router_class.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import json
  2. import time
  3. from collections.abc import Callable, Coroutine
  4. from typing import Any
  5. from fastapi import Request, Response
  6. from fastapi.routing import APIRoute
  7. from user_agents import parse
  8. from app.api.v1.module_system.auth.schema import AuthSchema
  9. from app.api.v1.module_system.log.schema import OperationLogCreateSchema
  10. from app.api.v1.module_system.log.service import OperationLogService
  11. from app.config.setting import settings
  12. from app.core.database import async_db_session
  13. from app.utils.ip_local_util import IpLocalUtil
  14. """
  15. 在 FastAPI 中,route_class 参数用于自定义路由的行为。
  16. 通过设置 route_class,你可以定义一个自定义的路由类,从而在每个路由处理之前或之后执行特定的操作。
  17. 这对于日志记录、权限验证、性能监控等场景非常有用。
  18. """
  19. class OperationLogRoute(APIRoute):
  20. """操作日志路由装饰器"""
  21. def get_route_handler(
  22. self,
  23. ) -> Callable[[Request], Coroutine[Any, Any, Response]]:
  24. """
  25. 自定义路由处理程序,在每个路由处理之前或之后执行特定的操作。
  26. 参数:
  27. - request (Request): FastAPI请求对象。
  28. 返回:
  29. - Response: FastAPI响应对象。
  30. """
  31. original_route_handler = super().get_route_handler()
  32. async def custom_route_handler(request: Request) -> Response:
  33. """
  34. 自定义路由处理程序,在每个路由处理之前或之后执行特定的操作。
  35. 参数:
  36. - request (Request): FastAPI请求对象。
  37. 描述:
  38. - 该方法在每个路由处理之前被调用,用于记录操作日志。
  39. 返回:
  40. - Response: FastAPI响应对象。
  41. """
  42. start_time = time.time()
  43. # 请求前的处理
  44. response: Response = await original_route_handler(request)
  45. # 请求后的处理
  46. if not settings.OPERATION_LOG_RECORD:
  47. return response
  48. if request.method not in settings.OPERATION_RECORD_METHOD:
  49. return response
  50. route: APIRoute = request.scope.get("route", None)
  51. if route.name in settings.IGNORE_OPERATION_FUNCTION:
  52. return response
  53. user_agent = parse(request.headers.get("user-agent"))
  54. payload = b"{}"
  55. req_content_type = request.headers.get("Content-Type", "")
  56. if req_content_type and (
  57. req_content_type.startswith((
  58. "multipart/form-data",
  59. "application/x-www-form-urlencoded",
  60. ))
  61. ):
  62. form_data = await request.form()
  63. oper_param = "\n".join([f"{k}: {v}" for k, v in form_data.items()])
  64. payload = oper_param # 直接使用字符串格式的参数
  65. else:
  66. payload = await request.body()
  67. path_params = request.path_params
  68. oper_param = {}
  69. # 处理请求体数据
  70. if payload:
  71. try:
  72. oper_param["body"] = json.loads(payload.decode())
  73. except (json.JSONDecodeError, UnicodeDecodeError):
  74. oper_param["body"] = payload.decode("utf-8", errors="ignore")
  75. # 处理路径参数
  76. if path_params:
  77. oper_param["path_params"] = dict(path_params)
  78. payload = json.dumps(oper_param, ensure_ascii=False)
  79. # 日志表请求参数字段长度最大为2000,因此在此处判断长度
  80. if len(payload) > 2000:
  81. payload = "请求参数过长"
  82. response_data = (
  83. response.body
  84. if "application/json" in response.headers.get("Content-Type", "")
  85. else b"{}"
  86. )
  87. process_time = f"{(time.time() - start_time):.2f}s"
  88. # 获取当前用户ID,如果是登录接口则为空
  89. log_type = 1 # 1:登录日志 2:操作日志
  90. current_user_id = None
  91. # 优化:只在操作日志场景下获取current_user_id
  92. if "user_id" in request.scope:
  93. current_user_id = request.scope.get("user_id")
  94. log_type = 2
  95. request_ip = None
  96. x_forwarded_for = request.headers.get("X-Forwarded-For")
  97. if x_forwarded_for:
  98. # 取第一个 IP 地址,通常为客户端真实 IP
  99. request_ip = x_forwarded_for.split(",")[0].strip()
  100. else:
  101. # 若没有 X-Forwarded-For 头,则使用 request.client.host
  102. if request.client:
  103. request_ip = request.client.host
  104. login_location = await IpLocalUtil.resolve_location_for_log(request_ip)
  105. # 判断请求是否来自api文档
  106. referer = request.headers.get("referer")
  107. request_from_swagger = referer and referer.endswith("docs")
  108. request_from_redoc = referer and referer.endswith("redoc")
  109. if request_from_swagger or request_from_redoc:
  110. # 如果请求来自api文档,则不记录日志
  111. pass
  112. else:
  113. async with async_db_session() as session:
  114. async with session.begin():
  115. auth = AuthSchema(db=session)
  116. await OperationLogService.create_log_service(
  117. data=OperationLogCreateSchema(
  118. type=log_type,
  119. request_path=request.url.path,
  120. request_method=request.method,
  121. request_payload=payload,
  122. request_ip=request_ip,
  123. login_location=login_location,
  124. request_os=user_agent.os.family,
  125. request_browser=user_agent.browser.family,
  126. response_code=response.status_code,
  127. response_json=response_data.decode()
  128. if isinstance(response_data, (bytes, bytearray))
  129. else str(response_data),
  130. process_time=process_time,
  131. description=route.summary,
  132. created_id=current_user_id,
  133. updated_id=current_user_id,
  134. ),
  135. auth=auth,
  136. )
  137. return response
  138. return custom_route_handler