common_util.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. import importlib
  2. import re
  3. import uuid
  4. from collections.abc import Generator, Sequence
  5. from pathlib import Path
  6. from typing import Any, Literal
  7. from sqlalchemy.engine.row import Row
  8. from sqlalchemy.orm import DeclarativeBase
  9. from sqlalchemy.orm.collections import InstrumentedList
  10. from sqlalchemy.sql.elements import Null
  11. from sqlalchemy.sql.expression import null
  12. from app.config.setting import settings
  13. from app.core.exceptions import CustomException
  14. from app.core.logger import log
  15. def import_module(module: str, desc: str) -> Any:
  16. """
  17. 动态导入模块
  18. 参数:
  19. - module (str): 模块名称。
  20. - desc (str): 模块描述。
  21. 返回:
  22. - Any: 模块对象。
  23. """
  24. try:
  25. module_path, module_class = module.rsplit(".", 1)
  26. module = importlib.import_module(module_path) # pyright: ignore[reportAssignmentType]
  27. return getattr(module, module_class)
  28. except ModuleNotFoundError:
  29. log.error(f"❗️ 导入{desc}失败,未找到模块:{module}")
  30. raise
  31. except AttributeError:
  32. log.error(f"❗ ️导入{desc}失败,未找到模块方法:{module}")
  33. raise
  34. async def import_modules_async(modules: list, desc: str, **kwargs) -> None:
  35. """
  36. 异步导入模块列表
  37. 参数:
  38. - modules (list[str]): 模块列表。
  39. - desc (str): 模块描述。
  40. - kwargs: 额外参数。
  41. 返回:
  42. - None
  43. """
  44. for module in modules:
  45. if not module:
  46. continue
  47. try:
  48. module_path = module[0 : module.rindex(".")]
  49. module_name = module[module.rindex(".") + 1 :]
  50. module_obj = importlib.import_module(module_path)
  51. await getattr(module_obj, module_name)(**kwargs)
  52. except ModuleNotFoundError:
  53. log.error(f"❌️ 导入{desc}失败,未找到模块:{module}")
  54. raise
  55. except AttributeError:
  56. log.error(f"❌️ 导入{desc}失败,未找到模块方法:{module}")
  57. raise
  58. def get_random_character() -> str:
  59. """
  60. 生成随机字符串
  61. 返回:
  62. - str: 随机字符串。
  63. """
  64. return uuid.uuid4().hex
  65. def generate_random_code(length: int = 6) -> str:
  66. """
  67. 生成指定长度的随机数字验证码
  68. 参数:
  69. - length (int): 验证码长度,默认6位
  70. 返回:
  71. - str: 随机数字验证码字符串
  72. """
  73. import random
  74. return ''.join(random.choices('0123456789', k=length))
  75. def uuid4_str() -> str:
  76. """
  77. 数据库引擎 UUID 类型兼容:返回无连字符的 UUID 字符串。
  78. 返回:
  79. - str: UUID 字符串。
  80. """
  81. return str(uuid.uuid4())
  82. def get_parent_id_map(model_list: Sequence[DeclarativeBase]) -> dict[int, int]:
  83. """
  84. 获取父级 ID 映射字典
  85. 参数:
  86. - model_list (Sequence[DeclarativeBase]): 模型列表。
  87. 返回:
  88. - Dict[int, int]: {id: parent_id} 映射字典。
  89. """
  90. return {item.id: item.parent_id for item in model_list} # pyright: ignore[reportAttributeAccessIssue]
  91. def get_parent_recursion(
  92. id: int, id_map: dict[int, int], ids: list[int] | None = None
  93. ) -> list[int]:
  94. """
  95. 递归获取所有父级 ID
  96. 参数:
  97. - id (int): 当前 ID。
  98. - id_map (dict[int, int]): ID 映射字典。
  99. - ids (list[int] | None): 已收集的 ID 列表。
  100. 返回:
  101. - list[int]: 所有父级 ID 列表。
  102. """
  103. ids = ids or []
  104. if id in ids:
  105. raise CustomException(msg="递归获取父级ID失败,不可以自引用")
  106. ids.append(id)
  107. parent_id = id_map.get(id)
  108. if parent_id:
  109. get_parent_recursion(parent_id, id_map, ids)
  110. return ids
  111. def get_child_id_map(
  112. model_list: Sequence[DeclarativeBase],
  113. ) -> dict[int, list[int]]:
  114. """
  115. 获取子级 ID 映射字典
  116. 参数:
  117. - model_list (Sequence[DeclarativeBase]): 模型列表。
  118. 返回:
  119. - Dict[int, List[int]]: {id: [child_ids]} 映射字典。
  120. """
  121. data_map = {}
  122. for model in model_list:
  123. data_map.setdefault(model.id, []) # pyright: ignore[reportAttributeAccessIssue]
  124. if model.parent_id: # pyright: ignore[reportAttributeAccessIssue]
  125. data_map.setdefault(model.parent_id, []).append(model.id) # pyright: ignore[reportAttributeAccessIssue]
  126. return data_map
  127. def get_child_recursion(
  128. id: int, id_map: dict[int, list[int]], ids: list[int] | None = None
  129. ) -> list[int]:
  130. """
  131. 递归获取所有子级 ID
  132. 参数:
  133. - id (int): 当前 ID。
  134. - id_map (dict[int, list[int]]): ID 映射字典。
  135. - ids (list[int] | None): 已收集的 ID 列表。
  136. 返回:
  137. - list[int]: 所有子级 ID 列表。
  138. """
  139. ids = ids or []
  140. ids.append(id)
  141. for child in id_map.get(id, []):
  142. get_child_recursion(child, id_map, ids)
  143. return ids
  144. def traversal_to_tree(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
  145. """
  146. 通过遍历算法构造树形结构
  147. 参数:
  148. - nodes (list[dict[str, Any]]): 树节点列表。
  149. 返回:
  150. - list[dict[str, Any]]: 构造后的树形结构列表。
  151. """
  152. tree: list[dict[str, Any]] = []
  153. node_dict = {node["id"]: node for node in nodes}
  154. for node in nodes:
  155. # 确保每个节点都有children字段,即使没有子节点也设置为null
  156. if "children" not in node:
  157. node["children"] = None
  158. parent_id = node["parent_id"]
  159. if parent_id is None:
  160. tree.append(node)
  161. else:
  162. parent_node = node_dict.get(parent_id)
  163. if parent_node is not None:
  164. if "children" not in parent_node or parent_node["children"] is None:
  165. parent_node["children"] = []
  166. if node not in parent_node["children"]:
  167. parent_node["children"].append(node)
  168. else:
  169. if node not in tree:
  170. tree.append(node)
  171. # 确保所有节点都有children字段
  172. for node in tree:
  173. if "children" not in node:
  174. node["children"] = None
  175. return tree
  176. def recursive_to_tree(
  177. nodes: list[dict[str, Any]], *, parent_id: int | None = None
  178. ) -> list[dict[str, Any]]:
  179. """
  180. 通过递归算法构造树形结构(性能影响较大)
  181. 参数:
  182. - nodes (list[dict[str, Any]]): 树节点列表。
  183. - parent_id (int | None): 父节点 ID,默认为 None 表示根节点。
  184. 返回:
  185. - list[dict[str, Any]]: 构造后的树形结构列表。
  186. """
  187. tree: list[dict[str, Any]] = []
  188. for node in nodes:
  189. if node["parent_id"] == parent_id:
  190. child_nodes = recursive_to_tree(nodes, parent_id=node["id"])
  191. if child_nodes:
  192. node["children"] = child_nodes
  193. tree.append(node)
  194. return tree
  195. def bytes2human(n: int, format_str: str = "%(value).1f%(symbol)s") -> str:
  196. """
  197. 字节数转人类可读格式
  198. Used by various scripts. See:
  199. http://goo.gl/zeJZl
  200. >>> bytes2human(10000)
  201. '9.8K'
  202. >>> bytes2human(100001221)
  203. '95.4M'
  204. 参数:
  205. - n (int): 字节数。
  206. - format_str (str): 格式化字符串,默认 '%(value).1f%(symbol)s'。
  207. 返回:
  208. - str: 可读的字节字符串,如 '1.5MB'。
  209. """
  210. symbols = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
  211. prefix = {s: 1 << (i + 1) * 10 for i, s in enumerate(symbols[1:])}
  212. for symbol in reversed(symbols[1:]):
  213. if n >= prefix[symbol]:
  214. value = float(n) / prefix[symbol]
  215. return format_str % locals()
  216. return format_str % {"symbol": symbols[0], "value": n}
  217. def bytes2file_response(bytes_info: bytes) -> Generator[bytes, Any, None]:
  218. """
  219. 将字节内容封装为单块流式生成器,供文件下载响应使用。
  220. 参数:
  221. - bytes_info (bytes): 文件二进制内容。
  222. 返回:
  223. - Generator[bytes, Any, None]: 仅 yield 一次的字节生成器。
  224. """
  225. yield bytes_info
  226. def get_filepath_from_url(url: str) -> Path:
  227. """
  228. 工具方法:根据请求参数获取文件路径
  229. 参数:
  230. - url (str): 请求参数中的 url 参数。
  231. 返回:
  232. - Path: 文件路径。
  233. """
  234. file_info = url.split("?")[1].split("&")
  235. task_id = file_info[0].split("=")[1]
  236. file_name = file_info[1].split("=")[1]
  237. task_path = file_info[2].split("=")[1]
  238. filepath = settings.STATIC_ROOT.joinpath(task_path, task_id, file_name)
  239. return filepath
  240. class SqlalchemyUtil:
  241. """
  242. sqlalchemy工具类
  243. """
  244. @classmethod
  245. def base_to_dict(
  246. cls,
  247. obj: DeclarativeBase | dict[str, Any],
  248. transform_case: Literal["no_case", "snake_to_camel", "camel_to_snake"] = "no_case",
  249. ):
  250. """
  251. 将 SQLAlchemy 模型或字典转为普通 dict,并可做键名大小写转换。
  252. 参数:
  253. - obj (DeclarativeBase | dict[str, Any]): 模型实例或字典。
  254. - transform_case (Literal[...]): no_case / snake_to_camel / camel_to_snake。
  255. 返回:
  256. - dict[str, Any]: 扁平字典结果。
  257. """
  258. if isinstance(obj, DeclarativeBase):
  259. base_dict = obj.__dict__.copy()
  260. base_dict.pop("_sa_instance_state", None)
  261. for name, value in base_dict.items():
  262. if isinstance(value, InstrumentedList):
  263. base_dict[name] = cls.serialize_result(value, "snake_to_camel")
  264. elif isinstance(obj, dict):
  265. base_dict = obj.copy()
  266. if transform_case == "snake_to_camel":
  267. return {CamelCaseUtil.snake_to_camel(k): v for k, v in base_dict.items()}
  268. if transform_case == "camel_to_snake":
  269. return {SnakeCaseUtil.camel_to_snake(k): v for k, v in base_dict.items()}
  270. return base_dict
  271. @classmethod
  272. def serialize_result(
  273. cls,
  274. result: Any,
  275. transform_case: Literal["no_case", "snake_to_camel", "camel_to_snake"] = "no_case",
  276. ):
  277. """
  278. 将 SQLAlchemy 查询结果(模型、列表、Row 等)递归序列化为可 JSON 化结构。
  279. 参数:
  280. - result (Any): ORM 对象、列表、Row 等。
  281. - transform_case (Literal[...]): 键名转换策略。
  282. 返回:
  283. - Any: 序列化后的 Python 内置类型或嵌套结构。
  284. """
  285. if isinstance(result, (DeclarativeBase, dict)):
  286. return cls.base_to_dict(result, transform_case)
  287. if isinstance(result, list):
  288. return [cls.serialize_result(row, transform_case) for row in result]
  289. if isinstance(result, Row):
  290. if all(isinstance(row, DeclarativeBase) for row in result):
  291. return [cls.base_to_dict(row, transform_case) for row in result]
  292. if any(isinstance(row, DeclarativeBase) for row in result):
  293. return [cls.serialize_result(row, transform_case) for row in result]
  294. result_dict = result._asdict()
  295. if transform_case == "snake_to_camel":
  296. return {CamelCaseUtil.snake_to_camel(k): v for k, v in result_dict.items()}
  297. if transform_case == "camel_to_snake":
  298. return {SnakeCaseUtil.camel_to_snake(k): v for k, v in result_dict.items()}
  299. return result_dict
  300. return result
  301. @classmethod
  302. def get_server_default_null(
  303. cls, dialect_name: str, need_explicit_null: bool = True
  304. ) -> Null | None:
  305. """
  306. 按方言返回列默认值中的 NULL 表达(PostgreSQL 可显式 DEFAULT NULL)。
  307. 参数:
  308. - dialect_name (str): 数据库方言名。
  309. - need_explicit_null (bool): 是否生成显式 NULL 默认值。
  310. 返回:
  311. - Null | None: SQLAlchemy null() 或 None。
  312. """
  313. if need_explicit_null and dialect_name == "postgres":
  314. return null()
  315. return None
  316. class CamelCaseUtil:
  317. """
  318. 下划线形式(snake_case)转小驼峰形式(camelCase)工具方法
  319. """
  320. @classmethod
  321. def snake_to_camel(cls, snake_str: str):
  322. """
  323. 下划线形式 (snake_case) 转为小驼峰形式 (camelCase)。
  324. 参数:
  325. - snake_str (str): 下划线分隔字符串。
  326. 返回:
  327. - str: 合并首字母大写后的驼峰字符串。
  328. """
  329. # 分割字符串
  330. words = snake_str.split("_")
  331. # 小驼峰命名,第一个词首字母小写,其余词首字母大写
  332. # return words[0] + ''.join(word.capitalize() for word in words[1:])
  333. # 大驼峰命名,所有词首字母大写
  334. return "".join(word.capitalize() for word in words)
  335. @classmethod
  336. def transform_result(cls, result: Any):
  337. """
  338. 将查询结果递归序列化并将键名转为小驼峰。
  339. 参数:
  340. - result (Any): ORM 查询结果或嵌套结构。
  341. 返回:
  342. - Any: 小驼峰键名的序列化结果。
  343. """
  344. return SqlalchemyUtil.serialize_result(result=result, transform_case="snake_to_camel")
  345. class SnakeCaseUtil:
  346. """
  347. 小驼峰形式(camelCase)转下划线形式(snake_case)工具方法
  348. """
  349. @classmethod
  350. def camel_to_snake(cls, camel_str: str):
  351. """
  352. 小驼峰形式 (camelCase) 转为下划线形式 (snake_case)。
  353. 参数:
  354. - camel_str (str): 驼峰字符串。
  355. 返回:
  356. - str: 下划线分隔且全小写。
  357. """
  358. # 在大写字母前添加一个下划线,然后将整个字符串转为小写
  359. words = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", camel_str)
  360. return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", words).lower()
  361. @classmethod
  362. def transform_result(cls, result: Any):
  363. """
  364. 将查询结果递归序列化并将键名转为下划线形式。
  365. 参数:
  366. - result (Any): ORM 查询结果或嵌套结构。
  367. 返回:
  368. - Any: 下划线键名的序列化结果。
  369. """
  370. return SqlalchemyUtil.serialize_result(result=result, transform_case="camel_to_snake")