# tts_client.py (20 KB)


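  """
  TTS speech-synthesis client module - optimized version.

  Optimizations:
  1. AudioStreamManager - less lock overhead, bytearray buffering, automatic bytes_per_sample detection
  2. ByteFIFO - improved read method with fewer repeated operations
  3. WebSocket connection management - Event instead of polling, improved merging of cached audio
  4. Cached playback - Event instead of sleep, improved alignment calculation
  5. Exception handling - precise exception types, better log output
  6. Global control logic - Event instead of hard waits, improved state management
  """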
  11. import websocket
  12. import sounddevice as sd
  13. from time import mktime
  14. from wsgiref.handlers import format_date_time
  15. from urllib.parse import urlparse, urlencode
  16. import threading
  17. import time
  18. import json
  19. import hmac
  20. import hashlib
  21. import datetime
  22. import base64
  23. import queue
  24. import platform
  # Choose sounddevice's default device for the current operating system:
  # Windows and macOS get ``None``, Linux gets the device named 'pulse'.
  # Any other platform leaves sounddevice's defaults untouched.
  system = platform.system().lower()
  if system in ("windows", "darwin"):
      sd.default.device = None
  elif system == "linux":
      sd.default.device = 'pulse'
  32. import io
  33. from collections import deque
  34. from config.config.settings import config
  35. from utils.pc2_requests import _send_led_color_task
  36. from utils.tts_cache import get_cached_audio, save_audio_cache
  37. from utils.logger import logger
  # ==============================
  # Global audio configuration
  # ==============================
  # Output-stream parameters used on every platform except Windows:
  # 16 kHz, mono, 16-bit PCM, 512-frame blocks.
  UNIFIED_AUDIO_CONFIG = dict(
      samplerate=16000,
      channels=1,
      dtype='int16',
      blocksize=512,
      latency='low',
      device=None,
  )

  # Windows uses the same parameters with a larger block size
  # (the original author notes it is more stable there).
  WINDOWS_AUDIO_CONFIG = {**UNIFIED_AUDIO_CONFIG, 'blocksize': 1024}
  # ==============================
  # Optimized byte FIFO buffer (thread-safe)
  # ==============================
  class ByteFIFO:
      """Thread-safe first-in/first-out queue of raw bytes.

      Written chunks are kept whole in a deque. A partially consumed head
      chunk is tracked with an offset instead of being re-sliced, so draining
      a large chunk through many small reads does not copy its remainder on
      every call.
      """

      def __init__(self):
          self._buf = deque()   # queued chunks, oldest first
          self._size = 0        # total number of unread bytes
          self._head = 0        # offset of the first unread byte in self._buf[0]
          self._lock = threading.Lock()

      def write(self, data: bytes):
          """Append ``data`` to the queue; empty input is ignored."""
          if not data:
              return
          # Snapshot the payload so a caller that later mutates a bytearray or
          # memoryview cannot alter queued audio (no copy is made for bytes).
          chunk = bytes(data)
          with self._lock:
              self._buf.append(chunk)
              self._size += len(chunk)

      def read(self, nbytes: int) -> bytes:
          """Remove and return up to ``nbytes`` bytes from the front.

          Returns fewer bytes (possibly ``b''``) when the queue holds less
          than requested, and ``b''`` for a non-positive ``nbytes``.
          """
          if nbytes <= 0:
              return b''
          out = bytearray()
          with self._lock:
              while nbytes > 0 and self._buf:
                  chunk = self._buf[0]
                  start = self._head
                  take = min(len(chunk) - start, nbytes)
                  stop = start + take
                  # The memoryview slice avoids an intermediate bytes copy.
                  out += memoryview(chunk)[start:stop]
                  if stop < len(chunk):
                      self._head = stop
                  else:
                      self._buf.popleft()
                      self._head = 0
                  self._size -= take
                  nbytes -= take
          return bytes(out)

      def clear(self):
          """Discard all queued bytes."""
          with self._lock:
              self._buf.clear()
              self._size = 0
              self._head = 0

      def __len__(self):
          """Return the number of unread bytes."""
          with self._lock:
              return self._size
  # ==============================
  # Optimized audio stream manager (singleton RawOutputStream + callback)
  # ==============================
  class AudioStreamManager:
      """Owns the callback-driven ``sd.RawOutputStream`` used for playback.

      Producers push raw PCM bytes with ``play_bytes()``; the stream callback
      pulls them from an internal ``ByteFIFO`` and pads with silence when the
      queue runs dry.
      """

      # Bytes per sample for the sample formats a raw stream can be opened with.
      _SAMPLE_BYTES = {
          'int8': 1,
          'uint8': 1,
          'int16': 2,
          'int24': 3,
          'int32': 4,
          'float32': 4,
      }

      def __init__(self):
          self._lock = threading.Lock()   # guards self._stream
          self._stream = None
          self._fifo = ByteFIFO()
          self._is_windows = platform.system().lower() == 'windows'
          self._config = WINDOWS_AUDIO_CONFIG if self._is_windows else UNIFIED_AUDIO_CONFIG
          # Derive the sample width from the configured dtype instead of opening
          # a throwaway stream, so constructing the manager (which happens at
          # import time) never touches an audio device. Unknown dtypes fall back
          # to 2 bytes (int16).
          self._bytes_per_sample = self._SAMPLE_BYTES.get(self._config['dtype'], 2)
          self._frame_bytes = self._config['channels'] * self._bytes_per_sample

      def get_audio_config(self):
          """Return the stream configuration dict selected for this platform."""
          return self._config

      def _callback(self, outdata, frames, time_info, status):
          """Stream callback: fill ``outdata`` from the FIFO, zero-padding any shortfall."""
          required = frames * self._frame_bytes
          data = self._fifo.read(required)
          got = len(data)
          if got < required:
              outdata[:got] = data
              outdata[got:] = bytes(required - got)   # silence
          else:
              outdata[:] = data

      def init_stream(self):
          """Return the active output stream, opening and starting one if needed.

          The FIFO is cleared whenever a new stream is opened. Errors raised
          while opening or starting the stream propagate to the caller.
          """
          with self._lock:
              current = self._stream
              if current:
                  if current.active:
                      return current
                  # Stopped but never closed: release it before replacing it.
                  try:
                      current.close()
                  except Exception as e:
                      logger.info(f"[音频] 关闭失败: {e}")
                  self._stream = None
              cfg = self._config
              self._fifo.clear()
              stream = sd.RawOutputStream(
                  samplerate=cfg['samplerate'],
                  channels=cfg['channels'],
                  dtype=cfg['dtype'],
                  blocksize=cfg['blocksize'],
                  latency=cfg['latency'],
                  device=cfg['device'],
                  callback=self._callback,
              )
              try:
                  stream.start()
              except Exception:
                  # Do not keep (or leak) a stream that failed to start.
                  stream.close()
                  raise
              self._stream = stream
              logger.info(f"[音频] RawOutputStream 初始化成功: blocksize={cfg['blocksize']}, samplerate={cfg['samplerate']}")
              return stream

      def play_bytes(self, chunk: bytes):
          """Queue raw PCM bytes for playback; empty chunks are ignored."""
          if not chunk:
              return
          self._fifo.write(chunk)

      def clear_buffer(self):
          """Drop all queued audio without touching the stream."""
          with self._lock:
              self._fifo.clear()
              logger.info("[音频] 缓冲区已清空")

      def stop_stream(self):
          """Stop and close the stream (if any) and clear the FIFO.

          Failures while stopping or closing are logged, not raised.
          """
          with self._lock:
              if self._stream:
                  try:
                      self._stream.stop()
                      self._stream.close()
                      logger.info("[音频] 音频流已安全关闭")
                  except Exception as e:
                      logger.info(f"[音频] 关闭失败: {e}")
                  finally:
                      self._stream = None
                      self._fifo.clear()

      def expected_chunk_size(self):
          """Return the byte size of one stream block (suggested feed alignment)."""
          return self._config['blocksize'] * self._frame_bytes


  # Module-wide manager shared by every TTS client.
  _audio_manager = AudioStreamManager()
  # ==============================
  # Optimized TTS client
  # ==============================
  class AIUITTSClient:
      """One-shot TTS playback client for a single piece of text.

      ``start()`` plays a cached recording when ``get_cached_audio`` returns
      one. Otherwise it opens a WebSocket to
      ``config.XUNFEI_STREAMING_TTS_URL``, feeds every received PCM chunk to
      the shared ``_audio_manager`` and, once the final frame arrives, stores
      the whole recording with ``save_audio_cache``. ``interrupt()`` aborts
      either path.
      """

      def __init__(self, text: str, use_cache: bool = True):
          """Prepare a client for ``text``.

          Raises:
              ValueError: if ``text`` is empty or only whitespace.
          """
          if not text or not text.strip():
              raise ValueError("文本内容不能为空")
          self.text = text.strip()
          self.use_cache = use_cache
          self.call_time = time.time()
          self.start_play_time = None
          self.handshake = self.assemble_auth_url(
              config.XUNFEI_STREAMING_TTS_URL)
          self.ws = None
          self._play_thread = None          # thread running the WebSocket loop
          self._cache_play_thread = None    # thread feeding cached audio
          self._audio_buffer = io.BytesIO()  # every PCM byte received so far
          self._interrupted = threading.Event()
          self._connection_established = threading.Event()
          # Set when playback finished or was interrupted; wakes the waiter.
          self._audio_done_event = threading.Event()
          self._max_retries = 3
          self._retry_count = 0

      # ---- helpers ----
      @staticmethod
      def _bytes_per_second() -> int:
          """Return the playback byte rate of the shared output stream (>= 1)."""
          cfg = _audio_manager.get_audio_config()
          # expected_chunk_size() is blocksize * bytes-per-frame.
          frame_bytes = _audio_manager.expected_chunk_size() // max(1, cfg['blocksize'])
          return max(1, cfg['samplerate'] * frame_bytes)

      def _pending_playback_seconds(self) -> float:
          """Return how long the audio still queued in the manager will take to play."""
          return len(_audio_manager._fifo) / self._bytes_per_second()

      def is_connection_ready(self, ws) -> bool:
          """Return True when ``ws`` is connected and this client is not interrupted."""
          try:
              return bool(ws and ws.sock and
                          hasattr(ws.sock, 'connected') and
                          ws.sock.connected and
                          not self._interrupted.is_set())
          except Exception:
              return False

      def assemble_auth_url(self, base_url: str) -> str:
          """Return ``base_url`` with the signed ``host``/``date``/``authorization`` query.

          The signature is an HMAC-SHA256 over the host, date and request line,
          keyed with ``config.XUNFEI_API_SECRET``.
          """
          parsed = urlparse(base_url)
          host = parsed.netloc
          path = parsed.path
          now = datetime.datetime.now()
          date = format_date_time(mktime(now.timetuple()))
          signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
          signature_sha = hmac.new(config.XUNFEI_API_SECRET.encode(),
                                   signature_origin.encode(), digestmod=hashlib.sha256).digest()
          signature_base64 = base64.b64encode(signature_sha).decode()
          authorization_origin = (
              f'api_key="{config.XUNFEI_API_KEY}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_base64}"'
          )
          authorization = base64.b64encode(
              authorization_origin.encode()).decode()
          v = {"host": host, "date": date, "authorization": authorization}
          return base_url + "?" + urlencode(v)

      def send_text(self):
          """Send the synthesis request for ``self.text`` over the open WebSocket.

          Raises:
              ConnectionError: if the WebSocket is not connected.
          """
          if not self._is_ws_connected():
              raise ConnectionError("WebSocket连接不可用")
          aiui_data = {
              "header": {
                  "sn": "yd-00:00:00:00:00:01",
                  "appid": config.XUNFEI_APP_ID,
                  "stmid": "text-1",
                  "status": 3,
                  "scene": 'IFLYTEK.hts'
              },
              "parameter": {
                  "tts": {
                      "vcn": config.TTS_VOICE,
                      "tts": {
                          "channels": 1,
                          "sample_rate": 16000,
                          "bit_depth": 16,
                          "encoding": "raw"
                      }
                  }
              },
              "payload": {
                  "text": {
                      "compress": "raw",
                      "format": "plain",
                      "text": base64.b64encode(self.text.encode()).decode(),
                      "encoding": "utf8",
                      "status": 3
                  }
              }
          }
          self.ws.send(json.dumps(aiui_data, ensure_ascii=False))

      def _is_ws_connected(self) -> bool:
          """Return True when ``self.ws`` exists and its socket reports connected."""
          try:
              return bool(self.ws and hasattr(self.ws, 'sock') and self.ws.sock and
                          hasattr(self.ws.sock, 'connected') and self.ws.sock.connected)
          except Exception:
              return False

      def interrupt(self):
          """Abort playback: stop the audio stream and close the WebSocket."""
          self._interrupted.set()
          # Wake the WebSocket thread if it is waiting for playback to finish.
          self._audio_done_event.set()
          logger.info("[中断] 开始中断播放")
          _audio_manager.stop_stream()
          logger.info("[TTS] 音频流已完全停止")
          ws = self.ws
          if ws:
              try:
                  self._connection_established.clear()
                  ws.close()
                  logger.info("[中断] WebSocket连接已关闭")
              except Exception as e:
                  logger.info(f"[中断] 关闭WebSocket失败: {e}")
              finally:
                  self.ws = None
          logger.info("[中断] 播放已中断")

      # ---- WebSocket callbacks ----
      def on_open(self, ws):
          """Connection opened: send the synthesis request immediately."""
          logger.info("[TTS] WebSocket连接已建立")
          self._connection_established.set()
          self._retry_count = 0
          try:
              if self._is_ws_connected():
                  self.send_text()
                  logger.info("[TTS] 数据发送成功")
              else:
                  logger.warning("[TTS] 连接已无效,无法发送数据")
                  self._connection_established.clear()
          except Exception as e:
              logger.warning(f"[TTS] 发送数据失败: {e}")
              self._connection_established.clear()

      def on_message(self, ws, message):
          """Handle one server frame: play its audio and finish on the final frame."""
          try:
              data = json.loads(message)
              header = data.get("header", {})
              if header.get("code") != 0:
                  logger.info(f"TTS错误: {data}")
                  ws.close()
                  return
              if self._interrupted.is_set():
                  ws.close()
                  return
              tts_audio_b64 = data.get("payload", {}).get("tts", {}).get("audio")
              if tts_audio_b64:
                  try:
                      pcm = base64.b64decode(tts_audio_b64)
                      self._audio_buffer.write(pcm)    # kept for the cache
                      _audio_manager.play_bytes(pcm)   # handed to the stream callback
                  except Exception as e:
                      logger.info(f"[错误] 音频解码失败: {e}")
              # Status 2 is treated as the final frame of the synthesis.
              if header.get("status") == 2:
                  logger.info("[TTS] 收到结束状态,准备关闭连接")
                  if self.use_cache and self._audio_buffer.tell() > 0:
                      try:
                          audio_data = self._audio_buffer.getvalue()
                          save_audio_cache(self.text, audio_data)
                          logger.info(f"[TTS] 缓存保存成功,总大小: {len(audio_data)} bytes")
                      except Exception as e:
                          logger.info(f"[TTS] 缓存保存失败: {e}")
                  # Let the queued audio finish; this also stops the stream
                  # unless playback was interrupted meanwhile. When there was
                  # no audio at all, on_close() stops the stream instead.
                  self._wait_for_audio_completion()
                  ws.close()
                  logger.info("[TTS] WebSocket连接已关闭")
          except json.JSONDecodeError as e:
              logger.info(f"[TTS] JSON解析错误: {e}")
          except Exception as e:
              logger.info(f"[TTS] 消息处理异常: {e}")
              import traceback
              traceback.print_exc()

      def _wait_for_audio_completion(self):
          """Block until the queued audio has played, then stop the stream.

          Only the audio still queued in the manager is waited for (plus a
          0.5 s margin), because everything received earlier has already been
          playing while it streamed in. ``interrupt()`` ends the wait early,
          and in that case the stream is left alone because it may already
          belong to the next client.
          """
          try:
              if self._audio_buffer.tell() == 0:
                  logger.info("[TTS] 无音频数据,无需等待")
                  return
              total_bytes = self._audio_buffer.tell()
              bytes_per_sec = self._bytes_per_second()
              audio_duration = total_bytes / bytes_per_sec
              logger.info(f"[TTS] 音频总时长: {audio_duration:.2f}秒,等待播放完成...")
              pending = len(_audio_manager._fifo) / bytes_per_sec
              self._audio_done_event.wait(timeout=pending + 0.5)
              if self._interrupted.is_set():
                  return
              logger.info("[TTS] 音频播放完成")
              self._audio_done_event.set()
              # Stop the stream so it does not keep playing silence.
              _audio_manager.stop_stream()
              logger.info("[TTS] 音频流已完全停止")
          except Exception as e:
              logger.info(f"[TTS] 等待音频完成时出错: {e}")

      def on_error(self, ws, error):
          """Log connection errors; close only on timeout or refused connection."""
          error_str = str(error) if error else "未知错误"
          lowered = error_str.lower()
          # Errors about an already closed connection are expected and ignored.
          if "already closed" in lowered or "connection is closed" in lowered:
              logger.info(f"[TTS] 连接已关闭,忽略错误: {error_str}")
              return
          logger.info(f"[TTS] 连接错误: {error_str}")
          if "timeout" in lowered or "connection refused" in lowered:
              logger.info("[TTS] 严重连接错误,关闭连接")
              ws.close()

      def on_close(self, ws, code, reason):
          """Connection closed: log it and stop the stream unless interrupted."""
          self._connection_established.clear()
          close_info = f"代码:{code}" if code else ""
          if reason:
              close_info += f", 原因:{reason}"
          logger.info(f"[TTS] 连接关闭 {close_info}")
          if self._interrupted.is_set():
              # interrupt() already stopped the stream. Stopping again here
              # could close the stream a newer client has just opened.
              return
          _audio_manager.stop_stream()
          logger.info("[TTS] 连接关闭时音频流已完全停止")

      # ---- cached playback ----
      def play_cached_audio(self, audio_data: bytes):
          """Feed cached PCM to the stream block by block until done or interrupted."""
          logger.info(f"[缓存] 开始播放,大小={len(audio_data)} bytes")
          if self._interrupted.is_set():
              # Interrupted before this thread ran: do not reopen the stream.
              logger.info("[缓存] 播放被中断")
              return
          _audio_manager.init_stream()
          align = _audio_manager.expected_chunk_size()
          if align <= 0:
              align = 2048
          for i in range(0, len(audio_data), align):
              if self._interrupted.is_set():
                  logger.info("[缓存] 播放被中断")
                  return
              _audio_manager.play_bytes(audio_data[i:i + align])
              # Short pause between chunks that doubles as an interrupt check.
              if self._interrupted.wait(0.01):
                  logger.info("[缓存] 播放被中断")
                  return
          # Feeding runs ahead of playback, so wait only for what is still queued.
          if self._interrupted.wait(timeout=self._pending_playback_seconds() + 0.05):
              logger.info("[缓存] 播放被中断")
              return
          self._audio_done_event.set()

      def start(self):
          """Start playback from the cache when possible, otherwise via the TTS API."""
          if self.use_cache:
              cached_audio = get_cached_audio(self.text)
              if cached_audio:
                  logger.info("[缓存] 命中,使用缓存播放")
                  self._cache_play_thread = threading.Thread(
                      target=self.play_cached_audio, args=(cached_audio,), daemon=True)
                  self._cache_play_thread.start()
                  return   # cache hit: no API request
          _audio_manager.init_stream()
          logger.info(f"[API] 请求TTS: {self.text[:30]}...")
          self._create_new_connection()

      def _create_new_connection(self):
          """Create a fresh WebSocketApp and run it on a daemon thread."""
          logger.info("[TTS] 开始创建新的WebSocket连接")
          self._connection_established.clear()
          self._interrupted.clear()
          self._audio_done_event.clear()
          logger.info("[TTS] 连接状态已重置")
          ws = websocket.WebSocketApp(
              self.handshake,
              on_open=self.on_open,
              on_message=self.on_message,
              on_error=self.on_error,
              on_close=self.on_close,
          )
          self.ws = ws
          logger.info("[TTS] WebSocket对象已创建,准备启动连接线程")
          # Give the thread the local reference: interrupt() may reset self.ws
          # to None before the thread starts running.
          self._play_thread = threading.Thread(
              target=ws.run_forever, daemon=True)
          self._play_thread.start()
          logger.info("[TTS] WebSocket连接线程已启动")
  # ==============================
  # Optimized global control
  # ==============================
  # The client currently playing (or None), guarded by ``_lock``.
  _current_tts_client = None
  _lock = threading.Lock()
  # Waited on after interrupting a previous client. Nothing in this part of
  # the file sets it, so the wait acts as a bounded pause.
  _playback_done_event = threading.Event()


  def play_text_async(text: str, use_cache: bool = True):
      """Speak ``text`` in the background, interrupting any current playback.

      Occurrences of '小勇' in ``text`` are replaced by ``config.machinename``
      when that is set. Input that is not a string, or is empty or whitespace
      after the replacement, is logged and ignored.

      Args:
          text: Text to synthesise.
          use_cache: Passed to ``AIUITTSClient``, which then tries the audio
              cache before the TTS API.
      """
      global _current_tts_client
      # Check the type before calling str.replace, so None or other
      # non-string input is rejected cleanly instead of raising.
      if not isinstance(text, str):
          logger.info("[错误] 文本内容不能为空")
          return
      machinename = getattr(config, 'machinename', None) or '小勇'
      text = text.replace('小勇', machinename)
      if not text.strip():
          logger.info("[错误] 文本内容不能为空")
          return
      with _lock:
          if _current_tts_client:
              logger.info("[停止] 打断旧播放")
              _current_tts_client.interrupt()
              _current_tts_client = None
              # Give the interrupted client a moment to wind down.
              _playback_done_event.wait(timeout=0.3)
          tts_client = AIUITTSClient(text, use_cache)
          _current_tts_client = tts_client
      tts_client.start()
  def stop_playback():
      """Interrupt the current client, if any, and stop the shared audio stream."""
      global _current_tts_client
      with _lock:
          client = _current_tts_client
          if client:
              client.interrupt()
              _current_tts_client = None
          # NOTE(review): assumed to run while the lock is held; confirm
          # against the original indentation.
          _audio_manager.stop_stream()
  def is_playing() -> bool:
      """Return True while a client is registered and has not been interrupted."""
      with _lock:
          client = _current_tts_client
          if client is None:
              return False
          return not client._interrupted.is_set()