Add File
This commit is contained in:
234
src/landppt/auth/auth_service.py
Normal file
234
src/landppt/auth/auth_service.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""
|
||||||
|
Authentication service for LandPPT
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
from ..database.models import User, UserSession
|
||||||
|
from ..database.database import get_db
|
||||||
|
from ..core.config import app_config
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
"""Authentication service"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.session_expire_minutes = app_config.access_token_expire_minutes
|
||||||
|
|
||||||
|
def _get_current_expire_minutes(self) -> int:
|
||||||
|
"""Get current session expire minutes from config (for real-time updates)"""
|
||||||
|
return app_config.access_token_expire_minutes
|
||||||
|
|
||||||
|
def create_user(self, db: Session, username: str, password: str, email: Optional[str] = None, is_admin: bool = False) -> User:
|
||||||
|
"""Create a new user"""
|
||||||
|
# Check if user already exists
|
||||||
|
existing_user = db.query(User).filter(User.username == username).first()
|
||||||
|
if existing_user:
|
||||||
|
raise ValueError("用户名已存在")
|
||||||
|
|
||||||
|
if email:
|
||||||
|
existing_email = db.query(User).filter(User.email == email).first()
|
||||||
|
if existing_email:
|
||||||
|
raise ValueError("邮箱已存在")
|
||||||
|
|
||||||
|
# Create new user
|
||||||
|
user = User(
|
||||||
|
username=username,
|
||||||
|
email=email,
|
||||||
|
is_admin=is_admin
|
||||||
|
)
|
||||||
|
user.set_password(password)
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
|
||||||
|
"""Authenticate user with username and password"""
|
||||||
|
user = db.query(User).filter(
|
||||||
|
and_(User.username == username, User.is_active == True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if user and user.check_password(password):
|
||||||
|
# Update last login time
|
||||||
|
user.last_login = time.time()
|
||||||
|
db.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_session(self, db: Session, user: User) -> str:
|
||||||
|
"""Create a new session for user"""
|
||||||
|
# Generate session ID
|
||||||
|
session_id = secrets.token_urlsafe(64)
|
||||||
|
|
||||||
|
# Get current expire minutes (for real-time config updates)
|
||||||
|
current_expire_minutes = self._get_current_expire_minutes()
|
||||||
|
|
||||||
|
# Calculate expiration time
|
||||||
|
# If session_expire_minutes is 0, set to a very far future date (never expire)
|
||||||
|
if current_expire_minutes == 0:
|
||||||
|
# Set expiration to year 2099 (effectively never expires)
|
||||||
|
expires_at = time.mktime(time.strptime("2099-12-31 23:59:59", "%Y-%m-%d %H:%M:%S"))
|
||||||
|
else:
|
||||||
|
expires_at = time.time() + (current_expire_minutes * 60)
|
||||||
|
|
||||||
|
# Create session record
|
||||||
|
session = UserSession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=expires_at
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(session)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
def get_user_by_session(self, db: Session, session_id: str) -> Optional[User]:
|
||||||
|
"""Get user by session ID"""
|
||||||
|
session = db.query(UserSession).filter(
|
||||||
|
and_(
|
||||||
|
UserSession.session_id == session_id,
|
||||||
|
UserSession.is_active == True
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not session or session.is_expired():
|
||||||
|
if session:
|
||||||
|
# Mark session as inactive
|
||||||
|
session.is_active = False
|
||||||
|
db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
return session.user
|
||||||
|
|
||||||
|
def logout_user(self, db: Session, session_id: str) -> bool:
|
||||||
|
"""Logout user by deactivating session"""
|
||||||
|
session = db.query(UserSession).filter(
|
||||||
|
UserSession.session_id == session_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if session:
|
||||||
|
session.is_active = False
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cleanup_expired_sessions(self, db: Session) -> int:
|
||||||
|
"""Clean up expired sessions"""
|
||||||
|
current_time = time.time()
|
||||||
|
# Don't clean up sessions that are set to never expire (year 2099 or later)
|
||||||
|
year_2099_timestamp = time.mktime(time.strptime("2099-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"))
|
||||||
|
|
||||||
|
expired_sessions = db.query(UserSession).filter(
|
||||||
|
and_(
|
||||||
|
UserSession.expires_at < current_time,
|
||||||
|
UserSession.expires_at < year_2099_timestamp # Exclude never-expire sessions
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
count = len(expired_sessions)
|
||||||
|
for session in expired_sessions:
|
||||||
|
session.is_active = False
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def get_user_by_id(self, db: Session, user_id: int) -> Optional[User]:
|
||||||
|
"""Get user by ID"""
|
||||||
|
return db.query(User).filter(
|
||||||
|
and_(User.id == user_id, User.is_active == True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
def get_user_by_username(self, db: Session, username: str) -> Optional[User]:
|
||||||
|
"""Get user by username"""
|
||||||
|
return db.query(User).filter(
|
||||||
|
and_(User.username == username, User.is_active == True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
def update_user_password(self, db: Session, user: User, new_password: str) -> bool:
|
||||||
|
"""Update user password"""
|
||||||
|
try:
|
||||||
|
user.set_password(new_password)
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def deactivate_user(self, db: Session, user: User) -> bool:
|
||||||
|
"""Deactivate user account"""
|
||||||
|
try:
|
||||||
|
user.is_active = False
|
||||||
|
# Deactivate all user sessions
|
||||||
|
sessions = db.query(UserSession).filter(UserSession.user_id == user.id).all()
|
||||||
|
for session in sessions:
|
||||||
|
session.is_active = False
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_users(self, db: Session, skip: int = 0, limit: int = 100) -> list[User]:
|
||||||
|
"""List all users"""
|
||||||
|
return db.query(User).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
def get_user_sessions(self, db: Session, user: User) -> list[UserSession]:
|
||||||
|
"""Get all active sessions for a user"""
|
||||||
|
return db.query(UserSession).filter(
|
||||||
|
and_(
|
||||||
|
UserSession.user_id == user.id,
|
||||||
|
UserSession.is_active == True
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
# Global auth service instance
|
||||||
|
auth_service = AuthService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service() -> AuthService:
|
||||||
|
"""Get auth service instance"""
|
||||||
|
return auth_service
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_admin(db: Session) -> None:
|
||||||
|
"""Initialize default admin user if no users exist"""
|
||||||
|
user_count = db.query(User).count()
|
||||||
|
|
||||||
|
if user_count == 0:
|
||||||
|
# Create default admin user
|
||||||
|
default_username = "admin"
|
||||||
|
default_password = "admin123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth_service.create_user(
|
||||||
|
db=db,
|
||||||
|
username=default_username,
|
||||||
|
password=default_password,
|
||||||
|
is_admin=True
|
||||||
|
)
|
||||||
|
print(f"默认管理员账户已创建: {default_username} / {default_password}")
|
||||||
|
print("请及时修改默认密码!")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"创建默认管理员账户失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""Hash password using SHA256"""
|
||||||
|
return hashlib.sha256(password.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, hashed: str) -> bool:
|
||||||
|
"""Verify password against hash"""
|
||||||
|
return hash_password(password) == hashed
|
||||||
Reference in New Issue
Block a user