import_util.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import importlib
  2. import inspect
  3. import os
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import Any
  7. from sqlalchemy import inspect as sa_inspect
  8. from app.config.path_conf import BASE_DIR
  9. from app.core.exceptions import CustomException
  10. class ImportUtil:
  11. """
  12. 扫描工程中的 ORM 模型文件并做有效性校验的辅助类。
  13. """
  14. @classmethod
  15. def find_project_root(cls) -> Path:
  16. """
  17. 返回项目根目录(与配置中的 `BASE_DIR` 一致)。
  18. 返回:
  19. - Path: 项目根路径。
  20. """
  21. return BASE_DIR
  22. @classmethod
  23. def is_valid_model(cls, obj: Any, base_class: type) -> bool:
  24. """
  25. 判断是否为可映射的 SQLAlchemy 模型类(含表名与非空列)。
  26. 参数:
  27. - obj (Any): 待验证对象(一般为类)。
  28. - base_class (type): ORM 声明基类。
  29. 返回:
  30. - bool: 是否为有效模型类。
  31. """
  32. # 必须继承自base_class且不是base_class本身
  33. if not (inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class):
  34. return False
  35. # 必须有表名定义(排除抽象基类)
  36. if not hasattr(obj, "__tablename__") or obj.__tablename__ is None:
  37. return False
  38. # 必须有至少一个列定义
  39. try:
  40. return len(sa_inspect(obj).columns) > 0
  41. except Exception:
  42. return False
  43. @classmethod
  44. @lru_cache(maxsize=256)
  45. def find_models(cls, base_class: type) -> list[Any]:
  46. """
  47. 遍历工程内 `model.py` / `models.py`,收集去重后的有效模型类。
  48. 参数:
  49. - base_class (type): SQLAlchemy 声明基类。
  50. 返回:
  51. - list[Any]: 模型类列表。
  52. 异常:
  53. - ImportError: 模块导入失败(非「无法从某名导入」类警告)。
  54. - CustomException: 处理模块时发生未预期错误。
  55. """
  56. models = []
  57. # 按类对象去重
  58. seen_models = set()
  59. # 按表名去重(防止同表名冲突)
  60. seen_tables = set()
  61. # 记录已经处理过的model.py文件路径
  62. processed_model_files = set()
  63. project_root = cls.find_project_root()
  64. # 排除目录扩展
  65. exclude_dirs = {
  66. "venv",
  67. ".env",
  68. ".git",
  69. "__pycache__",
  70. "migrations",
  71. "alembic",
  72. "tests",
  73. "test",
  74. "docs",
  75. "examples",
  76. "scripts",
  77. ".venv",
  78. "static",
  79. "templates",
  80. "sql",
  81. "env",
  82. }
  83. # 定义要搜索的模型目录模式
  84. model_dir_patterns = ["model.py", "models.py"]
  85. # 使用一个更高效的方法来查找所有model.py文件
  86. model_files = []
  87. for root, dirs, files in os.walk(project_root):
  88. # 过滤排除目录
  89. dirs[:] = [d for d in dirs if d not in exclude_dirs]
  90. for file in files:
  91. if file in model_dir_patterns:
  92. file_path = Path(root) / file
  93. # 构建相对于项目根的模块路径
  94. relative_path = file_path.relative_to(project_root)
  95. model_files.append((file_path, relative_path))
  96. # 按模块路径排序,确保先导入基础模块
  97. model_files.sort(key=lambda x: str(x[1]))
  98. for file_path, relative_path in model_files:
  99. # 确保文件路径没有被处理过
  100. if str(file_path) in processed_model_files:
  101. continue
  102. processed_model_files.add(str(file_path))
  103. # 构建模块名(将路径分隔符转换为点)
  104. module_parts = (*relative_path.parts[:-1], relative_path.stem)
  105. module_name = ".".join(module_parts)
  106. try:
  107. # 导入模块
  108. module = importlib.import_module(module_name)
  109. # 获取模块中的所有类
  110. for _name, obj in inspect.getmembers(module, inspect.isclass):
  111. # 验证模型有效性
  112. if not cls.is_valid_model(obj, base_class):
  113. continue
  114. # 检查类对象重复
  115. if obj in seen_models:
  116. continue
  117. # 检查表名重复
  118. table_name = obj.__tablename__
  119. if table_name in seen_tables:
  120. continue
  121. # 添加到已处理集合
  122. seen_models.add(obj)
  123. seen_tables.add(table_name)
  124. models.append(obj)
  125. except ImportError as e:
  126. if "cannot import name" not in str(e):
  127. raise ImportError(f"❗️ 警告: 无法导入模块 {module_name}: {e}")
  128. except Exception as e:
  129. raise CustomException(f"❌️ 处理模块 {module_name} 时出错: {e}")
  130. # 查找apscheduler_jobs表的模型(如果存在)
  131. cls._find_apscheduler_model(base_class, models, seen_models, seen_tables)
  132. return models
  133. @classmethod
  134. def _find_apscheduler_model(
  135. cls,
  136. base_class: type,
  137. models: list[Any],
  138. seen_models: set[Any],
  139. seen_tables: set[str],
  140. ) -> None:
  141. """
  142. 尝试从调度相关模块补充 `apscheduler_jobs` 表对应模型。
  143. 参数:
  144. - base_class (type): ORM 声明基类。
  145. - models (list[Any]): 已收集模型列表(就地追加)。
  146. - seen_models (set[Any]): 已见模型对象集合。
  147. - seen_tables (set[str]): 已见表名集合。
  148. 返回:
  149. - None
  150. 异常:
  151. - CustomException: 扫描过程出现未预期错误。
  152. """
  153. # 尝试从apscheduler相关模块导入
  154. try:
  155. # 检查是否有自定义的apscheduler模型
  156. for module_name in [
  157. "app.core.ap_scheduler",
  158. "app.module_task.scheduler_test",
  159. ]:
  160. try:
  161. module = importlib.import_module(module_name)
  162. for _name, obj in inspect.getmembers(module, inspect.isclass):
  163. if (
  164. cls.is_valid_model(obj, base_class)
  165. and hasattr(obj, "__tablename__")
  166. and obj.__tablename__ == "apscheduler_jobs"
  167. ) and (obj not in seen_models and "apscheduler_jobs" not in seen_tables):
  168. seen_models.add(obj)
  169. seen_tables.add("apscheduler_jobs")
  170. models.append(obj)
  171. print(
  172. f"✅️ 找到有效模型: {obj.__module__}.{obj.__name__} (表: apscheduler_jobs)"
  173. )
  174. except ImportError:
  175. pass
  176. except Exception as e:
  177. raise CustomException(f"❗️ 查找APScheduler模型时出错: {e}")