Browse Source

Use promise/future to replace cond_var

ithewei 3 months ago
parent
commit
e4878eddf7
1 changed files with 49 additions and 40 deletions
  1. 49 40
      examples/protorpc/protorpc_client.cpp

+ 49 - 40
examples/protorpc/protorpc_client.cpp

@@ -9,8 +9,9 @@
 
 #include "TcpClient.h"
 
+#include <atomic>
 #include <mutex>
-#include <condition_variable>
+#include <future>
 
 using namespace hv;
 
@@ -44,21 +45,8 @@ enum ProtoRpcResult {
 
 class ProtoRpcContext {
 public:
-    protorpc::RequestPtr    req;
-    protorpc::ResponsePtr   res;
-private:
-    std::mutex              _mutex;
-    std::condition_variable _cond;
-
-public:
-    void wait(int timeout_ms) {
-        std::unique_lock<std::mutex> locker(_mutex);
-        _cond.wait_for(locker, std::chrono::milliseconds(timeout_ms));
-    }
-
-    void notify() {
-        _cond.notify_one();
-    }
+    protorpc::RequestPtr                req;
+    std::promise<protorpc::ResponsePtr> res;
 };
 typedef std::shared_ptr<ProtoRpcContext>    ContextPtr;
 
@@ -66,10 +54,9 @@ class ProtoRpcClient : public TcpClient {
 public:
     ProtoRpcClient() : TcpClient()
     {
-        connect_state = kInitialized;
-
-        setConnectTimeout(5000);
+        connect_status = kInitialized;
 
+        // reconnect setting
         reconn_setting_t reconn;
         reconn_setting_init(&reconn);
         reconn.min_delay = 1000;
@@ -91,10 +78,10 @@ public:
         onConnection = [this](const SocketChannelPtr& channel) {
             std::string peeraddr = channel->peeraddr();
             if (channel->isConnected()) {
-                connect_state = kConnected;
+                connect_status = kConnected;
                 printf("connected to %s! connfd=%d\n", peeraddr.c_str(), channel->fd());
             } else {
-                connect_state = kDisconnectd;
+                connect_status = kDisconnectd;
                 printf("disconnected to %s! connfd=%d\n", peeraddr.c_str(), channel->fd());
             }
         };
@@ -127,20 +114,33 @@ public:
             }
             auto ctx = iter->second;
             calls_mutex.unlock();
-            ctx->res = res;
-            ctx->notify();
+            ctx->res.set_value(res);
         };
     }
 
-    int connect(int port, const char* host = "127.0.0.1") {
-        createsocket(port, host);
-        connect_state = kConnecting;
+    // @retval >0 connfd, <0 error, =0 connecting
+    int connect(int port, const char* host = "127.0.0.1", bool wait_connect = true, int connect_timeout = 5000) {
+        int fd = createsocket(port, host);
+        if (fd < 0) {
+            return fd;
+        }
+        setConnectTimeout(connect_timeout);
+        connect_status = kConnecting;
         start();
+        if (wait_connect) {
+            while (connect_status == kConnecting) hv_msleep(1);
+            return connect_status == kConnected ? fd : -1;
+        }
         return 0;
     }
 
+    bool isConnected() {
+        return connect_status == kConnected;
+    }
+
     protorpc::ResponsePtr call(protorpc::RequestPtr& req, int timeout_ms = 10000) {
-        if (connect_state != kConnected) {
+        if (!isConnected()) {
+            printf("RPC not connected!\n");
             return NULL;
         }
         static std::atomic<uint64_t> s_id = ATOMIC_VAR_INIT(0);
@@ -165,17 +165,25 @@ public:
             channel->write(writebuf, packlen);
         }
         HV_STACK_FREE(writebuf);
-        // wait until response come or timeout
-        ctx->wait(timeout_ms);
-        auto res = ctx->res;
+        protorpc::ResponsePtr res;
+        if (timeout_ms > 0) {
+            // wait until response come or timeout
+            auto fut = ctx->res.get_future();
+            auto status = fut.wait_for(std::chrono::milliseconds(timeout_ms));
+            if (status == std::future_status::ready) {
+                res = fut.get();
+                if (res->has_error()) {
+                    printf("RPC error:\n%s\n", res->error().DebugString().c_str());
+                }
+            } else if (status == std::future_status::timeout) {
+                printf("RPC timeout!\n");
+            } else {
+                printf("RPC unexpected status: %d!\n", (int)status);
+            }
+        }
         calls_mutex.lock();
         calls.erase(req->id());
         calls_mutex.unlock();
-        if (res == NULL) {
-            printf("RPC timeout!\n");
-        } else if (res->has_error()) {
-            printf("RPC error:\n%s\n", res->error().DebugString().c_str());
-        }
         return res;
     }
 
@@ -217,12 +225,14 @@ public:
         return kRpcSuccess;
     }
 
-    enum {
+private:
+    enum ConnectStatus {
         kInitialized,
         kConnecting,
         kConnected,
         kDisconnectd,
-    } connect_state;
+    };
+    std::atomic<ConnectStatus> connect_status;
     std::map<uint64_t, protorpc::ContextPtr> calls;
     std::mutex calls_mutex;
 };
@@ -244,9 +254,8 @@ int main(int argc, char** argv) {
     const char* param2 = argv[5];
 
     protorpc::ProtoRpcClient cli;
-    cli.connect(port, host);
-    while (cli.connect_state == protorpc::ProtoRpcClient::kConnecting) hv_msleep(1);
-    if (cli.connect_state == protorpc::ProtoRpcClient::kDisconnectd) {
+    int ret = cli.connect(port, host, true);
+    if (ret < 0) {
         return -20;
     }