snowflake.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Twitter 的雪花算法(Snowflake)实现
  4. #
  5. # 64位结构:
  6. #
  7. # ---------------------------------------------------------------------------------
  8. # | 符号位(1bits) | 时间戳(41bits) | 数据中心(5bits) | 机器标识(5bits) | 序列号(12bits) |
  9. # ---------------------------------------------------------------------------------
  10. import threading
  11. import time
  12. # 纪元时间戳
  13. DEFAULT_EPOCH: int = 1480166465631
  14. # 序列号需占用位数
  15. SEQUENCE_BITS: int = 12
  16. # 机器标识需占用位数
  17. MACHINE_BITS: int = 5
  18. # 数据中心需占用位数
  19. DATACENTER_BITS: int = 5
  20. # 时间戳、数据中心、机器标识所需左移位数
  21. MACHINE_SHIFT: int = SEQUENCE_BITS
  22. DATACENTER_SHIFT: int = SEQUENCE_BITS + MACHINE_BITS
  23. TIMESTAMP_SHIFT: int = SEQUENCE_BITS + MACHINE_BITS + DATACENTER_BITS
  24. # 数据中心、机器标识、序列号所支持的最大值
  25. DATACENTER_MAX: int = -1 ^ (-1 << DATACENTER_BITS)
  26. MACHINE_MAX: int = -1 ^ (-1 << MACHINE_BITS)
  27. SEQUENCE_MAX: int = -1 ^ (-1 << SEQUENCE_BITS)
  28. class SnowflakeId:
  29. def __init__(self, datacenter_id: int, machine_id: int):
  30. if datacenter_id > DATACENTER_MAX or datacenter_id < 0:
  31. raise ValueError(f"Datacenter id must be between 0 and {DATACENTER_MAX}")
  32. if machine_id > MACHINE_MAX or machine_id < 0:
  33. raise ValueError(f"Machine id must be between 0 and {MACHINE_MAX}")
  34. self.lock = threading.Lock()
  35. self.datacenter_id: int = datacenter_id
  36. self.machine_id: int = machine_id
  37. self.sequence: int = 0
  38. self.last_milli: int = -1
  39. self.epoch: int = DEFAULT_EPOCH
  40. def next_id(self) -> int:
  41. with self.lock:
  42. curr_milli: int = self.now_milli
  43. if curr_milli < self.last_milli:
  44. raise RuntimeError(f"Clock moved backwards. Refusing to generate id for {self.last_milli - curr_milli} milliseconds")
  45. if curr_milli == self.last_milli:
  46. self.sequence = (self.sequence + 1) & SEQUENCE_MAX
  47. if self.sequence == 0:
  48. curr_milli = self.next_milli
  49. else:
  50. self.sequence = 0
  51. self.last_milli = curr_milli
  52. # 1 位符号位(固定为0)
  53. # 41 位时间戳 + 10
  54. # 10 位机器ID(这里为machineId + datacenterId)
  55. # 12 位序列号
  56. return (
  57. (curr_milli - self.epoch) << TIMESTAMP_SHIFT
  58. | self.datacenter_id << DATACENTER_SHIFT
  59. | self.machine_id << MACHINE_SHIFT
  60. | self.sequence
  61. )
  62. @property
  63. def next_milli(self) -> int:
  64. curr_milli = self.now_milli
  65. while curr_milli <= self.last_milli:
  66. curr_milli = self.now_milli
  67. return curr_milli
  68. @property
  69. def now_milli(self) -> int:
  70. return int(time.time() * 1000)
# Module-level generator used by get_snowflake_id(); it is created at import
# time with datacenter id 1 and machine id 1.
# NOTE(review): both ids are hard-coded, so every process importing this module
# shares them — confirm that uniqueness across processes is handled elsewhere.
ID_GENERATOR = SnowflakeId(datacenter_id=1, machine_id=1)
  72. def get_snowflake_id() -> int:
  73. """
  74. 生成雪花算法ID
  75. 返回:
  76. - int: 雪花算法生成的唯一ID
  77. """
  78. return ID_GENERATOR.next_id()
  79. def get_snowflake_id_str(tenant_id: int) -> str:
  80. """
  81. 生成雪花算法ID字符串
  82. 返回:
  83. - str: 雪花算法生成的唯一ID字符串
  84. """
  85. # 保证id为19位,不足补0
  86. return f"{get_snowflake_id():019d}{tenant_id}"
  87. def extract_tenant_id_from_id_str(id_str: str) -> int:
  88. """
  89. 从雪花算法ID字符串中提取租户ID
  90. 参数:
  91. - id_str: 雪花算法生成的唯一ID字符串
  92. 返回:
  93. - int: 租户ID
  94. """
  95. # 从19位后开始提取,先判断是否足够长
  96. if len(id_str) <= 19:
  97. raise ValueError(f"ID字符串长度不足,必须为19位以上")
  98. return int(id_str[19:])