فهرست منبع

Add hssl_ctx_new

ithewei 3 سال پیش
والد
کامیت
64d9d96d57
8فایلهای تغییر یافته به همراه114 افزوده شده و 7 حذف شده
  1. 19 0
      event/hevent.c
  2. 3 1
      event/hevent.h
  3. 7 2
      event/hloop.h
  4. 24 2
      event/nio.c
  5. 18 0
      evpp/Channel.h
  6. 5 0
      evpp/TcpClient.h
  7. 32 2
      http/client/http_client.cpp
  8. 6 0
      http/client/http_client.h

+ 19 - 0
event/hevent.c

@@ -134,6 +134,8 @@ void hio_ready(hio_t* io) {
     io->unpack_setting = NULL;
     // ssl
     io->ssl = NULL;
+    io->ssl_ctx = NULL;
+    io->alloced_ssl_ctx = 0;
     // context
     io->ctx = NULL;
     // private:
@@ -459,12 +461,29 @@ hssl_t hio_get_ssl(hio_t* io) {
     return io->ssl;
 }
 
+hssl_ctx_t hio_get_ssl_ctx(hio_t* io) {
+    return io->ssl_ctx;
+}
+
 int hio_set_ssl(hio_t* io, hssl_t ssl) {
     io->io_type = HIO_TYPE_SSL;
     io->ssl = ssl;
     return 0;
 }
 
+int hio_set_ssl_ctx(hio_t* io, hssl_ctx_t ssl_ctx) {
+    io->io_type = HIO_TYPE_SSL;
+    io->ssl_ctx = ssl_ctx;
+    return 0;
+}
+
+int hio_new_ssl_ctx(hio_t* io, hssl_ctx_opt_t* opt) {
+    hssl_ctx_t ssl_ctx = hssl_ctx_new(opt);
+    if (ssl_ctx == NULL) return HSSL_ERROR;
+    io->alloced_ssl_ctx = 1;
+    return hio_set_ssl_ctx(io, ssl_ctx);
+}
+
 void hio_set_readbuf(hio_t* io, void* buf, size_t len) {
     assert(io && buf && len != 0);
     hio_free_readbuf(io);

+ 3 - 1
event/hevent.h

@@ -112,6 +112,7 @@ struct hio_s {
     unsigned    sendto      :1;
     unsigned    close       :1;
     unsigned    alloced_readbuf :1; // for hio_alloc_readbuf
+    unsigned    alloced_ssl_ctx :1; // for hio_new_ssl_ctx
 // public:
     hio_type_e  io_type;
     uint32_t    id; // fd cannot be used as unique identifier, so we provide an id
@@ -161,7 +162,8 @@ struct hio_s {
     // unpack
     unpack_setting_t*   unpack_setting; // for hio_set_unpack
     // ssl
-    void*       ssl; // for hio_enable_ssl / hio_set_ssl
+    void*       ssl;        // for hio_set_ssl
+    void*       ssl_ctx;    // for hio_set_ssl_ctx
     // context
     void*       ctx; // for hio_context / hio_set_context
 // private:

+ 7 - 2
event/hloop.h

@@ -285,8 +285,13 @@ HV_EXPORT hclose_cb   hio_getcb_close(hio_t* io);
 // Enable SSL/TLS is so easy :)
 HV_EXPORT int  hio_enable_ssl(hio_t* io);
 HV_EXPORT bool hio_is_ssl(hio_t* io);
-HV_EXPORT hssl_t hio_get_ssl(hio_t* io);
-HV_EXPORT int  hio_set_ssl(hio_t* io, hssl_t ssl);
+HV_EXPORT int  hio_set_ssl    (hio_t* io, hssl_t ssl);
+HV_EXPORT int  hio_set_ssl_ctx(hio_t* io, hssl_ctx_t ssl_ctx);
+// hssl_ctx_new(opt) -> hio_set_ssl_ctx
+HV_EXPORT int  hio_new_ssl_ctx(hio_t* io, hssl_ctx_opt_t* opt);
+HV_EXPORT hssl_t     hio_get_ssl(hio_t* io);
+HV_EXPORT hssl_ctx_t hio_get_ssl_ctx(hio_t* io);
+
 // NOTE: One loop per thread, one readbuf per loop.
 // But you can pass in your own readbuf instead of the default readbuf to avoid memcopy.
 HV_EXPORT void hio_set_readbuf(hio_t* io, void* buf, size_t len);

+ 24 - 2
event/nio.c

@@ -138,7 +138,16 @@ static void nio_accept(hio_t* io) {
 
         if (io->io_type == HIO_TYPE_SSL) {
             if (connio->ssl == NULL) {
-                hssl_ctx_t ssl_ctx = hssl_ctx_instance();
+                // io->ssl_ctx > g_ssl_ctx > hssl_ctx_new
+                hssl_ctx_t ssl_ctx = NULL;
+                if (io->ssl_ctx) {
+                    ssl_ctx = io->ssl_ctx;
+                } else if (g_ssl_ctx) {
+                    ssl_ctx = g_ssl_ctx;
+                } else {
+                    io->ssl_ctx = ssl_ctx = hssl_ctx_new(NULL);
+                    io->alloced_ssl_ctx = 1;
+                }
                 if (ssl_ctx == NULL) {
                     io->error = HSSL_ERROR;
                     goto accept_error;
@@ -180,7 +189,16 @@ static void nio_connect(hio_t* io) {
 
         if (io->io_type == HIO_TYPE_SSL) {
             if (io->ssl == NULL) {
-                hssl_ctx_t ssl_ctx = hssl_ctx_instance();
+                // io->ssl_ctx > g_ssl_ctx > hssl_ctx_new
+                hssl_ctx_t ssl_ctx = NULL;
+                if (io->ssl_ctx) {
+                    ssl_ctx = io->ssl_ctx;
+                } else if (g_ssl_ctx) {
+                    ssl_ctx = g_ssl_ctx;
+                } else {
+                    io->ssl_ctx = ssl_ctx = hssl_ctx_new(NULL);
+                    io->alloced_ssl_ctx = 1;
+                }
                 if (ssl_ctx == NULL) {
                     goto connect_failed;
                 }
@@ -539,6 +557,10 @@ int hio_close (hio_t* io) {
         hssl_free(io->ssl);
         io->ssl = NULL;
     }
+    if (io->ssl_ctx && io->alloced_ssl_ctx) {
+        hssl_ctx_free(io->ssl_ctx);
+        io->ssl_ctx = NULL;
+    }
     if (io->io_type & HIO_TYPE_SOCKET) {
         closesocket(io->fd);
     }

+ 18 - 0
evpp/Channel.h

@@ -187,9 +187,27 @@ public:
     }
     virtual ~SocketChannel() {}
 
+    // SSL/TLS
     int enableSSL() {
+        if (io_ == NULL) return -1;
         return hio_enable_ssl(io_);
     }
+    bool isSSL() {
+        if (io_ == NULL) return false;
+        return hio_is_ssl(io_);
+    }
+    int setSSL(hssl_t ssl) {
+        if (io_ == NULL) return -1;
+        return hio_set_ssl(io_, ssl);
+    }
+    int setSslCtx(hssl_ctx_t ssl_ctx) {
+        if (io_ == NULL) return -1;
+        return hio_set_ssl_ctx(io_, ssl_ctx);
+    }
+    int newSslCtx(hssl_ctx_opt_t* opt) {
+        if (io_ == NULL) return -1;
+        return hio_new_ssl_ctx(io_, opt);
+    }
 
     void setConnectTimeout(int timeout_ms) {
         if (io_ == NULL) return;

+ 5 - 0
evpp/TcpClient.h

@@ -180,6 +180,7 @@ public:
         return send(str.data(), str.size());
     }
 
+    // deprecated: use withTLS(opt) after createsocket
     int withTLS(const char* cert_file = NULL, const char* key_file = NULL, bool verify_peer = false) {
         if (cert_file) {
             hssl_ctx_init_param_t param;
@@ -196,6 +197,10 @@ public:
         tls = true;
         return 0;
     }
+    int withTLS(hssl_ctx_opt_t* opt) {
+        if (!channel) return -1;
+        return channel->newSslCtx(opt);
+    }
 
     void setConnectTimeout(int ms) {
         connect_timeout = ms;

+ 32 - 2
http/client/http_client.cpp

@@ -39,6 +39,8 @@ struct http_client_s {
     // for sync
     int             fd;
     hssl_t          ssl;
+    hssl_ctx_t      ssl_ctx;
+    bool            alloced_ssl_ctx;
     HttpParserPtr   parser;
     // for async
     std::mutex                              mutex_;
@@ -54,10 +56,16 @@ struct http_client_s {
 #endif
         fd = -1;
         ssl = NULL;
+        ssl_ctx = NULL;
+        alloced_ssl_ctx = false;
     }
 
     ~http_client_s() {
         Close();
+        if (ssl_ctx && alloced_ssl_ctx) {
+            hssl_ctx_free(ssl_ctx);
+            ssl_ctx = NULL;
+        }
     }
 
     void Close() {
@@ -103,6 +111,19 @@ int http_client_set_timeout(http_client_t* cli, int timeout) {
     return 0;
 }
 
+int http_client_set_ssl_ctx(http_client_t* cli, hssl_ctx_t ssl_ctx) {
+    cli->ssl_ctx = ssl_ctx;
+    return 0;
+}
+
+int http_client_new_ssl_ctx(http_client_t* cli, hssl_ctx_opt_t* opt) {
+    opt->endpoint = HSSL_CLIENT;
+    hssl_ctx_t ssl_ctx = hssl_ctx_new(opt);
+    if (ssl_ctx == NULL) return HSSL_ERROR;
+    cli->alloced_ssl_ctx = true;
+    return http_client_set_ssl_ctx(cli, ssl_ctx);
+}
+
 int http_client_clear_headers(http_client_t* cli) {
     cli->headers.clear();
     return 0;
@@ -418,8 +439,17 @@ static int http_client_connect(http_client_t* cli, const char* host, int port, i
     }
     tcp_nodelay(connfd, 1);
 
-    if (https) {
-        hssl_ctx_t ssl_ctx = hssl_ctx_instance();
+    if (https && cli->ssl == NULL) {
+        // cli->ssl_ctx > g_ssl_ctx > hssl_ctx_new
+        hssl_ctx_t ssl_ctx = NULL;
+        if (cli->ssl_ctx) {
+            ssl_ctx = cli->ssl_ctx;
+        } else if (g_ssl_ctx) {
+            ssl_ctx = g_ssl_ctx;
+        } else {
+            cli->ssl_ctx = ssl_ctx = hssl_ctx_new(NULL);
+            cli->alloced_ssl_ctx = true;
+        }
         if (ssl_ctx == NULL) {
             closesocket(connfd);
             return HSSL_ERROR;

+ 6 - 0
http/client/http_client.h

@@ -2,6 +2,7 @@
 #define HV_HTTP_CLIENT_H_
 
 #include "hexport.h"
+#include "hssl.h"
 #include "HttpMessage.h"
 
 /*
@@ -36,6 +37,11 @@ HV_EXPORT const char* http_client_strerror(int errcode);
 
 HV_EXPORT int http_client_set_timeout(http_client_t* cli, int timeout);
 
+// SSL/TLS
+HV_EXPORT int http_client_set_ssl_ctx(http_client_t* cli, hssl_ctx_t ssl_ctx);
+// hssl_ctx_new(opt) -> http_client_set_ssl_ctx
+HV_EXPORT int http_client_new_ssl_ctx(http_client_t* cli, hssl_ctx_opt_t* opt);
+
 // common headers
 HV_EXPORT int http_client_clear_headers(http_client_t* cli);
 HV_EXPORT int http_client_set_header(http_client_t* cli, const char* key, const char* value);