1
0

wintls.c 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837
  1. #include "hssl.h"
  2. #ifdef WITH_WINTLS
  3. // #define PRINT_DEBUG
  4. // #define PRINT_ERROR
  5. #include "hdef.h"
  6. #include <schannel.h>
  7. #include <wincrypt.h>
  8. #include <windows.h>
  9. #include <wintrust.h>
  10. #define SECURITY_WIN32
  11. #include <security.h>
  12. #include <sspi.h>
  /* Capacity of each per-connection TLS buffer (ciphertext in, ciphertext out). */
#define TLS_SOCKET_BUFFER_SIZE 17000

/*
 * Fallback definitions of Schannel protocol / credential flags for SDK
 * headers that predate them. Each value is only supplied when <schannel.h>
 * did not already define it.
 */

/* SSL 2.0 */
#if !defined(SP_PROT_SSL2_SERVER)
#define SP_PROT_SSL2_SERVER 0x00000004
#endif
#if !defined(SP_PROT_SSL2_CLIENT)
#define SP_PROT_SSL2_CLIENT 0x00000008
#endif

/* SSL 3.0 */
#if !defined(SP_PROT_SSL3_SERVER)
#define SP_PROT_SSL3_SERVER 0x00000010
#endif
#if !defined(SP_PROT_SSL3_CLIENT)
#define SP_PROT_SSL3_CLIENT 0x00000020
#endif

/* TLS 1.0 (the TLS1_0 names are aliases of the TLS1 names) */
#if !defined(SP_PROT_TLS1_SERVER)
#define SP_PROT_TLS1_SERVER 0x00000040
#endif
#if !defined(SP_PROT_TLS1_CLIENT)
#define SP_PROT_TLS1_CLIENT 0x00000080
#endif
#if !defined(SP_PROT_TLS1_0_SERVER)
#define SP_PROT_TLS1_0_SERVER SP_PROT_TLS1_SERVER
#endif
#if !defined(SP_PROT_TLS1_0_CLIENT)
#define SP_PROT_TLS1_0_CLIENT SP_PROT_TLS1_CLIENT
#endif

/* TLS 1.1 */
#if !defined(SP_PROT_TLS1_1_SERVER)
#define SP_PROT_TLS1_1_SERVER 0x00000100
#endif
#if !defined(SP_PROT_TLS1_1_CLIENT)
#define SP_PROT_TLS1_1_CLIENT 0x00000200
#endif

/* TLS 1.2 */
#if !defined(SP_PROT_TLS1_2_SERVER)
#define SP_PROT_TLS1_2_SERVER 0x00000400
#endif
#if !defined(SP_PROT_TLS1_2_CLIENT)
#define SP_PROT_TLS1_2_CLIENT 0x00000800
#endif

/* TLS 1.3 */
#if !defined(SP_PROT_TLS1_3_SERVER)
#define SP_PROT_TLS1_3_SERVER 0x00001000
#endif
#if !defined(SP_PROT_TLS1_3_CLIENT)
#define SP_PROT_TLS1_3_CLIENT 0x00002000
#endif

/* SCHANNEL_CRED::dwFlags bit */
#if !defined(SCH_USE_STRONG_CRYPTO)
#define SCH_USE_STRONG_CRYPTO 0x00400000
#endif

/* SecBuffer type that carries TLS alert data */
#if !defined(SECBUFFER_ALERT)
#define SECBUFFER_ALERT 17
#endif
  /* Identifies the TLS implementation compiled into this build. */
const char* hssl_backend()
{
    static const char backend_name[] = "schannel";
    return backend_name;
}
  66. static PCCERT_CONTEXT getservercert(const char* path)
  67. {
  68. /*
  69. According to the information I searched from the internet, it is not possible to specify an x509 private key and certificate using the
  70. CertCreateCertificateContext interface. We must first export them as a pkcs#12 formatted file, and then import them into the Windows certificate store. This
  71. is because the Windows certificate store is an integrated system location that does not support the direct use of separate private key files and certificate
  72. files. The pkcs#12 format is a complex format that can store and protect keys and certificates. You can use the OpenSSL tool to combine the private key file
  73. and certificate file into a pkcs#12 formatted file, For example: OpenSSL pkcs12 -export -out cert.pfx -inkey private.key -in cert.cer Then, you can use the
  74. certutil tool or a graphical interface to import this file into the personal store of your local computer. After importing, you can use the
  75. CertFindCertificateInStore interface to create and manipulate certificate contexts.
  76. */
  77. return NULL;
  78. }
  79. hssl_ctx_t hssl_ctx_new(hssl_ctx_opt_t* opt)
  80. {
  81. SECURITY_STATUS SecStatus;
  82. TimeStamp Lifetime;
  83. CredHandle* hCred = NULL;
  84. SCHANNEL_CRED credData = { 0 };
  85. TCHAR unisp_name[] = UNISP_NAME;
  86. unsigned long credflag;
  87. if (opt && opt->endpoint == HSSL_SERVER) {
  88. PCCERT_CONTEXT serverCert = NULL; // server-side certificate
  89. #if 1 // create cert from store
  90. //-------------------------------------------------------
  91. // Get the server certificate.
  92. //-------------------------------------------------------
  93. // Open the My store(personal store).
  94. HCERTSTORE hMyCertStore = CertOpenStore(CERT_STORE_PROV_SYSTEM, X509_ASN_ENCODING, 0, CERT_SYSTEM_STORE_LOCAL_MACHINE, L"MY");
  95. if (hMyCertStore == NULL) {
  96. printe("Error opening MY store for server.\n");
  97. return NULL;
  98. }
  99. //-------------------------------------------------------
  100. // Search for a certificate match its subject string to opt->crt_file.
  101. serverCert = CertFindCertificateInStore(hMyCertStore, X509_ASN_ENCODING, 0, CERT_FIND_SUBJECT_STR_A, opt->crt_file, NULL);
  102. CertCloseStore(hMyCertStore, 0);
  103. if (serverCert == NULL) {
  104. printe("Error retrieving server certificate. %x\n", GetLastError());
  105. return NULL;
  106. }
  107. #else
  108. serverCert = getservercert(opt->ca_file);
  109. #endif
  110. credData.cCreds = 1; // 数量
  111. credData.paCred = &serverCert;
  112. // credData.dwCredFormat = SCH_CRED_FORMAT_CERT_HASH;
  113. credData.grbitEnabledProtocols = SP_PROT_TLS1_2_SERVER | SP_PROT_TLS1_3_SERVER;
  114. credflag = SECPKG_CRED_INBOUND;
  115. } else {
  116. credData.grbitEnabledProtocols = SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_3_CLIENT;
  117. credflag = SECPKG_CRED_OUTBOUND;
  118. }
  119. #if 0 // just use the system defalut algs
  120. ALG_ID rgbSupportedAlgs[4];
  121. rgbSupportedAlgs[0] = CALG_DH_EPHEM;
  122. rgbSupportedAlgs[1] = CALG_RSA_KEYX;
  123. rgbSupportedAlgs[2] = CALG_AES_128;
  124. rgbSupportedAlgs[3] = CALG_SHA_256;
  125. credData.cSupportedAlgs = 4;
  126. credData.palgSupportedAlgs = rgbSupportedAlgs;
  127. #endif
  128. credData.dwVersion = SCHANNEL_CRED_VERSION;
  129. // credData.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_NO_SERVERNAME_CHECK | SCH_USE_STRONG_CRYPTO | SCH_CRED_MANUAL_CRED_VALIDATION | SCH_CRED_IGNORE_NO_REVOCATION_CHECK | SCH_CRED_IGNORE_REVOCATION_OFFLINE;
  130. // credData.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN | SCH_CRED_IGNORE_REVOCATION_OFFLINE;
  131. // credData.dwMinimumCipherStrength = -1;
  132. // credData.dwMaximumCipherStrength = -1;
  133. //-------------------------------------------------------
  134. hCred = (CredHandle*)malloc(sizeof(CredHandle));
  135. if (hCred == NULL) {
  136. return NULL;
  137. }
  138. SecStatus = AcquireCredentialsHandle(NULL, unisp_name, credflag, NULL, &credData, NULL, NULL, hCred, &Lifetime);
  139. if (SecStatus == SEC_E_OK) {
  140. #ifndef NDEBUG
  141. SecPkgCred_SupportedAlgs algs;
  142. if (QueryCredentialsAttributesA(hCred, SECPKG_ATTR_SUPPORTED_ALGS, &algs) == SEC_E_OK) {
  143. for (int i = 0; i < algs.cSupportedAlgs; i++) {
  144. printd("alg: 0x%08x\n", algs.palgSupportedAlgs[i]);
  145. }
  146. }
  147. #endif
  148. } else {
  149. printe("ERROR: AcquireCredentialsHandle: 0x%x\n", SecStatus);
  150. free(hCred);
  151. hCred = NULL;
  152. }
  153. return hCred;
  154. }
  155. void hssl_ctx_free(hssl_ctx_t ssl_ctx)
  156. {
  157. SECURITY_STATUS sec_status = FreeCredentialsHandle(ssl_ctx);
  158. if (sec_status != SEC_E_OK) {
  159. printe("free_cred_handle FreeCredentialsHandle %d\n", sec_status);
  160. }
  161. }
  162. static void init_sec_buffer(SecBuffer* secure_buffer, unsigned long type, unsigned long len, void* buffer)
  163. {
  164. secure_buffer->BufferType = type;
  165. secure_buffer->cbBuffer = len;
  166. secure_buffer->pvBuffer = buffer;
  167. }
  168. static void init_sec_buffer_desc(SecBufferDesc* secure_buffer_desc, unsigned long version, unsigned long num_buffers, SecBuffer* buffers)
  169. {
  170. secure_buffer_desc->ulVersion = version;
  171. secure_buffer_desc->cBuffers = num_buffers;
  172. secure_buffer_desc->pBuffers = buffers;
  173. }
  /*
 * States of the non-blocking client handshake state machine driven by
 * hssl_connect(). Values are fixed explicitly, because the server path
 * reuses the same storage as a plain 0/1 flag.
 */
typedef enum {
    ssl_connect_1 = 0,         /* nothing sent yet: step 1 sends the ClientHello */
    ssl_connect_2 = 1,         /* exchanging handshake messages with the server */
    ssl_connect_2_reading = 2, /* sub-state of step 2 (reserved) */
    ssl_connect_2_writing = 3, /* sub-state of step 2 (reserved) */
    ssl_connect_3 = 4,         /* handshake finished by step 2 */
    ssl_connect_done = 5,      /* final state */
} ssl_connect_state;
  183. struct wintls_s {
  184. hssl_ctx_t ssl_ctx; // CredHandle
  185. int fd;
  186. union {
  187. ssl_connect_state state2;
  188. ssl_connect_state connecting_state;
  189. };
  190. SecHandle sechandle;
  191. SecPkgContext_StreamSizes stream_sizes_;
  192. size_t buffer_to_decrypt_offset_;
  193. size_t dec_len_;
  194. char encrypted_buffer_[TLS_SOCKET_BUFFER_SIZE];
  195. char buffer_to_decrypt_[TLS_SOCKET_BUFFER_SIZE];
  196. char decrypted_buffer_[TLS_SOCKET_BUFFER_SIZE + TLS_SOCKET_BUFFER_SIZE];
  197. char* sni;
  198. };
  199. hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd)
  200. {
  201. struct wintls_s* ret = malloc(sizeof(*ret));
  202. if (ret) {
  203. memset(ret, 0, sizeof(*ret));
  204. ret->ssl_ctx = ssl_ctx;
  205. ret->fd = fd;
  206. ret->sechandle.dwLower = 0;
  207. ret->sechandle.dwUpper = 0;
  208. }
  209. return ret;
  210. }
  211. void hssl_free(hssl_t _ssl)
  212. {
  213. struct wintls_s* ssl = _ssl;
  214. SECURITY_STATUS sec_status = DeleteSecurityContext(&ssl->sechandle);
  215. if (sec_status != SEC_E_OK) {
  216. printe("hssl_free DeleteSecurityContext %d", sec_status);
  217. }
  218. if (ssl->sni) {
  219. free(ssl->sni);
  220. }
  221. free(ssl);
  222. }
  223. static void free_all_buffers(SecBufferDesc* secure_buffer_desc)
  224. {
  225. for (unsigned long i = 0; i < secure_buffer_desc->cBuffers; ++i) {
  226. void* buffer = secure_buffer_desc->pBuffers[i].pvBuffer;
  227. if (buffer != NULL) {
  228. FreeContextBuffer(buffer);
  229. }
  230. }
  231. }
  232. static int __sendwrapper(SOCKET fd, const char* buf, size_t len, int flags)
  233. {
  234. int left = len;
  235. int offset = 0;
  236. while (left > 0) {
  237. int bytes_sent = send(fd, buf + offset, left, flags);
  238. if (bytes_sent == 0 || (bytes_sent == SOCKET_ERROR && WSAGetLastError() != WSAEWOULDBLOCK && WSAGetLastError() != WSAEINTR)) {
  239. break;
  240. }
  241. if (bytes_sent > 0) {
  242. offset += bytes_sent;
  243. left -= bytes_sent;
  244. }
  245. }
  246. return offset;
  247. }
  248. static int __recvwrapper(SOCKET fd, char* buf, int len, int flags)
  249. {
  250. int ret = 0;
  251. do {
  252. ret = recv(fd, buf, len, flags);
  253. } while (ret == SOCKET_ERROR && WSAGetLastError() == WSAEINTR);
  254. return ret;
  255. }
  256. int hssl_accept(hssl_t ssl)
  257. {
  258. int ret = HSSL_ERROR;
  259. struct wintls_s* winssl = ssl;
  260. bool authn_completed = false;
  261. // Input buffer
  262. char buffer_in[TLS_SOCKET_BUFFER_SIZE];
  263. SecBuffer secure_buffer_in[2] = { 0 };
  264. init_sec_buffer(&secure_buffer_in[0], SECBUFFER_TOKEN, TLS_SOCKET_BUFFER_SIZE, buffer_in);
  265. init_sec_buffer(&secure_buffer_in[1], SECBUFFER_EMPTY, 0, NULL);
  266. SecBufferDesc secure_buffer_desc_in = { 0 };
  267. init_sec_buffer_desc(&secure_buffer_desc_in, SECBUFFER_VERSION, 2, secure_buffer_in);
  268. // Output buffer
  269. SecBuffer secure_buffer_out[3] = { 0 };
  270. init_sec_buffer(&secure_buffer_out[0], SECBUFFER_TOKEN, 0, NULL);
  271. init_sec_buffer(&secure_buffer_out[1], SECBUFFER_ALERT, 0, NULL);
  272. init_sec_buffer(&secure_buffer_out[2], SECBUFFER_EMPTY, 0, NULL);
  273. SecBufferDesc secure_buffer_desc_out = { 0 };
  274. init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 3, secure_buffer_out);
  275. unsigned long context_requirements = ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_CONFIDENTIALITY;
  276. // We use ASC_REQ_ALLOCATE_MEMORY which means the buffers will be allocated for us, we need to make sure we free them.
  277. ULONG context_attributes = 0;
  278. TimeStamp life_time = { 0 };
  279. secure_buffer_in[0].cbBuffer = __recvwrapper(winssl->fd, (char*)secure_buffer_in[0].pvBuffer, TLS_SOCKET_BUFFER_SIZE, 0);
  280. // printd("%s recv %d %d\n", __func__, secure_buffer_in[0].cbBuffer, WSAGetLastError());
  281. if (secure_buffer_in[0].cbBuffer == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK) {
  282. ret = HSSL_WANT_READ;
  283. } else if (secure_buffer_in[0].cbBuffer > 0) {
  284. SECURITY_STATUS sec_status = AcceptSecurityContext(winssl->ssl_ctx, winssl->state2 == 0 ? NULL : &winssl->sechandle, &secure_buffer_desc_in,
  285. context_requirements, 0, &winssl->sechandle, &secure_buffer_desc_out, &context_attributes, &life_time);
  286. winssl->state2 = 1;
  287. // printd("establish_server_security_context AcceptSecurityContext %x\n", sec_status);
  288. if (secure_buffer_out[0].cbBuffer > 0) {
  289. int rc = __sendwrapper(winssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
  290. if (rc != secure_buffer_out[0].cbBuffer) {
  291. goto END;
  292. }
  293. }
  294. switch (sec_status) {
  295. case SEC_E_OK:
  296. ret = HSSL_OK;
  297. authn_completed = true;
  298. break;
  299. case SEC_I_CONTINUE_NEEDED:
  300. ret = HSSL_WANT_READ;
  301. break;
  302. case SEC_I_COMPLETE_AND_CONTINUE:
  303. case SEC_I_COMPLETE_NEEDED: {
  304. SECURITY_STATUS complete_sec_status = SEC_E_OK;
  305. complete_sec_status = CompleteAuthToken(&winssl->sechandle, &secure_buffer_desc_out);
  306. if (complete_sec_status != SEC_E_OK) {
  307. printe("establish_server_security_context CompleteAuthToken %x\n", complete_sec_status);
  308. goto END;
  309. }
  310. if (sec_status == SEC_I_COMPLETE_NEEDED) {
  311. authn_completed = true;
  312. ret = HSSL_OK;
  313. } else {
  314. ret = HSSL_WANT_READ;
  315. }
  316. break;
  317. }
  318. default:
  319. break;
  320. }
  321. }
  322. END:
  323. free_all_buffers(&secure_buffer_desc_out);
  324. if (authn_completed) {
  325. SECURITY_STATUS sec_status = QueryContextAttributes(&winssl->sechandle, SECPKG_ATTR_STREAM_SIZES, &winssl->stream_sizes_);
  326. if (sec_status != SEC_E_OK) {
  327. printe("get_stream_sizes QueryContextAttributes %d\n", sec_status);
  328. }
  329. }
  330. return ret;
  331. }
  332. static int schannel_connect_step1(struct wintls_s* ssl)
  333. {
  334. int ret = 0;
  335. ULONG context_attributes = 0;
  336. unsigned long context_requirements = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
  337. TimeStamp life_time = { 0 };
  338. SecBuffer secure_buffer_out[1] = { 0 };
  339. init_sec_buffer(&secure_buffer_out[0], SECBUFFER_EMPTY, 0, NULL);
  340. SecBufferDesc secure_buffer_desc_out = { 0 };
  341. init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 1, secure_buffer_out);
  342. SECURITY_STATUS sec_status = InitializeSecurityContext(ssl->ssl_ctx, NULL, ssl->sni, context_requirements, 0, 0, NULL, 0, &ssl->sechandle,
  343. &secure_buffer_desc_out, &context_attributes, &life_time);
  344. if (sec_status != SEC_I_CONTINUE_NEEDED) {
  345. printe("1InitializeSecurityContext: %x\n", sec_status);
  346. }
  347. if (secure_buffer_out[0].cbBuffer > 0) {
  348. int rc = __sendwrapper(ssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
  349. if (rc != secure_buffer_out[0].cbBuffer) {
  350. // TODO: Handle the error
  351. printe("%s :send failed\n", __func__);
  352. ret = -1;
  353. } else {
  354. printd("%s :send len=%d\n", __func__, rc);
  355. ssl->connecting_state = ssl_connect_2;
  356. }
  357. }
  358. free_all_buffers(&secure_buffer_desc_out);
  359. return ret;
  360. }
  361. static int schannel_connect_step2(struct wintls_s* ssl)
  362. {
  363. int ret = HSSL_ERROR;
  364. ULONG context_attributes = 0;
  365. bool verify_server_cert = 0;
  366. unsigned long context_requirements = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
  367. if (!verify_server_cert) {
  368. context_requirements |= ISC_REQ_MANUAL_CRED_VALIDATION;
  369. }
  370. TimeStamp life_time = { 0 };
  371. // Allocate a temporary buffer for input
  372. char* buffer_in = malloc(TLS_SOCKET_BUFFER_SIZE);
  373. if (buffer_in == NULL) {
  374. printe("schannel_connect_step2: Memory allocation failed\n");
  375. return HSSL_ERROR;
  376. }
  377. int offset = 0;
  378. bool skip_recv = false;
  379. bool authn_complete = false;
  380. while (!authn_complete) {
  381. int in_buffer_size = 0;
  382. if (!skip_recv) {
  383. int received = __recvwrapper(ssl->fd, buffer_in + offset, TLS_SOCKET_BUFFER_SIZE, 0);
  384. if (received == SOCKET_ERROR) {
  385. if (WSAGetLastError() == WSAEWOULDBLOCK) {
  386. ret = HSSL_WANT_READ;
  387. } else {
  388. printe("schannel_connect_step2: Receive failed\n");
  389. }
  390. break;
  391. } else if (received == 0) {
  392. printe("schannel_connect_step2: peer closed\n");
  393. break;
  394. }
  395. in_buffer_size = received + offset;
  396. } else {
  397. in_buffer_size = offset;
  398. }
  399. skip_recv = false;
  400. offset = 0;
  401. // Input buffer
  402. SecBuffer secure_buffer_in[4] = { 0 };
  403. init_sec_buffer(&secure_buffer_in[0], SECBUFFER_TOKEN, in_buffer_size, buffer_in);
  404. init_sec_buffer(&secure_buffer_in[1], SECBUFFER_EMPTY, 0, NULL);
  405. SecBufferDesc secure_buffer_desc_in = { 0 };
  406. init_sec_buffer_desc(&secure_buffer_desc_in, SECBUFFER_VERSION, 2, secure_buffer_in);
  407. // Output buffer
  408. SecBuffer secure_buffer_out[3] = { 0 };
  409. init_sec_buffer(&secure_buffer_out[0], SECBUFFER_TOKEN, 0, NULL);
  410. init_sec_buffer(&secure_buffer_out[1], SECBUFFER_ALERT, 0, NULL);
  411. init_sec_buffer(&secure_buffer_out[2], SECBUFFER_EMPTY, 0, NULL);
  412. SecBufferDesc secure_buffer_desc_out = { 0 };
  413. init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 3, secure_buffer_out);
  414. printd("h2:%d\n", in_buffer_size);
  415. SECURITY_STATUS sec_status = InitializeSecurityContext(ssl->ssl_ctx, &ssl->sechandle, ssl->sni, context_requirements, 0, 0, &secure_buffer_desc_in, 0,
  416. &ssl->sechandle, &secure_buffer_desc_out, &context_attributes, &life_time);
  417. printd("h2 0x%x inbuf[1] type=%d %d inbuf[0]=%d\n", sec_status, secure_buffer_in[1].BufferType, secure_buffer_in[1].cbBuffer, secure_buffer_in[0].cbBuffer);
  418. if (sec_status == SEC_E_OK || sec_status == SEC_I_CONTINUE_NEEDED) {
  419. // for (size_t i = 0; i < 3; i++) {
  420. // printd("obuf[%zu] type=%d %d\n", i, secure_buffer_out[i].BufferType, secure_buffer_out[i].cbBuffer);
  421. // }
  422. if (secure_buffer_out[0].cbBuffer > 0) {
  423. int rc = __sendwrapper(ssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
  424. if (rc != secure_buffer_out[0].cbBuffer) {
  425. printe("schannel_connect_step2: Send failed\n");
  426. // TODO: Handle the error
  427. break;
  428. }
  429. // printd("%s :send ok\n", __func__);
  430. }
  431. if (sec_status == SEC_I_CONTINUE_NEEDED) {
  432. if (secure_buffer_in[1].BufferType == SECBUFFER_EXTRA && secure_buffer_in[1].cbBuffer > 0) {
  433. offset = secure_buffer_in[0].cbBuffer - secure_buffer_in[1].cbBuffer;
  434. memmove(buffer_in, buffer_in + offset, secure_buffer_in[1].cbBuffer);
  435. offset = secure_buffer_in[1].cbBuffer;
  436. skip_recv = true;
  437. }
  438. } else if (sec_status == SEC_E_OK) {
  439. authn_complete = true;
  440. ret = HSSL_OK;
  441. ssl->connecting_state = ssl_connect_3;
  442. }
  443. } else if (sec_status == SEC_E_INCOMPLETE_MESSAGE) {
  444. offset = secure_buffer_in[0].cbBuffer;
  445. } else {
  446. printe("2InitializeSecurityContext: 0x%x\n", sec_status);
  447. break;
  448. }
  449. free_all_buffers(&secure_buffer_desc_out);
  450. }
  451. // END:
  452. free(buffer_in); // Free the temporary buffer
  453. return ret;
  454. }
  455. static void dumpconninfo(SecHandle* sechandle)
  456. {
  457. SECURITY_STATUS Status;
  458. SecPkgContext_ConnectionInfo ConnectionInfo;
  459. Status = QueryContextAttributes(sechandle,
  460. SECPKG_ATTR_CONNECTION_INFO,
  461. (PVOID)&ConnectionInfo);
  462. if (Status != SEC_E_OK) {
  463. printe("Error 0x%x querying connection info\n", Status);
  464. return;
  465. }
  466. printd("\n");
  467. switch (ConnectionInfo.dwProtocol) {
  468. case SP_PROT_TLS1_CLIENT:
  469. printd("Protocol: TLS1\n");
  470. break;
  471. case SP_PROT_SSL3_CLIENT:
  472. printd("Protocol: SSL3\n");
  473. break;
  474. case SP_PROT_SSL2_CLIENT:
  475. printd("Protocol: SSL2\n");
  476. break;
  477. case SP_PROT_PCT1_CLIENT:
  478. printd("Protocol: PCT\n");
  479. break;
  480. default:
  481. printd("Protocol: 0x%x\n", ConnectionInfo.dwProtocol);
  482. }
  483. switch (ConnectionInfo.aiCipher) {
  484. case CALG_RC4:
  485. printd("Cipher: RC4\n");
  486. break;
  487. case CALG_3DES:
  488. printd("Cipher: Triple DES\n");
  489. break;
  490. case CALG_RC2:
  491. printd("Cipher: RC2\n");
  492. break;
  493. case CALG_DES:
  494. case CALG_CYLINK_MEK:
  495. printd("Cipher: DES\n");
  496. break;
  497. case CALG_SKIPJACK:
  498. printd("Cipher: Skipjack\n");
  499. break;
  500. case CALG_AES_128:
  501. printd("Cipher: aes128\n");
  502. break;
  503. default:
  504. printd("Cipher: 0x%x\n", ConnectionInfo.aiCipher);
  505. }
  506. printd("Cipher strength: %d\n", ConnectionInfo.dwCipherStrength);
  507. switch (ConnectionInfo.aiHash) {
  508. case CALG_MD5:
  509. printd("Hash: MD5\n");
  510. break;
  511. case CALG_SHA:
  512. printd("Hash: SHA\n");
  513. break;
  514. default:
  515. printd("Hash: 0x%x\n", ConnectionInfo.aiHash);
  516. }
  517. printd("Hash strength: %d\n", ConnectionInfo.dwHashStrength);
  518. switch (ConnectionInfo.aiExch) {
  519. case CALG_RSA_KEYX:
  520. case CALG_RSA_SIGN:
  521. printd("Key exchange: RSA\n");
  522. break;
  523. case CALG_KEA_KEYX:
  524. printd("Key exchange: KEA\n");
  525. break;
  526. case CALG_DH_EPHEM:
  527. printd("Key exchange: DH Ephemeral\n");
  528. break;
  529. default:
  530. printd("Key exchange: 0x%x\n", ConnectionInfo.aiExch);
  531. }
  532. printd("Key exchange strength: %d\n", ConnectionInfo.dwExchStrength);
  533. }
  534. int hssl_connect(hssl_t _ssl)
  535. {
  536. int ret = 0;
  537. struct wintls_s* ssl = _ssl;
  538. if (ssl->connecting_state == ssl_connect_1) {
  539. ret = schannel_connect_step1(ssl);
  540. }
  541. if (!ret && ssl->connecting_state == ssl_connect_2) {
  542. ret = schannel_connect_step2(ssl);
  543. }
  544. // printd("%s %x\n", __func__, ret);
  545. if (!ret) {
  546. if (ssl->connecting_state == ssl_connect_3) {
  547. // ret = schannel_connect_step3(ssl);
  548. }
  549. SECURITY_STATUS sec_status = QueryContextAttributes(&ssl->sechandle, SECPKG_ATTR_STREAM_SIZES, &ssl->stream_sizes_);
  550. if (sec_status != SEC_E_OK) {
  551. printe("get_stream_sizes QueryContextAttributes %d\n", sec_status);
  552. } else {
  553. printd("stream_sizes bs:%d h:%d t:%d max:%d bfs:%d\n", ssl->stream_sizes_.cbBlockSize, ssl->stream_sizes_.cbHeader, ssl->stream_sizes_.cbTrailer, ssl->stream_sizes_.cbMaximumMessage, ssl->stream_sizes_.cBuffers);
  554. }
  555. dumpconninfo(&ssl->sechandle);
  556. }
  557. return ret;
  558. }
  559. static int decrypt_message(SecHandle security_context, unsigned long* extra, char* in_buf, int in_len, char* out_buf, int out_len)
  560. {
  561. printd("%s: inlen=%d\n", __func__, in_len);
  562. // Initialize the secure buffers
  563. SecBuffer secure_buffers[4] = { 0 };
  564. init_sec_buffer(&secure_buffers[0], SECBUFFER_DATA, in_len, in_buf);
  565. init_sec_buffer(&secure_buffers[1], SECBUFFER_EMPTY, 0, NULL);
  566. init_sec_buffer(&secure_buffers[2], SECBUFFER_EMPTY, 0, NULL);
  567. init_sec_buffer(&secure_buffers[3], SECBUFFER_EMPTY, 0, NULL);
  568. // Initialize the secure buffer descriptor
  569. SecBufferDesc secure_buffer_desc = { 0 };
  570. init_sec_buffer_desc(&secure_buffer_desc, SECBUFFER_VERSION, 4, secure_buffers);
  571. // Decrypt the message using the security context
  572. SECURITY_STATUS sec_status = DecryptMessage(&security_context, &secure_buffer_desc, 0, NULL);
  573. for (size_t i = 1; i < 4; i++) {
  574. printd("%d: %u %u\n", i, secure_buffers[i].BufferType, secure_buffers[i].cbBuffer);
  575. }
  576. if (sec_status == SEC_E_INCOMPLETE_MESSAGE) {
  577. printe("decrypt_message SEC_E_INCOMPLETE_MESSAGE\n");
  578. return -1;
  579. } else if (sec_status == SEC_E_DECRYPT_FAILURE) {
  580. printe("decrypt_message ignore SEC_E_DECRYPT_FAILURE\n");
  581. return 0;
  582. } else if (sec_status == SEC_E_UNSUPPORTED_FUNCTION) {
  583. printe("decrypt_message ignore SEC_E_UNSUPPORTED_FUNCTION\n");
  584. return 0;
  585. }
  586. if (sec_status != SEC_E_OK) {
  587. printe("decrypt_message DecryptMessage: 0x%x\n", sec_status);
  588. return -1;
  589. }
  590. if (secure_buffers[3].BufferType == SECBUFFER_EXTRA && secure_buffers[3].cbBuffer > 0) {
  591. *extra = secure_buffers[3].cbBuffer;
  592. }
  593. memcpy(out_buf, secure_buffers[1].pvBuffer, secure_buffers[1].cbBuffer);
  594. // printd("ob:%s\n", out_buf);
  595. return secure_buffers[1].cbBuffer;
  596. }
  597. int hssl_read(hssl_t _ssl, void* buf, int len)
  598. {
  599. struct wintls_s* ssl = _ssl;
  600. printd("%s: dec_len_= %zu\n", __func__, ssl->dec_len_);
  601. if (ssl->dec_len_ > 0) {
  602. if (buf == NULL) {
  603. return 0;
  604. }
  605. int decrypted = MIN(ssl->dec_len_, len);
  606. memcpy(buf, ssl->decrypted_buffer_, (size_t)decrypted);
  607. ssl->dec_len_ -= decrypted;
  608. if (ssl->dec_len_) {
  609. memmove(ssl->decrypted_buffer_, ssl->decrypted_buffer_ + decrypted, (size_t)ssl->dec_len_);
  610. } else {
  611. // hssl_read(_ssl, NULL, 0);
  612. }
  613. return decrypted;
  614. }
  615. // We might have leftovers, an incomplete message from a previous call.
  616. // Calculate the available buffer length for tcp recv.
  617. int recv_max_len = TLS_SOCKET_BUFFER_SIZE - ssl->buffer_to_decrypt_offset_;
  618. int bytes_received = __recvwrapper(ssl->fd, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_, recv_max_len, 0);
  619. // printd("%s recv %d %d\n", __func__, bytes_received, WSAGetLastError());
  620. if (bytes_received == SOCKET_ERROR) {
  621. if (WSAGetLastError() == WSAEWOULDBLOCK) {
  622. bytes_received = 0;
  623. return 0;
  624. } else {
  625. return -1;
  626. }
  627. } else if (bytes_received == 0) {
  628. return 0;
  629. }
  630. int encrypted_buffer_len = ssl->buffer_to_decrypt_offset_ + bytes_received;
  631. ssl->buffer_to_decrypt_offset_ = 0;
  632. while (true) {
  633. // printd("%s:buffer_to_decrypt_offset_ = %d , encrypted_buffer_len= %d\n", __func__, ssl->buffer_to_decrypt_offset_, encrypted_buffer_len);
  634. if (ssl->buffer_to_decrypt_offset_ >= encrypted_buffer_len) {
  635. // Reached the encrypted buffer length, we decrypted everything so we can stop.
  636. break;
  637. }
  638. unsigned long extra = 0;
  639. int decrypted_len = decrypt_message(ssl->sechandle, &extra, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_,
  640. encrypted_buffer_len - ssl->buffer_to_decrypt_offset_, ssl->decrypted_buffer_ + ssl->dec_len_,
  641. TLS_SOCKET_BUFFER_SIZE + TLS_SOCKET_BUFFER_SIZE - ssl->dec_len_);
  642. if (decrypted_len == -1) {
  643. // Incomplete message, we shuold keep it so it will be decrypted on the next call to recv().
  644. // Shift the remaining buffer to the beginning and break the loop.
  645. memmove(ssl->buffer_to_decrypt_, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_, encrypted_buffer_len - ssl->buffer_to_decrypt_offset_);
  646. break;
  647. }
  648. ssl->dec_len_ += decrypted_len;
  649. ssl->buffer_to_decrypt_offset_ = encrypted_buffer_len - extra;
  650. }
  651. ssl->buffer_to_decrypt_offset_ = encrypted_buffer_len - ssl->buffer_to_decrypt_offset_;
  652. return hssl_read(_ssl, buf, len);
  653. }
  654. int hssl_write(hssl_t _ssl, const void* buf, int len)
  655. {
  656. struct wintls_s* ssl = _ssl;
  657. SecPkgContext_StreamSizes* stream_sizes = &ssl->stream_sizes_;
  658. if (len > (int)stream_sizes->cbMaximumMessage) {
  659. len = stream_sizes->cbMaximumMessage;
  660. }
  661. // Calculate the minimum output buffer length
  662. int min_out_len = stream_sizes->cbHeader + len + stream_sizes->cbTrailer;
  663. if (min_out_len > TLS_SOCKET_BUFFER_SIZE) {
  664. printe("encrypt_message: Output buffer is too small");
  665. return -1;
  666. }
  667. // Initialize the secure buffers
  668. SecBuffer secure_buffers[4] = { 0 };
  669. init_sec_buffer(&secure_buffers[0], SECBUFFER_STREAM_HEADER, stream_sizes->cbHeader, ssl->encrypted_buffer_);
  670. init_sec_buffer(&secure_buffers[1], SECBUFFER_DATA, len, ssl->encrypted_buffer_ + stream_sizes->cbHeader);
  671. init_sec_buffer(&secure_buffers[2], SECBUFFER_STREAM_TRAILER, stream_sizes->cbTrailer, ssl->encrypted_buffer_ + stream_sizes->cbHeader + len);
  672. init_sec_buffer(&secure_buffers[3], SECBUFFER_EMPTY, 0, NULL);
  673. // Initialize the secure buffer descriptor
  674. SecBufferDesc secure_buffer_desc = { 0 };
  675. init_sec_buffer_desc(&secure_buffer_desc, SECBUFFER_VERSION, 4, secure_buffers);
  676. // Copy the input buffer to the data buffer
  677. memcpy(secure_buffers[1].pvBuffer, buf, len);
  678. // Encrypt the message using the security context
  679. SECURITY_STATUS sec_status = EncryptMessage(&ssl->sechandle, 0, &secure_buffer_desc, 0);
  680. // Check the encryption status and the data buffer length
  681. if (sec_status != SEC_E_OK) {
  682. printe("encrypt_message EncryptMessage %d\n", sec_status);
  683. return -1;
  684. }
  685. if (secure_buffers[1].cbBuffer > (unsigned int)len) {
  686. printe("encrypt_message: Data buffer is too large\n");
  687. return -1;
  688. }
  689. // Adjust the minimum output buffer length
  690. min_out_len = secure_buffers[0].cbBuffer + secure_buffers[1].cbBuffer + secure_buffers[2].cbBuffer;
  691. printd("enc02: %d %d\n", secure_buffers[0].cbBuffer, secure_buffers[2].cbBuffer);
  692. // Send the encrypted message to the socket
  693. int offset = __sendwrapper(ssl->fd, ssl->encrypted_buffer_, min_out_len, 0);
  694. // Check the send result
  695. if (offset != min_out_len) {
  696. printe("hssl_write: Send failed\n");
  697. return -1;
  698. } else {
  699. printd("hssl_write: Send %d\n", min_out_len);
  700. }
  701. // Return the number of bytes sent excluding the header and trailer
  702. return offset - secure_buffers[0].cbBuffer - secure_buffers[2].cbBuffer;
  703. }
  704. int hssl_close(hssl_t _ssl)
  705. {
  706. return 0;
  707. }
  708. int hssl_set_sni_hostname(hssl_t _ssl, const char* hostname)
  709. {
  710. struct wintls_s* ssl = _ssl;
  711. ssl->sni = strdup(hostname);
  712. return 0;
  713. }
  714. #endif // WITH_WINTLS