dependencies.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import json
  2. from collections.abc import AsyncGenerator
  3. from fastapi import Depends, Query, Request
  4. from redis.asyncio.client import Redis
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from sqlalchemy.orm import selectinload
  7. from app.api.v1.module_system.auth.schema import AuthSchema
  8. from app.api.v1.module_system.user.crud import UserCRUD
  9. from app.api.v1.module_system.user.model import UserModel
  10. from app.common.enums import RedisInitKeyConfig
  11. from app.config.setting import settings
  12. from app.core.database import async_db_session
  13. from app.core.exceptions import CustomException
  14. from app.core.logger import log
  15. from app.core.redis_crud import RedisCURD
  16. from app.core.security import OAuth2Schema, decode_access_token
  17. async def db_getter() -> AsyncGenerator[AsyncSession, None]:
  18. """获取数据库会话连接
  19. 返回:
  20. - AsyncSession: 数据库会话连接
  21. """
  22. async with async_db_session() as session:
  23. async with session.begin():
  24. yield session
  25. async def redis_getter(request: Request) -> Redis:
  26. """获取Redis连接
  27. 参数:
  28. - request (Request): 请求对象
  29. 返回:
  30. - Redis: Redis连接
  31. """
  32. return request.app.state.redis
  33. async def get_current_user(
  34. request: Request,
  35. db: AsyncSession = Depends(db_getter),
  36. redis: Redis = Depends(redis_getter),
  37. token: str = Depends(OAuth2Schema),
  38. ) -> AuthSchema:
  39. """获取当前用户
  40. 参数:
  41. - request (Request): 请求对象
  42. - db (AsyncSession): 数据库会话
  43. - redis (Redis): Redis连接
  44. - token (str): 访问令牌
  45. 返回:
  46. - AuthSchema: 认证信息模型
  47. """
  48. if not token:
  49. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  50. # 处理Bearer token
  51. if token.startswith("Bearer"):
  52. token = token.split(" ")[1]
  53. payload = decode_access_token(token)
  54. if not payload or not hasattr(payload, "is_refresh") or payload.is_refresh:
  55. raise CustomException(msg="非法凭证", code=10401, status_code=401)
  56. online_user_info = payload.sub
  57. # 从Redis中获取用户信息
  58. user_info = json.loads(online_user_info) # 确保是字典类型
  59. session_id = user_info.get("session_id")
  60. if not session_id:
  61. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  62. # 检查用户是否在线
  63. online_ok = await RedisCURD(redis).exists(
  64. key=f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}"
  65. )
  66. if not online_ok:
  67. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  68. # 如果启用了滑动过期,自动续期token
  69. if settings.TOKEN_SLIDING_EXPIRE:
  70. await RedisCURD(redis).expire(
  71. key=f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}",
  72. expire=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
  73. )
  74. await RedisCURD(redis).expire(
  75. key=f"{RedisInitKeyConfig.REFRESH_TOKEN.key}:{session_id}",
  76. expire=settings.REFRESH_TOKEN_EXPIRE_MINUTES,
  77. )
  78. # 关闭数据权限过滤,避免当前用户查询被拦截
  79. auth = AuthSchema(db=db, check_data_scope=False)
  80. username = user_info.get("user_name")
  81. if not username:
  82. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  83. # 获取用户信息,使用深层预加载确保RoleModel.creator被正确加载
  84. user = await UserCRUD(auth).get_by_username_crud(
  85. username=username,
  86. preload=[
  87. "dept",
  88. selectinload(UserModel.roles),
  89. "positions",
  90. "created_by",
  91. ],
  92. )
  93. if not user:
  94. raise CustomException(msg="用户不存在", code=10401, status_code=401)
  95. if user.status == "1":
  96. raise CustomException(msg="用户已被停用", code=10401, status_code=401)
  97. # 设置请求上下文
  98. request.scope["user_id"] = user.id
  99. request.scope["user_username"] = user.username
  100. # 过滤可用的角色和职位
  101. if hasattr(user, "roles"):
  102. user.roles = [role for role in user.roles if role and role.status]
  103. if hasattr(user, "positions"):
  104. user.positions = [pos for pos in user.positions if pos and pos.status]
  105. if hasattr(user, "tenant"):
  106. auth.tenant_id = user.tenant.id # type: ignore
  107. auth.user = user
  108. auth.tenant_id = user.tenant_id
  109. return auth
  110. async def get_current_user_ws(
  111. token: str = Query(..., description="认证token"),
  112. db: AsyncSession = Depends(db_getter),
  113. redis: Redis = Depends(redis_getter),
  114. ) -> AuthSchema:
  115. """获取当前用户(WebSocket专用,从查询参数获取token)
  116. 参数:
  117. - token (str): 认证token
  118. - db (AsyncSession): 数据库会话
  119. - redis (Redis): Redis连接
  120. 返回:
  121. - AuthSchema: 认证信息模型
  122. """
  123. return await _verify_token(token, db, redis)
  124. async def _verify_token(
  125. token: str,
  126. db: AsyncSession,
  127. redis: Redis,
  128. ) -> AuthSchema:
  129. """验证token并返回用户信息
  130. 参数:
  131. - token (str): 认证token
  132. - db (AsyncSession): 数据库会话
  133. - redis (Redis): Redis连接
  134. 返回:
  135. - AuthSchema: 认证信息模型
  136. """
  137. if not token:
  138. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  139. # 处理Bearer token(如果通过查询参数传递时包含Bearer前缀)
  140. if token.startswith("Bearer"):
  141. token = token.split(" ")[1]
  142. payload = decode_access_token(token)
  143. if not payload or not hasattr(payload, "is_refresh") or payload.is_refresh:
  144. raise CustomException(msg="非法凭证", code=10401, status_code=401)
  145. online_user_info = payload.sub
  146. # 从Redis中获取用户信息
  147. user_info = json.loads(online_user_info) # 确保是字典类型
  148. session_id = user_info.get("session_id")
  149. if not session_id:
  150. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  151. # 检查用户是否在线
  152. online_ok = await RedisCURD(redis).exists(
  153. key=f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}"
  154. )
  155. if not online_ok:
  156. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  157. # 如果启用了滑动过期,自动续期token
  158. if settings.TOKEN_SLIDING_EXPIRE:
  159. await RedisCURD(redis).expire(
  160. key=f"{RedisInitKeyConfig.ACCESS_TOKEN.key}:{session_id}",
  161. expire=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
  162. )
  163. await RedisCURD(redis).expire(
  164. key=f"{RedisInitKeyConfig.REFRESH_TOKEN.key}:{session_id}",
  165. expire=settings.REFRESH_TOKEN_EXPIRE_MINUTES,
  166. )
  167. # 关闭数据权限过滤,避免当前用户查询被拦截
  168. auth = AuthSchema(db=db, check_data_scope=False)
  169. username = user_info.get("user_name")
  170. if not username:
  171. raise CustomException(msg="认证已失效", code=10401, status_code=401)
  172. # 获取用户信息,使用深层预加载确保RoleModel.creator被正确加载
  173. user = await UserCRUD(auth).get_by_username_crud(
  174. username=username,
  175. preload=[
  176. "dept",
  177. selectinload(UserModel.roles),
  178. "positions",
  179. "created_by",
  180. ],
  181. )
  182. if not user:
  183. raise CustomException(msg="用户不存在", code=10401, status_code=401)
  184. if user.status == "1":
  185. raise CustomException(msg="用户已被停用", code=10401, status_code=401)
  186. # 设置请求上下文
  187. # request.scope["user_id"] = user.id
  188. # request.scope["user_username"] = user.username
  189. # 过滤可用的角色和职位
  190. if hasattr(user, "roles"):
  191. user.roles = [role for role in user.roles if role and role.status]
  192. if hasattr(user, "positions"):
  193. user.positions = [pos for pos in user.positions if pos and pos.status]
  194. auth.user = user
  195. return auth
  196. class AuthPermission:
  197. """权限验证类"""
  198. def __init__(
  199. self,
  200. permissions: list[str] | None = None,
  201. check_data_scope: bool = True,
  202. ) -> None:
  203. """
  204. 初始化权限验证
  205. 参数:
  206. - permissions (list[str] | None): 权限标识列表。
  207. - check_data_scope (bool): 是否启用严格模式校验。
  208. """
  209. self.permissions = permissions or []
  210. self.check_data_scope = check_data_scope
  211. async def __call__(self, auth: AuthSchema = Depends(get_current_user)) -> AuthSchema:
  212. """
  213. 调用权限验证
  214. 参数:
  215. - auth (AuthSchema): 认证信息对象。
  216. 返回:
  217. - AuthSchema: 认证信息对象。
  218. """
  219. auth.check_data_scope = self.check_data_scope
  220. # 超级管理员直接通过
  221. if auth.user and auth.user.is_superuser:
  222. return auth
  223. # 无需验证权限
  224. if not self.permissions:
  225. return auth
  226. # 超级管理员权限标识
  227. if "*" in self.permissions or "*:*:*" in self.permissions:
  228. return auth
  229. # 检查用户是否有角色
  230. if not auth.user or not auth.user.roles:
  231. raise CustomException(msg="无权限操作", code=10403, status_code=403)
  232. # 获取用户权限集合
  233. user_permissions = {
  234. menu.permission
  235. for role in auth.user.roles
  236. for menu in role.menus
  237. if role.status == "0" and menu.permission and menu.status == "0"
  238. }
  239. # 权限验证 - 满足任一权限即可
  240. if not any(perm in user_permissions for perm in self.permissions):
  241. log.error(f"用户缺少任何所需的权限: {self.permissions}")
  242. raise CustomException(msg="无权限操作", code=10403, status_code=403)
  243. return auth