middlewares.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import json
  2. import time
  3. from starlette.middleware.base import (
  4. BaseHTTPMiddleware,
  5. RequestResponseEndpoint,
  6. )
  7. from starlette.middleware.cors import CORSMiddleware
  8. from starlette.middleware.gzip import GZipMiddleware
  9. from starlette.requests import Request
  10. from starlette.responses import Response
  11. from starlette.types import ASGIApp
  12. from app.api.v1.module_system.params.service import ParamsService
  13. from app.common.response import ErrorResponse
  14. from app.config.setting import settings
  15. from app.core.exceptions import CustomException
  16. from app.core.logger import log
  17. from app.core.security import decode_access_token
  18. class CustomCORSMiddleware(CORSMiddleware):
  19. """CORS跨域中间件"""
  20. def __init__(self, app: ASGIApp) -> None:
  21. super().__init__(
  22. app,
  23. allow_origins=settings.ALLOW_ORIGINS,
  24. allow_methods=settings.ALLOW_METHODS,
  25. allow_headers=settings.ALLOW_HEADERS,
  26. allow_credentials=settings.ALLOW_CREDENTIALS,
  27. expose_headers=settings.CORS_EXPOSE_HEADERS,
  28. )
  29. class RequestLogMiddleware(BaseHTTPMiddleware):
  30. """
  31. 记录请求日志中间件: 提供一个基础的中间件类,允许你自定义请求和响应处理逻辑。
  32. """
  33. def __init__(self, app: ASGIApp) -> None:
  34. super().__init__(app)
  35. @staticmethod
  36. def _extract_session_id_from_request(request: Request) -> str | None:
  37. """
  38. 从请求中提取session_id(支持从Token或已设置的scope中获取)
  39. 参数:
  40. - request (Request): 请求对象
  41. 返回:
  42. - str | None: 会话ID,如果无法提取则返回None
  43. """
  44. # 1. 先检查 scope 中是否已经有 session_id(登录接口会设置)
  45. session_id = request.scope.get("session_id")
  46. if session_id:
  47. return session_id
  48. # 2. 尝试从 Authorization Header 中提取
  49. try:
  50. authorization = request.headers.get("Authorization")
  51. if not authorization:
  52. return None
  53. # 处理Bearer token
  54. token = authorization.replace("Bearer ", "").strip()
  55. # 解码token
  56. payload = decode_access_token(token)
  57. if not payload or not hasattr(payload, "sub"):
  58. return None
  59. # 从payload中提取session_id
  60. user_info = json.loads(payload.sub)
  61. session_id = user_info.get("session_id")
  62. # 同时设置到request.scope中,避免后续重复解析
  63. if session_id:
  64. request.scope["session_id"] = session_id
  65. return session_id
  66. except Exception:
  67. # 解析失败静默处理,返回None(可能是未认证请求)
  68. return None
  69. async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
  70. """
  71. 记录请求日志并透传响应。
  72. 参数:
  73. - request (Request): 当前请求。
  74. - call_next (RequestResponseEndpoint): 下一层 ASGI 可调用对象。
  75. 返回:
  76. - Response: 下游中间件/路由产生的响应。
  77. """
  78. start_time = time.time()
  79. # 尝试提取session_id
  80. session_id = self._extract_session_id_from_request(request)
  81. # 组装请求日志字段
  82. log_fields = (
  83. f"请求来源: {request.client.host if request.client else '未知'},"
  84. f"请求方法: {request.method},"
  85. f"请求路径: {request.url.path}"
  86. )
  87. log.info(log_fields)
  88. try:
  89. # 初始化响应变量
  90. response = None
  91. # 获取请求路径
  92. path = request.scope.get("path")
  93. # 尝试获取客户端真实IP
  94. request_ip = None
  95. request_ip = (
  96. x_forwarded_for.split(",")[0].strip()
  97. if (x_forwarded_for := request.headers.get("X-Forwarded-For"))
  98. else request.client.host
  99. if request.client
  100. else None
  101. )
  102. # 检查是否启用演示模式
  103. demo_enable = False
  104. ip_white_list = []
  105. white_api_list_path = []
  106. ip_black_list = []
  107. try:
  108. # 从应用实例获取Redis连接
  109. redis = request.app.state.redis
  110. if not redis:
  111. raise CustomException(msg="无法获取Redis连接")
  112. # 使用ParamsService获取系统配置
  113. system_config = await ParamsService.get_system_config_for_middleware(redis)
  114. # 提取配置值
  115. demo_enable = system_config["demo_enable"]
  116. ip_white_list = system_config["ip_white_list"]
  117. white_api_list_path = system_config["white_api_list_path"]
  118. ip_black_list = system_config["ip_black_list"]
  119. except Exception as e:
  120. log.error(f"获取系统配置失败: {e}")
  121. # 检查是否需要拦截请求
  122. should_block = False
  123. block_reason = ""
  124. # 1. 首先检查IP是否在黑名单中
  125. if request_ip and request_ip in ip_black_list:
  126. should_block = True
  127. block_reason = f"IP地址 {request_ip} 在黑名单中"
  128. # 2. 如果不在黑名单中,检查是否在演示模式下需要拦截
  129. elif demo_enable in ["true", "True"] and request.method != "GET":
  130. # 在演示模式下,非GET请求需要检查白名单
  131. is_ip_whitelisted = request_ip in ip_white_list
  132. is_path_whitelisted = path in white_api_list_path
  133. if not is_ip_whitelisted and not is_path_whitelisted:
  134. should_block = True
  135. block_reason = f"演示模式下拦截非GET请求,IP: {request_ip}, 路径: {path}"
  136. if should_block:
  137. # 增强安全审计:记录详细的拦截日志
  138. log.warning(
  139. " | ".join(
  140. [
  141. f"会话ID: {session_id or '未认证'}",
  142. f"请求被拦截: {block_reason}",
  143. f"请求来源: {request_ip}",
  144. f"请求方法: {request.method}",
  145. f"请求路径: {path}",
  146. f"用户代理: {request.headers.get('user-agent', '未知')}",
  147. f"演示模式: {demo_enable}",
  148. ]
  149. )
  150. )
  151. # 拦截请求
  152. return ErrorResponse(msg="演示环境,禁止操作")
  153. # 正常处理请求
  154. response = await call_next(request)
  155. # 计算处理时间并添加到响应头
  156. process_time = round(time.time() - start_time, 5)
  157. response.headers["X-Process-Time"] = str(process_time)
  158. # 构建响应日志信息
  159. content_length = response.headers.get("content-length", "0")
  160. response_info = f"响应状态: {response.status_code}, 响应内容长度: {content_length}, 处理时间: {round(process_time * 1000, 3)}ms"
  161. log.info(response_info)
  162. return response
  163. except CustomException as e:
  164. log.exception(f"中间件处理异常: {e!s}")
  165. return ErrorResponse(msg="系统异常,请联系管理员", data=str(e))
  166. class CustomGZipMiddleware(GZipMiddleware):
  167. """GZip压缩中间件"""
  168. def __init__(self, app: ASGIApp) -> None:
  169. super().__init__(
  170. app,
  171. minimum_size=settings.GZIP_MIN_SIZE,
  172. compresslevel=settings.GZIP_COMPRESS_LEVEL,
  173. )