| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837 |
#include "hssl.h"
#ifdef WITH_WINTLS
// #define PRINT_DEBUG
// #define PRINT_ERROR
#include "hdef.h"
#include <limits.h>
#include <schannel.h>
#include <wincrypt.h>
#include <windows.h>
#include <wintrust.h>
#define SECURITY_WIN32
#include <security.h>
#include <sspi.h>
/* Capacity, in bytes, of each per-connection TLS buffer (see struct wintls_s). */
#define TLS_SOCKET_BUFFER_SIZE 17000

/*
 * Fallback values for SSPI/Schannel constants that older SDK headers may not provide.
 * Each one is only defined when the system headers have not already done so.
 */

/* SSL 2.0 / 3.0 protocol flags */
#if !defined(SP_PROT_SSL2_SERVER)
#define SP_PROT_SSL2_SERVER 0x00000004
#endif
#if !defined(SP_PROT_SSL2_CLIENT)
#define SP_PROT_SSL2_CLIENT 0x00000008
#endif
#if !defined(SP_PROT_SSL3_SERVER)
#define SP_PROT_SSL3_SERVER 0x00000010
#endif
#if !defined(SP_PROT_SSL3_CLIENT)
#define SP_PROT_SSL3_CLIENT 0x00000020
#endif

/* TLS 1.0 protocol flags (the TLS1_0 names alias the TLS1 ones) */
#if !defined(SP_PROT_TLS1_SERVER)
#define SP_PROT_TLS1_SERVER 0x00000040
#endif
#if !defined(SP_PROT_TLS1_CLIENT)
#define SP_PROT_TLS1_CLIENT 0x00000080
#endif
#if !defined(SP_PROT_TLS1_0_SERVER)
#define SP_PROT_TLS1_0_SERVER SP_PROT_TLS1_SERVER
#endif
#if !defined(SP_PROT_TLS1_0_CLIENT)
#define SP_PROT_TLS1_0_CLIENT SP_PROT_TLS1_CLIENT
#endif

/* TLS 1.1 / 1.2 / 1.3 protocol flags */
#if !defined(SP_PROT_TLS1_1_SERVER)
#define SP_PROT_TLS1_1_SERVER 0x00000100
#endif
#if !defined(SP_PROT_TLS1_1_CLIENT)
#define SP_PROT_TLS1_1_CLIENT 0x00000200
#endif
#if !defined(SP_PROT_TLS1_2_SERVER)
#define SP_PROT_TLS1_2_SERVER 0x00000400
#endif
#if !defined(SP_PROT_TLS1_2_CLIENT)
#define SP_PROT_TLS1_2_CLIENT 0x00000800
#endif
#if !defined(SP_PROT_TLS1_3_SERVER)
#define SP_PROT_TLS1_3_SERVER 0x00001000
#endif
#if !defined(SP_PROT_TLS1_3_CLIENT)
#define SP_PROT_TLS1_3_CLIENT 0x00002000
#endif

/* Credential flag and buffer type */
#if !defined(SCH_USE_STRONG_CRYPTO)
#define SCH_USE_STRONG_CRYPTO 0x00400000
#endif
#if !defined(SECBUFFER_ALERT)
#define SECBUFFER_ALERT 17
#endif
/* Name of the TLS implementation compiled into this backend. */
const char* hssl_backend()
{
    static const char backend_name[] = "schannel";
    return backend_name;
}
- static PCCERT_CONTEXT getservercert(const char* path)
- {
- /*
- According to the information I searched from the internet, it is not possible to specify an x509 private key and certificate using the
- CertCreateCertificateContext interface. We must first export them as a pkcs#12 formatted file, and then import them into the Windows certificate store. This
- is because the Windows certificate store is an integrated system location that does not support the direct use of separate private key files and certificate
- files. The pkcs#12 format is a complex format that can store and protect keys and certificates. You can use the OpenSSL tool to combine the private key file
- and certificate file into a pkcs#12 formatted file, For example: OpenSSL pkcs12 -export -out cert.pfx -inkey private.key -in cert.cer Then, you can use the
- certutil tool or a graphical interface to import this file into the personal store of your local computer. After importing, you can use the
- CertFindCertificateInStore interface to create and manipulate certificate contexts.
- */
- return NULL;
- }
- hssl_ctx_t hssl_ctx_new(hssl_ctx_opt_t* opt)
- {
- SECURITY_STATUS SecStatus;
- TimeStamp Lifetime;
- CredHandle* hCred = NULL;
- SCHANNEL_CRED credData = { 0 };
- TCHAR unisp_name[] = UNISP_NAME;
- unsigned long credflag;
- if (opt && opt->endpoint == HSSL_SERVER) {
- PCCERT_CONTEXT serverCert = NULL; // server-side certificate
- #if 1 // create cert from store
- //-------------------------------------------------------
- // Get the server certificate.
- //-------------------------------------------------------
- // Open the My store(personal store).
- HCERTSTORE hMyCertStore = CertOpenStore(CERT_STORE_PROV_SYSTEM, X509_ASN_ENCODING, 0, CERT_SYSTEM_STORE_LOCAL_MACHINE, L"MY");
- if (hMyCertStore == NULL) {
- printe("Error opening MY store for server.\n");
- return NULL;
- }
- //-------------------------------------------------------
- // Search for a certificate match its subject string to opt->crt_file.
- serverCert = CertFindCertificateInStore(hMyCertStore, X509_ASN_ENCODING, 0, CERT_FIND_SUBJECT_STR_A, opt->crt_file, NULL);
- CertCloseStore(hMyCertStore, 0);
- if (serverCert == NULL) {
- printe("Error retrieving server certificate. %x\n", GetLastError());
- return NULL;
- }
- #else
- serverCert = getservercert(opt->ca_file);
- #endif
- credData.cCreds = 1; // 数量
- credData.paCred = &serverCert;
- // credData.dwCredFormat = SCH_CRED_FORMAT_CERT_HASH;
- credData.grbitEnabledProtocols = SP_PROT_TLS1_2_SERVER | SP_PROT_TLS1_3_SERVER;
- credflag = SECPKG_CRED_INBOUND;
- } else {
- credData.grbitEnabledProtocols = SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_3_CLIENT;
- credflag = SECPKG_CRED_OUTBOUND;
- }
- #if 0 // just use the system defalut algs
- ALG_ID rgbSupportedAlgs[4];
- rgbSupportedAlgs[0] = CALG_DH_EPHEM;
- rgbSupportedAlgs[1] = CALG_RSA_KEYX;
- rgbSupportedAlgs[2] = CALG_AES_128;
- rgbSupportedAlgs[3] = CALG_SHA_256;
- credData.cSupportedAlgs = 4;
- credData.palgSupportedAlgs = rgbSupportedAlgs;
- #endif
- credData.dwVersion = SCHANNEL_CRED_VERSION;
- // credData.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_NO_SERVERNAME_CHECK | SCH_USE_STRONG_CRYPTO | SCH_CRED_MANUAL_CRED_VALIDATION | SCH_CRED_IGNORE_NO_REVOCATION_CHECK | SCH_CRED_IGNORE_REVOCATION_OFFLINE;
- // credData.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN | SCH_CRED_IGNORE_REVOCATION_OFFLINE;
- // credData.dwMinimumCipherStrength = -1;
- // credData.dwMaximumCipherStrength = -1;
- //-------------------------------------------------------
- hCred = (CredHandle*)malloc(sizeof(CredHandle));
- if (hCred == NULL) {
- return NULL;
- }
- SecStatus = AcquireCredentialsHandle(NULL, unisp_name, credflag, NULL, &credData, NULL, NULL, hCred, &Lifetime);
- if (SecStatus == SEC_E_OK) {
- #ifndef NDEBUG
- SecPkgCred_SupportedAlgs algs;
- if (QueryCredentialsAttributesA(hCred, SECPKG_ATTR_SUPPORTED_ALGS, &algs) == SEC_E_OK) {
- for (int i = 0; i < algs.cSupportedAlgs; i++) {
- printd("alg: 0x%08x\n", algs.palgSupportedAlgs[i]);
- }
- }
- #endif
- } else {
- printe("ERROR: AcquireCredentialsHandle: 0x%x\n", SecStatus);
- free(hCred);
- hCred = NULL;
- }
- return hCred;
- }
- void hssl_ctx_free(hssl_ctx_t ssl_ctx)
- {
- SECURITY_STATUS sec_status = FreeCredentialsHandle(ssl_ctx);
- if (sec_status != SEC_E_OK) {
- printe("free_cred_handle FreeCredentialsHandle %d\n", sec_status);
- }
- }
- static void init_sec_buffer(SecBuffer* secure_buffer, unsigned long type, unsigned long len, void* buffer)
- {
- secure_buffer->BufferType = type;
- secure_buffer->cbBuffer = len;
- secure_buffer->pvBuffer = buffer;
- }
- static void init_sec_buffer_desc(SecBufferDesc* secure_buffer_desc, unsigned long version, unsigned long num_buffers, SecBuffer* buffers)
- {
- secure_buffer_desc->ulVersion = version;
- secure_buffer_desc->cBuffers = num_buffers;
- secure_buffer_desc->pBuffers = buffers;
- }
/*
 * States of the non-blocking client handshake state machine.
 * Values are consecutive from 0, so a zero-filled connection starts in ssl_connect_1.
 */
typedef enum {
    ssl_connect_1 = 0,     // initial state: ClientHello not sent yet
    ssl_connect_2 = 1,     // ClientHello sent; processing the server's handshake messages
    ssl_connect_2_reading = 2,
    ssl_connect_2_writing = 3,
    ssl_connect_3 = 4,     // handshake messages exchanged
    ssl_connect_done = 5
} ssl_connect_state;
- struct wintls_s {
- hssl_ctx_t ssl_ctx; // CredHandle
- int fd;
- union {
- ssl_connect_state state2;
- ssl_connect_state connecting_state;
- };
- SecHandle sechandle;
- SecPkgContext_StreamSizes stream_sizes_;
- size_t buffer_to_decrypt_offset_;
- size_t dec_len_;
- char encrypted_buffer_[TLS_SOCKET_BUFFER_SIZE];
- char buffer_to_decrypt_[TLS_SOCKET_BUFFER_SIZE];
- char decrypted_buffer_[TLS_SOCKET_BUFFER_SIZE + TLS_SOCKET_BUFFER_SIZE];
- char* sni;
- };
- hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd)
- {
- struct wintls_s* ret = malloc(sizeof(*ret));
- if (ret) {
- memset(ret, 0, sizeof(*ret));
- ret->ssl_ctx = ssl_ctx;
- ret->fd = fd;
- ret->sechandle.dwLower = 0;
- ret->sechandle.dwUpper = 0;
- }
- return ret;
- }
- void hssl_free(hssl_t _ssl)
- {
- struct wintls_s* ssl = _ssl;
- SECURITY_STATUS sec_status = DeleteSecurityContext(&ssl->sechandle);
- if (sec_status != SEC_E_OK) {
- printe("hssl_free DeleteSecurityContext %d", sec_status);
- }
- if (ssl->sni) {
- free(ssl->sni);
- }
- free(ssl);
- }
- static void free_all_buffers(SecBufferDesc* secure_buffer_desc)
- {
- for (unsigned long i = 0; i < secure_buffer_desc->cBuffers; ++i) {
- void* buffer = secure_buffer_desc->pBuffers[i].pvBuffer;
- if (buffer != NULL) {
- FreeContextBuffer(buffer);
- }
- }
- }
- static int __sendwrapper(SOCKET fd, const char* buf, size_t len, int flags)
- {
- int left = len;
- int offset = 0;
- while (left > 0) {
- int bytes_sent = send(fd, buf + offset, left, flags);
- if (bytes_sent == 0 || (bytes_sent == SOCKET_ERROR && WSAGetLastError() != WSAEWOULDBLOCK && WSAGetLastError() != WSAEINTR)) {
- break;
- }
- if (bytes_sent > 0) {
- offset += bytes_sent;
- left -= bytes_sent;
- }
- }
- return offset;
- }
- static int __recvwrapper(SOCKET fd, char* buf, int len, int flags)
- {
- int ret = 0;
- do {
- ret = recv(fd, buf, len, flags);
- } while (ret == SOCKET_ERROR && WSAGetLastError() == WSAEINTR);
- return ret;
- }
- int hssl_accept(hssl_t ssl)
- {
- int ret = HSSL_ERROR;
- struct wintls_s* winssl = ssl;
- bool authn_completed = false;
- // Input buffer
- char buffer_in[TLS_SOCKET_BUFFER_SIZE];
- SecBuffer secure_buffer_in[2] = { 0 };
- init_sec_buffer(&secure_buffer_in[0], SECBUFFER_TOKEN, TLS_SOCKET_BUFFER_SIZE, buffer_in);
- init_sec_buffer(&secure_buffer_in[1], SECBUFFER_EMPTY, 0, NULL);
- SecBufferDesc secure_buffer_desc_in = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc_in, SECBUFFER_VERSION, 2, secure_buffer_in);
- // Output buffer
- SecBuffer secure_buffer_out[3] = { 0 };
- init_sec_buffer(&secure_buffer_out[0], SECBUFFER_TOKEN, 0, NULL);
- init_sec_buffer(&secure_buffer_out[1], SECBUFFER_ALERT, 0, NULL);
- init_sec_buffer(&secure_buffer_out[2], SECBUFFER_EMPTY, 0, NULL);
- SecBufferDesc secure_buffer_desc_out = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 3, secure_buffer_out);
- unsigned long context_requirements = ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_CONFIDENTIALITY;
- // We use ASC_REQ_ALLOCATE_MEMORY which means the buffers will be allocated for us, we need to make sure we free them.
- ULONG context_attributes = 0;
- TimeStamp life_time = { 0 };
- secure_buffer_in[0].cbBuffer = __recvwrapper(winssl->fd, (char*)secure_buffer_in[0].pvBuffer, TLS_SOCKET_BUFFER_SIZE, 0);
- // printd("%s recv %d %d\n", __func__, secure_buffer_in[0].cbBuffer, WSAGetLastError());
- if (secure_buffer_in[0].cbBuffer == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK) {
- ret = HSSL_WANT_READ;
- } else if (secure_buffer_in[0].cbBuffer > 0) {
- SECURITY_STATUS sec_status = AcceptSecurityContext(winssl->ssl_ctx, winssl->state2 == 0 ? NULL : &winssl->sechandle, &secure_buffer_desc_in,
- context_requirements, 0, &winssl->sechandle, &secure_buffer_desc_out, &context_attributes, &life_time);
- winssl->state2 = 1;
- // printd("establish_server_security_context AcceptSecurityContext %x\n", sec_status);
- if (secure_buffer_out[0].cbBuffer > 0) {
- int rc = __sendwrapper(winssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
- if (rc != secure_buffer_out[0].cbBuffer) {
- goto END;
- }
- }
- switch (sec_status) {
- case SEC_E_OK:
- ret = HSSL_OK;
- authn_completed = true;
- break;
- case SEC_I_CONTINUE_NEEDED:
- ret = HSSL_WANT_READ;
- break;
- case SEC_I_COMPLETE_AND_CONTINUE:
- case SEC_I_COMPLETE_NEEDED: {
- SECURITY_STATUS complete_sec_status = SEC_E_OK;
- complete_sec_status = CompleteAuthToken(&winssl->sechandle, &secure_buffer_desc_out);
- if (complete_sec_status != SEC_E_OK) {
- printe("establish_server_security_context CompleteAuthToken %x\n", complete_sec_status);
- goto END;
- }
- if (sec_status == SEC_I_COMPLETE_NEEDED) {
- authn_completed = true;
- ret = HSSL_OK;
- } else {
- ret = HSSL_WANT_READ;
- }
- break;
- }
- default:
- break;
- }
- }
- END:
- free_all_buffers(&secure_buffer_desc_out);
- if (authn_completed) {
- SECURITY_STATUS sec_status = QueryContextAttributes(&winssl->sechandle, SECPKG_ATTR_STREAM_SIZES, &winssl->stream_sizes_);
- if (sec_status != SEC_E_OK) {
- printe("get_stream_sizes QueryContextAttributes %d\n", sec_status);
- }
- }
- return ret;
- }
- static int schannel_connect_step1(struct wintls_s* ssl)
- {
- int ret = 0;
- ULONG context_attributes = 0;
- unsigned long context_requirements = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
- TimeStamp life_time = { 0 };
- SecBuffer secure_buffer_out[1] = { 0 };
- init_sec_buffer(&secure_buffer_out[0], SECBUFFER_EMPTY, 0, NULL);
- SecBufferDesc secure_buffer_desc_out = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 1, secure_buffer_out);
- SECURITY_STATUS sec_status = InitializeSecurityContext(ssl->ssl_ctx, NULL, ssl->sni, context_requirements, 0, 0, NULL, 0, &ssl->sechandle,
- &secure_buffer_desc_out, &context_attributes, &life_time);
- if (sec_status != SEC_I_CONTINUE_NEEDED) {
- printe("1InitializeSecurityContext: %x\n", sec_status);
- }
- if (secure_buffer_out[0].cbBuffer > 0) {
- int rc = __sendwrapper(ssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
- if (rc != secure_buffer_out[0].cbBuffer) {
- // TODO: Handle the error
- printe("%s :send failed\n", __func__);
- ret = -1;
- } else {
- printd("%s :send len=%d\n", __func__, rc);
- ssl->connecting_state = ssl_connect_2;
- }
- }
- free_all_buffers(&secure_buffer_desc_out);
- return ret;
- }
- static int schannel_connect_step2(struct wintls_s* ssl)
- {
- int ret = HSSL_ERROR;
- ULONG context_attributes = 0;
- bool verify_server_cert = 0;
- unsigned long context_requirements = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
- if (!verify_server_cert) {
- context_requirements |= ISC_REQ_MANUAL_CRED_VALIDATION;
- }
- TimeStamp life_time = { 0 };
- // Allocate a temporary buffer for input
- char* buffer_in = malloc(TLS_SOCKET_BUFFER_SIZE);
- if (buffer_in == NULL) {
- printe("schannel_connect_step2: Memory allocation failed\n");
- return HSSL_ERROR;
- }
- int offset = 0;
- bool skip_recv = false;
- bool authn_complete = false;
- while (!authn_complete) {
- int in_buffer_size = 0;
- if (!skip_recv) {
- int received = __recvwrapper(ssl->fd, buffer_in + offset, TLS_SOCKET_BUFFER_SIZE, 0);
- if (received == SOCKET_ERROR) {
- if (WSAGetLastError() == WSAEWOULDBLOCK) {
- ret = HSSL_WANT_READ;
- } else {
- printe("schannel_connect_step2: Receive failed\n");
- }
- break;
- } else if (received == 0) {
- printe("schannel_connect_step2: peer closed\n");
- break;
- }
- in_buffer_size = received + offset;
- } else {
- in_buffer_size = offset;
- }
- skip_recv = false;
- offset = 0;
- // Input buffer
- SecBuffer secure_buffer_in[4] = { 0 };
- init_sec_buffer(&secure_buffer_in[0], SECBUFFER_TOKEN, in_buffer_size, buffer_in);
- init_sec_buffer(&secure_buffer_in[1], SECBUFFER_EMPTY, 0, NULL);
- SecBufferDesc secure_buffer_desc_in = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc_in, SECBUFFER_VERSION, 2, secure_buffer_in);
- // Output buffer
- SecBuffer secure_buffer_out[3] = { 0 };
- init_sec_buffer(&secure_buffer_out[0], SECBUFFER_TOKEN, 0, NULL);
- init_sec_buffer(&secure_buffer_out[1], SECBUFFER_ALERT, 0, NULL);
- init_sec_buffer(&secure_buffer_out[2], SECBUFFER_EMPTY, 0, NULL);
- SecBufferDesc secure_buffer_desc_out = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 3, secure_buffer_out);
- printd("h2:%d\n", in_buffer_size);
- SECURITY_STATUS sec_status = InitializeSecurityContext(ssl->ssl_ctx, &ssl->sechandle, ssl->sni, context_requirements, 0, 0, &secure_buffer_desc_in, 0,
- &ssl->sechandle, &secure_buffer_desc_out, &context_attributes, &life_time);
- printd("h2 0x%x inbuf[1] type=%d %d inbuf[0]=%d\n", sec_status, secure_buffer_in[1].BufferType, secure_buffer_in[1].cbBuffer, secure_buffer_in[0].cbBuffer);
- if (sec_status == SEC_E_OK || sec_status == SEC_I_CONTINUE_NEEDED) {
- // for (size_t i = 0; i < 3; i++) {
- // printd("obuf[%zu] type=%d %d\n", i, secure_buffer_out[i].BufferType, secure_buffer_out[i].cbBuffer);
- // }
- if (secure_buffer_out[0].cbBuffer > 0) {
- int rc = __sendwrapper(ssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
- if (rc != secure_buffer_out[0].cbBuffer) {
- printe("schannel_connect_step2: Send failed\n");
- // TODO: Handle the error
- break;
- }
- // printd("%s :send ok\n", __func__);
- }
- if (sec_status == SEC_I_CONTINUE_NEEDED) {
- if (secure_buffer_in[1].BufferType == SECBUFFER_EXTRA && secure_buffer_in[1].cbBuffer > 0) {
- offset = secure_buffer_in[0].cbBuffer - secure_buffer_in[1].cbBuffer;
- memmove(buffer_in, buffer_in + offset, secure_buffer_in[1].cbBuffer);
- offset = secure_buffer_in[1].cbBuffer;
- skip_recv = true;
- }
- } else if (sec_status == SEC_E_OK) {
- authn_complete = true;
- ret = HSSL_OK;
- ssl->connecting_state = ssl_connect_3;
- }
- } else if (sec_status == SEC_E_INCOMPLETE_MESSAGE) {
- offset = secure_buffer_in[0].cbBuffer;
- } else {
- printe("2InitializeSecurityContext: 0x%x\n", sec_status);
- break;
- }
- free_all_buffers(&secure_buffer_desc_out);
- }
- // END:
- free(buffer_in); // Free the temporary buffer
- return ret;
- }
- static void dumpconninfo(SecHandle* sechandle)
- {
- SECURITY_STATUS Status;
- SecPkgContext_ConnectionInfo ConnectionInfo;
- Status = QueryContextAttributes(sechandle,
- SECPKG_ATTR_CONNECTION_INFO,
- (PVOID)&ConnectionInfo);
- if (Status != SEC_E_OK) {
- printe("Error 0x%x querying connection info\n", Status);
- return;
- }
- printd("\n");
- switch (ConnectionInfo.dwProtocol) {
- case SP_PROT_TLS1_CLIENT:
- printd("Protocol: TLS1\n");
- break;
- case SP_PROT_SSL3_CLIENT:
- printd("Protocol: SSL3\n");
- break;
- case SP_PROT_SSL2_CLIENT:
- printd("Protocol: SSL2\n");
- break;
- case SP_PROT_PCT1_CLIENT:
- printd("Protocol: PCT\n");
- break;
- default:
- printd("Protocol: 0x%x\n", ConnectionInfo.dwProtocol);
- }
- switch (ConnectionInfo.aiCipher) {
- case CALG_RC4:
- printd("Cipher: RC4\n");
- break;
- case CALG_3DES:
- printd("Cipher: Triple DES\n");
- break;
- case CALG_RC2:
- printd("Cipher: RC2\n");
- break;
- case CALG_DES:
- case CALG_CYLINK_MEK:
- printd("Cipher: DES\n");
- break;
- case CALG_SKIPJACK:
- printd("Cipher: Skipjack\n");
- break;
- case CALG_AES_128:
- printd("Cipher: aes128\n");
- break;
- default:
- printd("Cipher: 0x%x\n", ConnectionInfo.aiCipher);
- }
- printd("Cipher strength: %d\n", ConnectionInfo.dwCipherStrength);
- switch (ConnectionInfo.aiHash) {
- case CALG_MD5:
- printd("Hash: MD5\n");
- break;
- case CALG_SHA:
- printd("Hash: SHA\n");
- break;
- default:
- printd("Hash: 0x%x\n", ConnectionInfo.aiHash);
- }
- printd("Hash strength: %d\n", ConnectionInfo.dwHashStrength);
- switch (ConnectionInfo.aiExch) {
- case CALG_RSA_KEYX:
- case CALG_RSA_SIGN:
- printd("Key exchange: RSA\n");
- break;
- case CALG_KEA_KEYX:
- printd("Key exchange: KEA\n");
- break;
- case CALG_DH_EPHEM:
- printd("Key exchange: DH Ephemeral\n");
- break;
- default:
- printd("Key exchange: 0x%x\n", ConnectionInfo.aiExch);
- }
- printd("Key exchange strength: %d\n", ConnectionInfo.dwExchStrength);
- }
- int hssl_connect(hssl_t _ssl)
- {
- int ret = 0;
- struct wintls_s* ssl = _ssl;
- if (ssl->connecting_state == ssl_connect_1) {
- ret = schannel_connect_step1(ssl);
- }
- if (!ret && ssl->connecting_state == ssl_connect_2) {
- ret = schannel_connect_step2(ssl);
- }
- // printd("%s %x\n", __func__, ret);
- if (!ret) {
- if (ssl->connecting_state == ssl_connect_3) {
- // ret = schannel_connect_step3(ssl);
- }
- SECURITY_STATUS sec_status = QueryContextAttributes(&ssl->sechandle, SECPKG_ATTR_STREAM_SIZES, &ssl->stream_sizes_);
- if (sec_status != SEC_E_OK) {
- printe("get_stream_sizes QueryContextAttributes %d\n", sec_status);
- } else {
- printd("stream_sizes bs:%d h:%d t:%d max:%d bfs:%d\n", ssl->stream_sizes_.cbBlockSize, ssl->stream_sizes_.cbHeader, ssl->stream_sizes_.cbTrailer, ssl->stream_sizes_.cbMaximumMessage, ssl->stream_sizes_.cBuffers);
- }
- dumpconninfo(&ssl->sechandle);
- }
- return ret;
- }
- static int decrypt_message(SecHandle security_context, unsigned long* extra, char* in_buf, int in_len, char* out_buf, int out_len)
- {
- printd("%s: inlen=%d\n", __func__, in_len);
- // Initialize the secure buffers
- SecBuffer secure_buffers[4] = { 0 };
- init_sec_buffer(&secure_buffers[0], SECBUFFER_DATA, in_len, in_buf);
- init_sec_buffer(&secure_buffers[1], SECBUFFER_EMPTY, 0, NULL);
- init_sec_buffer(&secure_buffers[2], SECBUFFER_EMPTY, 0, NULL);
- init_sec_buffer(&secure_buffers[3], SECBUFFER_EMPTY, 0, NULL);
- // Initialize the secure buffer descriptor
- SecBufferDesc secure_buffer_desc = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc, SECBUFFER_VERSION, 4, secure_buffers);
- // Decrypt the message using the security context
- SECURITY_STATUS sec_status = DecryptMessage(&security_context, &secure_buffer_desc, 0, NULL);
- for (size_t i = 1; i < 4; i++) {
- printd("%d: %u %u\n", i, secure_buffers[i].BufferType, secure_buffers[i].cbBuffer);
- }
- if (sec_status == SEC_E_INCOMPLETE_MESSAGE) {
- printe("decrypt_message SEC_E_INCOMPLETE_MESSAGE\n");
- return -1;
- } else if (sec_status == SEC_E_DECRYPT_FAILURE) {
- printe("decrypt_message ignore SEC_E_DECRYPT_FAILURE\n");
- return 0;
- } else if (sec_status == SEC_E_UNSUPPORTED_FUNCTION) {
- printe("decrypt_message ignore SEC_E_UNSUPPORTED_FUNCTION\n");
- return 0;
- }
- if (sec_status != SEC_E_OK) {
- printe("decrypt_message DecryptMessage: 0x%x\n", sec_status);
- return -1;
- }
- if (secure_buffers[3].BufferType == SECBUFFER_EXTRA && secure_buffers[3].cbBuffer > 0) {
- *extra = secure_buffers[3].cbBuffer;
- }
- memcpy(out_buf, secure_buffers[1].pvBuffer, secure_buffers[1].cbBuffer);
- // printd("ob:%s\n", out_buf);
- return secure_buffers[1].cbBuffer;
- }
- int hssl_read(hssl_t _ssl, void* buf, int len)
- {
- struct wintls_s* ssl = _ssl;
- printd("%s: dec_len_= %zu\n", __func__, ssl->dec_len_);
- if (ssl->dec_len_ > 0) {
- if (buf == NULL) {
- return 0;
- }
- int decrypted = MIN(ssl->dec_len_, len);
- memcpy(buf, ssl->decrypted_buffer_, (size_t)decrypted);
- ssl->dec_len_ -= decrypted;
- if (ssl->dec_len_) {
- memmove(ssl->decrypted_buffer_, ssl->decrypted_buffer_ + decrypted, (size_t)ssl->dec_len_);
- } else {
- // hssl_read(_ssl, NULL, 0);
- }
- return decrypted;
- }
- // We might have leftovers, an incomplete message from a previous call.
- // Calculate the available buffer length for tcp recv.
- int recv_max_len = TLS_SOCKET_BUFFER_SIZE - ssl->buffer_to_decrypt_offset_;
- int bytes_received = __recvwrapper(ssl->fd, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_, recv_max_len, 0);
- // printd("%s recv %d %d\n", __func__, bytes_received, WSAGetLastError());
- if (bytes_received == SOCKET_ERROR) {
- if (WSAGetLastError() == WSAEWOULDBLOCK) {
- bytes_received = 0;
- return 0;
- } else {
- return -1;
- }
- } else if (bytes_received == 0) {
- return 0;
- }
- int encrypted_buffer_len = ssl->buffer_to_decrypt_offset_ + bytes_received;
- ssl->buffer_to_decrypt_offset_ = 0;
- while (true) {
- // printd("%s:buffer_to_decrypt_offset_ = %d , encrypted_buffer_len= %d\n", __func__, ssl->buffer_to_decrypt_offset_, encrypted_buffer_len);
- if (ssl->buffer_to_decrypt_offset_ >= encrypted_buffer_len) {
- // Reached the encrypted buffer length, we decrypted everything so we can stop.
- break;
- }
- unsigned long extra = 0;
- int decrypted_len = decrypt_message(ssl->sechandle, &extra, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_,
- encrypted_buffer_len - ssl->buffer_to_decrypt_offset_, ssl->decrypted_buffer_ + ssl->dec_len_,
- TLS_SOCKET_BUFFER_SIZE + TLS_SOCKET_BUFFER_SIZE - ssl->dec_len_);
- if (decrypted_len == -1) {
- // Incomplete message, we shuold keep it so it will be decrypted on the next call to recv().
- // Shift the remaining buffer to the beginning and break the loop.
- memmove(ssl->buffer_to_decrypt_, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_, encrypted_buffer_len - ssl->buffer_to_decrypt_offset_);
- break;
- }
- ssl->dec_len_ += decrypted_len;
- ssl->buffer_to_decrypt_offset_ = encrypted_buffer_len - extra;
- }
- ssl->buffer_to_decrypt_offset_ = encrypted_buffer_len - ssl->buffer_to_decrypt_offset_;
- return hssl_read(_ssl, buf, len);
- }
- int hssl_write(hssl_t _ssl, const void* buf, int len)
- {
- struct wintls_s* ssl = _ssl;
- SecPkgContext_StreamSizes* stream_sizes = &ssl->stream_sizes_;
- if (len > (int)stream_sizes->cbMaximumMessage) {
- len = stream_sizes->cbMaximumMessage;
- }
- // Calculate the minimum output buffer length
- int min_out_len = stream_sizes->cbHeader + len + stream_sizes->cbTrailer;
- if (min_out_len > TLS_SOCKET_BUFFER_SIZE) {
- printe("encrypt_message: Output buffer is too small");
- return -1;
- }
- // Initialize the secure buffers
- SecBuffer secure_buffers[4] = { 0 };
- init_sec_buffer(&secure_buffers[0], SECBUFFER_STREAM_HEADER, stream_sizes->cbHeader, ssl->encrypted_buffer_);
- init_sec_buffer(&secure_buffers[1], SECBUFFER_DATA, len, ssl->encrypted_buffer_ + stream_sizes->cbHeader);
- init_sec_buffer(&secure_buffers[2], SECBUFFER_STREAM_TRAILER, stream_sizes->cbTrailer, ssl->encrypted_buffer_ + stream_sizes->cbHeader + len);
- init_sec_buffer(&secure_buffers[3], SECBUFFER_EMPTY, 0, NULL);
- // Initialize the secure buffer descriptor
- SecBufferDesc secure_buffer_desc = { 0 };
- init_sec_buffer_desc(&secure_buffer_desc, SECBUFFER_VERSION, 4, secure_buffers);
- // Copy the input buffer to the data buffer
- memcpy(secure_buffers[1].pvBuffer, buf, len);
- // Encrypt the message using the security context
- SECURITY_STATUS sec_status = EncryptMessage(&ssl->sechandle, 0, &secure_buffer_desc, 0);
- // Check the encryption status and the data buffer length
- if (sec_status != SEC_E_OK) {
- printe("encrypt_message EncryptMessage %d\n", sec_status);
- return -1;
- }
- if (secure_buffers[1].cbBuffer > (unsigned int)len) {
- printe("encrypt_message: Data buffer is too large\n");
- return -1;
- }
- // Adjust the minimum output buffer length
- min_out_len = secure_buffers[0].cbBuffer + secure_buffers[1].cbBuffer + secure_buffers[2].cbBuffer;
- printd("enc02: %d %d\n", secure_buffers[0].cbBuffer, secure_buffers[2].cbBuffer);
- // Send the encrypted message to the socket
- int offset = __sendwrapper(ssl->fd, ssl->encrypted_buffer_, min_out_len, 0);
- // Check the send result
- if (offset != min_out_len) {
- printe("hssl_write: Send failed\n");
- return -1;
- } else {
- printd("hssl_write: Send %d\n", min_out_len);
- }
- // Return the number of bytes sent excluding the header and trailer
- return offset - secure_buffers[0].cbBuffer - secure_buffers[2].cbBuffer;
- }
- int hssl_close(hssl_t _ssl)
- {
- return 0;
- }
- int hssl_set_sni_hostname(hssl_t _ssl, const char* hostname)
- {
- struct wintls_s* ssl = _ssl;
- ssl->sni = strdup(hostname);
- return 0;
- }
- #endif // WITH_WINTLS
|