crud.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. from collections.abc import Sequence
  2. from typing import Any, List, Optional
  3. from datetime import datetime
  4. from decimal import Decimal
  5. from sqlalchemy import func, select
  6. from sqlalchemy.engine import Result
  7. from app.api.v1.module_system.auth.schema import AuthSchema
  8. from app.core.base_crud import CRUDBase
  9. from app.core.exceptions import CustomException
  10. from .model import AlipayNotifyLogModel, PayBillModel, PayBillOrderModel, PayBillVoucherModel
  11. class AlipayNotifyLogCRUD(CRUDBase[AlipayNotifyLogModel, Any, Any]):
  12. """支付宝通知日志 CRUD 操作"""
  13. def __init__(self, auth: AuthSchema) -> None:
  14. self.auth = auth
  15. super().__init__(model=AlipayNotifyLogModel, auth=auth)
  16. async def get_by_notify_id(
  17. self, notify_id: str
  18. ) -> AlipayNotifyLogModel | None:
  19. return await self.get(notify_id=notify_id)
  20. async def update_by_notify_id(
  21. self, notify_id: str, data: dict
  22. ) -> AlipayNotifyLogModel | None:
  23. obj = await self.get(notify_id=notify_id, preload=[])
  24. if not obj:
  25. raise CustomException(msg="通知日志不存在")
  26. if self.auth.user and hasattr(obj, "updated_id"):
  27. setattr(obj, "updated_id", self.auth.user.id)
  28. for key, value in data.items():
  29. if hasattr(obj, key):
  30. setattr(obj, key, value)
  31. await self.auth.db.flush()
  32. await self.auth.db.refresh(obj)
  33. return obj
  34. class BillCRUD(CRUDBase[PayBillModel, Any, Any]):
  35. """账单 CRUD 操作"""
  36. def __init__(self, auth: AuthSchema) -> None:
  37. self.auth = auth
  38. super().__init__(model=PayBillModel, auth=auth)
  39. async def get_by_pay_no(
  40. self, pay_no: str
  41. ) -> PayBillModel | None:
  42. """根据账单号查询账单"""
  43. return await self.get(pay_no=pay_no)
  44. async def get_by_enterprise_id(
  45. self, enterprise_id: str
  46. ) -> Sequence[PayBillModel]:
  47. """根据企业ID查询账单列表"""
  48. return await self.list({"enterprise_id": enterprise_id})
  49. async def update_by_pay_no(
  50. self, pay_no: str, data: dict
  51. ) -> PayBillModel | None:
  52. """根据账单号更新账单"""
  53. obj = await self.get(pay_no=pay_no, preload=[])
  54. if not obj:
  55. raise CustomException(msg="账单不存在")
  56. if self.auth.user and hasattr(obj, "updated_id"):
  57. setattr(obj, "updated_id", self.auth.user.id)
  58. for key, value in data.items():
  59. if hasattr(obj, key):
  60. setattr(obj, key, value)
  61. await self.auth.db.flush()
  62. await self.auth.db.refresh(obj)
  63. return obj
  64. async def get_consume_amount(
  65. self,
  66. enterprise_id: Optional[str] = None,
  67. start_date: Optional[datetime] = None,
  68. end_date: Optional[datetime] = None,
  69. tenant_id: Optional[int] = None,
  70. ) -> Decimal:
  71. """统计消费金额:consume_type=CONSUME, status=PROCESSED,SUM consume_amount"""
  72. conditions = [
  73. PayBillModel.consume_type == "CONSUME",
  74. PayBillModel.status == "PROCESSED",
  75. ]
  76. if tenant_id:
  77. conditions.append(PayBillModel.tenant_id == tenant_id)
  78. if enterprise_id:
  79. conditions.append(PayBillModel.enterprise_id == enterprise_id)
  80. if start_date:
  81. conditions.append(PayBillModel.gmt_biz_create >= start_date)
  82. if end_date:
  83. conditions.append(PayBillModel.gmt_biz_create <= end_date)
  84. try:
  85. sql = select(func.sum(PayBillModel.consume_amount).label("total_amount")).where(
  86. *conditions
  87. )
  88. sql = await self.filter_permissions(sql)
  89. result: Result = await self.auth.db.execute(sql)
  90. return result.scalars().first() or Decimal(0)
  91. except Exception as e:
  92. raise CustomException(msg=f"列表查询失败: {e!s}")
  93. async def create_or_update(
  94. self, pay_no: str, data: dict
  95. ) -> PayBillModel:
  96. """创建或更新账单"""
  97. obj = await self.get(pay_no=pay_no, preload=[])
  98. if obj:
  99. if self.auth.user and hasattr(obj, "updated_id"):
  100. setattr(obj, "updated_id", self.auth.user.id)
  101. for key, value in data.items():
  102. if hasattr(obj, key):
  103. setattr(obj, key, value)
  104. await self.auth.db.flush()
  105. await self.auth.db.refresh(obj)
  106. return obj
  107. else:
  108. data["pay_no"] = pay_no
  109. if "tenant_id" not in data:
  110. data["tenant_id"] = self.auth.tenant_id
  111. return await self.create(data=data, skip_tenant_id=True)
  112. class BillOrderCRUD(CRUDBase[PayBillOrderModel, Any, Any]):
  113. """订单 CRUD 操作"""
  114. def __init__(self, auth: AuthSchema) -> None:
  115. self.auth = auth
  116. super().__init__(model=PayBillOrderModel, auth=auth)
  117. async def get_by_order_no(
  118. self, order_no: str
  119. ) -> PayBillOrderModel | None:
  120. """根据订单号查询订单"""
  121. return await self.get(order_no=order_no)
  122. async def get_by_pay_no(
  123. self, pay_no: str
  124. ) -> Sequence[PayBillOrderModel]:
  125. """根据账单号查询订单列表"""
  126. return await self.list({"pay_no": pay_no})
  127. async def update_by_order_no(
  128. self, order_no: str, data: dict
  129. ) -> PayBillOrderModel | None:
  130. """根据订单号更新订单"""
  131. obj = await self.get(order_no=order_no, preload=[])
  132. if not obj:
  133. raise CustomException(msg="订单不存在")
  134. if self.auth.user and hasattr(obj, "updated_id"):
  135. setattr(obj, "updated_id", self.auth.user.id)
  136. for key, value in data.items():
  137. if hasattr(obj, key):
  138. setattr(obj, key, value)
  139. await self.auth.db.flush()
  140. await self.auth.db.refresh(obj)
  141. return obj
  142. async def create_or_update(
  143. self, order_no: str, data: dict
  144. ) -> PayBillOrderModel:
  145. """创建或更新订单"""
  146. obj = await self.get(order_no=order_no, preload=[])
  147. if obj:
  148. if self.auth.user and hasattr(obj, "updated_id"):
  149. setattr(obj, "updated_id", self.auth.user.id)
  150. for key, value in data.items():
  151. if hasattr(obj, key):
  152. setattr(obj, key, value)
  153. await self.auth.db.flush()
  154. await self.auth.db.refresh(obj)
  155. return obj
  156. else:
  157. data["order_no"] = order_no
  158. if "tenant_id" not in data:
  159. data["tenant_id"] = self.auth.tenant_id
  160. return await self.create(data=data, skip_tenant_id=True)
  161. class BillVoucherCRUD(CRUDBase[PayBillVoucherModel, Any, Any]):
  162. """凭证 CRUD 操作"""
  163. def __init__(self, auth: AuthSchema) -> None:
  164. self.auth = auth
  165. super().__init__(model=PayBillVoucherModel, auth=auth)
  166. async def get_by_voucher_id(
  167. self, voucher_id: str
  168. ) -> PayBillVoucherModel | None:
  169. """根据凭证ID查询凭证"""
  170. return await self.get(voucher_id=voucher_id)
  171. async def get_by_pay_no(
  172. self, pay_no: str
  173. ) -> Sequence[PayBillVoucherModel]:
  174. """根据账单号查询凭证列表"""
  175. return await self.list({"pay_no": pay_no})
  176. async def update_by_voucher_id(
  177. self, voucher_id: str, data: dict
  178. ) -> PayBillVoucherModel | None:
  179. """根据凭证ID更新凭证"""
  180. obj = await self.get(voucher_id=voucher_id, preload=[])
  181. if not obj:
  182. raise CustomException(msg="凭证不存在")
  183. if self.auth.user and hasattr(obj, "updated_id"):
  184. setattr(obj, "updated_id", self.auth.user.id)
  185. for key, value in data.items():
  186. if hasattr(obj, key):
  187. setattr(obj, key, value)
  188. await self.auth.db.flush()
  189. await self.auth.db.refresh(obj)
  190. return obj
  191. async def create_or_update(
  192. self, voucher_id: str, data: dict
  193. ) -> PayBillVoucherModel:
  194. """创建或更新凭证"""
  195. obj = await self.get(voucher_id=voucher_id, preload=[])
  196. if obj:
  197. if self.auth.user and hasattr(obj, "updated_id"):
  198. setattr(obj, "updated_id", self.auth.user.id)
  199. for key, value in data.items():
  200. if hasattr(obj, key):
  201. setattr(obj, key, value)
  202. await self.auth.db.flush()
  203. await self.auth.db.refresh(obj)
  204. return obj
  205. else:
  206. data["voucher_id"] = voucher_id
  207. if "tenant_id" not in data:
  208. data["tenant_id"] = self.auth.tenant_id
  209. return await self.create(data=data, skip_tenant_id=True)