This commit is contained in:
2025-11-07 09:05:10 +08:00
parent bc6bbd617b
commit 971abe1293

View File

@@ -0,0 +1,219 @@
"""
Authentication middleware for LandPPT
"""
from typing import Optional, Callable
from fastapi import Request, Response, HTTPException, Depends
from fastapi.responses import RedirectResponse
from sqlalchemy.orm import Session
import logging
from .auth_service import get_auth_service, AuthService
from ..database.database import get_db
from ..database.models import User
logger = logging.getLogger(__name__)
class AuthMiddleware:
"""Authentication middleware"""
def __init__(self):
self.auth_service = get_auth_service()
# 不需要认证的路径
self.public_paths = {
"/",
"/auth/login",
"/auth/logout",
"/api/auth/login",
"/api/auth/logout",
"/api/auth/check",
"/docs",
"/redoc",
"/openapi.json",
"/static",
"/favicon.ico"
}
# 不需要认证的路径前缀
self.public_prefixes = [
"/static/",
"/temp/", # 添加temp目录用于图片缓存访问
"/api/image/view/", # 图床图片访问无需认证
"/api/image/thumbnail/", # 图片缩略图访问无需认证
"/share/", # 公开分享链接无需认证
"/api/share/", # 分享API无需认证
"/docs",
"/redoc",
"/openapi.json"
]
def is_public_path(self, path: str) -> bool:
"""Check if path is public (doesn't require authentication)"""
# Check exact matches
if path in self.public_paths:
return True
# Check prefixes
for prefix in self.public_prefixes:
if path.startswith(prefix):
return True
return False
async def __call__(self, request: Request, call_next: Callable):
"""Middleware function"""
path = request.url.path
# Skip authentication for public paths
if self.is_public_path(path):
response = await call_next(request)
return response
# Get session from cookie
session_id = request.cookies.get("session_id")
if not session_id:
# No session, redirect to login
if path.startswith("/api/"):
# API endpoints return 401
return Response(
content='{"detail": "Authentication required"}',
status_code=401,
media_type="application/json"
)
else:
# Web endpoints redirect to login
return RedirectResponse(url="/auth/login", status_code=302)
# Validate session
try:
# Get database session
db_gen = get_db()
db = next(db_gen)
try:
user = self.auth_service.get_user_by_session(db, session_id)
if not user:
# Invalid session, redirect to login
if path.startswith("/api/"):
return Response(
content='{"detail": "Invalid session"}',
status_code=401,
media_type="application/json"
)
else:
response = RedirectResponse(url="/auth/login", status_code=302)
response.delete_cookie("session_id")
return response
# Add user to request state
request.state.user = user
# Continue with request
response = await call_next(request)
return response
finally:
db.close()
except Exception as e:
logger.error(f"Authentication middleware error: {e}")
if path.startswith("/api/"):
return Response(
content='{"detail": "Authentication error"}',
status_code=500,
media_type="application/json"
)
else:
return RedirectResponse(url="/auth/login", status_code=302)
def get_current_user(request: Request) -> Optional[User]:
"""Get current authenticated user from request"""
return getattr(request.state, 'user', None)
def require_auth(request: Request) -> User:
"""Dependency to require authentication"""
user = get_current_user(request)
if not user:
raise HTTPException(status_code=401, detail="Authentication required")
return user
def require_admin(request: Request) -> User:
"""Dependency to require admin privileges"""
user = require_auth(request)
if not user.is_admin:
raise HTTPException(status_code=403, detail="Admin privileges required")
return user
def get_current_user_optional(
request: Request,
db: Session = Depends(get_db)
) -> Optional[User]:
"""
Get current user if authenticated, None otherwise.
For use with FastAPI dependency injection.
"""
session_id = request.cookies.get("session_id")
if not session_id:
return None
auth_service = get_auth_service()
return auth_service.get_user_by_session(db, session_id)
def get_current_user_required(
request: Request,
db: Session = Depends(get_db)
) -> User:
"""
Get current user, raise exception if not authenticated.
For use with FastAPI dependency injection.
"""
user = get_current_user_optional(request, db)
if not user:
raise HTTPException(status_code=401, detail="Authentication required")
return user
def get_current_admin_user(
request: Request,
db: Session = Depends(get_db)
) -> User:
"""
Get current admin user, raise exception if not admin.
For use with FastAPI dependency injection.
"""
user = get_current_user_required(request, db)
if not user.is_admin:
raise HTTPException(status_code=403, detail="Admin privileges required")
return user
def create_auth_middleware() -> AuthMiddleware:
"""Create authentication middleware instance"""
return AuthMiddleware()
# Utility functions for templates
def is_authenticated(request: Request) -> bool:
"""Check if user is authenticated"""
return get_current_user(request) is not None
def is_admin(request: Request) -> bool:
"""Check if user is admin"""
user = get_current_user(request)
return user is not None and user.is_admin
def get_user_info(request: Request) -> Optional[dict]:
"""Get user info for templates"""
user = get_current_user(request)
if user:
return user.to_dict()
return None