| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
import json
from collections.abc import AsyncGenerator

from fastapi import Depends, Query, Request
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm.attributes import set_committed_value

from app.api.v1.module_system.auth.schema import AuthSchema
from app.api.v1.module_system.user.crud import UserCRUD
from app.api.v1.module_system.user.model import UserModel
from app.common.enums import RedisInitKeyConfig
from app.config.setting import settings
from app.core.database import async_db_session
from app.core.exceptions import CustomException
from app.core.logger import log
from app.core.redis_crud import RedisCURD
from app.core.security import OAuth2Schema, decode_access_token
async def db_getter() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session with a transaction already begun.

    The transaction is managed by ``session.begin()``: it is committed when
    the dependency exits normally and rolled back if an exception propagates
    into it.

    Returns:
        - AsyncSession: the session bound to the open transaction.
    """
    # Both context managers are entered left to right, so ``session`` is bound
    # before ``session.begin()`` is evaluated (same as the nested form).
    async with async_db_session() as session, session.begin():
        yield session
async def redis_getter(request: Request) -> Redis:
    """Return the Redis client stored on the application state.

    Args:
        - request (Request): the incoming request; ``request.app.state.redis``
          is read from it.

    Returns:
        - Redis: the client object held on the app state (presumably created
          at application startup — not shown here).
    """
    app_state = request.app.state
    return app_state.redis
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(db_getter),
    redis: Redis = Depends(redis_getter),
    token: str = Depends(OAuth2Schema),
) -> AuthSchema:
    """Resolve the authenticated user for an HTTP request.

    Token parsing, the Redis online-session check, sliding expiry and user
    loading are shared with the WebSocket dependency via ``_verify_token``;
    this function adds the HTTP-specific parts: recording the user on the
    request scope and setting the tenant on the auth context.

    Args:
        - request (Request): the current request; its ASGI scope receives
          ``user_id`` and ``user_username``.
        - db (AsyncSession): database session.
        - redis (Redis): Redis connection.
        - token (str): access token supplied by ``OAuth2Schema``.

    Returns:
        - AuthSchema: auth context with ``user`` and ``tenant_id`` set.

    Raises:
        - CustomException: 401 (raised by ``_verify_token``) when the token
          or session is invalid, or the user is missing or disabled.
    """
    auth = await _verify_token(token, db, redis)
    user = auth.user

    # Record the authenticated user on the request scope.
    request.scope["user_id"] = user.id
    request.scope["user_username"] = user.username

    # Read the tenant id from the user's own column. The ``tenant``
    # relationship is not in the preload list, so touching it could trigger a
    # lazy load (or fail when it is None). The previous value taken from it was
    # unconditionally overwritten by this assignment anyway.
    auth.tenant_id = user.tenant_id
    return auth
async def get_current_user_ws(
    token: str = Query(..., description="认证token"),
    db: AsyncSession = Depends(db_getter),
    redis: Redis = Depends(redis_getter),
) -> AuthSchema:
    """Resolve the authenticated user for a WebSocket endpoint.

    The token is read from the ``token`` query parameter rather than through
    ``OAuth2Schema``; validation is delegated entirely to ``_verify_token``.

    Args:
        - token (str): access token from the query string (an optional
          ``Bearer`` prefix is handled by ``_verify_token``).
        - db (AsyncSession): database session.
        - redis (Redis): Redis connection.

    Returns:
        - AuthSchema: auth context for the connected user.
    """
    auth = await _verify_token(token=token, db=db, redis=redis)
    return auth
def _keep_enabled(user: UserModel, relation: str) -> None:
    """Narrow ``user.<relation>`` to its enabled members without dirtying it.

    Status flags are strings: ``"0"`` is the enabled value used for roles and
    menus in ``AuthPermission``, and ``"1"`` is the disabled value checked on
    the user. Positions are assumed to follow the same convention — confirm.

    ``set_committed_value`` installs the filtered list as if it had been
    loaded, so the session records no change. A plain assignment is tracked
    as a collection change; for an association (presumably many-to-many)
    relationship it would be flushed as deletions of the user's links when the
    surrounding transaction commits.

    Args:
        - user (UserModel): the loaded user instance.
        - relation (str): name of the collection attribute to filter.
    """
    items = getattr(user, relation, None)
    if items is None:
        return
    enabled = [item for item in items if item is not None and item.status == "0"]
    set_committed_value(user, relation, enabled)


async def _verify_token(
    token: str,
    db: AsyncSession,
    redis: Redis,
) -> AuthSchema:
    """Validate an access token and load the user it belongs to.

    Steps: strip an optional ``Bearer`` prefix, decode the token, reject
    refresh tokens, parse the JSON subject, require the session's access-token
    key to exist in Redis (pushing TTLs out when sliding expiry is enabled),
    then load the user and keep only enabled roles/positions.

    Args:
        - token (str): raw access token, optionally prefixed with ``Bearer ``.
        - db (AsyncSession): database session used to load the user.
        - redis (Redis): Redis connection holding the online-session keys.

    Returns:
        - AuthSchema: auth context with ``user`` set and data-scope filtering
          disabled.

    Raises:
        - CustomException: 401 when the token is missing or malformed, is a
          refresh token, its session is no longer in Redis, or the user is
          missing or disabled.
    """
    if not token:
        raise CustomException(msg="认证已失效", code=10401, status_code=401)

    # Accept "Bearer <jwt>" (scheme matched case-insensitively) as well as a
    # bare "<jwt>". The previous ``split(" ")[1]`` raised IndexError -> HTTP 500
    # for a lone "Bearer" or a "Bearer" prefix with no space.
    scheme, sep, credentials = token.partition(" ")
    if sep and scheme.lower() == "bearer":
        token = credentials.strip()
        if not token:
            raise CustomException(msg="认证已失效", code=10401, status_code=401)

    payload = decode_access_token(token)
    # Missing payloads, refresh tokens, and payloads without ``is_refresh``
    # are all rejected.
    if not payload or getattr(payload, "is_refresh", True):
        raise CustomException(msg="非法凭证", code=10401, status_code=401)

    # The subject is a JSON object carrying session_id / user_name. A malformed
    # subject is a bad credential (401), not a server error.
    try:
        user_info = json.loads(getattr(payload, "sub", None))
    except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
        raise CustomException(msg="非法凭证", code=10401, status_code=401) from None
    if not isinstance(user_info, dict):
        raise CustomException(msg="非法凭证", code=10401, status_code=401)

    session_id = user_info.get("session_id")
    username = user_info.get("user_name")
    # Validate both claims before touching Redis so malformed tokens never
    # extend a session's TTL.
    if not session_id or not username:
        raise CustomException(msg="认证已失效", code=10401, status_code=401)

    redis_crud = RedisCURD(redis)
    access_key = f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}"
    # A session is considered online only while its access-token key exists.
    if not await redis_crud.exists(key=access_key):
        raise CustomException(msg="认证已失效", code=10401, status_code=401)

    if settings.TOKEN_SLIDING_EXPIRE:
        # Sliding expiry: each authenticated request pushes both TTLs out.
        # NOTE(review): the settings are named *_MINUTES; confirm that
        # RedisCURD.expire expects minutes (Redis EXPIRE itself takes seconds).
        await redis_crud.expire(
            key=access_key,
            expire=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
        )
        await redis_crud.expire(
            key=f"{RedisInitKeyConfig.REFRESH_TOKEN.key}:{session_id}",
            expire=settings.REFRESH_TOKEN_EXPIRE_MINUTES,
        )

    # Data-scope filtering is disabled so that loading the caller's own record
    # is not filtered out.
    auth = AuthSchema(db=db, check_data_scope=False)
    user = await UserCRUD(auth).get_by_username_crud(
        username=username,
        preload=[
            "dept",
            selectinload(UserModel.roles),
            "positions",
            "created_by",
        ],
    )
    if not user:
        raise CustomException(msg="用户不存在", code=10401, status_code=401)
    if user.status == "1":
        raise CustomException(msg="用户已被停用", code=10401, status_code=401)

    # Keep only enabled roles and positions on the returned user.
    _keep_enabled(user, "roles")
    _keep_enabled(user, "positions")
    auth.user = user
    return auth
class AuthPermission:
    """Permission-checking dependency built on ``get_current_user``.

    An instance is callable: calling it resolves the current user and grants
    access when the user holds at least one of the required permission
    identifiers.
    """

    def __init__(
        self,
        permissions: list[str] | None = None,
        check_data_scope: bool = True,
    ) -> None:
        """Configure the check.

        Args:
            - permissions (list[str] | None): required permission identifiers;
              any one of them suffices. ``None`` or empty disables the check.
            - check_data_scope (bool): copied onto
              ``AuthSchema.check_data_scope`` for the request, i.e. whether
              data-permission filtering applies.
        """
        self.check_data_scope = check_data_scope
        self.permissions = permissions if permissions else []

    async def __call__(self, auth: AuthSchema = Depends(get_current_user)) -> AuthSchema:
        """Authorize the current user.

        Args:
            - auth (AuthSchema): auth context resolved by ``get_current_user``.

        Returns:
            - AuthSchema: the same ``auth`` object, with ``check_data_scope`` set.

        Raises:
            - CustomException: 403 when the user has no roles, or none of the
              required permissions.
        """
        auth.check_data_scope = self.check_data_scope
        user = auth.user
        required = self.permissions

        # Superusers bypass every check.
        if user and user.is_superuser:
            return auth
        # Nothing required, or a wildcard requirement: any authenticated user passes.
        # NOTE(review): the wildcard is looked up in the endpoint's *required*
        # list, not in the user's granted permissions — confirm that's intended.
        if not required or "*" in required or "*:*:*" in required:
            return auth

        if not user or not user.roles:
            raise CustomException(msg="无权限操作", code=10403, status_code=403)

        # Collect permissions granted via enabled ("0") roles and enabled menus.
        granted = set()
        for role in user.roles:
            for menu in role.menus:
                if role.status == "0" and menu.permission and menu.status == "0":
                    granted.add(menu.permission)

        # A single overlapping permission is enough.
        if granted.isdisjoint(required):
            log.error(f"用户缺少任何所需的权限: {self.permissions}")
            raise CustomException(msg="无权限操作", code=10403, status_code=403)
        return auth
|