Переглянути джерело

精简websocket实现代码

cqm 3 роки тому
батько
коміт
3092204006
3 змінених файлів з 73 додано та 104 видалено
  1. 53 2
      http/server/HttpHandler.cpp
  2. 10 44
      http/server/HttpHandler.h
  3. 10 58
      http/server/HttpServer.cpp

+ 53 - 2
http/server/HttpHandler.cpp

@@ -1,4 +1,4 @@
-#include "HttpHandler.h"
+#include "HttpHandler.h"
 
 #include "hbase.h"
 #include "herr.h"
@@ -6,6 +6,57 @@
 #include "hasync.h" // import hv::async for http_async_handler
 #include "http_page.h"
 
+#include "htime.h"
+bool HttpHandler::SwitchWebSocket(hio_t* io, ws_session_type type) {
+    if(!io || !ws_service) return false;
+    protocol = WEBSOCKET;
+    ws_parser.reset(new WebSocketParser);
+    ws_channel.reset(new hv::WebSocketChannel(io, type));
+    ws_parser->onMessage = [this](int opcode, const std::string& msg){
+        switch(opcode) {
+        case WS_OPCODE_CLOSE:
+            ws_channel->close(true);
+            break;
+        case WS_OPCODE_PING:
+            // printf("recv ping\n");
+            // printf("send pong\n");
+            ws_channel->sendPong();
+            break;
+        case WS_OPCODE_PONG:
+            // printf("recv pong\n");
+            this->last_recv_pong_time = gethrtime_us();
+            break;
+        case WS_OPCODE_TEXT:
+        case WS_OPCODE_BINARY:
+            // onmessage
+            if (ws_service && ws_service->onmessage) {
+                ws_service->onmessage(ws_channel, msg);
+            }
+            break;
+        default:
+            break;
+        }
+    };
+    // NOTE: cancel keepalive timer, judge alive by heartbeat.
+    hio_set_keepalive_timeout(io, 0);
+    if (ws_service && ws_service->ping_interval > 0) {
+        int ping_interval = MAX(ws_service->ping_interval, 1000);
+        ws_channel->setHeartbeat(ping_interval, [this](){
+            if (last_recv_pong_time < last_send_ping_time) {
+                hlogw("[%s:%d] websocket no pong!", ip, port);
+                ws_channel->close(true);
+            } else {
+                // printf("send ping\n");
+                ws_channel->sendPing();
+                last_send_ping_time = gethrtime_us();
+            }
+        });
+    }
+    // onopen
+    WebSocketOnOpen();
+    return true;
+}
+
 int HttpHandler::customHttpHandler(const http_handler& handler) {
     return invokeHttpHandler(&handler);
 }
@@ -200,7 +251,7 @@ int HttpHandler::defaultErrorHandler() {
 int HttpHandler::FeedRecvData(const char* data, size_t len) {
     int nfeed = 0;
     if (protocol == HttpHandler::WEBSOCKET) {
-        nfeed = ws->parser->FeedRecvData(data, len);
+        nfeed = ws_parser->FeedRecvData(data, len);
         if (nfeed != len) {
             hloge("[%s:%d] websocket parse error!", ip, port);
         }

+ 10 - 44
http/server/HttpHandler.h

@@ -8,35 +8,6 @@
 #include "WebSocketServer.h"
 #include "WebSocketParser.h"
 
-class WebSocketHandler {
-public:
-    WebSocketChannelPtr         channel;
-    WebSocketParserPtr          parser;
-    uint64_t                    last_send_ping_time;
-    uint64_t                    last_recv_pong_time;
-
-    WebSocketHandler() {
-        last_send_ping_time = 0;
-        last_recv_pong_time = 0;
-    }
-
-    void Init(hio_t* io = NULL, ws_session_type type = WS_SERVER) {
-        parser.reset(new WebSocketParser);
-        if (io) {
-            channel.reset(new hv::WebSocketChannel(io, type));
-        }
-    }
-
-    void onopen() {
-        channel->status = hv::SocketChannel::CONNECTED;
-    }
-
-    void onclose() {
-        channel->status = hv::SocketChannel::DISCONNECTED;
-    }
-};
-typedef std::shared_ptr<WebSocketHandler> WebSocketHandlerPtr;
-
 class HttpHandler {
 public:
     enum ProtocolType {
@@ -76,7 +47,10 @@ public:
     std::string             body;
 
     // for websocket
-    WebSocketHandlerPtr         ws;
+    WebSocketChannelPtr         ws_channel;
+    WebSocketParserPtr          ws_parser;
+    uint64_t                    last_send_ping_time;
+    uint64_t                    last_recv_pong_time;
     WebSocketService*           ws_service;
 
     HttpHandler() {
@@ -143,26 +117,18 @@ public:
     int GetSendData(char** data, size_t* len);
 
     // websocket
-    WebSocketHandler* SwitchWebSocket() {
-        ws.reset(new WebSocketHandler);
-        protocol = WEBSOCKET;
-        return ws.get();
-    }
+    bool SwitchWebSocket(hio_t* io, ws_session_type type = WS_SERVER);
+
     void WebSocketOnOpen() {
-        ws->onopen();
+        ws_channel->status = hv::SocketChannel::CONNECTED;
         if (ws_service && ws_service->onopen) {
-            ws_service->onopen(ws->channel, req->url);
+            ws_service->onopen(ws_channel, req->url);
         }
     }
     void WebSocketOnClose() {
-        ws->onclose();
+        ws_channel->status = hv::SocketChannel::DISCONNECTED;
         if (ws_service && ws_service->onclose) {
-            ws_service->onclose(ws->channel);
-        }
-    }
-    void WebSocketOnMessage(const std::string& msg) {
-        if (ws_service && ws_service->onmessage) {
-            ws_service->onmessage(ws->channel, msg);
+            ws_service->onclose(ws_channel);
         }
     }
 

+ 10 - 58
http/server/HttpServer.cpp

@@ -36,45 +36,6 @@ struct HttpServerPrivdata {
     std::mutex                  mutex_;
 };
 
-static void websocket_heartbeat(hio_t* io) {
-    HttpHandler* handler = (HttpHandler*)hevent_userdata(io);
-    WebSocketHandler* ws = handler->ws.get();
-    if (ws->last_recv_pong_time < ws->last_send_ping_time) {
-        hlogw("[%s:%d] websocket no pong!", handler->ip, handler->port);
-        ws->channel->close(true);
-    } else {
-        // printf("send ping\n");
-        ws->channel->sendPing();
-        ws->last_send_ping_time = gethrtime_us();
-    }
-}
-
-static void websocket_onmessage(int opcode, const std::string& msg, hio_t* io) {
-    HttpHandler* handler = (HttpHandler*)hevent_userdata(io);
-    WebSocketHandler* ws = handler->ws.get();
-    switch(opcode) {
-    case WS_OPCODE_CLOSE:
-        ws->channel->close(true);
-        break;
-    case WS_OPCODE_PING:
-        // printf("recv ping\n");
-        // printf("send pong\n");
-        ws->channel->sendPong();
-        break;
-    case WS_OPCODE_PONG:
-        // printf("recv pong\n");
-        ws->last_recv_pong_time = gethrtime_us();
-        break;
-    case WS_OPCODE_TEXT:
-    case WS_OPCODE_BINARY:
-        // onmessage
-        handler->WebSocketOnMessage(msg);
-        break;
-    default:
-        break;
-    }
-}
-
 static void on_recv(hio_t* io, void* _buf, int readbytes) {
     // printf("on_recv fd=%d readbytes=%d\n", hio_fd(io), readbytes);
     const char* buf = (const char*)_buf;
@@ -146,7 +107,6 @@ static void on_recv(hio_t* io, void* _buf, int readbytes) {
 
     // Upgrade:
     bool upgrade = false;
-    HttpHandler::ProtocolType upgrade_protocol = HttpHandler::UNKNOWN;
     auto iter_upgrade = req->headers.find("upgrade");
     if (iter_upgrade != req->headers.end()) {
         upgrade = true;
@@ -154,6 +114,11 @@ static void on_recv(hio_t* io, void* _buf, int readbytes) {
         hlogi("[%s:%d] Upgrade: %s", handler->ip, handler->port, upgrade_proto);
         // websocket
         if (stricmp(upgrade_proto, "websocket") == 0) {
+            if (!handler->SwitchWebSocket(io)) {
+                hloge("[%s:%d] unsupported websocket", handler->ip, handler->port);
+                hio_close(io);
+                return;
+            }
             /*
             HTTP/1.1 101 Switching Protocols
             Connection: Upgrade
@@ -169,7 +134,10 @@ static void on_recv(hio_t* io, void* _buf, int readbytes) {
                 ws_encode_key(iter_key->second.c_str(), ws_accept);
                 resp->headers[SEC_WEBSOCKET_ACCEPT] = ws_accept;
             }
-            upgrade_protocol = HttpHandler::WEBSOCKET;
+            
+            // write upgrade resp
+            std::string header = resp->Dump(true, false);
+            hio_write(io, header.data(), header.length());
         }
         // h2/h2c
         else if (strnicmp(upgrade_proto, "h2", 2) == 0) {
@@ -214,22 +182,6 @@ static void on_recv(hio_t* io, void* _buf, int readbytes) {
         http_method_str(req->method), req->path.c_str(),
         resp->status_code, resp->status_message());
 
-    // switch protocol to websocket
-    if (upgrade && upgrade_protocol == HttpHandler::WEBSOCKET) {
-        WebSocketHandler* ws = handler->SwitchWebSocket();
-        ws->Init(io);
-        ws->parser->onMessage = std::bind(websocket_onmessage, std::placeholders::_1, std::placeholders::_2, io);
-        // NOTE: cancel keepalive timer, judge alive by heartbeat.
-        hio_set_keepalive_timeout(io, 0);
-        if (handler->ws_service && handler->ws_service->ping_interval > 0) {
-            int ping_interval = MAX(handler->ws_service->ping_interval, 1000);
-            hio_set_heartbeat(io, ping_interval, websocket_heartbeat);
-        }
-        // onopen
-        handler->WebSocketOnOpen();
-        return;
-    }
-
     if (status_code && !keepalive) {
         hio_close(io);
     }
@@ -311,7 +263,7 @@ static void loop_thread(void* userdata) {
             FileCache* filecache = default_filecache();
             filecache->RemoveExpiredFileCache();
         }, DEFAULT_FILE_EXPIRED_TIME * 1000);
-        // NOTE: add timer to update date every 1s
+        // NOTE: add timer to update s_date every 1s
         htimer_add(hloop, [](htimer_t* timer) {
             gmtime_fmt(hloop_now(hevent_loop(timer)), HttpMessage::s_date);
         }, 1000);