upload_util.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. import os
  2. import random
  3. import re
  4. from datetime import datetime
  5. from pathlib import Path
  6. from urllib.parse import urljoin
  7. import aiofiles
  8. from fastapi import UploadFile
  9. from app.config.setting import settings
  10. from app.core.exceptions import CustomException
  11. from app.core.logger import log
  12. DANGEROUS_EXTENSIONS = {
  13. ".py",
  14. ".pyc",
  15. ".pyo",
  16. ".php",
  17. ".php3",
  18. ".php4",
  19. ".php5",
  20. ".phtml",
  21. ".exe",
  22. ".bat",
  23. ".cmd",
  24. ".sh",
  25. ".bash",
  26. ".zsh",
  27. ".ps1",
  28. ".ps2",
  29. ".psm1",
  30. ".psd1",
  31. ".vbs",
  32. ".vbe",
  33. ".js",
  34. ".jse",
  35. ".wsf",
  36. ".wsh",
  37. ".msi",
  38. ".dll",
  39. ".so",
  40. ".dylib",
  41. ".jar",
  42. ".class",
  43. ".jsp",
  44. ".jspx",
  45. ".asp",
  46. ".aspx",
  47. ".asa",
  48. ".asax",
  49. ".cer",
  50. ".cdx",
  51. ".config",
  52. ".htaccess",
  53. ".htpasswd",
  54. ".sql",
  55. ".db",
  56. ".sqlite",
  57. ".sqlite3",
  58. }
  59. MIME_TYPE_MAPPING = {
  60. "image/jpeg": ".jpg",
  61. "image/png": ".png",
  62. "image/gif": ".gif",
  63. "image/webp": ".webp",
  64. "image/svg+xml": ".svg",
  65. "image/x-icon": ".ico",
  66. "image/bmp": ".bmp",
  67. "application/vnd.ms-excel": ".xls",
  68. "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
  69. "application/msword": ".doc",
  70. "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
  71. "application/pdf": ".pdf",
  72. "text/plain": ".txt",
  73. "text/csv": ".csv",
  74. }
  75. class UploadUtil:
  76. """
  77. 上传工具类
  78. """
  79. @staticmethod
  80. def generate_random_number() -> str:
  81. """
  82. 生成3位随机数字字符串。
  83. 返回:
  84. - str: 三位随机数字字符串。
  85. """
  86. return f"{random.randint(1, 999):03}"
  87. @staticmethod
  88. def check_file_exists(filepath: str) -> bool:
  89. """
  90. 检查文件是否存在。
  91. 参数:
  92. - filepath (str): 文件路径。
  93. 返回:
  94. - bool: 文件是否存在。
  95. """
  96. return Path(filepath).exists()
  97. @staticmethod
  98. def sanitize_filename(filename: str) -> str:
  99. """
  100. 清理文件名,移除危险字符和路径穿越。
  101. 参数:
  102. - filename (str): 原始文件名。
  103. 返回:
  104. - str: 安全的文件名。
  105. """
  106. if not filename:
  107. return ""
  108. filename = os.path.basename(filename)
  109. filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "", filename)
  110. filename = re.sub(r"\.{2,}", ".", filename)
  111. filename = filename.strip(". ")
  112. if not filename:
  113. filename = f"file_{datetime.now().strftime('%Y%m%d%H%M%S')}"
  114. return filename
  115. @staticmethod
  116. def check_path_traversal(filename: str) -> bool:
  117. """
  118. 检查文件名是否包含路径穿越。
  119. 参数:
  120. - filename (str): 文件名。
  121. 返回:
  122. - bool: 是否安全(True 表示安全,False 表示存在路径穿越)。
  123. """
  124. dangerous_patterns = ["../", "..\\", "/", "\\", "\0"]
  125. for pattern in dangerous_patterns:
  126. if pattern in filename:
  127. return False
  128. return True
  129. @staticmethod
  130. def get_extension_from_filename(filename: str) -> str:
  131. """
  132. 从文件名获取扩展名。
  133. 参数:
  134. - filename (str): 文件名。
  135. 返回:
  136. - str: 扩展名(小写,包含点),如 ".jpg"。
  137. """
  138. if not filename or "." not in filename:
  139. return ""
  140. ext = filename.rsplit(".", 1)[-1].lower()
  141. return f".{ext}" if ext else ""
  142. @staticmethod
  143. def is_dangerous_extension(extension: str) -> bool:
  144. """
  145. 检查扩展名是否为危险类型。
  146. 参数:
  147. - extension (str): 文件扩展名。
  148. 返回:
  149. - bool: 是否为危险扩展名。
  150. """
  151. return extension.lower() in DANGEROUS_EXTENSIONS
  152. @staticmethod
  153. def detect_file_type(content: bytes) -> str | None:
  154. """
  155. 通过文件内容检测真实文件类型。
  156. 参数:
  157. - content (bytes): 文件内容(前几字节即可)。
  158. 返回:
  159. - str | None: 检测到的 MIME 类型,无法识别返回 None。
  160. """
  161. if content.startswith(b"\xff\xd8\xff"):
  162. return "image/jpeg"
  163. if content.startswith(b"\x89PNG\r\n\x1a\n"):
  164. return "image/png"
  165. if content.startswith(b"GIF87a") or content.startswith(b"GIF89a"):
  166. return "image/gif"
  167. if content.startswith(b"PK\x03\x04"):
  168. if b"[Content_Types].xml" in content[:1000]:
  169. return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
  170. return "application/zip"
  171. if content.startswith(b"%PDF"):
  172. return "application/pdf"
  173. if content.startswith(b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"):
  174. return "application/msword"
  175. return None
  176. @classmethod
  177. def validate_file_extension(cls, extension: str) -> bool:
  178. """
  179. 验证文件扩展名是否在允许列表中。
  180. 参数:
  181. - extension (str): 文件扩展名。
  182. 返回:
  183. - bool: 是否允许。
  184. 异常:
  185. - CustomException: 扩展名不允许时抛出。
  186. """
  187. ext_lower = extension.lower()
  188. if cls.is_dangerous_extension(ext_lower):
  189. raise CustomException(msg=f"不允许上传此类型的文件: {extension}")
  190. if ext_lower not in settings.ALLOWED_EXTENSIONS:
  191. raise CustomException(
  192. msg=f"文件类型不支持,允许的类型: {', '.join(settings.ALLOWED_EXTENSIONS)}"
  193. )
  194. return True
  195. @classmethod
  196. def validate_file_content_type(cls, content: bytes, claimed_extension: str) -> bool:
  197. """
  198. 验证文件内容类型与声明的扩展名是否匹配。
  199. 参数:
  200. - content (bytes): 文件内容。
  201. - claimed_extension (str): 声明的文件扩展名。
  202. 返回:
  203. - bool: 是否匹配。
  204. 异常:
  205. - CustomException: 类型不匹配时抛出。
  206. """
  207. detected_type = cls.detect_file_type(content)
  208. if detected_type:
  209. expected_ext = MIME_TYPE_MAPPING.get(detected_type, "")
  210. if expected_ext and expected_ext != claimed_extension.lower():
  211. log.warning(
  212. f"文件类型不匹配: 声明扩展名={claimed_extension}, 检测类型={detected_type}"
  213. )
  214. return True
  215. @staticmethod
  216. def check_file_size(file: UploadFile) -> bool:
  217. """
  218. 校验文件大小是否合法。
  219. 参数:
  220. - file (UploadFile): 上传的文件对象。
  221. 返回:
  222. - bool: 文件大小是否合法。
  223. 异常:
  224. - CustomException: 文件过大时抛出。
  225. """
  226. if file.size and file.size > settings.MAX_FILE_SIZE:
  227. raise CustomException(
  228. msg=f"文件大小超过限制,最大允许 {settings.MAX_FILE_SIZE // (1024 * 1024)}MB"
  229. )
  230. return True
  231. @classmethod
  232. def generate_safe_filename(cls, original_filename: str, extension: str) -> str:
  233. """
  234. 生成安全的文件名。
  235. 参数:
  236. - original_filename (str): 原始文件名。
  237. - extension (str): 文件扩展名。
  238. 返回:
  239. - str: 安全的文件名。
  240. """
  241. safe_name = cls.sanitize_filename(original_filename)
  242. if safe_name and "." in safe_name:
  243. name_part = safe_name.rsplit(".", 1)[0]
  244. else:
  245. name_part = safe_name or "file"
  246. name_part = re.sub(r"[^a-zA-Z0-9_\-\u4e00-\u9fa5]", "", name_part)
  247. if len(name_part) > 50:
  248. name_part = name_part[:50]
  249. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  250. random_suffix = cls.generate_random_number()
  251. return f"{name_part}_{timestamp}{settings.UPLOAD_MACHINE}{random_suffix}{extension}"
  252. @staticmethod
  253. def check_file_timestamp(filename: str) -> bool:
  254. """
  255. 校验文件时间戳是否合法。
  256. 参数:
  257. - filename (str): 文件名(包含时间戳片段)。
  258. 返回:
  259. - bool: 时间戳是否合法。
  260. """
  261. try:
  262. name_parts = filename.rsplit(".", 1)[0].split("_")
  263. timestamp = name_parts[-1].split(settings.UPLOAD_MACHINE)[0]
  264. datetime.strptime(timestamp, "%Y%m%d%H%M%S")
  265. return True
  266. except (ValueError, IndexError):
  267. return False
  268. @staticmethod
  269. def check_file_machine(filename: str) -> bool:
  270. """
  271. 校验文件机器码是否合法。
  272. 参数:
  273. - filename (str): 文件名。
  274. 返回:
  275. - bool: 机器码是否合法。
  276. """
  277. try:
  278. name_without_ext = filename.rsplit(".", 1)[0]
  279. return len(name_without_ext) >= 4 and name_without_ext[-4] == settings.UPLOAD_MACHINE
  280. except IndexError:
  281. return False
  282. @staticmethod
  283. def check_file_random_code(filename: str) -> bool:
  284. """
  285. 校验文件随机码是否合法。
  286. 参数:
  287. - filename (str): 文件名。
  288. 返回:
  289. - bool: 随机码是否合法(000–999)。
  290. """
  291. try:
  292. code = filename.rsplit(".", 1)[0][-3:]
  293. return code.isdigit() and 1 <= int(code) <= 999
  294. except IndexError:
  295. return False
  296. @staticmethod
  297. def generate_file(filepath: Path, chunk_size: int = 8192):
  298. """
  299. 根据文件生成二进制数据迭代器。
  300. 参数:
  301. - filepath (Path): 文件路径。
  302. - chunk_size (int): 分块大小,默认 8192 字节。
  303. 返回:
  304. - Iterator[bytes]: 文件二进制数据分块迭代器。
  305. """
  306. with filepath.open("rb") as f:
  307. while chunk := f.read(chunk_size):
  308. yield chunk
  309. @staticmethod
  310. def delete_file(filepath: Path) -> bool:
  311. """
  312. 删除文件。
  313. 参数:
  314. - filepath (Path): 文件路径。
  315. 返回:
  316. - bool: 删除是否成功。
  317. """
  318. try:
  319. filepath.unlink(missing_ok=True)
  320. return True
  321. except OSError:
  322. return False
  323. @classmethod
  324. async def upload_file(cls, file: UploadFile, base_url: str) -> tuple[str, Path, str]:
  325. """
  326. 安全文件上传。
  327. 参数:
  328. - file (UploadFile): 上传的文件对象。
  329. - base_url (str): 基础 URL。
  330. 返回:
  331. - tuple[str, Path, str]: (文件名, 文件路径, 文件 URL)。
  332. 异常:
  333. - CustomException: 当文件校验失败时抛出。
  334. """
  335. if not file or not file.filename:
  336. raise CustomException(msg="请选择要上传的文件")
  337. original_filename = file.filename
  338. if not cls.check_path_traversal(original_filename):
  339. log.error(f"检测到路径穿越攻击: {original_filename}")
  340. raise CustomException(msg="文件名包含非法字符")
  341. extension = cls.get_extension_from_filename(original_filename)
  342. if not extension:
  343. raise CustomException(msg="无法识别文件类型")
  344. cls.validate_file_extension(extension)
  345. cls.check_file_size(file)
  346. content = await file.read()
  347. await file.seek(0)
  348. cls.validate_file_content_type(content, extension)
  349. safe_filename = cls.generate_safe_filename(original_filename, extension)
  350. try:
  351. dir_path = settings.UPLOAD_FILE_PATH.joinpath(datetime.now().strftime("%Y/%m/%d"))
  352. dir_path.mkdir(parents=True, exist_ok=True)
  353. filepath = dir_path.joinpath(safe_filename)
  354. if not filepath.resolve().is_relative_to(settings.UPLOAD_FILE_PATH.resolve()):
  355. log.error(f"检测到路径穿越攻击,目标路径: {filepath}")
  356. raise CustomException(msg="非法的文件路径")
  357. file_url = urljoin(base_url, str(filepath))
  358. chunk_size = 8 * 1024 * 1024
  359. async with aiofiles.open(filepath, "wb") as f:
  360. while chunk := await file.read(chunk_size):
  361. await f.write(chunk)
  362. log.info(f"文件上传成功: {safe_filename}")
  363. return safe_filename, filepath, file_url
  364. except CustomException:
  365. raise
  366. except Exception as e:
  367. log.error(f"文件上传失败: {e}")
  368. raise CustomException(msg=f"文件上传失败: {e}")
  369. @staticmethod
  370. def get_file_tree(file_path: str) -> list[dict]:
  371. """
  372. 获取文件树结构。
  373. 参数:
  374. - file_path (str): 文件路径。
  375. 返回:
  376. - list[dict]: 文件树列表。
  377. """
  378. return [{"name": item.name, "is_dir": item.is_dir()} for item in Path(file_path).iterdir()]
  379. @classmethod
  380. async def download_file(cls, file_path: str) -> str:
  381. """
  382. 下载文件,生成新的文件名。
  383. 参数:
  384. - file_path (str): 文件路径。
  385. 返回:
  386. - str: 文件下载信息。
  387. """
  388. filename = cls.generate_file(Path(file_path))
  389. return str(filename)