| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- import os
- import random
- import re
- from datetime import datetime
- from pathlib import Path
- from urllib.parse import urljoin
- import aiofiles
- from fastapi import UploadFile
- from app.config.setting import settings
- from app.core.exceptions import CustomException
- from app.core.logger import log
- DANGEROUS_EXTENSIONS = {
- ".py",
- ".pyc",
- ".pyo",
- ".php",
- ".php3",
- ".php4",
- ".php5",
- ".phtml",
- ".exe",
- ".bat",
- ".cmd",
- ".sh",
- ".bash",
- ".zsh",
- ".ps1",
- ".ps2",
- ".psm1",
- ".psd1",
- ".vbs",
- ".vbe",
- ".js",
- ".jse",
- ".wsf",
- ".wsh",
- ".msi",
- ".dll",
- ".so",
- ".dylib",
- ".jar",
- ".class",
- ".jsp",
- ".jspx",
- ".asp",
- ".aspx",
- ".asa",
- ".asax",
- ".cer",
- ".cdx",
- ".config",
- ".htaccess",
- ".htpasswd",
- ".sql",
- ".db",
- ".sqlite",
- ".sqlite3",
- }
- MIME_TYPE_MAPPING = {
- "image/jpeg": ".jpg",
- "image/png": ".png",
- "image/gif": ".gif",
- "image/webp": ".webp",
- "image/svg+xml": ".svg",
- "image/x-icon": ".ico",
- "image/bmp": ".bmp",
- "application/vnd.ms-excel": ".xls",
- "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
- "application/msword": ".doc",
- "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
- "application/pdf": ".pdf",
- "text/plain": ".txt",
- "text/csv": ".csv",
- }
- class UploadUtil:
- """
- 上传工具类
- """
- @staticmethod
- def generate_random_number() -> str:
- """
- 生成3位随机数字字符串。
- 返回:
- - str: 三位随机数字字符串。
- """
- return f"{random.randint(1, 999):03}"
- @staticmethod
- def check_file_exists(filepath: str) -> bool:
- """
- 检查文件是否存在。
- 参数:
- - filepath (str): 文件路径。
- 返回:
- - bool: 文件是否存在。
- """
- return Path(filepath).exists()
- @staticmethod
- def sanitize_filename(filename: str) -> str:
- """
- 清理文件名,移除危险字符和路径穿越。
- 参数:
- - filename (str): 原始文件名。
- 返回:
- - str: 安全的文件名。
- """
- if not filename:
- return ""
- filename = os.path.basename(filename)
- filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "", filename)
- filename = re.sub(r"\.{2,}", ".", filename)
- filename = filename.strip(". ")
- if not filename:
- filename = f"file_{datetime.now().strftime('%Y%m%d%H%M%S')}"
- return filename
- @staticmethod
- def check_path_traversal(filename: str) -> bool:
- """
- 检查文件名是否包含路径穿越。
- 参数:
- - filename (str): 文件名。
- 返回:
- - bool: 是否安全(True 表示安全,False 表示存在路径穿越)。
- """
- dangerous_patterns = ["../", "..\\", "/", "\\", "\0"]
- for pattern in dangerous_patterns:
- if pattern in filename:
- return False
- return True
- @staticmethod
- def get_extension_from_filename(filename: str) -> str:
- """
- 从文件名获取扩展名。
- 参数:
- - filename (str): 文件名。
- 返回:
- - str: 扩展名(小写,包含点),如 ".jpg"。
- """
- if not filename or "." not in filename:
- return ""
- ext = filename.rsplit(".", 1)[-1].lower()
- return f".{ext}" if ext else ""
- @staticmethod
- def is_dangerous_extension(extension: str) -> bool:
- """
- 检查扩展名是否为危险类型。
- 参数:
- - extension (str): 文件扩展名。
- 返回:
- - bool: 是否为危险扩展名。
- """
- return extension.lower() in DANGEROUS_EXTENSIONS
- @staticmethod
- def detect_file_type(content: bytes) -> str | None:
- """
- 通过文件内容检测真实文件类型。
- 参数:
- - content (bytes): 文件内容(前几字节即可)。
- 返回:
- - str | None: 检测到的 MIME 类型,无法识别返回 None。
- """
- if content.startswith(b"\xff\xd8\xff"):
- return "image/jpeg"
- if content.startswith(b"\x89PNG\r\n\x1a\n"):
- return "image/png"
- if content.startswith(b"GIF87a") or content.startswith(b"GIF89a"):
- return "image/gif"
- if content.startswith(b"PK\x03\x04"):
- if b"[Content_Types].xml" in content[:1000]:
- return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
- return "application/zip"
- if content.startswith(b"%PDF"):
- return "application/pdf"
- if content.startswith(b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1"):
- return "application/msword"
- return None
- @classmethod
- def validate_file_extension(cls, extension: str) -> bool:
- """
- 验证文件扩展名是否在允许列表中。
- 参数:
- - extension (str): 文件扩展名。
- 返回:
- - bool: 是否允许。
- 异常:
- - CustomException: 扩展名不允许时抛出。
- """
- ext_lower = extension.lower()
- if cls.is_dangerous_extension(ext_lower):
- raise CustomException(msg=f"不允许上传此类型的文件: {extension}")
- if ext_lower not in settings.ALLOWED_EXTENSIONS:
- raise CustomException(
- msg=f"文件类型不支持,允许的类型: {', '.join(settings.ALLOWED_EXTENSIONS)}"
- )
- return True
- @classmethod
- def validate_file_content_type(cls, content: bytes, claimed_extension: str) -> bool:
- """
- 验证文件内容类型与声明的扩展名是否匹配。
- 参数:
- - content (bytes): 文件内容。
- - claimed_extension (str): 声明的文件扩展名。
- 返回:
- - bool: 是否匹配。
- 异常:
- - CustomException: 类型不匹配时抛出。
- """
- detected_type = cls.detect_file_type(content)
- if detected_type:
- expected_ext = MIME_TYPE_MAPPING.get(detected_type, "")
- if expected_ext and expected_ext != claimed_extension.lower():
- log.warning(
- f"文件类型不匹配: 声明扩展名={claimed_extension}, 检测类型={detected_type}"
- )
- return True
- @staticmethod
- def check_file_size(file: UploadFile) -> bool:
- """
- 校验文件大小是否合法。
- 参数:
- - file (UploadFile): 上传的文件对象。
- 返回:
- - bool: 文件大小是否合法。
- 异常:
- - CustomException: 文件过大时抛出。
- """
- if file.size and file.size > settings.MAX_FILE_SIZE:
- raise CustomException(
- msg=f"文件大小超过限制,最大允许 {settings.MAX_FILE_SIZE // (1024 * 1024)}MB"
- )
- return True
- @classmethod
- def generate_safe_filename(cls, original_filename: str, extension: str) -> str:
- """
- 生成安全的文件名。
- 参数:
- - original_filename (str): 原始文件名。
- - extension (str): 文件扩展名。
- 返回:
- - str: 安全的文件名。
- """
- safe_name = cls.sanitize_filename(original_filename)
- if safe_name and "." in safe_name:
- name_part = safe_name.rsplit(".", 1)[0]
- else:
- name_part = safe_name or "file"
- name_part = re.sub(r"[^a-zA-Z0-9_\-\u4e00-\u9fa5]", "", name_part)
- if len(name_part) > 50:
- name_part = name_part[:50]
- timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
- random_suffix = cls.generate_random_number()
- return f"{name_part}_{timestamp}{settings.UPLOAD_MACHINE}{random_suffix}{extension}"
- @staticmethod
- def check_file_timestamp(filename: str) -> bool:
- """
- 校验文件时间戳是否合法。
- 参数:
- - filename (str): 文件名(包含时间戳片段)。
- 返回:
- - bool: 时间戳是否合法。
- """
- try:
- name_parts = filename.rsplit(".", 1)[0].split("_")
- timestamp = name_parts[-1].split(settings.UPLOAD_MACHINE)[0]
- datetime.strptime(timestamp, "%Y%m%d%H%M%S")
- return True
- except (ValueError, IndexError):
- return False
- @staticmethod
- def check_file_machine(filename: str) -> bool:
- """
- 校验文件机器码是否合法。
- 参数:
- - filename (str): 文件名。
- 返回:
- - bool: 机器码是否合法。
- """
- try:
- name_without_ext = filename.rsplit(".", 1)[0]
- return len(name_without_ext) >= 4 and name_without_ext[-4] == settings.UPLOAD_MACHINE
- except IndexError:
- return False
- @staticmethod
- def check_file_random_code(filename: str) -> bool:
- """
- 校验文件随机码是否合法。
- 参数:
- - filename (str): 文件名。
- 返回:
- - bool: 随机码是否合法(000–999)。
- """
- try:
- code = filename.rsplit(".", 1)[0][-3:]
- return code.isdigit() and 1 <= int(code) <= 999
- except IndexError:
- return False
- @staticmethod
- def generate_file(filepath: Path, chunk_size: int = 8192):
- """
- 根据文件生成二进制数据迭代器。
- 参数:
- - filepath (Path): 文件路径。
- - chunk_size (int): 分块大小,默认 8192 字节。
- 返回:
- - Iterator[bytes]: 文件二进制数据分块迭代器。
- """
- with filepath.open("rb") as f:
- while chunk := f.read(chunk_size):
- yield chunk
- @staticmethod
- def delete_file(filepath: Path) -> bool:
- """
- 删除文件。
- 参数:
- - filepath (Path): 文件路径。
- 返回:
- - bool: 删除是否成功。
- """
- try:
- filepath.unlink(missing_ok=True)
- return True
- except OSError:
- return False
- @classmethod
- async def upload_file(cls, file: UploadFile, base_url: str) -> tuple[str, Path, str]:
- """
- 安全文件上传。
- 参数:
- - file (UploadFile): 上传的文件对象。
- - base_url (str): 基础 URL。
- 返回:
- - tuple[str, Path, str]: (文件名, 文件路径, 文件 URL)。
- 异常:
- - CustomException: 当文件校验失败时抛出。
- """
- if not file or not file.filename:
- raise CustomException(msg="请选择要上传的文件")
- original_filename = file.filename
- if not cls.check_path_traversal(original_filename):
- log.error(f"检测到路径穿越攻击: {original_filename}")
- raise CustomException(msg="文件名包含非法字符")
- extension = cls.get_extension_from_filename(original_filename)
- if not extension:
- raise CustomException(msg="无法识别文件类型")
- cls.validate_file_extension(extension)
- cls.check_file_size(file)
- content = await file.read()
- await file.seek(0)
- cls.validate_file_content_type(content, extension)
- safe_filename = cls.generate_safe_filename(original_filename, extension)
- try:
- dir_path = settings.UPLOAD_FILE_PATH.joinpath(datetime.now().strftime("%Y/%m/%d"))
- dir_path.mkdir(parents=True, exist_ok=True)
- filepath = dir_path.joinpath(safe_filename)
- if not filepath.resolve().is_relative_to(settings.UPLOAD_FILE_PATH.resolve()):
- log.error(f"检测到路径穿越攻击,目标路径: {filepath}")
- raise CustomException(msg="非法的文件路径")
- file_url = urljoin(base_url, str(filepath))
- chunk_size = 8 * 1024 * 1024
- async with aiofiles.open(filepath, "wb") as f:
- while chunk := await file.read(chunk_size):
- await f.write(chunk)
- log.info(f"文件上传成功: {safe_filename}")
- return safe_filename, filepath, file_url
- except CustomException:
- raise
- except Exception as e:
- log.error(f"文件上传失败: {e}")
- raise CustomException(msg=f"文件上传失败: {e}")
- @staticmethod
- def get_file_tree(file_path: str) -> list[dict]:
- """
- 获取文件树结构。
- 参数:
- - file_path (str): 文件路径。
- 返回:
- - list[dict]: 文件树列表。
- """
- return [{"name": item.name, "is_dir": item.is_dir()} for item in Path(file_path).iterdir()]
- @classmethod
- async def download_file(cls, file_path: str) -> str:
- """
- 下载文件,生成新的文件名。
- 参数:
- - file_path (str): 文件路径。
- 返回:
- - str: 文件下载信息。
- """
- filename = cls.generate_file(Path(file_path))
- return str(filename)
|