aiui.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. import _thread as thread
  2. import base64
  3. import datetime
  4. import hashlib
  5. import hmac
  6. import json
  7. import traceback
  8. from urllib.parse import urlparse
  9. import time
  10. from datetime import datetime
  11. from time import mktime
  12. from urllib.parse import urlencode
  13. from wsgiref.handlers import format_date_time
  14. import websocket
# Edit the application config and file path below, then run this script directly.
# Request endpoint
url = "wss://aiui.xf-yun.com/v3/aiint/sos"
# Application credentials
# NOTE(review): hard-coded credentials — move to env vars / a config file
# before sharing or committing this script.
appid = "f016fce3"
api_key = "fcb58dc79de9b0568d2287bd8184e291"
api_secret = "YTFiN2NkOGVjNTVjY2QyMTlmMTViOTBh"
sn = "test-sn"
# Scene name
scene = "test_box"
# TTS voice
vcn = "x5_lingxiaoyue_flow"
# Request type: 'text' for a text request, 'audio' for an audio request
data_type = 'text'
# For audio requests, set audio_path first.
# Audio defaults to pcm 16k 16bit; to use another format, also update the
# audio-related payload fields in genAudioReq.
# data_type = 'audio'
# Path of the audio file uploaded for audio requests
text_msg = ""
audio_path = "weather.pcm"
# Text of a text request
# (the second assignment wins; the first line is a spare sample question)
question = "介绍下苏超?"
question = "你好,今天天气怎么样,介绍下苏超"
# The two parameters below pair with the audio sample rate:
# for 16k 16bit audio, send 1280 bytes every 40 milliseconds.
# Bytes per audio frame
frame_size = 1280
# Delay between audio frames, in seconds
# (name typo "inetrval" kept — it is referenced by AIUIV3WsClient.audio_req)
sleep_inetrval = 0.04
  42. class AIUIV3WsClient(object):
  43. # 初始化
  44. def __init__(self):
  45. self.handshake = self.assemble_auth_url(url)
  46. # 生成握手url
  47. def assemble_auth_url(self, base_url):
  48. host = urlparse(base_url).netloc
  49. path = urlparse(base_url).path
  50. # 生成RFC1123格式的时间戳
  51. now = datetime.now()
  52. date = format_date_time(mktime(now.timetuple()))
  53. # 拼接字符串
  54. signature_origin = "host: " + host + "\n"
  55. signature_origin += "date: " + date + "\n"
  56. signature_origin += "GET " + path + " HTTP/1.1"
  57. # 进行hmac-sha256进行加密
  58. print(signature_origin)
  59. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  60. digestmod=hashlib.sha256).digest()
  61. signature_sha_base64 = base64.b64encode(
  62. signature_sha).decode(encoding='utf-8')
  63. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  64. print('get authorization_origin:', authorization_origin)
  65. authorization = base64.b64encode(
  66. authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  67. # 将请求的鉴权参数组合为字典
  68. v = {
  69. "host": host,
  70. "date": date,
  71. "authorization": authorization,
  72. }
  73. # 拼接鉴权参数,生成url
  74. url = base_url + '?' + urlencode(v)
  75. # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
  76. return url
  77. def on_open(self, ws):
  78. # 连接建立成功后开始发送数据
  79. print("### ws connect open")
  80. thread.start_new_thread(self.run, ())
  81. def run(self):
  82. if data_type == "text":
  83. self.text_req()
  84. if data_type == "audio":
  85. self.audio_req()
  86. def text_req(self):
  87. # 文本请求status固定为3,interact_mode固定为oneshot
  88. aiui_data = {
  89. "header": {
  90. "appid": appid,
  91. "sn": sn,
  92. "stmid": "text-1",
  93. "status": 3,
  94. "scene": scene,
  95. "msc.lat": 19.65309164062,
  96. "msc.lng": 109.259056086,
  97. "os_sys": "android",
  98. "interact_mode":"oneshot"
  99. },
  100. "parameter": {
  101. "nlp": {
  102. "nlp": {
  103. "compress": "raw",
  104. "format": "json",
  105. "encoding": "utf8"
  106. },
  107. "new_session": True,
  108. },
  109. # 合成参数
  110. "tts": {
  111. # 发音人
  112. "vcn": vcn,
  113. "tts": {
  114. "channels": 1,
  115. "bit_depth": 16,
  116. "sample_rate": 16000,
  117. "encoding": "raw"
  118. }
  119. }
  120. },
  121. "payload": {
  122. "text": {
  123. "compress": "raw",
  124. "format": "plain",
  125. "text": base64.b64encode(question.encode('utf-8')).decode('utf-8'),
  126. "encoding": "utf8",
  127. "status": 3
  128. }
  129. }
  130. }
  131. data = json.dumps(aiui_data)
  132. print('text request data:', data)
  133. self.ws.send(data)
  134. def audio_req(self):
  135. f = open(audio_path, 'rb')
  136. try:
  137. f.seek(0, 2)
  138. eof = f.tell()
  139. f.seek(0, 0)
  140. first = True
  141. status = 0
  142. while True:
  143. d = f.read(frame_size)
  144. if not d:
  145. break
  146. if f.tell() >= eof:
  147. # 尾帧
  148. status = 2
  149. elif not first:
  150. # 中间帧
  151. status = 1
  152. req = self.genAudioReq(d, status)
  153. first = False
  154. self.ws.send(req)
  155. # 发送间隔
  156. time.sleep(sleep_inetrval)
  157. finally:
  158. f.close()
  159. def genAudioReq(self, data, status):
  160. # 构造pcm音频请求参数
  161. aiui_data = {
  162. "header": {
  163. "appid": appid,
  164. "sn": sn,
  165. "stmid": "audio-1",
  166. "status": status,
  167. "scene": scene,
  168. "interact_mode": "continuous"
  169. },
  170. "parameter": {
  171. "nlp": {
  172. "nlp": {
  173. "compress": "raw",
  174. "format": "json",
  175. "encoding": "utf8"
  176. },
  177. "new_session": True
  178. },
  179. # 合成参数
  180. "tts": {
  181. # 发音人
  182. "vcn": vcn,
  183. "tts": {
  184. "channels": 1,
  185. "bit_depth": 16,
  186. "sample_rate": 16000,
  187. "encoding": "raw"
  188. }
  189. }
  190. },
  191. "payload": {
  192. "audio": {
  193. "encoding": "raw",
  194. "sample_rate": 16000,
  195. "channels": 1,
  196. "bit_depth": 16,
  197. "status": status,
  198. "audio": base64.b64encode(data).decode(),
  199. }
  200. }
  201. }
  202. return json.dumps(aiui_data)
  203. # 收到websocket消息的处理
  204. def on_message(self, ws, message):
  205. try:
  206. data = json.loads(message)
  207. # print('原始结果:', message)
  208. header = data['header']
  209. code = header['code']
  210. # 结果解析
  211. if code != 0:
  212. print('请求错误:', code, json.dumps(data, ensure_ascii=False))
  213. ws.close()
  214. sid = header.get('sid', "sid")
  215. payload = data.get('payload', {})
  216. parameter = data.get('parameter', {})
  217. if 'event' in payload:
  218. # 事件结果
  219. event_json = payload['event']
  220. event_text_bs64 = event_json['text']
  221. event_text = base64.b64decode(event_text_bs64).decode('utf-8')
  222. print("事件,", event_text)
  223. if 'iat' in payload:
  224. # 识别结果
  225. iat_json = payload['iat']
  226. iat_text_bs64 = iat_json['text']
  227. iat_text = base64.b64decode(iat_text_bs64).decode('utf-8')
  228. print("识别结果,seq:", iat_json['seq'], ",status:",
  229. iat_json['status'], ",", self.parse_iat_result(iat_text))
  230. if 'cbm_tidy' in payload:
  231. # 语义规整结果(历史改写),意图拆分
  232. cbm_tidy_json = payload['cbm_tidy']
  233. cbm_tidy_text_bs64 = cbm_tidy_json['text']
  234. cbm_tidy_text = base64.b64decode(
  235. cbm_tidy_text_bs64).decode('utf-8')
  236. cbm_tidy_json = json.loads(cbm_tidy_text)
  237. print("语义规整结果:")
  238. intents = cbm_tidy_json['intent']
  239. for intent in intents:
  240. print(" intent index:",
  241. intent['index'], ",意图语料:", intent['value'])
  242. if 'cbm_intent_domain' in payload:
  243. # 意图拆分后的落域结果
  244. cbm_intent_domain_json = payload['cbm_intent_domain']
  245. cbm_intent_domain_text_bs64 = cbm_intent_domain_json['text']
  246. cbm_intent_domain_text = base64.b64decode(
  247. cbm_intent_domain_text_bs64).decode('utf-8')
  248. index = self.get_intent_index(parameter, "cbm_intent_domain")
  249. print("intent index:", index, ",落域结果:", cbm_intent_domain_text)
  250. if 'cbm_semantic' in payload:
  251. # 技能结果
  252. cbm_semantic_json = payload['cbm_semantic']
  253. cbm_semantic_text_bs64 = cbm_semantic_json['text']
  254. cbm_semantic_text = base64.b64decode(
  255. cbm_semantic_text_bs64).decode('utf-8')
  256. cbm_semantic_json = json.loads(cbm_semantic_text)
  257. index = self.get_intent_index(parameter, "cbm_semantic")
  258. if cbm_semantic_json['rc'] != 0:
  259. print("intent index:", index, ",技能结果:说法:",
  260. cbm_semantic_json['text'], ",", cbm_semantic_text)
  261. else:
  262. print("intent index:", index, ",技能结果:说法:",
  263. cbm_semantic_json['text'], ",命中技能:", cbm_semantic_json['category'], ",回复:", cbm_semantic_json['answer']['text'])
  264. if 'nlp' in payload:
  265. # 语义结果,经过大模型润色的最终结果
  266. nlp_json = payload['nlp']
  267. nlp_text_bs64 = nlp_json['text']
  268. nlp_text = base64.b64decode(nlp_text_bs64).decode('utf-8')
  269. print("语义结果 seq:", nlp_json['seq'], ",status:",
  270. nlp_json['status'], ",nlp.text: ", nlp_text)
  271. if 'tts' in payload:
  272. # 将结果保存到文件,文件后缀名需要根据tts参数中的encoding来决定
  273. audioData = payload['tts']['audio']
  274. if audioData != None:
  275. audioBytes = base64.b64decode(audioData)
  276. print("tts结果: ", len(audioBytes), " 字节")
  277. with open(sid + "." + self.get_suffix(payload['tts']['encoding']), 'ab') as file:
  278. file.write(audioBytes)
  279. if 'status' in header and header['status'] == 2:
  280. # 接收最后一帧结果,关闭连接
  281. ws.close()
  282. except Exception as e:
  283. traceback.print_exc()
  284. pass
  285. def parse_iat_result(self, iat_res):
  286. iat_text = ""
  287. iat_res_json = json.loads(iat_res)
  288. for cw in iat_res_json['text']['ws']:
  289. for cw_item in cw["cw"]:
  290. iat_text += cw_item['w']
  291. return iat_text
  292. def get_intent_index(self, parameter, key):
  293. if key in parameter:
  294. return parameter[key]['loc']['intent']
  295. return "-"
  296. def get_suffix(self, encoding):
  297. if encoding == 'raw':
  298. return 'pcm'
  299. if encoding == 'lame':
  300. return 'mp3'
  301. return 'unknow'
  302. def on_error(self, ws, error):
  303. print("### connection error: ", str(error))
  304. ws.close()
  305. def on_close(self, ws, close_status_code, close_msg):
  306. print("### connection is closed ###, cloce code:", close_status_code)
  307. def start(self):
  308. self.ws = websocket.WebSocketApp(
  309. self.handshake,
  310. on_open=self.on_open,
  311. on_message=self.on_message,
  312. on_error=self.on_error,
  313. on_close=self.on_close,
  314. )
  315. self.ws.run_forever()
  316. if __name__ == "__main__":
  317. client = AIUIV3WsClient()
  318. client.start()