FastAPI Role Based Access Control With JWT

FastAPI is a modern, high-performance web framework for building APIs with Python 3.8+, and one of the fastest Python frameworks available. Authentication and authorization are essential parts of any API, whatever framework you use. In this article, let’s implement role-based access control (RBAC) with JWT in FastAPI.

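Before you start, install these Python packages. The code in this article uses the str | None type syntax in Pydantic models, so it needs Python 3.10 or newer.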

  • FastAPI

  • Pydantic

  • uvicorn[standard]

  • passlib[bcrypt]

  • python-jose[cryptography]

Setting Up The Environment

Let’s create two API endpoints in the main.py file.

from fastapi import FastAPI

app = FastAPI()

@app.get("/hello")
def hello_func():
  return "Hello World"

@app.get("/data")
def get_data():
  return {"data": "This is important data"}

Next, let’s create a User model and a Token model in models.py.

from pydantic import BaseModel


class User(BaseModel):
    username: str | None = None
    email: str | None = None
    role: str | None = None
    disabled: bool | None = None
    hashed_password: str | None = None


class Token(BaseModel):
    access_token: str | None = None
    refresh_token: str | None = None

For this tutorial, I will store dummy users in a Python dictionary in data.py and create a list to store refresh tokens. In a real application, you would use a database such as PostgreSQL or MongoDB instead.

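from passlib.context import CryptContext

# Demo only: hash placeholder passwords when the module loads so that
# pwd_context.verify() in auth.py has real bcrypt hashes to check against.
# These passwords are hypothetical; never store plain-text passwords.
_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Keyed by username so that get_user() can look users up directly
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "email": "john@email.com",
        "role": "admin",
        "hashed_password": _pwd_context.hash("admin-password"),
        "disabled": False,
    },
    "alice": {
        "username": "alice",
        "email": "alice@email.com",
        "role": "user",
        "hashed_password": _pwd_context.hash("user-password"),
        "disabled": False,
    },
}

# In-memory store of issued refresh tokens
refresh_tokens = []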

How This Works

Here is how the authentication flow works. A user logs in by sending a username and password in a POST request. The backend checks the credentials and, if they are valid, returns two tokens: a short-lived access token and a longer-lived refresh token. To call a protected endpoint, the client sends the access token in the Authorization header as Bearer <token>. When the access token expires, the client sends its refresh token to the backend to get a new access token. Because access tokens expire quickly, a leaked access token is only useful for a short time, while users do not have to log in again every few minutes.
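To make the token part of this flow concrete, here is a standard-library sketch of an HS256-signed JWT with an exp claim, using the article's 20-minute and 120-minute lifetimes. It only illustrates the idea; in the app, python-jose's jwt.encode() and jwt.decode() do this work. SECRET, sign() and verify() are hypothetical names, and time is passed in as plain seconds so the expiry is easy to follow.

```python
import base64
import hashlib
import hmac
import json

# Hypothetical demo key; the article's SECRET_KEY plays this role.
SECRET = b"demo-secret-key"


def b64url(raw: bytes) -> str:
    # JWTs use URL-safe base64 without "=" padding.
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def b64url_decode(text: str) -> bytes:
    # Restore the padding that b64url() stripped.
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign(claims: dict) -> str:
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = b64url(json.dumps(claims).encode())
    signing_input = f"{header}.{payload}".encode("ascii")
    signature = hmac.new(SECRET, signing_input, hashlib.sha256).digest()
    return f"{header}.{payload}.{b64url(signature)}"


def verify(token: str, now: int) -> dict:
    header, payload, signature = token.split(".")
    signing_input = f"{header}.{payload}".encode("ascii")
    expected = hmac.new(SECRET, signing_input, hashlib.sha256).digest()
    # Any change to header or payload breaks the signature.
    if not hmac.compare_digest(expected, b64url_decode(signature)):
        raise ValueError("invalid signature")
    claims = json.loads(b64url_decode(payload))
    if now >= claims["exp"]:
        raise ValueError("token expired")
    return claims


# Issue both tokens at t = 0 seconds with the article's lifetimes.
issued_at = 0
access = sign({"sub": "johndoe", "role": "admin", "exp": issued_at + 20 * 60})
refresh = sign({"sub": "johndoe", "role": "admin", "exp": issued_at + 120 * 60})
```

At t = 1800 (30 minutes), verify(access, 1800) raises ValueError("token expired") while verify(refresh, 1800) still returns the claims. That is the moment the client would call the refresh endpoint.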

Role Based Access Control (RBAC)

FastAPI provides several ways to handle security. Here we use OAuth2 with the password flow, which the FastAPI documentation covers in detail. We implement it with the OAuth2PasswordBearer class, and we use passlib's CryptContext to hash and verify passwords.

Let’s create auth.py. First, create instances of the above classes.

#auth.py
from fastapi.security import OAuth2PasswordBearer 
from passlib.context import CryptContext

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

The tokenUrl parameter is the URL that the client uses to send the username and password to get a token. We haven’t created this endpoint yet; we will create it later.
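The token that oauth2_scheme hands to our functions comes from the Authorization request header, in the form Bearer <token>. The sketch below approximates that extraction step in plain Python so the contract is visible. extract_bearer_token is a hypothetical name; FastAPI's real class rejects a missing or malformed header with a 401 error instead of raising PermissionError.

```python
from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> str:
    """Hypothetical helper: return the token from 'Authorization: Bearer <token>'."""
    if not authorization:
        raise PermissionError("Not authenticated")
    scheme, _, token = authorization.partition(" ")
    # The scheme name is matched case-insensitively, so "bearer" works too.
    if scheme.lower() != "bearer" or not token:
        raise PermissionError("Not authenticated")
    return token
```

For example, extract_bearer_token("Bearer abc.def.ghi") returns "abc.def.ghi", while a Basic credential or a missing header is rejected.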

Next, create a function that gets a user's details from the database and another that authenticates a user by checking the password.

#auth.py
from passlib.context import CryptContext

from models import User

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_user(db, username: str):
    # db is a dict keyed by username; returns None for unknown users
    if username in db:
        user = db[username]
        return User(**user)


def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    # Compare the plain-text password with the stored bcrypt hash
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
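In authenticate_user(), passlib's bcrypt does the hashing and verification. To show what "hash and verify" means without extra dependencies, here is a sketch that swaps in PBKDF2 from Python's standard library (hashlib.pbkdf2_hmac) instead of bcrypt. The function names, the iteration count, and the salt$digest storage format are hypothetical choices for this illustration, not what passlib produces.

```python
import hashlib
import hmac
import os

ITERATIONS = 600_000  # illustrative cost factor; tune it for your hardware


def hash_password(password: str) -> str:
    salt = os.urandom(16)  # a fresh random salt for every password
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify_password(password: str, stored: str) -> bool:
    salt_hex, digest_hex = stored.split("$")
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS)
    # Constant-time comparison avoids leaking how many bytes matched.
    return hmac.compare_digest(digest, bytes.fromhex(digest_hex))


def authenticate(users: dict, username: str, password: str):
    # Same shape as authenticate_user(): False on any failure, else the user.
    user = users.get(username)
    if user is None or not verify_password(password, user["hashed_password"]):
        return False
    return user
```

Because each password gets its own salt, hashing the same password twice gives different stored strings, yet both still verify.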

Now let’s handle the JWT part. Define a few settings and a function that creates a JWT.

#auth.py
from datetime import datetime, timedelta, timezone

from jose import jwt

# Demo value only: in production, load a long random secret from the environment
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120


def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    # Fall back to a 15-minute lifetime when no expiry is given
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

We pass the data (the token's claims) and the token lifetime to this function, and it returns the encoded JWT. If no lifetime is given, the token expires after 15 minutes.

Next, we need a function that returns the currently logged-in user. It takes the token as input, decodes it to read the username from the sub claim, and looks the user up in the database. If the token cannot be decoded, has no sub claim, or refers to a user that does not exist, the function raises a 401 Unauthorized exception. This way, only requests carrying a valid token for an existing user get through.

#auth.py
from typing import Annotated

from fastapi import Depends, HTTPException, status
from jose import JWTError, jwt

from data import fake_users_db
from models import User

SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)]
):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

We also add get_current_active_user(), which checks whether the user is disabled and raises an exception if so. Depends() declares a dependency: get_current_active_user() depends on get_current_user(), so FastAPI runs get_current_user() first and passes the returned user into get_current_active_user(). If you step through the code in a debugger, you will see get_current_user() execute first.
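The dependency chain behaves much like nested function calls. The plain-Python analogy below is not how FastAPI resolves dependencies internally (FastAPI inspects the Depends() declarations and injects the results); it only shows the execution order and the hand-off of the user object. DemoUser, USERS, and the claim dictionaries are hypothetical, and token decoding is assumed to have already happened.

```python
from dataclasses import dataclass


@dataclass
class DemoUser:
    username: str
    role: str
    disabled: bool = False


# Hypothetical users keyed by username.
USERS = {
    "johndoe": DemoUser("johndoe", "admin"),
    "bob": DemoUser("bob", "user", disabled=True),
}
call_order: list[str] = []


def current_user(claims: dict) -> DemoUser:
    call_order.append("current_user")
    username = claims.get("sub")
    if username is None or username not in USERS:
        raise PermissionError("Could not validate credentials")
    return USERS[username]


def current_active_user(claims: dict) -> DemoUser:
    # Like Depends(get_current_user): the dependency runs first,
    # and its result is what this function checks.
    user = current_user(claims)
    call_order.append("current_active_user")
    if user.disabled:
        raise PermissionError("Inactive user")
    return user
```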

Now we add a RoleChecker class to validate user roles. It is created with a list of allowed roles; when FastAPI calls it, it returns True if the current user's role is in that list and raises an exception otherwise. This ensures that users can only access the functionality their role allows.

#auth.py
from typing import Annotated

from fastapi import Depends, HTTPException, status

from models import User


class RoleChecker:
    def __init__(self, allowed_roles: list[str]):
        self.allowed_roles = allowed_roles

    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        # 403 Forbidden: the user is authenticated but the role is not allowed
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have enough permissions",
        )
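Because RoleChecker defines __call__, each instance is itself a callable that FastAPI can use as a dependency, with the allowed roles fixed when the instance is created. The standard-library sketch below isolates that pattern outside FastAPI; DemoUser, DemoRoleChecker, admin_only and admin_or_user are hypothetical names, and PermissionError stands in for the HTTPException.

```python
from dataclasses import dataclass


@dataclass
class DemoUser:
    username: str
    role: str


class DemoRoleChecker:
    """Same shape as RoleChecker: configure once, then call with a user."""

    def __init__(self, allowed_roles: list[str]):
        self.allowed_roles = allowed_roles

    def __call__(self, user: DemoUser) -> bool:
        if user.role in self.allowed_roles:
            return True
        raise PermissionError("You don't have enough permissions")


# One checker instance per policy, created when the route is defined.
admin_only = DemoRoleChecker(allowed_roles=["admin"])
admin_or_user = DemoRoleChecker(allowed_roles=["admin", "user"])
```

Creating one instance per policy keeps the role list next to the route that uses it, so reading an endpoint's signature tells you who may call it.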

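We need one more function to validate refresh tokens. When the access token expires, the client sends its refresh token (again as a Bearer token) to get a new access token. The function checks that the token is one we issued and still hold in refresh_tokens, decodes it, and returns both the user and the token.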

#auth.py
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt

from data import fake_users_db, refresh_tokens


SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Only accept tokens we issued and have not rotated out yet
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception

    # jwt.decode() raises JWTError (including ExpiredSignatureError) for bad tokens
    except JWTError:
        raise credentials_exception

    user = get_user(fake_users_db, username=username)

    if user is None:
        raise credentials_exception

    return user, token
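The /refresh endpoint later removes the old token from refresh_tokens and stores the new one, so each refresh token can be used only once (token rotation). The sketch below models just that bookkeeping. RefreshTokenStore is a hypothetical class, and random strings from secrets.token_urlsafe() stand in for the signed JWTs, so expiry is not modelled here.

```python
import secrets


class RefreshTokenStore:
    """Hypothetical in-memory store, playing the role of the refresh_tokens list."""

    def __init__(self) -> None:
        self._active: dict[str, str] = {}  # token -> username

    def issue(self, username: str) -> str:
        token = secrets.token_urlsafe(32)  # stand-in for a signed JWT
        self._active[token] = username
        return token

    def rotate(self, old_token: str) -> tuple[str, str]:
        # Reject unknown or already-used tokens, like `token in refresh_tokens`.
        username = self._active.pop(old_token, None)
        if username is None:
            raise PermissionError("Could not validate credentials")
        return username, self.issue(username)
```

A dict (or set) gives constant-time membership checks, whereas the article's list is scanned linearly; that difference only matters once many tokens are stored.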

The final auth.py file looks like this.

from datetime import datetime, timedelta, timezone
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext

from data import fake_users_db, refresh_tokens
from models import User

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"

def get_user(db, username: str):
  if username in db:
    user = db[username]
    return User(**user)


def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user


def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user


async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception

    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user, token


class RoleChecker:
    def __init__(self, allowed_roles: list[str]):
        self.allowed_roles = allowed_roles

    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        # 403 Forbidden: the user is authenticated but the role is not allowed
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have enough permissions",
        )

We now have the authentication and authorization parts. Before adding them to our API endpoints, we need two more endpoints: one for logging in and one for refreshing tokens. Let’s go back to main.py.

from datetime import timedelta
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordRequestForm

from auth import create_token, authenticate_user, RoleChecker, get_current_active_user, validate_refresh_token
from data import fake_users_db, refresh_tokens
from models import User, Token

app = FastAPI()

ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120

@app.get("/hello")
def hello_func():
  return "Hello World"

@app.get("/data")
def get_data():
  return {"data": "This is important data"} 

@app.post("/token")
async def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> Token:
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)

    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)

@app.post("/refresh")
async def refresh_access_token(token_data: Annotated[tuple[User, str], Depends(validate_refresh_token)]) -> Token:
    user, token = token_data
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)

    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)

    # Rotate: the old refresh token can no longer be used
    refresh_tokens.remove(token)
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)
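The two endpoints expect different inputs: /token reads a form-encoded body (that is what OAuth2PasswordRequestForm parses), while /refresh reads the refresh token from the Authorization: Bearer header through oauth2_scheme. The standard-library client sketch below only builds the requests; BASE_URL assumes uvicorn's default address, and the helper names are hypothetical.

```python
import urllib.parse
import urllib.request

BASE_URL = "http://127.0.0.1:8000"  # assumption: uvicorn's default host and port


def build_login_request(username: str, password: str) -> urllib.request.Request:
    # OAuth2PasswordRequestForm reads form fields, not a JSON body.
    body = urllib.parse.urlencode({"username": username, "password": password}).encode()
    return urllib.request.Request(
        f"{BASE_URL}/token",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def build_refresh_request(refresh_token: str) -> urllib.request.Request:
    # validate_refresh_token() gets its token through oauth2_scheme,
    # so the refresh token travels in the Authorization header.
    return urllib.request.Request(
        f"{BASE_URL}/refresh",
        headers={"Authorization": f"Bearer {refresh_token}"},
        method="POST",
    )
```

With the server running (for example, uvicorn main:app --reload), pass a request to urllib.request.urlopen() and read the JSON response with json.load() to get the access_token and refresh_token fields.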

Add RBAC To API

Now let’s add RBAC to our endpoints. At the moment, the “/data” endpoint is not protected, so anyone can access it. You can check this with the Swagger UI that FastAPI serves at /docs, or with Postman. Let’s protect it so that only admin users can read it.

@app.get("/data")
def get_data(_: Annotated[bool, Depends(RoleChecker(allowed_roles=["admin"]))]):
  return {"data": "This is important data"}

Now the endpoint can only be accessed by a user who has logged in as an admin. You can protect any other endpoint the same way by adding a RoleChecker dependency with the roles it should allow. This is only one way to add RBAC to FastAPI; for example, FastAPI also supports OAuth2 scopes for finer-grained permissions. Happy coding!

You can connect with me on hirushafernando.com