import base64
import json
from typing import Any, Optional

from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.security.utils import get_authorization_scheme_param
import jwt
from sqlmodel import Session
from starlette.middleware.base import BaseHTTPMiddleware

from apps.system.models.system_model import AssistantModel
from common.core.db import engine
from apps.system.crud.assistant import get_assistant_info, get_assistant_user
from apps.system.crud.user import get_user_by_account, get_user_info
from apps.system.schemas.system_schema import AssistantHeader, UserInfoDTO
from common.core import security
from common.core.config import settings
from common.core.schemas import TokenPayload
from common.utils.locale import I18n
from common.utils.utils import SQLBotLogUtil
from common.utils.whitelist import whiteUtils
from common.core.deps import get_i18n


class TokenMiddleware(BaseHTTPMiddleware):
    """Authenticate every non-whitelisted request before it reaches a route.

    Credentials are looked up in this order:

    * the header named by ``settings.ASSISTANT_TOKEN_KEY`` carrying an
      ``Assistant <jwt>`` or ``Embedded <jwt>`` token. On success both
      ``request.state.current_user`` and ``request.state.assistant`` are set;
    * otherwise the header named by ``settings.TOKEN_KEY`` carrying a
      ``Bearer <jwt>`` token. On success ``request.state.current_user`` is set.

    Any failure short-circuits with a 401 JSON response.
    """

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request, call_next):
        """Validate the request's token and forward it, or reply 401."""
        # Pre-flight requests and whitelisted paths need no credentials.
        if self.is_options(request) or whiteUtils.is_whitelisted(request.url.path):
            return await call_next(request)

        assistant_token = request.headers.get(settings.ASSISTANT_TOKEN_KEY)
        trans = await get_i18n(request)

        # An assistant/embedded token, when present, takes precedence over the
        # regular user token.
        if assistant_token:
            validator: tuple[Any, ...] = await self.validateAssistant(assistant_token, trans)
            if validator[0]:
                request.state.current_user = validator[1]
                request.state.assistant = validator[2]
                return await call_next(request)
            return self._unauthorized(trans, validator[1])

        token = request.headers.get(settings.TOKEN_KEY)
        validate_pass, data = await self.validateToken(token, trans)
        if validate_pass:
            request.state.current_user = data
            return await call_next(request)
        return self._unauthorized(trans, data)

    @staticmethod
    def _unauthorized(trans: I18n, reason: Any) -> JSONResponse:
        """Build the 401 response returned for any authentication failure.

        Args:
            trans: translator used to localise the error message.
            reason: failure detail (a message string or an exception) embedded
                into the localised message.
        """
        message = trans('i18n_permission.authenticate_invalid', msg=reason)
        # NOTE(review): presumably set so cross-origin callers can read the
        # 401 body; confirm this is still required.
        return JSONResponse(
            message, status_code=401, headers={"Access-Control-Allow-Origin": "*"}
        )

    def is_options(self, request: Request):
        """Return True for CORS pre-flight (``OPTIONS``) requests."""
        return request.method == "OPTIONS"

    @staticmethod
    def _check_user_usable(session_user: UserInfoDTO, trans: I18n) -> None:
        """Raise if the user is disabled (``status != 1``) or has no ``oid``."""
        if session_user.status != 1:
            raise Exception(trans('i18n_login.user_disable', msg=trans('i18n_concat_admin')))
        # `not oid` already covers oid == 0.
        if not session_user.oid:
            raise Exception(trans('i18n_login.no_associated_ws', msg=trans('i18n_concat_admin')))

    @staticmethod
    async def _load_assistant(session: Session, assistant_id: Any) -> AssistantHeader:
        """Fetch an assistant by id and convert it to an ``AssistantHeader``."""
        assistant_info = await get_assistant_info(session=session, assistant_id=assistant_id)
        assistant_info = AssistantModel.model_validate(assistant_info)
        return AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))

    async def validateToken(self, token: Optional[str], trans: I18n) -> tuple[bool, Any]:
        """Validate a ``Bearer <jwt>`` user token.

        Args:
            token: raw header value, or None when the header is missing.
            trans: translator for user-facing error messages.

        Returns:
            ``(True, UserInfoDTO)`` on success. Otherwise ``(False, reason)``,
            where *reason* is a message string or the caught exception. Expired
            tokens yield a ``jwt.ExpiredSignatureError`` with a localised message.
        """
        if not token:
            return False, f"Miss Token[{settings.TOKEN_KEY}]!"
        schema, param = get_authorization_scheme_param(token)
        if schema.lower() != "bearer":
            return False, "Token schema error!"
        try:
            payload = jwt.decode(param, settings.SECRET_KEY, algorithms=[security.ALGORITHM])
            token_data = TokenPayload(**payload)
            with Session(engine) as session:
                session_user = await get_user_info(session=session, user_id=token_data.id)
                if not session_user:
                    raise Exception(trans('i18n_not_exist', msg=trans('i18n_user.account')))
                session_user = UserInfoDTO.model_validate(session_user)
                self._check_user_usable(session_user, trans)
                return True, session_user
        except Exception as e:
            msg = str(e)
            SQLBotLogUtil.exception(f"Token validation error: {msg}")
            # Check the exception type explicitly; the substring test is kept
            # for backward compatibility with other "expired" errors.
            if isinstance(e, jwt.ExpiredSignatureError) or 'expired' in msg:
                return False, jwt.ExpiredSignatureError(trans('i18n_permission.token_expired'))
            return False, e

    async def validateAssistant(
        self, assistantToken: Optional[str], trans: I18n
    ) -> tuple[Any, ...]:
        """Validate an ``Assistant <jwt>`` token, or delegate ``Embedded <jwt>``.

        Returns:
            ``(True, user, AssistantHeader)`` on success. Otherwise
            ``(False, reason)``, where *reason* is a message string or the
            caught exception.
        """
        if not assistantToken:
            return False, f"Miss Token[{settings.ASSISTANT_TOKEN_KEY}]!"
        schema, param = get_authorization_scheme_param(assistantToken)
        try:
            scheme = schema.lower()
            if scheme == 'embedded':
                return await self.validateEmbedded(param, trans)
            if scheme != "assistant":
                return False, "Token schema error!"
            payload = jwt.decode(param, settings.SECRET_KEY, algorithms=[security.ALGORITHM])
            token_data = TokenPayload(**payload)
            assistant_id = payload.get('assistant_id')
            if not assistant_id:
                return False, "Miss assistant payload error!"
            with Session(engine) as session:
                session_user = get_assistant_user(id=token_data.id)
                assistant_info = await self._load_assistant(session, assistant_id)
                # NOTE(review): the meaning of type 0 isn't visible here. For it,
                # the user's oid comes from the token's `oid` claim when present,
                # otherwise from the assistant's JSON configuration (default 1).
                if assistant_info and assistant_info.type == 0:
                    token_oid = payload.get('oid')
                    if token_oid:
                        session_user.oid = int(token_oid)
                    else:
                        configuration = assistant_info.configuration
                        config_obj = json.loads(configuration) if configuration else {}
                        session_user.oid = int(config_obj.get('oid', 1))
                return True, session_user, assistant_info
        except Exception as e:
            SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}")
            # Return False and the exception itself as the failure reason.
            return False, e

    async def validateEmbedded(self, param: str, trans: I18n) -> tuple[Any, ...]:
        """Validate an embedded-assistant JWT (the part after ``Embedded``).

        The payload must carry ``account``. The assistant id comes from
        ``embeddedId`` or, when absent, is recovered from ``appId`` via
        :func:`xor_decrypt`.

        Returns:
            ``(True, UserInfoDTO, AssistantHeader)`` on success. Otherwise
            ``(False, reason)``, where *reason* is a message string or the
            caught exception.
        """
        try:
            # SECURITY(review): the signature and expiry are NOT verified, so any
            # caller can forge a token naming an arbitrary `account` and be
            # authenticated as that user. This should verify the signature with
            # the secret that issues embedded tokens (not visible in this file).
            payload: dict = jwt.decode(
                param,
                options={"verify_signature": False, "verify_exp": False},
                algorithms=[security.ALGORITHM],
            )
            embedded_id = payload.get('embeddedId') or xor_decrypt(payload.get('appId', ''))
            account = payload.get('account')
            if not account:
                return False, "Miss account payload error!"
            with Session(engine) as session:
                session_user = get_user_by_account(session=session, account=account)
                if not session_user:
                    raise Exception(trans('i18n_not_exist', msg=trans('i18n_user.account')))
                session_user = await get_user_info(session=session, user_id=session_user.id)
                session_user = UserInfoDTO.model_validate(session_user)
                self._check_user_usable(session_user, trans)
                assistant_info = await self._load_assistant(session, embedded_id)
                return True, session_user, assistant_info
        except Exception as e:
            SQLBotLogUtil.exception(f"Embedded validation error: {str(e)}")
            # Return False and the exception itself as the failure reason.
            return False, e


def xor_decrypt(encrypted_str: str, key: int = 0xABCD1234) -> int:
    """Recover an integer id from its base64url-encoded, XOR-ed form.

    The string is base64url-decoded (missing ``=`` padding is restored first).
    The bytes are read as a big-endian unsigned integer, and the result is
    XOR-ed with *key*.

    Raises:
        ValueError: if *encrypted_str* is empty or not valid base64url
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    padded = encrypted_str + "=" * (-len(encrypted_str) % 4)
    encrypted_bytes = base64.urlsafe_b64decode(padded)
    if not encrypted_bytes:
        raise ValueError("xor_decrypt: empty input")
    return int.from_bytes(encrypted_bytes, "big") ^ key