| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- import time
- from typing import Optional, TypeVar, Generic, Type
- from fastapi import Request, Depends, HTTPException
- from sqlalchemy.ext.asyncio import AsyncSession
- from app.core.dependencies import db_getter
- from app.api.v1.module_system.auth.schema import AuthSchema
- from app.plugin.module_payment.apikey.service import TenantApiKeyService
- from app.plugin.module_payment.apikey.schema import ApiKeyPayload
T = TypeVar("T")


class TenantApiKeyAuth(Generic[T]):
    """FastAPI dependency that authenticates a tenant request by API key and signature.

    The request must carry two headers:

    * ``Authorization: ApiKey <key>`` -- the key is checked with
      ``TenantApiKeyService.validate_api_key``.
    * ``Signature: <sig>`` -- checked with ``TenantApiKeyService.verify_signature``
      against the key's secret and the JSON-decoded request body (``{}`` when the
      body is empty or not valid JSON).

    On success the call is logged, the key's last-used time is updated, and an
    ``ApiKeyPayload`` is returned whose ``data`` is the request body parsed into
    ``data_type``.

    Args:
        data_type: Type used to parse the JSON body; it is called as
            ``data_type(**body)`` when the body is a JSON object.
        auto_error: If True (default), authentication failures raise
            ``HTTPException(401)``; if False, the dependency returns ``None``
            instead.
    """

    # Scheme prefix the Authorization header must start with; the API key is
    # everything after it.
    _SCHEME_PREFIX = "ApiKey "

    def __init__(self, data_type: Type[T], auto_error: bool = True):
        self.data_type: Type[T] = data_type
        self.auto_error: bool = auto_error

    def _reject(self, detail: str) -> None:
        """Raise a 401 with *detail* when ``auto_error`` is set; otherwise return None."""
        if self.auto_error:
            raise HTTPException(status_code=401, detail=detail)
        return None

    @staticmethod
    async def _read_json_body(request: Request) -> object:
        """Return the JSON-decoded request body, or ``{}`` if it cannot be decoded."""
        try:
            return await request.json()
        except ValueError:
            # json.JSONDecodeError (including an empty body) and
            # UnicodeDecodeError are both ValueError subclasses. Catching only
            # these, instead of a bare ``except:``, lets asyncio.CancelledError
            # and other BaseExceptions propagate.
            return {}

    def _parse_data(self, request_data: object) -> object:
        """Build ``data_type`` from a JSON-object body; pass other JSON values through.

        Raises:
            HTTPException: 422 when ``data_type(**request_data)`` rejects the body.
        """
        if not isinstance(request_data, dict):
            # Non-object JSON (e.g. a list) is returned unparsed, as before.
            return request_data
        try:
            return self.data_type(**request_data)
        except (TypeError, ValueError) as exc:
            # TypeError: unexpected/missing keyword arguments for a plain class.
            # ValueError: validation errors (pydantic's ValidationError, if
            # data_type is a pydantic model, subclasses ValueError).
            raise HTTPException(status_code=422, detail="Invalid request data") from exc

    async def __call__(
        self,
        request: Request,
        db: AsyncSession = Depends(db_getter),
    ) -> Optional[ApiKeyPayload[T]]:
        """Authenticate the request and return its API key payload.

        Args:
            request: The incoming request. ``request.state.start_time`` is set here.
            db: Database session injected through ``db_getter``.

        Returns:
            The authenticated ``ApiKeyPayload``, or ``None`` when authentication
            fails and ``auto_error`` is False.

        Raises:
            HTTPException: 401 on an authentication failure when ``auto_error``
                is True; 422 when the JSON body cannot be parsed into
                ``data_type`` (regardless of ``auto_error``).
        """
        # Passed to log_api_call as start_time below.
        request.state.start_time = time.time()

        authorization = request.headers.get("Authorization")
        if not authorization:
            return self._reject("Authorization header required")

        signature = request.headers.get("Signature")
        if not signature:
            return self._reject("Signature header required")

        # ========================= Verify API key =========================
        if not authorization.startswith(self._SCHEME_PREFIX):
            return self._reject("Invalid authorization format")
        api_key = authorization[len(self._SCHEME_PREFIX):]

        # NOTE(review): tenant_id=1 is a placeholder because the key's real
        # tenant is unknown until validate_api_key returns. It is presumably a
        # default/system tenant; confirm that validate_api_key, log_api_call and
        # update_last_used do not scope their queries by auth.tenant_id.
        temp_auth = AuthSchema(user=None, db=db, tenant_id=1, check_data_scope=False)
        api_key_obj = await TenantApiKeyService.validate_api_key(temp_auth, api_key)
        if not api_key_obj:
            # NOTE: logging of failed calls (log_api_call with response_code=401)
            # is currently disabled on this path and on the signature path below.
            return self._reject("Invalid API Key")

        # ========================= Verify signature =========================
        request_data = await self._read_json_body(request)
        if not TenantApiKeyService.verify_signature(api_key_obj.api_secret, request_data, signature):
            return self._reject("Invalid Signature")

        # Parse before logging success, so a malformed payload is not recorded
        # as a 200 and does not touch last_used.
        parsed_data = self._parse_data(request_data)

        # This runs inside the dependency, before the endpoint handler, so
        # response_code=200 records successful authentication rather than the
        # endpoint's eventual response.
        await TenantApiKeyService.log_api_call(
            auth=temp_auth,
            api_key_id=api_key_obj.id,
            tenant_id=api_key_obj.tenant_id,
            endpoint=str(request.url.path),
            method=request.method,
            request_ip=request.client.host if request.client else "unknown",
            request_data=None,  # body deliberately not logged, to avoid storing sensitive data
            response_code=200,
            start_time=request.state.start_time,
        )
        await TenantApiKeyService.update_last_used(temp_auth, api_key_obj.id)

        return ApiKeyPayload(
            api_key=api_key_obj.api_key,
            api_secret=api_key_obj.api_secret,
            tenant_id=api_key_obj.tenant_id,
            # Unlike temp_auth, this auth is scoped to the key's own tenant.
            auth=AuthSchema(user=None, db=db, tenant_id=api_key_obj.tenant_id, check_data_scope=False),
            data=parsed_data,
        )
|