Ver Fonte

feat: add wintls via schannel (#412)

firear há 2 anos atrás
pai
commit
9068de30e1
1 ficheiros alterados com 738 adições e 20 exclusões
  1. 738 20
      ssl/wintls.c

+ 738 - 20
ssl/wintls.c

@@ -2,48 +2,766 @@
 
 #ifdef WITH_WINTLS
 
-const char* hssl_backend() {
-    return "nossl";
+// #define PRINT_DEBUG
+// #define PRINT_ERROR
+#include "hdef.h"
+#include <schannel.h>
+#include <wincrypt.h>
+#include <windows.h>
+#include <wintrust.h>
+
+#define SECURITY_WIN32
+#include <security.h>
+#include <sspi.h>
+
+#pragma comment(lib, "Secur32.lib")
+#pragma comment(lib, "crypt32.lib")
+
+#define TLS_SOCKET_BUFFER_SIZE 17000
+
+const char* hssl_backend()
+{
+    return "schannel";
 }
 
-hssl_ctx_t hssl_ctx_new(hssl_ctx_opt_t* opt) {
-    fprintf(stderr, "Please recompile WITH_SSL.\n");
+static PCCERT_CONTEXT getservercert(const char* path)
+{
+    /*
+    According to the information I searched from the internet, it is not possible to specify an x509 private key and certificate using the
+    CertCreateCertificateContext interface. We must first export them as a pkcs#12 formatted file, and then import them into the Windows certificate store. This
+    is because the Windows certificate store is an integrated system location that does not support the direct use of separate private key files and certificate
+    files. The pkcs#12 format is a complex format that can store and protect keys and certificates. You can use the OpenSSL tool to combine the private key file
+    and certificate file into a pkcs#12 formatted file, For example: OpenSSL pkcs12 -export -out cert.pfx -inkey private.key -in cert.cer Then, you can use the
+    certutil tool or a graphical interface to import this file into the personal store of your local computer. After importing, you can use the
+    CertFindCertificateInStore interface to create and manipulate certificate contexts.
+    */
     return NULL;
 }
 
-void hssl_ctx_free(hssl_ctx_t ssl_ctx) {
+hssl_ctx_t hssl_ctx_new(hssl_ctx_opt_t* opt)
+{
+    SECURITY_STATUS SecStatus;
+    TimeStamp Lifetime;
+    CredHandle* hCred = malloc(sizeof(CredHandle));
+    SCHANNEL_CRED credData = { 0 };
+    TCHAR unisp_name[] = UNISP_NAME;
+    unsigned long credflag;
+    SecPkgCred_SupportedAlgs algs;
+
+    if (opt && opt->endpoint == HSSL_SERVER) {
+        PCCERT_CONTEXT serverCert = NULL; // server-side certificate
+#if 1 // create ceart from store
+
+        //-------------------------------------------------------
+        // Get the server certificate.
+        //-------------------------------------------------------
+        // Open the My store(personal store).
+        HCERTSTORE hMyCertStore = CertOpenStore(CERT_STORE_PROV_SYSTEM, X509_ASN_ENCODING, 0, CERT_SYSTEM_STORE_LOCAL_MACHINE, L"MY");
+        if (hMyCertStore == NULL) {
+            printe("Error opening MY store for server.\n");
+            return NULL;
+        }
+        //-------------------------------------------------------
+        // Search for a certificate match its subject string to opt->crt_file.
+        serverCert = CertFindCertificateInStore(hMyCertStore, X509_ASN_ENCODING, 0, CERT_FIND_SUBJECT_STR_A, opt->crt_file, NULL);
+        CertCloseStore(hMyCertStore, 0);
+        if (serverCert == NULL) {
+            printe("Error retrieving server certificate. %x\n", GetLastError());
+            return NULL;
+        }
+#else
+        serverCert = getservercert(opt->ca_file);
+#endif
+        credData.cCreds = 1; // 数量
+        credData.paCred = &serverCert;
+        // credData.dwCredFormat = SCH_CRED_FORMAT_CERT_HASH;
+        credData.grbitEnabledProtocols = SP_PROT_TLS1_2_SERVER | SP_PROT_TLS1_3_SERVER;
+        credflag = SECPKG_CRED_INBOUND;
+    } else {
+        credData.grbitEnabledProtocols = SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_3_CLIENT;
+        credflag = SECPKG_CRED_OUTBOUND;
+    }
+#if 0 // just use the system defalut algs
+        ALG_ID rgbSupportedAlgs[4];
+        rgbSupportedAlgs[0] = CALG_DH_EPHEM;
+        rgbSupportedAlgs[1] = CALG_RSA_KEYX;
+        rgbSupportedAlgs[2] = CALG_AES_128;
+        rgbSupportedAlgs[3] = CALG_SHA_256;
+        credData.cSupportedAlgs = 4;
+        credData.palgSupportedAlgs = rgbSupportedAlgs;
+#endif
+    credData.dwVersion = SCHANNEL_CRED_VERSION;
+    credData.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_NO_SERVERNAME_CHECK | SCH_USE_STRONG_CRYPTO | SCH_CRED_MANUAL_CRED_VALIDATION | SCH_CRED_IGNORE_NO_REVOCATION_CHECK | SCH_CRED_IGNORE_REVOCATION_OFFLINE;
+
+    // credData.dwMinimumCipherStrength = -1;
+    // credData.dwMaximumCipherStrength = -1;
+
+    //-------------------------------------------------------
+    SecStatus = AcquireCredentialsHandle(NULL, unisp_name, credflag, NULL, &credData, NULL, NULL, hCred, &Lifetime);
+    if (SecStatus != SEC_E_OK) {
+        printe("ERROR: AcquireCredentialsHandle: 0x%x\n", SecStatus);
+        abort();
+    }
+    // Return the handle to the caller.
+    SecStatus = QueryCredentialsAttributesA(hCred, SECPKG_ATTR_SUPPORTED_ALGS, &algs);
+    if (SecStatus == SEC_E_OK) {
+        for (int i = 0; i < algs.cSupportedAlgs; i++) {
+            printd("alg: 0x%08x\n", algs.palgSupportedAlgs[i]);
+        }
+    }
+
+    return hCred;
 }
 
-hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd) {
-    return (void*)(intptr_t)fd;
+void hssl_ctx_free(hssl_ctx_t ssl_ctx)
+{
+    SECURITY_STATUS sec_status = FreeCredentialsHandle(ssl_ctx);
+    if (sec_status != SEC_E_OK) {
+        printe("free_cred_handle FreeCredentialsHandle %d\n", sec_status);
+    }
 }
 
-void hssl_free(hssl_t ssl) {
+static void init_sec_buffer(SecBuffer* secure_buffer, unsigned long type, unsigned long len, void* buffer)
+{
+    secure_buffer->BufferType = type;
+    secure_buffer->cbBuffer = len;
+    secure_buffer->pvBuffer = buffer;
 }
 
-int hssl_accept(hssl_t ssl) {
-    return 0;
+static void init_sec_buffer_desc(SecBufferDesc* secure_buffer_desc, unsigned long version, unsigned long num_buffers, SecBuffer* buffers)
+{
+    secure_buffer_desc->ulVersion = version;
+    secure_buffer_desc->cBuffers = num_buffers;
+    secure_buffer_desc->pBuffers = buffers;
 }
 
-int hssl_connect(hssl_t ssl) {
-    return 0;
+/* enum for the nonblocking SSL connection state machine */
+typedef enum {
+    ssl_connect_1,
+    ssl_connect_2,
+    ssl_connect_2_reading,
+    ssl_connect_2_writing,
+    ssl_connect_3,
+    ssl_connect_done
+} ssl_connect_state;
+
+struct wintls_s {
+    hssl_ctx_t ssl_ctx; // CredHandle
+    int fd;
+    union {
+        ssl_connect_state state2;
+        ssl_connect_state connecting_state;
+    };
+    SecHandle sechandle;
+    SecPkgContext_StreamSizes stream_sizes_;
+    size_t buffer_to_decrypt_offset_;
+    size_t dec_len_;
+    char encrypted_buffer_[TLS_SOCKET_BUFFER_SIZE];
+    char buffer_to_decrypt_[TLS_SOCKET_BUFFER_SIZE];
+    char decrypted_buffer_[TLS_SOCKET_BUFFER_SIZE + TLS_SOCKET_BUFFER_SIZE];
+    char* sni;
+};
+
+hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd)
+{
+    struct wintls_s* ret = malloc(sizeof(*ret));
+    memset(ret, 0, sizeof(*ret));
+    ret->ssl_ctx = ssl_ctx;
+    ret->fd = fd;
+    ret->sechandle.dwLower = 0;
+    ret->sechandle.dwUpper = 0;
+    return ret;
+}
+
+void hssl_free(hssl_t _ssl)
+{
+    struct wintls_s* ssl = _ssl;
+    SECURITY_STATUS sec_status = DeleteSecurityContext(&ssl->sechandle);
+    if (sec_status != SEC_E_OK) {
+        printe("hssl_free DeleteSecurityContext %d", sec_status);
+    }
+    if (ssl->sni) {
+        free(ssl->sni);
+    }
+    free(ssl);
+}
+
+static void free_all_buffers(SecBufferDesc* secure_buffer_desc)
+{
+    for (unsigned long i = 0; i < secure_buffer_desc->cBuffers; ++i) {
+        void* buffer = secure_buffer_desc->pBuffers[i].pvBuffer;
+        if (buffer != NULL) {
+            FreeContextBuffer(buffer);
+        }
+    }
+}
+
+static int __sendwrapper(SOCKET fd, const char* buf, size_t len, int flags)
+{
+    int left = len;
+    int offset = 0;
+    while (left > 0) {
+        int bytes_sent = send(fd, buf + offset, left, flags);
+        if (bytes_sent == 0 || (bytes_sent == SOCKET_ERROR && WSAGetLastError() != WSAEWOULDBLOCK && WSAGetLastError() != WSAEINTR)) {
+            break;
+        }
+        if (bytes_sent > 0) {
+            offset += bytes_sent;
+            left -= bytes_sent;
+        }
+    }
+    return offset;
+}
+
+static int __recvwrapper(SOCKET fd, char* buf, int len, int flags)
+{
+    int ret = 0;
+    do {
+        ret = recv(fd, buf, len, flags);
+    } while (ret == SOCKET_ERROR && WSAGetLastError() == WSAEINTR);
+    return ret;
+}
+
+int hssl_accept(hssl_t ssl)
+{
+    int ret = HSSL_ERROR;
+    struct wintls_s* winssl = ssl;
+    bool authn_completed = false;
+
+    // Input buffer
+    char buffer_in[TLS_SOCKET_BUFFER_SIZE];
+
+    SecBuffer secure_buffer_in[2] = { 0 };
+    init_sec_buffer(&secure_buffer_in[0], SECBUFFER_TOKEN, TLS_SOCKET_BUFFER_SIZE, buffer_in);
+    init_sec_buffer(&secure_buffer_in[1], SECBUFFER_EMPTY, 0, NULL);
+
+    SecBufferDesc secure_buffer_desc_in = { 0 };
+    init_sec_buffer_desc(&secure_buffer_desc_in, SECBUFFER_VERSION, 2, secure_buffer_in);
+
+    // Output buffer
+    SecBuffer secure_buffer_out[3] = { 0 };
+    init_sec_buffer(&secure_buffer_out[0], SECBUFFER_TOKEN, 0, NULL);
+    init_sec_buffer(&secure_buffer_out[1], SECBUFFER_ALERT, 0, NULL);
+    init_sec_buffer(&secure_buffer_out[2], SECBUFFER_EMPTY, 0, NULL);
+
+    SecBufferDesc secure_buffer_desc_out = { 0 };
+    init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 3, secure_buffer_out);
+
+    unsigned long context_requirements = ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_CONFIDENTIALITY;
+
+    // We use ASC_REQ_ALLOCATE_MEMORY which means the buffers will be allocated for us, we need to make sure we free them.
+
+    ULONG context_attributes = 0;
+    TimeStamp life_time = { 0 };
+
+    secure_buffer_in[0].cbBuffer = __recvwrapper(winssl->fd, (char*)secure_buffer_in[0].pvBuffer, TLS_SOCKET_BUFFER_SIZE, 0);
+    // printd("%s recv %d %d\n", __func__, secure_buffer_in[0].cbBuffer, WSAGetLastError());
+    if (secure_buffer_in[0].cbBuffer == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK) {
+        ret = HSSL_WANT_READ;
+    } else if (secure_buffer_in[0].cbBuffer > 0) {
+        SECURITY_STATUS sec_status = AcceptSecurityContext(winssl->ssl_ctx, winssl->state2 == 0 ? NULL : &winssl->sechandle, &secure_buffer_desc_in,
+            context_requirements, 0, &winssl->sechandle, &secure_buffer_desc_out, &context_attributes, &life_time);
+
+        winssl->state2 = 1;
+        // printd("establish_server_security_context AcceptSecurityContext %x\n", sec_status);
+
+        if (secure_buffer_out[0].cbBuffer > 0) {
+            int rc = __sendwrapper(winssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
+            if (rc != secure_buffer_out[0].cbBuffer) {
+                goto END;
+            }
+        }
+
+        switch (sec_status) {
+        case SEC_E_OK:
+            ret = HSSL_OK;
+            authn_completed = true;
+            break;
+        case SEC_I_CONTINUE_NEEDED:
+            ret = HSSL_WANT_READ;
+            break;
+        case SEC_I_COMPLETE_AND_CONTINUE:
+        case SEC_I_COMPLETE_NEEDED: {
+            SECURITY_STATUS complete_sec_status = SEC_E_OK;
+            complete_sec_status = CompleteAuthToken(&winssl->sechandle, &secure_buffer_desc_out);
+            if (complete_sec_status != SEC_E_OK) {
+                printe("establish_server_security_context CompleteAuthToken %x\n", complete_sec_status);
+                goto END;
+            }
+
+            if (sec_status == SEC_I_COMPLETE_NEEDED) {
+                authn_completed = true;
+                ret = HSSL_OK;
+            } else {
+                ret = HSSL_WANT_READ;
+            }
+            break;
+        }
+        default:
+            break;
+        }
+    }
+END:
+    free_all_buffers(&secure_buffer_desc_out);
+
+    if (authn_completed) {
+        SECURITY_STATUS sec_status = QueryContextAttributes(&winssl->sechandle, SECPKG_ATTR_STREAM_SIZES, &winssl->stream_sizes_);
+        if (sec_status != SEC_E_OK) {
+            printe("get_stream_sizes QueryContextAttributes %d\n", sec_status);
+        }
+    }
+    return ret;
+}
+
+static int schannel_connect_step1(struct wintls_s* ssl)
+{
+    int ret = 0;
+    ULONG context_attributes = 0;
+    unsigned long context_requirements = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
+
+    TimeStamp life_time = { 0 };
+
+    SecBuffer secure_buffer_out[1] = { 0 };
+    init_sec_buffer(&secure_buffer_out[0], SECBUFFER_EMPTY, 0, NULL);
+
+    SecBufferDesc secure_buffer_desc_out = { 0 };
+    init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 1, secure_buffer_out);
+
+    SECURITY_STATUS sec_status = InitializeSecurityContext(ssl->ssl_ctx, NULL, ssl->sni, context_requirements, 0, 0, NULL, 0, &ssl->sechandle,
+        &secure_buffer_desc_out, &context_attributes, &life_time);
+
+    if (sec_status != SEC_I_CONTINUE_NEEDED) {
+        printe("1InitializeSecurityContext: %x\n", sec_status);
+    }
+
+    if (secure_buffer_out[0].cbBuffer > 0) {
+        int rc = __sendwrapper(ssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
+        if (rc != secure_buffer_out[0].cbBuffer) {
+            // TODO: Handle the error
+            printe("%s :send failed\n", __func__);
+            ret = -1;
+        } else {
+            printd("%s :send len=%d\n", __func__, rc);
+            ssl->connecting_state = ssl_connect_2;
+        }
+    }
+    free_all_buffers(&secure_buffer_desc_out);
+    return ret;
+}
+
+static int schannel_connect_step2(struct wintls_s* ssl)
+{
+    int ret = HSSL_ERROR;
+    ULONG context_attributes = 0;
+    bool verify_server_cert = 0;
+
+    unsigned long context_requirements = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM;
+    if (!verify_server_cert) {
+        context_requirements |= ISC_REQ_MANUAL_CRED_VALIDATION;
+    }
+
+    TimeStamp life_time = { 0 };
+
+    // Allocate a temporary buffer for input
+    char* buffer_in = malloc(TLS_SOCKET_BUFFER_SIZE);
+    if (buffer_in == NULL) {
+        printe("schannel_connect_step2: Memory allocation failed\n");
+        return HSSL_ERROR;
+    }
+
+    int offset = 0;
+    bool skip_recv = false;
+    bool authn_complete = false;
+    while (!authn_complete) {
+        int in_buffer_size = 0;
+
+        if (!skip_recv) {
+            int received = __recvwrapper(ssl->fd, buffer_in + offset, TLS_SOCKET_BUFFER_SIZE, 0);
+            if (received == SOCKET_ERROR) {
+                if (WSAGetLastError() == WSAEWOULDBLOCK) {
+                    ret = HSSL_WANT_READ;
+                } else {
+                    printe("schannel_connect_step2: Receive failed\n");
+                }
+                break;
+            } else if (received == 0) {
+                printe("schannel_connect_step2: peer closed\n");
+                break;
+            }
+            in_buffer_size = received + offset;
+        } else {
+            in_buffer_size = offset;
+        }
+
+        skip_recv = false;
+        offset = 0;
+
+        // Input buffer
+        SecBuffer secure_buffer_in[4] = { 0 };
+        init_sec_buffer(&secure_buffer_in[0], SECBUFFER_TOKEN, in_buffer_size, buffer_in);
+        init_sec_buffer(&secure_buffer_in[1], SECBUFFER_EMPTY, 0, NULL);
+
+        SecBufferDesc secure_buffer_desc_in = { 0 };
+        init_sec_buffer_desc(&secure_buffer_desc_in, SECBUFFER_VERSION, 2, secure_buffer_in);
+
+        // Output buffer
+        SecBuffer secure_buffer_out[3] = { 0 };
+        init_sec_buffer(&secure_buffer_out[0], SECBUFFER_TOKEN, 0, NULL);
+        init_sec_buffer(&secure_buffer_out[1], SECBUFFER_ALERT, 0, NULL);
+        init_sec_buffer(&secure_buffer_out[2], SECBUFFER_EMPTY, 0, NULL);
+
+        SecBufferDesc secure_buffer_desc_out = { 0 };
+        init_sec_buffer_desc(&secure_buffer_desc_out, SECBUFFER_VERSION, 3, secure_buffer_out);
+        printd("h2:%d\n", in_buffer_size);
+        SECURITY_STATUS sec_status = InitializeSecurityContext(ssl->ssl_ctx, &ssl->sechandle, ssl->sni, context_requirements, 0, 0, &secure_buffer_desc_in, 0,
+            &ssl->sechandle, &secure_buffer_desc_out, &context_attributes, &life_time);
+
+        printd("h2 0x%x inbuf[1] type=%d %d inbuf[0]=%d\n", sec_status, secure_buffer_in[1].BufferType, secure_buffer_in[1].cbBuffer, secure_buffer_in[0].cbBuffer);
+        if (sec_status == SEC_E_OK || sec_status == SEC_I_CONTINUE_NEEDED) {
+            // for (size_t i = 0; i < 3; i++) {
+            //     printd("obuf[%zu] type=%d %d\n", i, secure_buffer_out[i].BufferType, secure_buffer_out[i].cbBuffer);
+            // }
+            if (secure_buffer_out[0].cbBuffer > 0) {
+                int rc = __sendwrapper(ssl->fd, (const char*)secure_buffer_out[0].pvBuffer, secure_buffer_out[0].cbBuffer, 0);
+                if (rc != secure_buffer_out[0].cbBuffer) {
+                    printe("schannel_connect_step2: Send failed\n");
+                    // TODO: Handle the error
+                    break;
+                }
+                // printd("%s :send ok\n", __func__);
+            }
+
+            if (sec_status == SEC_I_CONTINUE_NEEDED) {
+                if (secure_buffer_in[1].BufferType == SECBUFFER_EXTRA && secure_buffer_in[1].cbBuffer > 0) {
+                    offset = secure_buffer_in[0].cbBuffer - secure_buffer_in[1].cbBuffer;
+                    memmove(buffer_in, buffer_in + offset, secure_buffer_in[1].cbBuffer);
+                    offset = secure_buffer_in[1].cbBuffer;
+                    skip_recv = true;
+                }
+            } else if (sec_status == SEC_E_OK) {
+                authn_complete = true;
+                ret = HSSL_OK;
+                ssl->connecting_state = ssl_connect_3;
+            }
+        } else if (sec_status == SEC_E_INCOMPLETE_MESSAGE) {
+            offset = secure_buffer_in[0].cbBuffer;
+        } else {
+            printe("2InitializeSecurityContext: 0x%x\n", sec_status);
+            break;
+        }
+
+        free_all_buffers(&secure_buffer_desc_out);
+    }
+    // END:
+    free(buffer_in); // Free the temporary buffer
+    return ret;
+}
+
+static void dumpconninfo(SecHandle* sechandle)
+{
+    SECURITY_STATUS Status;
+    SecPkgContext_ConnectionInfo ConnectionInfo;
+
+    Status = QueryContextAttributes(sechandle,
+        SECPKG_ATTR_CONNECTION_INFO,
+        (PVOID)&ConnectionInfo);
+    if (Status != SEC_E_OK) {
+        printe("Error 0x%x querying connection info\n", Status);
+        return;
+    }
+
+    printd("\n");
+
+    switch (ConnectionInfo.dwProtocol) {
+    case SP_PROT_TLS1_CLIENT:
+        printd("Protocol: TLS1\n");
+        break;
+
+    case SP_PROT_SSL3_CLIENT:
+        printd("Protocol: SSL3\n");
+        break;
+
+    case SP_PROT_PCT1_CLIENT:
+        printd("Protocol: PCT\n");
+        break;
+
+    case SP_PROT_SSL2_CLIENT:
+        printd("Protocol: SSL2\n");
+        break;
+
+    default:
+        printd("Protocol: 0x%x\n", ConnectionInfo.dwProtocol);
+    }
+
+    switch (ConnectionInfo.aiCipher) {
+    case CALG_RC4:
+        printd("Cipher: RC4\n");
+        break;
+
+    case CALG_3DES:
+        printd("Cipher: Triple DES\n");
+        break;
+
+    case CALG_RC2:
+        printd("Cipher: RC2\n");
+        break;
+
+    case CALG_DES:
+    case CALG_CYLINK_MEK:
+        printd("Cipher: DES\n");
+        break;
+
+    case CALG_SKIPJACK:
+        printd("Cipher: Skipjack\n");
+        break;
+
+    case CALG_AES_128:
+        printd("Cipher: aes128\n");
+        break;
+    default:
+        printd("Cipher: 0x%x\n", ConnectionInfo.aiCipher);
+    }
+
+    printd("Cipher strength: %d\n", ConnectionInfo.dwCipherStrength);
+
+    switch (ConnectionInfo.aiHash) {
+    case CALG_MD5:
+        printd("Hash: MD5\n");
+        break;
+
+    case CALG_SHA:
+        printd("Hash: SHA\n");
+        break;
+
+    default:
+        printd("Hash: 0x%x\n", ConnectionInfo.aiHash);
+    }
+
+    printd("Hash strength: %d\n", ConnectionInfo.dwHashStrength);
+
+    switch (ConnectionInfo.aiExch) {
+    case CALG_RSA_KEYX:
+    case CALG_RSA_SIGN:
+        printd("Key exchange: RSA\n");
+        break;
+
+    case CALG_KEA_KEYX:
+        printd("Key exchange: KEA\n");
+        break;
+
+    case CALG_DH_EPHEM:
+        printd("Key exchange: DH Ephemeral\n");
+        break;
+
+    default:
+        printd("Key exchange: 0x%x\n", ConnectionInfo.aiExch);
+    }
+
+    printd("Key exchange strength: %d\n", ConnectionInfo.dwExchStrength);
 }
 
-int hssl_read(hssl_t ssl, void* buf, int len) {
-    int fd = (intptr_t)ssl;
-    return read(fd, buf, len);
+int hssl_connect(hssl_t _ssl)
+{
+    int ret = 0;
+    struct wintls_s* ssl = _ssl;
+    if (ssl->connecting_state == ssl_connect_1) {
+        ret = schannel_connect_step1(ssl);
+    }
+    if (!ret && ssl->connecting_state == ssl_connect_2) {
+        ret = schannel_connect_step2(ssl);
+    }
+    // printd("%s %x\n", __func__, ret);
+    if (!ret) {
+        if (ssl->connecting_state == ssl_connect_3) {
+            // ret = schannel_connect_step3(ssl);
+        }
+        SECURITY_STATUS sec_status = QueryContextAttributes(&ssl->sechandle, SECPKG_ATTR_STREAM_SIZES, &ssl->stream_sizes_);
+        if (sec_status != SEC_E_OK) {
+            printe("get_stream_sizes QueryContextAttributes %d\n", sec_status);
+        } else {
+            printd("stream_sizes bs:%d h:%d t:%d max:%d bfs:%d\n", ssl->stream_sizes_.cbBlockSize, ssl->stream_sizes_.cbHeader, ssl->stream_sizes_.cbTrailer, ssl->stream_sizes_.cbMaximumMessage, ssl->stream_sizes_.cBuffers);
+        }
+        dumpconninfo(&ssl->sechandle);
+    }
+    return ret;
+}
+
+static int decrypt_message(SecHandle security_context, unsigned long* extra, char* in_buf, int in_len, char* out_buf, int out_len)
+{
+    printd("%s: inlen=%d\n", __func__, in_len);
+    // Initialize the secure buffers
+    SecBuffer secure_buffers[4] = { 0 };
+    init_sec_buffer(&secure_buffers[0], SECBUFFER_DATA, in_len, in_buf);
+    init_sec_buffer(&secure_buffers[1], SECBUFFER_EMPTY, 0, NULL);
+    init_sec_buffer(&secure_buffers[2], SECBUFFER_EMPTY, 0, NULL);
+    init_sec_buffer(&secure_buffers[3], SECBUFFER_EMPTY, 0, NULL);
+
+    // Initialize the secure buffer descriptor
+    SecBufferDesc secure_buffer_desc = { 0 };
+    init_sec_buffer_desc(&secure_buffer_desc, SECBUFFER_VERSION, 4, secure_buffers);
+
+    // Decrypt the message using the security context
+    SECURITY_STATUS sec_status = DecryptMessage(&security_context, &secure_buffer_desc, 0, NULL);
+
+    for (size_t i = 1; i < 4; i++) {
+        printd("%d: %u %u\n", i, secure_buffers[i].BufferType, secure_buffers[i].cbBuffer);
+    }
+    if (sec_status == SEC_E_INCOMPLETE_MESSAGE) {
+        printe("decrypt_message SEC_E_INCOMPLETE_MESSAGE\n");
+        return -1;
+    } else if (sec_status == SEC_E_DECRYPT_FAILURE) {
+        printe("decrypt_message ignore SEC_E_DECRYPT_FAILURE\n");
+        return 0;
+    } else if (sec_status == SEC_E_UNSUPPORTED_FUNCTION) {
+        printe("decrypt_message ignore SEC_E_UNSUPPORTED_FUNCTION\n");
+        return 0;
+    }
+
+    if (sec_status != SEC_E_OK) {
+        printe("decrypt_message DecryptMessage: 0x%x\n", sec_status);
+        return -1;
+    }
+    if (secure_buffers[3].BufferType == SECBUFFER_EXTRA && secure_buffers[3].cbBuffer > 0) {
+        *extra = secure_buffers[3].cbBuffer;
+    }
+    memcpy(out_buf, secure_buffers[1].pvBuffer, secure_buffers[1].cbBuffer);
+    // printd("ob:%s\n", out_buf);
+    return secure_buffers[1].cbBuffer;
 }
 
-int hssl_write(hssl_t ssl, const void* buf, int len) {
-    int fd = (intptr_t)ssl;
-    return write(fd, buf, len);
+int hssl_read(hssl_t _ssl, void* buf, int len)
+{
+    struct wintls_s* ssl = _ssl;
+    printd("%s: dec_len_= %zu\n", __func__, ssl->dec_len_);
+    if (ssl->dec_len_ > 0) {
+        if (buf == NULL) {
+            return 0;
+        }
+        int decrypted = MIN(ssl->dec_len_, len);
+        memcpy(buf, ssl->decrypted_buffer_, (size_t)decrypted);
+        ssl->dec_len_ -= decrypted;
+        if (ssl->dec_len_) {
+            memmove(ssl->decrypted_buffer_, ssl->decrypted_buffer_ + decrypted, (size_t)ssl->dec_len_);
+        } else {
+            // hssl_read(_ssl, NULL, 0);
+        }
+        return decrypted;
+    }
+
+    // We might have leftovers, an incomplete message from a previous call.
+    // Calculate the available buffer length for tcp recv.
+    int recv_max_len = TLS_SOCKET_BUFFER_SIZE - ssl->buffer_to_decrypt_offset_;
+    int bytes_received = __recvwrapper(ssl->fd, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_, recv_max_len, 0);
+    // printd("%s recv %d %d\n", __func__, bytes_received, WSAGetLastError());
+    if (bytes_received == SOCKET_ERROR) {
+        if (WSAGetLastError() == WSAEWOULDBLOCK) {
+            bytes_received = 0;
+            return 0;
+        } else {
+            return -1;
+        }
+    } else if (bytes_received == 0) {
+        return 0;
+    }
+
+    int encrypted_buffer_len = ssl->buffer_to_decrypt_offset_ + bytes_received;
+    ssl->buffer_to_decrypt_offset_ = 0;
+    while (true) {
+        // printd("%s:buffer_to_decrypt_offset_ = %d , encrypted_buffer_len= %d\n", __func__, ssl->buffer_to_decrypt_offset_, encrypted_buffer_len);
+        if (ssl->buffer_to_decrypt_offset_ >= encrypted_buffer_len) {
+            // Reached the encrypted buffer length, we decrypted everything so we can stop.
+            break;
+        }
+        unsigned long extra = 0;
+        int decrypted_len = decrypt_message(ssl->sechandle, &extra, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_,
+            encrypted_buffer_len - ssl->buffer_to_decrypt_offset_, ssl->decrypted_buffer_ + ssl->dec_len_,
+            TLS_SOCKET_BUFFER_SIZE + TLS_SOCKET_BUFFER_SIZE - ssl->dec_len_);
+
+        if (decrypted_len == -1) {
+            // Incomplete message, we shuold keep it so it will be decrypted on the next call to recv().
+            // Shift the remaining buffer to the beginning and break the loop.
+
+            memmove(ssl->buffer_to_decrypt_, ssl->buffer_to_decrypt_ + ssl->buffer_to_decrypt_offset_, encrypted_buffer_len - ssl->buffer_to_decrypt_offset_);
+
+            break;
+        }
+
+        ssl->dec_len_ += decrypted_len;
+        ssl->buffer_to_decrypt_offset_ = encrypted_buffer_len - extra;
+    }
+    ssl->buffer_to_decrypt_offset_ = encrypted_buffer_len - ssl->buffer_to_decrypt_offset_;
+    return hssl_read(_ssl, buf, len);
+}
+
+int hssl_write(hssl_t _ssl, const void* buf, int len)
+{
+    struct wintls_s* ssl = _ssl;
+    SecPkgContext_StreamSizes* stream_sizes = &ssl->stream_sizes_;
+    if (len > (int)stream_sizes->cbMaximumMessage) {
+        len = stream_sizes->cbMaximumMessage;
+    }
+
+    // Calculate the minimum output buffer length
+    int min_out_len = stream_sizes->cbHeader + len + stream_sizes->cbTrailer;
+    if (min_out_len > TLS_SOCKET_BUFFER_SIZE) {
+        printe("encrypt_message: Output buffer is too small");
+        return -1;
+    }
+
+    // Initialize the secure buffers
+    SecBuffer secure_buffers[4] = { 0 };
+    init_sec_buffer(&secure_buffers[0], SECBUFFER_STREAM_HEADER, stream_sizes->cbHeader, ssl->encrypted_buffer_);
+    init_sec_buffer(&secure_buffers[1], SECBUFFER_DATA, len, ssl->encrypted_buffer_ + stream_sizes->cbHeader);
+    init_sec_buffer(&secure_buffers[2], SECBUFFER_STREAM_TRAILER, stream_sizes->cbTrailer, ssl->encrypted_buffer_ + stream_sizes->cbHeader + len);
+    init_sec_buffer(&secure_buffers[3], SECBUFFER_EMPTY, 0, NULL);
+
+    // Initialize the secure buffer descriptor
+    SecBufferDesc secure_buffer_desc = { 0 };
+    init_sec_buffer_desc(&secure_buffer_desc, SECBUFFER_VERSION, 4, secure_buffers);
+
+    // Copy the input buffer to the data buffer
+    memcpy(secure_buffers[1].pvBuffer, buf, len);
+
+    // Encrypt the message using the security context
+    SECURITY_STATUS sec_status = EncryptMessage(&ssl->sechandle, 0, &secure_buffer_desc, 0);
+
+    // Check the encryption status and the data buffer length
+    if (sec_status != SEC_E_OK) {
+        printe("encrypt_message EncryptMessage %d\n", sec_status);
+        return -1;
+    }
+    if (secure_buffers[1].cbBuffer > (unsigned int)len) {
+        printe("encrypt_message: Data buffer is too large\n");
+        return -1;
+    }
+
+    // Adjust the minimum output buffer length
+    min_out_len = secure_buffers[0].cbBuffer + secure_buffers[1].cbBuffer + secure_buffers[2].cbBuffer;
+    printd("enc02: %d %d\n", secure_buffers[0].cbBuffer, secure_buffers[2].cbBuffer);
+
+    // Send the encrypted message to the socket
+    int offset = __sendwrapper(ssl->fd, ssl->encrypted_buffer_, min_out_len, 0);
+    // Check the send result
+    if (offset != min_out_len) {
+        printe("hssl_write: Send failed\n");
+        return -1;
+    } else {
+        printd("hssl_write: Send %d\n", min_out_len);
+    }
+
+    // Return the number of bytes sent excluding the header and trailer
+    return offset - secure_buffers[0].cbBuffer - secure_buffers[2].cbBuffer;
 }
 
-int hssl_close(hssl_t ssl) {
+int hssl_close(hssl_t _ssl)
+{
     return 0;
 }
 
-int hssl_set_sni_hostname(hssl_t ssl, const char* hostname) {
+int hssl_set_sni_hostname(hssl_t _ssl, const char* hostname)
+{
+    struct wintls_s* ssl = _ssl;
+    ssl->sni = strdup(hostname);
     return 0;
 }