base_crud.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. import builtins
  2. from collections.abc import Sequence
  3. from typing import TYPE_CHECKING, Any, Generic, TypeVar
  4. from pydantic import BaseModel
  5. from sqlalchemy import Select, asc, delete, desc, false, func, select, update
  6. from sqlalchemy import inspect as sa_inspect
  7. from sqlalchemy.orm import selectinload
  8. from sqlalchemy.sql.elements import ColumnElement
  9. from app.api.v1.module_system.auth.schema import AuthSchema
  10. from app.core.base_model import MappedBase
  11. from app.core.exceptions import CustomException
  12. from app.core.permission import Permission
  13. if TYPE_CHECKING:
  14. from sqlalchemy.engine import Result
  15. ModelType = TypeVar("ModelType", bound=MappedBase)
  16. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  17. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  18. OutSchemaType = TypeVar("OutSchemaType", bound=BaseModel)
  19. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  20. """基础数据层"""
  21. def __init__(self, model: type[ModelType], auth: AuthSchema) -> None:
  22. """
  23. 初始化CRUDBase类
  24. 参数:
  25. - model (Type[ModelType]): 数据模型类。
  26. - auth (AuthSchema): 认证信息。
  27. 返回:
  28. - None
  29. """
  30. self.model = model
  31. self.auth = auth
  32. async def get(self, preload: list[str | Any] | None = None, **kwargs) -> ModelType | None:
  33. """
  34. 根据条件获取单个对象
  35. 参数:
  36. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  37. - **kwargs: 查询条件
  38. 返回:
  39. - Optional[ModelType]: 对象实例
  40. 异常:
  41. - CustomException: 查询失败时抛出异常
  42. """
  43. try:
  44. conditions = await self.__build_conditions(**kwargs)
  45. sql = select(self.model).where(*conditions)
  46. # 应用可配置的预加载选项
  47. for opt in self.__loader_options(preload):
  48. sql = sql.options(opt)
  49. sql = await self.__filter_permissions(sql)
  50. result: Result = await self.auth.db.execute(sql)
  51. obj = result.scalars().first()
  52. return obj
  53. except Exception as e:
  54. raise CustomException(msg=f"获取查询失败: {e!s}")
  55. async def get_first(self, preload: list[str | Any] | None = None) -> ModelType | None:
  56. """
  57. 获取第一个匹配的对象(不指定条件,仅按权限过滤)
  58. 参数:
  59. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  60. 返回:
  61. - Optional[ModelType]: 对象实例
  62. 异常:
  63. - CustomException: 查询失败时抛出异常
  64. """
  65. try:
  66. sql = select(self.model)
  67. # 应用可配置的预加载选项
  68. for opt in self.__loader_options(preload):
  69. sql = sql.options(opt)
  70. sql = await self.__filter_permissions(sql)
  71. result: Result = await self.auth.db.execute(sql)
  72. obj = result.scalars().first()
  73. return obj
  74. except Exception as e:
  75. raise CustomException(msg=f"获取查询失败: {e!s}")
  76. async def list(
  77. self,
  78. search: dict | None = None,
  79. order_by: list[dict[str, str]] | None = None,
  80. preload: list[str | Any] | None = None,
  81. ) -> Sequence[ModelType]:
  82. """
  83. 根据条件获取对象列表
  84. 参数:
  85. - search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
  86. - order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  87. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  88. 返回:
  89. - Sequence[ModelType]: 对象列表
  90. 异常:
  91. - CustomException: 查询失败时抛出异常
  92. """
  93. try:
  94. conditions = await self.__build_conditions(**search) if search else []
  95. order = order_by or [{"id": "asc"}]
  96. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  97. # 应用可配置的预加载选项
  98. for opt in self.__loader_options(preload):
  99. sql = sql.options(opt)
  100. sql = await self.__filter_permissions(sql)
  101. result: Result = await self.auth.db.execute(sql)
  102. return result.scalars().all()
  103. except Exception as e:
  104. raise CustomException(msg=f"列表查询失败: {e!s}")
  105. async def tree_list(
  106. self,
  107. search: dict | None = None,
  108. order_by: builtins.list[dict[str, str]] | None = None,
  109. children_attr: str = "children",
  110. preload: builtins.list[str | Any] | None = None,
  111. ) -> Sequence[ModelType]:
  112. """
  113. 获取树形结构数据列表
  114. 参数:
  115. - search (Optional[Dict]): 查询条件
  116. - order_by (Optional[List[Dict[str, str]]]): 排序字段
  117. - children_attr (str): 子节点属性名
  118. - preload (Optional[List[Union[str, Any]]]): 额外预加载关系,若为None则默认包含children_attr
  119. 返回:
  120. - Sequence[ModelType]: 树形结构数据列表
  121. 异常:
  122. - CustomException: 查询失败时抛出异常
  123. """
  124. try:
  125. conditions = await self.__build_conditions(**search) if search else []
  126. order = order_by or [{"id": "asc"}]
  127. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  128. # 处理预加载选项
  129. final_preload = preload
  130. # 如果没有提供preload且children_attr存在,则添加到预加载选项中
  131. if preload is None and children_attr and hasattr(self.model, children_attr):
  132. # 获取模型默认预加载选项
  133. model_defaults = getattr(self.model, "__loader_options__", [])
  134. # 将children_attr添加到默认预加载选项中
  135. final_preload = [*list(model_defaults), children_attr]
  136. # 应用预加载选项
  137. for opt in self.__loader_options(final_preload):
  138. sql = sql.options(opt)
  139. sql = await self.__filter_permissions(sql)
  140. result: Result = await self.auth.db.execute(sql)
  141. return result.scalars().all()
  142. except Exception as e:
  143. raise CustomException(msg=f"树形列表查询失败: {e!s}")
  144. async def page(
  145. self,
  146. offset: int,
  147. limit: int,
  148. order_by: builtins.list[dict[str, str]],
  149. search: dict,
  150. out_schema: type[OutSchemaType],
  151. preload: builtins.list[str | Any] | None = None,
  152. ) -> dict:
  153. """
  154. 获取分页数据
  155. 参数:
  156. - offset (int): 偏移量
  157. - limit (int): 每页数量
  158. - order_by (List[Dict[str, str]]): 排序字段
  159. - search (Dict): 查询条件
  160. - out_schema (Type[OutSchemaType]): 输出数据模型
  161. - preload (Optional[List[Union[str, Any]]]): 预加载关系
  162. 返回:
  163. - Dict: 分页数据
  164. 异常:
  165. - CustomException: 查询失败时抛出异常
  166. """
  167. try:
  168. conditions = await self.__build_conditions(**search) if search else []
  169. order = order_by or [{"id": "asc"}]
  170. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  171. # 应用预加载选项
  172. for opt in self.__loader_options(preload):
  173. sql = sql.options(opt)
  174. sql = await self.__filter_permissions(sql)
  175. # 优化count查询:使用主键计数而非全表扫描
  176. mapper = sa_inspect(self.model)
  177. pk_cols = list(getattr(mapper, "primary_key", []))
  178. if pk_cols:
  179. # 使用主键的第一列进行计数(主键必定非NULL,性能更好)
  180. count_sql = select(func.count(pk_cols[0])).select_from(self.model)
  181. else:
  182. # 降级方案:使用count(*)
  183. count_sql = select(func.count()).select_from(self.model)
  184. if conditions:
  185. count_sql = count_sql.where(*conditions)
  186. count_sql = await self.__filter_permissions(count_sql)
  187. total_result = await self.auth.db.execute(count_sql)
  188. total = total_result.scalar() or 0
  189. result: Result = await self.auth.db.execute(sql.offset(offset).limit(limit))
  190. objs = result.scalars().all()
  191. return {
  192. "page_no": offset // limit + 1 if limit else 1,
  193. "page_size": limit or 10,
  194. "total": total,
  195. "has_next": offset + limit < total,
  196. "items": [out_schema.model_validate(obj).model_dump() for obj in objs],
  197. }
  198. except Exception as e:
  199. raise CustomException(msg=f"分页查询失败: {e!s}")
  200. async def create(self, data: CreateSchemaType | dict, skip_tenant_id: bool = False) -> ModelType:
  201. """
  202. 创建新对象
  203. 参数:
  204. - data (Union[CreateSchemaType, Dict]): 对象属性
  205. - skip_tenant_id (bool, optional): 是否跳过设置租户ID。默认值为 False。
  206. 返回:
  207. - ModelType: 新创建的对象实例
  208. 异常:
  209. - CustomException: 创建失败时抛出异常
  210. """
  211. try:
  212. obj_dict = data if isinstance(data, dict) else data.model_dump()
  213. obj = self.model(**obj_dict)
  214. # 设置字段值(只检查一次current_user)
  215. if self.auth.user:
  216. if hasattr(obj, "created_id"):
  217. setattr(obj, "created_id", self.auth.user.id)
  218. if hasattr(obj, "updated_id"):
  219. setattr(obj, "updated_id", self.auth.user.id)
  220. # 自动设置 tenant_id(如果模型有该字段)
  221. if not skip_tenant_id and hasattr(self.model, "tenant_id") and hasattr(obj, "tenant_id"):
  222. if hasattr(self.auth, "tenant_id") and self.auth.tenant_id:
  223. setattr(obj, "tenant_id", self.auth.tenant_id)
  224. elif self.auth.user and hasattr(self.auth.user, "tenant_id"):
  225. setattr(obj, "tenant_id", self.auth.user.tenant_id)
  226. self.auth.db.add(obj)
  227. await self.auth.db.flush()
  228. await self.auth.db.refresh(obj)
  229. return obj
  230. except Exception as e:
  231. raise CustomException(msg=f"创建失败: {e!s}")
  232. async def update(self, id: int, data: UpdateSchemaType | dict) -> ModelType:
  233. """
  234. 更新对象
  235. 参数:
  236. - id (int): 对象ID
  237. - data (Union[UpdateSchemaType, Dict]): 更新的属性及值
  238. 返回:
  239. - ModelType: 更新后的对象实例
  240. 异常:
  241. - CustomException: 更新失败时抛出异常
  242. """
  243. try:
  244. obj_dict = (
  245. data
  246. if isinstance(data, dict)
  247. else data.model_dump(exclude_unset=True, exclude={"id"})
  248. )
  249. # 获取对象时不自动预加载关系,避免循环依赖
  250. obj = await self.get(id=id, preload=[])
  251. if not obj:
  252. raise CustomException(msg="更新对象不存在")
  253. # 设置字段值(只检查一次current_user)
  254. if self.auth.user and hasattr(obj, "updated_id"):
  255. setattr(obj, "updated_id", self.auth.user.id)
  256. for key, value in obj_dict.items():
  257. if hasattr(obj, key):
  258. setattr(obj, key, value)
  259. await self.auth.db.flush()
  260. # 刷新对象时不自动预加载关系
  261. await self.auth.db.refresh(obj)
  262. # 权限二次确认:flush后再次验证对象仍在权限范围内
  263. # 防止并发修改导致的权限逃逸(如其他事务修改了created_id)
  264. # 验证时也不自动预加载关系
  265. verify_obj = await self.get(id=id, preload=[])
  266. if not verify_obj:
  267. # 对象已被删除或权限已失效
  268. raise CustomException(msg="更新失败,对象不存在或无权限访问")
  269. return obj
  270. except Exception as e:
  271. raise CustomException(msg=f"更新失败: {e!s}")
  272. async def delete(self, ids: builtins.list[int]) -> None:
  273. """
  274. 删除对象
  275. 参数:
  276. - ids (List[int]): 对象ID列表
  277. 返回:
  278. - None
  279. 异常:
  280. - CustomException: 删除失败时抛出异常
  281. """
  282. try:
  283. mapper = sa_inspect(self.model)
  284. pk_cols = list(getattr(mapper, "primary_key", []))
  285. if not pk_cols:
  286. raise CustomException(msg="模型缺少主键,无法删除")
  287. if len(pk_cols) > 1:
  288. raise CustomException(msg="暂不支持复合主键的批量删除")
  289. # 只删除有权限的数据
  290. sql = delete(self.model).where(pk_cols[0].in_(ids))
  291. await self.auth.db.execute(sql)
  292. await self.auth.db.flush()
  293. except Exception as e:
  294. raise CustomException(msg=f"删除失败: {e!s}")
  295. async def clear(self) -> None:
  296. """
  297. 清空对象表
  298. 返回:
  299. - None
  300. 异常:
  301. - CustomException: 清空失败时抛出异常
  302. """
  303. try:
  304. sql = delete(self.model)
  305. await self.auth.db.execute(sql)
  306. await self.auth.db.flush()
  307. except Exception as e:
  308. raise CustomException(msg=f"清空失败: {e!s}")
  309. async def set(self, ids: builtins.list[int], **kwargs) -> None:
  310. """
  311. 批量更新对象
  312. 参数:
  313. - ids (List[int]): 对象ID列表
  314. - **kwargs: 更新的属性及值
  315. 返回:
  316. - None
  317. 异常:
  318. - CustomException: 更新失败时抛出异常
  319. """
  320. try:
  321. mapper = sa_inspect(self.model)
  322. pk_cols = list(getattr(mapper, "primary_key", []))
  323. if not pk_cols:
  324. raise CustomException(msg="模型缺少主键,无法更新")
  325. if len(pk_cols) > 1:
  326. raise CustomException(msg="暂不支持复合主键的批量更新")
  327. # 只更新有权限的数据
  328. sql = update(self.model).where(pk_cols[0].in_(ids)).values(**kwargs)
  329. await self.auth.db.execute(sql)
  330. await self.auth.db.flush()
  331. except CustomException:
  332. raise
  333. except Exception as e:
  334. raise CustomException(msg=f"批量更新失败: {e!s}")
  335. async def __filter_permissions(self, sql: Select) -> Select:
  336. """
  337. 过滤数据权限(仅用于Select)。
  338. """
  339. filter = Permission(model=self.model, auth=self.auth)
  340. return await filter.filter_query(sql)
  341. async def __build_conditions(self, **kwargs) -> builtins.list[ColumnElement]:
  342. """
  343. 构建查询条件
  344. 参数:
  345. - **kwargs: 查询参数
  346. 返回:
  347. - List[ColumnElement]: SQL条件表达式列表
  348. 异常:
  349. - CustomException: 查询参数不存在时抛出异常
  350. """
  351. conditions = []
  352. for key, value in kwargs.items():
  353. if value is None or value == "":
  354. continue
  355. attr = getattr(self.model, key)
  356. if isinstance(value, tuple):
  357. seq, val = value
  358. if seq == "None":
  359. conditions.append(attr.is_(None))
  360. elif seq == "not None":
  361. conditions.append(attr.isnot(None))
  362. elif seq == "date" and val:
  363. conditions.append(func.date_format(attr, "%Y-%m-%d") == val)
  364. elif seq == "month" and val:
  365. conditions.append(func.date_format(attr, "%Y-%m") == val)
  366. elif seq == "like" and val:
  367. conditions.append(attr.like(f"%{val}%"))
  368. elif seq == "in":
  369. # 通用约定:("in", []) 应当返回空集(恒假),不能跳过条件导致查询退化成全量(仅权限过滤)
  370. if val is None:
  371. continue
  372. if isinstance(val, (list, tuple, set)) and len(val) == 0:
  373. conditions.append(false())
  374. else:
  375. conditions.append(attr.in_(val))
  376. elif seq == "between" and isinstance(val, (list, tuple)) and len(val) == 2:
  377. conditions.append(attr.between(val[0], val[1]))
  378. elif seq == "!=" or (seq == "ne" and val):
  379. conditions.append(attr != val)
  380. elif seq == ">" or (seq == "gt" and val):
  381. conditions.append(attr > val)
  382. elif seq == ">=" or (seq == "ge" and val):
  383. conditions.append(attr >= val)
  384. elif seq == "<" or (seq == "lt" and val):
  385. conditions.append(attr < val)
  386. elif seq == "<=" or (seq == "le" and val):
  387. conditions.append(attr <= val)
  388. elif seq == "==" or (seq == "eq" and val):
  389. conditions.append(attr == val)
  390. else:
  391. conditions.append(attr == value)
  392. return conditions
  393. def __order_by(self, order_by: builtins.list[dict[str, str]]) -> builtins.list[ColumnElement]:
  394. """
  395. 获取排序字段
  396. 参数:
  397. - order_by (List[Dict[str, str]]): 排序字段列表,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  398. 返回:
  399. - List[ColumnElement]: 排序字段列表
  400. 异常:
  401. - CustomException: 排序字段不存在时抛出异常
  402. """
  403. columns = []
  404. for order in order_by:
  405. for field, direction in order.items():
  406. column = getattr(self.model, field)
  407. columns.append(desc(column) if direction.lower() == "desc" else asc(column))
  408. return columns
  409. def __loader_options(
  410. self, preload: builtins.list[str | Any] | None = None
  411. ) -> builtins.list[Any]:
  412. """
  413. 构建预加载选项
  414. 参数:
  415. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  416. 返回:
  417. - List[Any]: 预加载选项列表
  418. """
  419. options = []
  420. # 获取模型定义的默认加载选项
  421. model_loader_options = getattr(self.model, "__loader_options__", [])
  422. # 合并所有需要预加载的选项
  423. all_preloads = set(model_loader_options)
  424. if preload:
  425. for opt in preload:
  426. if isinstance(opt, str):
  427. all_preloads.add(opt)
  428. elif preload == []:
  429. # 如果明确指定空列表,则不使用任何预加载
  430. all_preloads = set()
  431. # 处理所有预加载选项
  432. for opt in all_preloads:
  433. if isinstance(opt, str):
  434. # 使用selectinload来避免在异步环境中的MissingGreenlet错误
  435. if hasattr(self.model, opt):
  436. options.append(selectinload(getattr(self.model, opt)))
  437. else:
  438. # 直接使用非字符串的加载选项
  439. options.append(opt)
  440. return options