Bläddra i källkod

Add ssl_ctx member for #338

ithewei 2 år sedan
förälder
incheckning
aee97038fa

+ 27 - 12
evpp/TcpServer.h

@@ -19,12 +19,15 @@ public:
         acceptor_loop = loop ? loop : std::make_shared<EventLoop>();
         listenfd = -1;
         tls = false;
-        unpack_setting.mode = UNPACK_MODE_NONE;
+        tls_setting = NULL;
+        unpack_setting = NULL;
         max_connections = 0xFFFFFFFF;
         load_balance = LB_RoundRobin;
     }
 
     virtual ~TcpServerEventLoopTmpl() {
+        HV_FREE(tls_setting);
+        HV_FREE(unpack_setting);
     }
 
     EventLoopPtr loop(int idx = -1) {
@@ -80,6 +83,14 @@ public:
         hevent_set_userdata(listenio, this);
         if (tls) {
             hio_enable_ssl(listenio);
+            if (tls_setting) {
+                int ret = hio_new_ssl_ctx(listenio, tls_setting);
+                if (ret != 0) {
+                    hloge("new SSL_CTX failed: %d", ret);
+                    closesocket();
+                    return ret;
+                }
+            }
         }
         return 0;
     }
@@ -111,21 +122,24 @@ public:
     int withTLS(hssl_ctx_opt_t* opt = NULL) {
         tls = true;
         if (opt) {
-            opt->endpoint = HSSL_SERVER;
-            if (hssl_ctx_init(opt) == NULL) {
-                fprintf(stderr, "hssl_ctx_init failed!\n");
-                return -1;
+            if (tls_setting == NULL) {
+                HV_ALLOC_SIZEOF(tls_setting);
             }
+            opt->endpoint = HSSL_SERVER;
+            *tls_setting = *opt;
         }
         return 0;
     }
 
     void setUnpack(unpack_setting_t* setting) {
-        if (setting) {
-            unpack_setting = *setting;
-        } else {
-            unpack_setting.mode = UNPACK_MODE_NONE;
+        if (setting == NULL) {
+            HV_FREE(unpack_setting);
+            return;
+        }
+        if (unpack_setting == NULL) {
+            HV_ALLOC_SIZEOF(unpack_setting);
         }
+        *unpack_setting = *setting;
     }
 
     // channel
@@ -214,8 +228,8 @@ private:
             // so in this lambda function, no code should be added below.
         };
 
-        if (server->unpack_setting.mode != UNPACK_MODE_NONE) {
-            channel->setUnpack(&server->unpack_setting);
+        if (server->unpack_setting) {
+            channel->setUnpack(server->unpack_setting);
         }
         channel->startRead();
         if (server->onConnection) {
@@ -240,7 +254,8 @@ public:
     int                     port;
     int                     listenfd;
     bool                    tls;
-    unpack_setting_t        unpack_setting;
+    hssl_ctx_opt_t*         tls_setting;
+    unpack_setting_t*       unpack_setting;
     // Callback
     std::function<void(const TSocketChannelPtr&)>           onConnection;
     std::function<void(const TSocketChannelPtr&, Buffer*)>  onMessage;

+ 9 - 9
examples/http_server_test.cpp

@@ -5,7 +5,8 @@
  */
 
 #include "HttpServer.h"
-#include "hssl.h"
+
+using namespace hv;
 
 /*
  * #define TEST_HTTPS 1
@@ -76,31 +77,30 @@ int main(int argc, char** argv) {
         return ctx->send(resp.dump(2));
     });
 
-    http_server_t server;
+    HttpServer server;
     server.service = &router;
     server.port = port;
 #if TEST_HTTPS
     server.https_port = 8443;
-    hssl_ctx_init_param_t param;
+    hssl_ctx_opt_t param;
     memset(&param, 0, sizeof(param));
     param.crt_file = "cert/server.crt";
     param.key_file = "cert/server.key";
     param.endpoint = HSSL_SERVER;
-    if (hssl_ctx_init(&param) == NULL) {
-        fprintf(stderr, "hssl_ctx_init failed!\n");
+    if (server.newSslCtx(&param) != 0) {
+        fprintf(stderr, "new SSL_CTX failed!\n");
         return -20;
     }
 #endif
 
     // uncomment to test multi-processes
-    // server.worker_processes = 4;
+    // server.setProcessNum(4);
     // uncomment to test multi-threads
-    // server.worker_threads = 4;
+    // server.setThreadNum(4);
 
-    http_server_run(&server, 0);
+    server.start();
 
     // press Enter to stop
     while (getchar() != '\n');
-    http_server_stop(&server);
     return 0;
 }

+ 2 - 2
examples/httpd/httpd.cpp

@@ -188,13 +188,13 @@ int parse_confile(const char* confile) {
         std::string key_file = ini.GetValue("ssl_privatekey");
         std::string ca_file = ini.GetValue("ssl_ca_certificate");
         hlogi("SSL backend is %s", hssl_backend());
-        hssl_ctx_init_param_t param;
+        hssl_ctx_opt_t param;
         memset(&param, 0, sizeof(param));
         param.crt_file = crt_file.c_str();
         param.key_file = key_file.c_str();
         param.ca_file = ca_file.c_str();
         param.endpoint = HSSL_SERVER;
-        if (hssl_ctx_init(&param) == NULL) {
+        if (g_http_server.newSslCtx(&param) != 0) {
             hloge("SSL certificate verify failed!");
             exit(0);
         }

+ 10 - 9
examples/websocket_server_test.cpp

@@ -11,7 +11,8 @@
 #include "WebSocketServer.h"
 #include "EventLoop.h"
 #include "htime.h"
-#include "hssl.h"
+
+using namespace hv;
 
 /*
  * #define TEST_WSS 1
@@ -87,26 +88,26 @@ int main(int argc, char** argv) {
         // channel->deleteContextPtr();
     };
 
-    websocket_server_t server;
+    WebSocketServer server;
     server.port = port;
 #if TEST_WSS
     server.https_port = port + 1;
-    hssl_ctx_init_param_t param;
+    hssl_ctx_opt_t param;
     memset(&param, 0, sizeof(param));
     param.crt_file = "cert/server.crt";
     param.key_file = "cert/server.key";
     param.endpoint = HSSL_SERVER;
-    if (hssl_ctx_init(&param) == NULL) {
-        fprintf(stderr, "hssl_ctx_init failed!\n");
+    if (server.newSslCtx(&param) != 0) {
+        fprintf(stderr, "new SSL_CTX failed!\n");
         return -20;
     }
 #endif
-    server.service = &http;
-    server.ws = &ws;
-    websocket_server_run(&server, 0);
+    server.registerHttpService(&http);
+    server.registerWebSocketService(&ws);
+
+    server.start();
 
     // press Enter to stop
     while (getchar() != '\n');
-    websocket_server_stop(&server);
     return 0;
 }

+ 24 - 7
http/server/HttpServer.cpp

@@ -1,7 +1,6 @@
 #include "HttpServer.h"
 
 #include "hv.h"
-#include "hssl.h"
 #include "hmain.h"
 
 #include "httpdef.h"
@@ -276,6 +275,9 @@ static void loop_thread(void* userdata) {
         hio_t* listenio = haccept(hloop, server->listenfd[1], on_accept);
         hevent_set_userdata(listenio, server);
         hio_enable_ssl(listenio);
+        if (server->ssl_ctx) {
+            hio_set_ssl_ctx(listenio, server->ssl_ctx);
+        }
     }
 
     HttpServerPrivdata* privdata = (HttpServerPrivdata*)server->privdata;
@@ -336,17 +338,26 @@ int http_server_run(http_server_t* server, int wait) {
         hlogi("http server listening on %s:%d", server->host, server->port);
     }
     // https_port
-    if (server->https_port > 0 && hssl_ctx_instance() != NULL) {
+    if (server->https_port > 0 && HV_WITH_SSL) {
+        server->listenfd[1] = Listen(server->https_port, server->host);
+        if (server->listenfd[1] < 0) return server->listenfd[1];
+        hlogi("https server listening on %s:%d", server->host, server->https_port);
+    }
+    // SSL_CTX
+    if (server->listenfd[1] >= 0) {
+        if (server->ssl_ctx == NULL) {
+            server->ssl_ctx = hssl_ctx_instance();
+        }
+        if (server->ssl_ctx == NULL) {
+            hloge("new SSL_CTX failed!");
+            return ERR_NEW_SSL_CTX;
+        }
 #ifdef WITH_NGHTTP2
 #ifdef WITH_OPENSSL
         static unsigned char s_alpn_protos[] = "\x02h2\x08http/1.1\x08http/1.0\x08http/0.9";
-        hssl_ctx_t ssl_ctx = hssl_ctx_instance();
-        hssl_ctx_set_alpn_protos(ssl_ctx, s_alpn_protos, sizeof(s_alpn_protos) - 1);
+        hssl_ctx_set_alpn_protos(server->ssl_ctx, s_alpn_protos, sizeof(s_alpn_protos) - 1);
 #endif
 #endif
-        server->listenfd[1] = Listen(server->https_port, server->host);
-        if (server->listenfd[1] < 0) return server->listenfd[1];
-        hlogi("https server listening on %s:%d", server->host, server->https_port);
     }
 
     HttpServerPrivdata* privdata = new HttpServerPrivdata;
@@ -414,6 +425,12 @@ int http_server_stop(http_server_t* server) {
         hthread_join(thrd);
     }
 
+    if (server->alloced_ssl_ctx && server->ssl_ctx) {
+        hssl_ctx_free(server->ssl_ctx);
+        server->alloced_ssl_ctx = 0;
+        server->ssl_ctx = NULL;
+    }
+
     delete privdata;
     server->privdata = NULL;
     return 0;

+ 31 - 3
http/server/HttpServer.h

@@ -2,6 +2,7 @@
 #define HV_HTTP_SERVER_H_
 
 #include "hexport.h"
+#include "hssl.h"
 #include "HttpService.h"
 // #include "WebSocketServer.h"
 namespace hv {
@@ -26,6 +27,9 @@ typedef struct http_server_s {
     // hooks
     std::function<void()> onWorkerStart;
     std::function<void()> onWorkerStop;
+    // SSL/TLS
+    hssl_ctx_t  ssl_ctx;
+    unsigned    alloced_ssl_ctx: 1;
 
 #ifdef __cplusplus
     http_server_s() {
@@ -44,6 +48,9 @@ typedef struct http_server_s {
         listenfd[0] = listenfd[1] = -1;
         userdata = NULL;
         privdata = NULL;
+        // SSL/TLS
+        ssl_ctx = NULL;
+        alloced_ssl_ctx = 0;
     }
 #endif
 } http_server_t;
@@ -78,7 +85,11 @@ namespace hv {
 
 class HttpServer : public http_server_t {
 public:
-    HttpServer() : http_server_t() {}
+    HttpServer(HttpService* service = NULL)
+        : http_server_t()
+    {
+        this->service = service;
+    }
     ~HttpServer() { stop(); }
 
     void registerHttpService(HttpService* service) {
@@ -90,8 +101,12 @@ public:
     }
 
     void setPort(int port = 0, int ssl_port = 0) {
-        if (port != 0) this->port = port;
-        if (ssl_port != 0) this->https_port = ssl_port;
+        if (port >= 0) this->port = port;
+        if (ssl_port >= 0) this->https_port = ssl_port;
+    }
+    void setListenFD(int fd = -1, int ssl_fd = -1) {
+        if (fd >= 0) this->listenfd[0] = fd;
+        if (ssl_fd >= 0) this->listenfd[1] = ssl_fd;
     }
 
     void setProcessNum(int num) {
@@ -102,6 +117,19 @@ public:
         this->worker_threads = num;
     }
 
+    // SSL/TLS
+    int setSslCtx(hssl_ctx_t ssl_ctx) {
+        this->ssl_ctx = ssl_ctx;
+        return 0;
+    }
+    int newSslCtx(hssl_ctx_opt_t* opt) {
+        // NOTE: hssl_ctx_free in http_server_stop
+        hssl_ctx_t ssl_ctx = hssl_ctx_new(opt);
+        if (ssl_ctx == NULL) return -1;
+        this->alloced_ssl_ctx = 1;
+        return setSslCtx(ssl_ctx);
+    }
+
     int run(bool wait = true) {
         return http_server_run(this, wait);
     }

+ 7 - 0
http/server/WebSocketServer.h

@@ -28,6 +28,13 @@ struct WebSocketService {
 
 class WebSocketServer : public HttpServer {
 public:
+    WebSocketServer(WebSocketService* service = NULL)
+        : HttpServer()
+    {
+        this->ws = service;
+    }
+    ~WebSocketServer() { stop(); }
+
     void registerWebSocketService(WebSocketService* service) {
         this->ws = service;
     }