Преглед на файлове

ws mqtt send thread-safe

ithewei преди 3 години
родител
ревизия
46db906753
променени са 4 файла, в които са добавени 48 реда и са изтрити 19 реда
  1. 27 8
      http/WebSocketChannel.h
  2. 2 2
      http/client/WebSocketClient.cpp
  3. 4 4
      http/server/HttpServer.cpp
  4. 15 5
      mqtt/mqtt_client.c

+ 27 - 8
http/WebSocketChannel.h

@@ -30,49 +30,68 @@ public:
         if (len > fragment) {
             return send(buf, len, fragment, opcode);
         }
-        return send_(buf, len, opcode, fin);
+        std::lock_guard<std::mutex> locker(mutex_);
+        return sendFrame(buf, len, opcode, fin);
     }
 
     // websocket fragment
+    // lock ->
     // send(p, fragment, opcode, false) ->
     // send(p, fragment, WS_OPCODE_CONTINUE, false) ->
     // ... ->
     // send(p, remain, WS_OPCODE_CONTINUE, true)
+    // unlock
     int send(const char* buf, int len, int fragment, enum ws_opcode opcode = WS_OPCODE_BINARY) {
+        std::lock_guard<std::mutex> locker(mutex_);
         if (len <= fragment) {
-            return send_(buf, len, opcode, true);
+            return sendFrame(buf, len, opcode, true);
         }
 
         // first fragment
-        int nsend = send_(buf, fragment, opcode, false);
+        int nsend = sendFrame(buf, fragment, opcode, false);
         if (nsend < 0) return nsend;
 
         const char* p = buf + fragment;
         int remain = len - fragment;
         while (remain > fragment) {
-            nsend = send_(p, fragment, WS_OPCODE_CONTINUE, false);
+            nsend = sendFrame(p, fragment, WS_OPCODE_CONTINUE, false);
             if (nsend < 0) return nsend;
             p += fragment;
             remain -= fragment;
         }
 
         // last fragment
-        nsend = send_(p, remain, WS_OPCODE_CONTINUE, true);
+        nsend = sendFrame(p, remain, WS_OPCODE_CONTINUE, true);
         if (nsend < 0) return nsend;
 
         return len;
     }
 
+    int sendPing() {
+        std::lock_guard<std::mutex> locker(mutex_);
+        if (type == WS_CLIENT) {
+            return write(WS_CLIENT_PING_FRAME, WS_CLIENT_MIN_FRAME_SIZE);
+        }
+        return write(WS_SERVER_PING_FRAME, WS_SERVER_MIN_FRAME_SIZE);
+    }
+
+    int sendPong() {
+        std::lock_guard<std::mutex> locker(mutex_);
+        if (type == WS_CLIENT) {
+            return write(WS_CLIENT_PONG_FRAME, WS_CLIENT_MIN_FRAME_SIZE);
+        }
+        return write(WS_SERVER_PONG_FRAME, WS_SERVER_MIN_FRAME_SIZE);
+    }
+
 protected:
-    int send_(const char* buf, int len, enum ws_opcode opcode = WS_OPCODE_BINARY, bool fin = true) {
+    int sendFrame(const char* buf, int len, enum ws_opcode opcode = WS_OPCODE_BINARY, bool fin = true) {
         bool has_mask = false;
         char mask[4] = {0};
         if (type == WS_CLIENT) {
-            has_mask = true;
             *(int*)mask = rand();
+            has_mask = true;
         }
         int frame_size = ws_calc_frame_size(len, has_mask);
-        std::lock_guard<std::mutex> locker(mutex_);
         if (sendbuf_.len < frame_size) {
             sendbuf_.resize(ceil2e(frame_size));
         }

+ 2 - 2
http/client/WebSocketClient.cpp

@@ -130,7 +130,7 @@ int WebSocketClient::open(const char* _url, const http_headers& headers) {
                     {
                         // printf("recv ping\n");
                         // printf("send pong\n");
-                        channel->write(WS_CLIENT_PONG_FRAME, WS_CLIENT_MIN_FRAME_SIZE);
+                        channel->sendPong();
                         break;
                     }
                     case WS_OPCODE_PONG:
@@ -158,7 +158,7 @@ int WebSocketClient::open(const char* _url, const http_headers& headers) {
                             return;
                         }
                         // printf("send ping\n");
-                        channel->write(WS_CLIENT_PING_FRAME, WS_CLIENT_MIN_FRAME_SIZE);
+                        channel->sendPing();
                     });
                 }
                 if (onopen) onopen();

+ 4 - 4
http/server/HttpServer.cpp

@@ -41,10 +41,10 @@ static void websocket_heartbeat(hio_t* io) {
     WebSocketHandler* ws = handler->ws.get();
     if (ws->last_recv_pong_time < ws->last_send_ping_time) {
         hlogw("[%s:%d] websocket no pong!", handler->ip, handler->port);
-        hio_close(io);
+        ws->channel->close(true);
     } else {
         // printf("send ping\n");
-        hio_write(io, WS_SERVER_PING_FRAME, WS_SERVER_MIN_FRAME_SIZE);
+        ws->channel->sendPing();
         ws->last_send_ping_time = gethrtime_us();
     }
 }
@@ -54,12 +54,12 @@ static void websocket_onmessage(int opcode, const std::string& msg, hio_t* io) {
     WebSocketHandler* ws = handler->ws.get();
     switch(opcode) {
     case WS_OPCODE_CLOSE:
-        hio_close_async(io);
+        ws->channel->close(true);
         break;
     case WS_OPCODE_PING:
         // printf("recv ping\n");
         // printf("send pong\n");
-        hio_write(io, WS_SERVER_PONG_FRAME, WS_SERVER_MIN_FRAME_SIZE);
+        ws->channel->sendPong();
         break;
     case WS_OPCODE_PONG:
         // printf("recv pong\n");

+ 15 - 5
mqtt/mqtt_client.c

@@ -8,17 +8,27 @@ static unsigned short mqtt_next_mid() {
     return ++s_mid;
 }
 
+static int mqtt_client_send(mqtt_client_t* cli, const void* buf, int len) {
+    // thread-safe
+    hmutex_lock(&cli->mutex_);
+    int nwrite = hio_write(cli->io, buf, len);
+    hmutex_unlock(&cli->mutex_);
+    return nwrite;
+}
+
 static int mqtt_send_head(hio_t* io, int type, int length) {
+    mqtt_client_t* cli = (mqtt_client_t*)hevent_userdata(io);
     mqtt_head_t head;
     memset(&head, 0, sizeof(head));
     head.type = type;
     head.length = length;
     unsigned char headbuf[8] = { 0 };
     int headlen = mqtt_head_pack(&head, headbuf);
-    return hio_write(io, headbuf, headlen);
+    return mqtt_client_send(cli, headbuf, headlen);
 }
 
 static int mqtt_send_head_with_mid(hio_t* io, int type, unsigned short mid) {
+    mqtt_client_t* cli = (mqtt_client_t*)hevent_userdata(io);
     mqtt_head_t head;
     memset(&head, 0, sizeof(head));
     head.type = type;
@@ -31,7 +41,7 @@ static int mqtt_send_head_with_mid(hio_t* io, int type, unsigned short mid) {
     int headlen = mqtt_head_pack(&head, p);
     p += headlen;
     PUSH16(p, mid);
-    return hio_write(io, headbuf, headlen + 2);
+    return mqtt_client_send(cli, headbuf, headlen + 2);
 }
 
 static void mqtt_send_ping(hio_t* io) {
@@ -143,7 +153,7 @@ static int mqtt_client_login(mqtt_client_t* cli) {
         PUSH_N(p, cli->password, password_len);
     }
 
-    int nwrite = hio_write(cli->io, buf, p - buf);
+    int nwrite = mqtt_client_send(cli, buf, p - buf);
     HV_STACK_FREE(buf);
     return nwrite < 0 ? nwrite : 0;
 }
@@ -536,7 +546,7 @@ int mqtt_client_subscribe(mqtt_client_t* cli, const char* topic, int qos) {
     PUSH_N(p, topic, topic_len);
     PUSH8(p, qos & 3);
     // send head + mid + topic + qos
-    int nwrite = hio_write(cli->io, buf, p - buf);
+    int nwrite = mqtt_client_send(cli, buf, p - buf);
     HV_STACK_FREE(buf);
     return nwrite < 0 ? nwrite : mid;
 }
@@ -562,7 +572,7 @@ int mqtt_client_unsubscribe(mqtt_client_t* cli, const char* topic) {
     PUSH16(p, topic_len);
     PUSH_N(p, topic, topic_len);
     // send head + mid + topic
-    int nwrite = hio_write(cli->io, buf, p - buf);
+    int nwrite = mqtt_client_send(cli, buf, p - buf);
     HV_STACK_FREE(buf);
     return nwrite < 0 ? nwrite : mid;
 }