base_crud.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. import builtins
  2. from collections.abc import Sequence
  3. from typing import TYPE_CHECKING, Any, Generic, TypeVar
  4. from pydantic import BaseModel
  5. from sqlalchemy import Select, asc, delete, desc, false, func, select, update
  6. from sqlalchemy import inspect as sa_inspect
  7. from sqlalchemy.orm import selectinload
  8. from sqlalchemy.sql.elements import ColumnElement
  9. from app.api.v1.module_system.auth.schema import AuthSchema
  10. from app.core.base_model import MappedBase
  11. from app.core.exceptions import CustomException
  12. from app.core.permission import Permission
  13. if TYPE_CHECKING:
  14. from sqlalchemy.engine import Result
  15. ModelType = TypeVar("ModelType", bound=MappedBase)
  16. CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
  17. UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
  18. OutSchemaType = TypeVar("OutSchemaType", bound=BaseModel)
  19. class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
  20. """基础数据层"""
  21. def __init__(self, model: type[ModelType], auth: AuthSchema) -> None:
  22. """
  23. 初始化CRUDBase类
  24. 参数:
  25. - model (Type[ModelType]): 数据模型类。
  26. - auth (AuthSchema): 认证信息。
  27. 返回:
  28. - None
  29. """
  30. self.model = model
  31. self.auth = auth
  32. async def get(self, preload: list[str | Any] | None = None, **kwargs) -> ModelType | None:
  33. """
  34. 根据条件获取单个对象
  35. 参数:
  36. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  37. - **kwargs: 查询条件
  38. 返回:
  39. - Optional[ModelType]: 对象实例
  40. 异常:
  41. - CustomException: 查询失败时抛出异常
  42. """
  43. try:
  44. conditions = await self.__build_conditions(**kwargs)
  45. sql = select(self.model).where(*conditions)
  46. # 应用可配置的预加载选项
  47. for opt in self.__loader_options(preload):
  48. sql = sql.options(opt)
  49. sql = await self.__filter_permissions(sql)
  50. result: Result = await self.auth.db.execute(sql)
  51. obj = result.scalars().first()
  52. return obj
  53. except Exception as e:
  54. raise CustomException(msg=f"获取查询失败: {e!s}")
  55. async def list(
  56. self,
  57. search: dict | None = None,
  58. order_by: list[dict[str, str]] | None = None,
  59. preload: list[str | Any] | None = None,
  60. ) -> Sequence[ModelType]:
  61. """
  62. 根据条件获取对象列表
  63. 参数:
  64. - search (Optional[Dict]): 查询条件,格式为 {'id': value, 'name': value}
  65. - order_by (Optional[List[Dict[str, str]]]): 排序字段,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  66. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  67. 返回:
  68. - Sequence[ModelType]: 对象列表
  69. 异常:
  70. - CustomException: 查询失败时抛出异常
  71. """
  72. try:
  73. conditions = await self.__build_conditions(**search) if search else []
  74. order = order_by or [{"id": "asc"}]
  75. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  76. # 应用可配置的预加载选项
  77. for opt in self.__loader_options(preload):
  78. sql = sql.options(opt)
  79. sql = await self.__filter_permissions(sql)
  80. result: Result = await self.auth.db.execute(sql)
  81. return result.scalars().all()
  82. except Exception as e:
  83. raise CustomException(msg=f"列表查询失败: {e!s}")
  84. async def tree_list(
  85. self,
  86. search: dict | None = None,
  87. order_by: builtins.list[dict[str, str]] | None = None,
  88. children_attr: str = "children",
  89. preload: builtins.list[str | Any] | None = None,
  90. ) -> Sequence[ModelType]:
  91. """
  92. 获取树形结构数据列表
  93. 参数:
  94. - search (Optional[Dict]): 查询条件
  95. - order_by (Optional[List[Dict[str, str]]]): 排序字段
  96. - children_attr (str): 子节点属性名
  97. - preload (Optional[List[Union[str, Any]]]): 额外预加载关系,若为None则默认包含children_attr
  98. 返回:
  99. - Sequence[ModelType]: 树形结构数据列表
  100. 异常:
  101. - CustomException: 查询失败时抛出异常
  102. """
  103. try:
  104. conditions = await self.__build_conditions(**search) if search else []
  105. order = order_by or [{"id": "asc"}]
  106. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  107. # 处理预加载选项
  108. final_preload = preload
  109. # 如果没有提供preload且children_attr存在,则添加到预加载选项中
  110. if preload is None and children_attr and hasattr(self.model, children_attr):
  111. # 获取模型默认预加载选项
  112. model_defaults = getattr(self.model, "__loader_options__", [])
  113. # 将children_attr添加到默认预加载选项中
  114. final_preload = [*list(model_defaults), children_attr]
  115. # 应用预加载选项
  116. for opt in self.__loader_options(final_preload):
  117. sql = sql.options(opt)
  118. sql = await self.__filter_permissions(sql)
  119. result: Result = await self.auth.db.execute(sql)
  120. return result.scalars().all()
  121. except Exception as e:
  122. raise CustomException(msg=f"树形列表查询失败: {e!s}")
  123. async def page(
  124. self,
  125. offset: int,
  126. limit: int,
  127. order_by: builtins.list[dict[str, str]],
  128. search: dict,
  129. out_schema: type[OutSchemaType],
  130. preload: builtins.list[str | Any] | None = None,
  131. ) -> dict:
  132. """
  133. 获取分页数据
  134. 参数:
  135. - offset (int): 偏移量
  136. - limit (int): 每页数量
  137. - order_by (List[Dict[str, str]]): 排序字段
  138. - search (Dict): 查询条件
  139. - out_schema (Type[OutSchemaType]): 输出数据模型
  140. - preload (Optional[List[Union[str, Any]]]): 预加载关系
  141. 返回:
  142. - Dict: 分页数据
  143. 异常:
  144. - CustomException: 查询失败时抛出异常
  145. """
  146. try:
  147. conditions = await self.__build_conditions(**search) if search else []
  148. order = order_by or [{"id": "asc"}]
  149. sql = select(self.model).where(*conditions).order_by(*self.__order_by(order))
  150. # 应用预加载选项
  151. for opt in self.__loader_options(preload):
  152. sql = sql.options(opt)
  153. sql = await self.__filter_permissions(sql)
  154. # 优化count查询:使用主键计数而非全表扫描
  155. mapper = sa_inspect(self.model)
  156. pk_cols = list(getattr(mapper, "primary_key", []))
  157. if pk_cols:
  158. # 使用主键的第一列进行计数(主键必定非NULL,性能更好)
  159. count_sql = select(func.count(pk_cols[0])).select_from(self.model)
  160. else:
  161. # 降级方案:使用count(*)
  162. count_sql = select(func.count()).select_from(self.model)
  163. if conditions:
  164. count_sql = count_sql.where(*conditions)
  165. count_sql = await self.__filter_permissions(count_sql)
  166. total_result = await self.auth.db.execute(count_sql)
  167. total = total_result.scalar() or 0
  168. result: Result = await self.auth.db.execute(sql.offset(offset).limit(limit))
  169. objs = result.scalars().all()
  170. return {
  171. "page_no": offset // limit + 1 if limit else 1,
  172. "page_size": limit or 10,
  173. "total": total,
  174. "has_next": offset + limit < total,
  175. "items": [out_schema.model_validate(obj).model_dump() for obj in objs],
  176. }
  177. except Exception as e:
  178. raise CustomException(msg=f"分页查询失败: {e!s}")
  179. async def create(self, data: CreateSchemaType | dict, skip_tenant_id: bool = False) -> ModelType:
  180. """
  181. 创建新对象
  182. 参数:
  183. - data (Union[CreateSchemaType, Dict]): 对象属性
  184. - skip_tenant_id (bool, optional): 是否跳过设置租户ID。默认值为 False。
  185. 返回:
  186. - ModelType: 新创建的对象实例
  187. 异常:
  188. - CustomException: 创建失败时抛出异常
  189. """
  190. try:
  191. obj_dict = data if isinstance(data, dict) else data.model_dump()
  192. obj = self.model(**obj_dict)
  193. # 设置字段值(只检查一次current_user)
  194. if self.auth.user:
  195. if hasattr(obj, "created_id"):
  196. setattr(obj, "created_id", self.auth.user.id)
  197. if hasattr(obj, "updated_id"):
  198. setattr(obj, "updated_id", self.auth.user.id)
  199. # 自动设置 tenant_id(如果模型有该字段)
  200. if not skip_tenant_id and hasattr(self.model, "tenant_id") and hasattr(obj, "tenant_id"):
  201. if hasattr(self.auth, "tenant_id") and self.auth.tenant_id:
  202. setattr(obj, "tenant_id", self.auth.tenant_id)
  203. elif self.auth.user and hasattr(self.auth.user, "tenant_id"):
  204. setattr(obj, "tenant_id", self.auth.user.tenant_id)
  205. self.auth.db.add(obj)
  206. await self.auth.db.flush()
  207. await self.auth.db.refresh(obj)
  208. return obj
  209. except Exception as e:
  210. raise CustomException(msg=f"创建失败: {e!s}")
  211. async def update(self, id: int, data: UpdateSchemaType | dict) -> ModelType:
  212. """
  213. 更新对象
  214. 参数:
  215. - id (int): 对象ID
  216. - data (Union[UpdateSchemaType, Dict]): 更新的属性及值
  217. 返回:
  218. - ModelType: 更新后的对象实例
  219. 异常:
  220. - CustomException: 更新失败时抛出异常
  221. """
  222. try:
  223. obj_dict = (
  224. data
  225. if isinstance(data, dict)
  226. else data.model_dump(exclude_unset=True, exclude={"id"})
  227. )
  228. # 获取对象时不自动预加载关系,避免循环依赖
  229. obj = await self.get(id=id, preload=[])
  230. if not obj:
  231. raise CustomException(msg="更新对象不存在")
  232. # 设置字段值(只检查一次current_user)
  233. if self.auth.user and hasattr(obj, "updated_id"):
  234. setattr(obj, "updated_id", self.auth.user.id)
  235. for key, value in obj_dict.items():
  236. if hasattr(obj, key):
  237. setattr(obj, key, value)
  238. await self.auth.db.flush()
  239. # 刷新对象时不自动预加载关系
  240. await self.auth.db.refresh(obj)
  241. # 权限二次确认:flush后再次验证对象仍在权限范围内
  242. # 防止并发修改导致的权限逃逸(如其他事务修改了created_id)
  243. # 验证时也不自动预加载关系
  244. verify_obj = await self.get(id=id, preload=[])
  245. if not verify_obj:
  246. # 对象已被删除或权限已失效
  247. raise CustomException(msg="更新失败,对象不存在或无权限访问")
  248. return obj
  249. except Exception as e:
  250. raise CustomException(msg=f"更新失败: {e!s}")
  251. async def delete(self, ids: builtins.list[int]) -> None:
  252. """
  253. 删除对象
  254. 参数:
  255. - ids (List[int]): 对象ID列表
  256. 返回:
  257. - None
  258. 异常:
  259. - CustomException: 删除失败时抛出异常
  260. """
  261. try:
  262. mapper = sa_inspect(self.model)
  263. pk_cols = list(getattr(mapper, "primary_key", []))
  264. if not pk_cols:
  265. raise CustomException(msg="模型缺少主键,无法删除")
  266. if len(pk_cols) > 1:
  267. raise CustomException(msg="暂不支持复合主键的批量删除")
  268. # 只删除有权限的数据
  269. sql = delete(self.model).where(pk_cols[0].in_(ids))
  270. await self.auth.db.execute(sql)
  271. await self.auth.db.flush()
  272. except Exception as e:
  273. raise CustomException(msg=f"删除失败: {e!s}")
  274. async def clear(self) -> None:
  275. """
  276. 清空对象表
  277. 返回:
  278. - None
  279. 异常:
  280. - CustomException: 清空失败时抛出异常
  281. """
  282. try:
  283. sql = delete(self.model)
  284. await self.auth.db.execute(sql)
  285. await self.auth.db.flush()
  286. except Exception as e:
  287. raise CustomException(msg=f"清空失败: {e!s}")
  288. async def set(self, ids: builtins.list[int], **kwargs) -> None:
  289. """
  290. 批量更新对象
  291. 参数:
  292. - ids (List[int]): 对象ID列表
  293. - **kwargs: 更新的属性及值
  294. 返回:
  295. - None
  296. 异常:
  297. - CustomException: 更新失败时抛出异常
  298. """
  299. try:
  300. mapper = sa_inspect(self.model)
  301. pk_cols = list(getattr(mapper, "primary_key", []))
  302. if not pk_cols:
  303. raise CustomException(msg="模型缺少主键,无法更新")
  304. if len(pk_cols) > 1:
  305. raise CustomException(msg="暂不支持复合主键的批量更新")
  306. # 只更新有权限的数据
  307. sql = update(self.model).where(pk_cols[0].in_(ids)).values(**kwargs)
  308. await self.auth.db.execute(sql)
  309. await self.auth.db.flush()
  310. except CustomException:
  311. raise
  312. except Exception as e:
  313. raise CustomException(msg=f"批量更新失败: {e!s}")
  314. async def __filter_permissions(self, sql: Select) -> Select:
  315. """
  316. 过滤数据权限(仅用于Select)。
  317. """
  318. filter = Permission(model=self.model, auth=self.auth)
  319. return await filter.filter_query(sql)
  320. async def __build_conditions(self, **kwargs) -> builtins.list[ColumnElement]:
  321. """
  322. 构建查询条件
  323. 参数:
  324. - **kwargs: 查询参数
  325. 返回:
  326. - List[ColumnElement]: SQL条件表达式列表
  327. 异常:
  328. - CustomException: 查询参数不存在时抛出异常
  329. """
  330. conditions = []
  331. for key, value in kwargs.items():
  332. if value is None or value == "":
  333. continue
  334. attr = getattr(self.model, key)
  335. if isinstance(value, tuple):
  336. seq, val = value
  337. if seq == "None":
  338. conditions.append(attr.is_(None))
  339. elif seq == "not None":
  340. conditions.append(attr.isnot(None))
  341. elif seq == "date" and val:
  342. conditions.append(func.date_format(attr, "%Y-%m-%d") == val)
  343. elif seq == "month" and val:
  344. conditions.append(func.date_format(attr, "%Y-%m") == val)
  345. elif seq == "like" and val:
  346. conditions.append(attr.like(f"%{val}%"))
  347. elif seq == "in":
  348. # 通用约定:("in", []) 应当返回空集(恒假),不能跳过条件导致查询退化成全量(仅权限过滤)
  349. if val is None:
  350. continue
  351. if isinstance(val, (list, tuple, set)) and len(val) == 0:
  352. conditions.append(false())
  353. else:
  354. conditions.append(attr.in_(val))
  355. elif seq == "between" and isinstance(val, (list, tuple)) and len(val) == 2:
  356. conditions.append(attr.between(val[0], val[1]))
  357. elif seq == "!=" or (seq == "ne" and val):
  358. conditions.append(attr != val)
  359. elif seq == ">" or (seq == "gt" and val):
  360. conditions.append(attr > val)
  361. elif seq == ">=" or (seq == "ge" and val):
  362. conditions.append(attr >= val)
  363. elif seq == "<" or (seq == "lt" and val):
  364. conditions.append(attr < val)
  365. elif seq == "<=" or (seq == "le" and val):
  366. conditions.append(attr <= val)
  367. elif seq == "==" or (seq == "eq" and val):
  368. conditions.append(attr == val)
  369. else:
  370. conditions.append(attr == value)
  371. return conditions
  372. def __order_by(self, order_by: builtins.list[dict[str, str]]) -> builtins.list[ColumnElement]:
  373. """
  374. 获取排序字段
  375. 参数:
  376. - order_by (List[Dict[str, str]]): 排序字段列表,格式为 [{'id': 'asc'}, {'name': 'desc'}]
  377. 返回:
  378. - List[ColumnElement]: 排序字段列表
  379. 异常:
  380. - CustomException: 排序字段不存在时抛出异常
  381. """
  382. columns = []
  383. for order in order_by:
  384. for field, direction in order.items():
  385. column = getattr(self.model, field)
  386. columns.append(desc(column) if direction.lower() == "desc" else asc(column))
  387. return columns
  388. def __loader_options(
  389. self, preload: builtins.list[str | Any] | None = None
  390. ) -> builtins.list[Any]:
  391. """
  392. 构建预加载选项
  393. 参数:
  394. - preload (Optional[List[Union[str, Any]]]): 预加载关系,支持关系名字符串或SQLAlchemy loader option
  395. 返回:
  396. - List[Any]: 预加载选项列表
  397. """
  398. options = []
  399. # 获取模型定义的默认加载选项
  400. model_loader_options = getattr(self.model, "__loader_options__", [])
  401. # 合并所有需要预加载的选项
  402. all_preloads = set(model_loader_options)
  403. if preload:
  404. for opt in preload:
  405. if isinstance(opt, str):
  406. all_preloads.add(opt)
  407. elif preload == []:
  408. # 如果明确指定空列表,则不使用任何预加载
  409. all_preloads = set()
  410. # 处理所有预加载选项
  411. for opt in all_preloads:
  412. if isinstance(opt, str):
  413. # 使用selectinload来避免在异步环境中的MissingGreenlet错误
  414. if hasattr(self.model, opt):
  415. options.append(selectinload(getattr(self.model, opt)))
  416. else:
  417. # 直接使用非字符串的加载选项
  418. options.append(opt)
  419. return options