Add File
This commit is contained in:
219
src/landppt/auth/middleware.py
Normal file
219
src/landppt/auth/middleware.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Authentication middleware for LandPPT
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable
|
||||
from fastapi import Request, Response, HTTPException, Depends
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
|
||||
from .auth_service import get_auth_service, AuthService
|
||||
from ..database.database import get_db
|
||||
from ..database.models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""Authentication middleware"""
|
||||
|
||||
def __init__(self):
|
||||
self.auth_service = get_auth_service()
|
||||
# 不需要认证的路径
|
||||
self.public_paths = {
|
||||
"/",
|
||||
"/auth/login",
|
||||
"/auth/logout",
|
||||
"/api/auth/login",
|
||||
"/api/auth/logout",
|
||||
"/api/auth/check",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/static",
|
||||
"/favicon.ico"
|
||||
}
|
||||
# 不需要认证的路径前缀
|
||||
self.public_prefixes = [
|
||||
"/static/",
|
||||
"/temp/", # 添加temp目录用于图片缓存访问
|
||||
"/api/image/view/", # 图床图片访问无需认证
|
||||
"/api/image/thumbnail/", # 图片缩略图访问无需认证
|
||||
"/share/", # 公开分享链接无需认证
|
||||
"/api/share/", # 分享API无需认证
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json"
|
||||
]
|
||||
|
||||
def is_public_path(self, path: str) -> bool:
|
||||
"""Check if path is public (doesn't require authentication)"""
|
||||
# Check exact matches
|
||||
if path in self.public_paths:
|
||||
return True
|
||||
|
||||
# Check prefixes
|
||||
for prefix in self.public_prefixes:
|
||||
if path.startswith(prefix):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
"""Middleware function"""
|
||||
path = request.url.path
|
||||
|
||||
# Skip authentication for public paths
|
||||
if self.is_public_path(path):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# Get session from cookie
|
||||
session_id = request.cookies.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
# No session, redirect to login
|
||||
if path.startswith("/api/"):
|
||||
# API endpoints return 401
|
||||
return Response(
|
||||
content='{"detail": "Authentication required"}',
|
||||
status_code=401,
|
||||
media_type="application/json"
|
||||
)
|
||||
else:
|
||||
# Web endpoints redirect to login
|
||||
return RedirectResponse(url="/auth/login", status_code=302)
|
||||
|
||||
# Validate session
|
||||
try:
|
||||
# Get database session
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
user = self.auth_service.get_user_by_session(db, session_id)
|
||||
|
||||
if not user:
|
||||
# Invalid session, redirect to login
|
||||
if path.startswith("/api/"):
|
||||
return Response(
|
||||
content='{"detail": "Invalid session"}',
|
||||
status_code=401,
|
||||
media_type="application/json"
|
||||
)
|
||||
else:
|
||||
response = RedirectResponse(url="/auth/login", status_code=302)
|
||||
response.delete_cookie("session_id")
|
||||
return response
|
||||
|
||||
# Add user to request state
|
||||
request.state.user = user
|
||||
|
||||
# Continue with request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication middleware error: {e}")
|
||||
if path.startswith("/api/"):
|
||||
return Response(
|
||||
content='{"detail": "Authentication error"}',
|
||||
status_code=500,
|
||||
media_type="application/json"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(url="/auth/login", status_code=302)
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Optional[User]:
|
||||
"""Get current authenticated user from request"""
|
||||
return getattr(request.state, 'user', None)
|
||||
|
||||
|
||||
def require_auth(request: Request) -> User:
|
||||
"""Dependency to require authentication"""
|
||||
user = get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(request: Request) -> User:
|
||||
"""Dependency to require admin privileges"""
|
||||
user = require_auth(request)
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin privileges required")
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user_optional(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
Get current user if authenticated, None otherwise.
|
||||
For use with FastAPI dependency injection.
|
||||
"""
|
||||
session_id = request.cookies.get("session_id")
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
auth_service = get_auth_service()
|
||||
return auth_service.get_user_by_session(db, session_id)
|
||||
|
||||
|
||||
def get_current_user_required(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Get current user, raise exception if not authenticated.
|
||||
For use with FastAPI dependency injection.
|
||||
"""
|
||||
user = get_current_user_optional(request, db)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return user
|
||||
|
||||
|
||||
def get_current_admin_user(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Get current admin user, raise exception if not admin.
|
||||
For use with FastAPI dependency injection.
|
||||
"""
|
||||
user = get_current_user_required(request, db)
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin privileges required")
|
||||
return user
|
||||
|
||||
|
||||
def create_auth_middleware() -> AuthMiddleware:
|
||||
"""Create authentication middleware instance"""
|
||||
return AuthMiddleware()
|
||||
|
||||
|
||||
# Utility functions for templates
|
||||
def is_authenticated(request: Request) -> bool:
|
||||
"""Check if user is authenticated"""
|
||||
return get_current_user(request) is not None
|
||||
|
||||
|
||||
def is_admin(request: Request) -> bool:
|
||||
"""Check if user is admin"""
|
||||
user = get_current_user(request)
|
||||
return user is not None and user.is_admin
|
||||
|
||||
|
||||
def get_user_info(request: Request) -> Optional[dict]:
|
||||
"""Get user info for templates"""
|
||||
user = get_current_user(request)
|
||||
if user:
|
||||
return user.to_dict()
|
||||
return None
|
||||
Reference in New Issue
Block a user