Prechádzať zdrojové kódy

Add SocketChannel::onconnect

hewei.it 4 rokov pred
rodič
commit
f4cfe4ada3
5 zmenil súbory, kde vykonal 167 pridanie a 125 odobranie
  1. 76 12
      evpp/Channel.h
  2. 13 22
      evpp/TcpClient.h
  3. 4 4
      evpp/TcpServer.h
  4. 49 52
      http/client/AsyncHttpClient.cpp
  5. 25 35
      http/client/AsyncHttpClient.h

+ 76 - 12
evpp/Channel.h

@@ -27,6 +27,7 @@ public:
             hio_setcb_write(io_, on_write);
             hio_setcb_close(io_, on_close);
         }
+        status = isOpened() ? OPENED : CLOSED;
     }
 
     virtual ~Channel() {
@@ -38,11 +39,27 @@ public:
     int id() { return id_; }
     int error() { return hio_error(io_); }
 
+    // context
+    void* context() {
+        return ctx_;
+    }
     void setContext(void* ctx) {
         ctx_ = ctx;
     }
-    void* context() {
-        return ctx_;
+    template<class T>
+    T* newContext() {
+        ctx_ = new T;
+    }
+    template<class T>
+    T* getContext() {
+        return (T*)ctx_;
+    }
+    template<class T>
+    void deleteContext() {
+        if (ctx_) {
+            delete (T*)ctx_;
+            ctx_ = NULL;
+        }
     }
 
     bool isOpened() {
@@ -83,6 +100,15 @@ public:
     int         fd_;
     uint32_t    id_;
     void*       ctx_;
+    enum Status {
+        // Channel::Status
+        OPENED,
+        CLOSED,
+        // SocketChannel::Status
+        CONNECTING,
+        CONNECTED,
+        DISCONNECTED,
+    } status;
     std::function<void(Buffer*)> onread;
     std::function<void(Buffer*)> onwrite;
     std::function<void()>        onclose;
@@ -106,27 +132,54 @@ private:
 
     static void on_close(hio_t* io) {
         Channel* channel = (Channel*)hio_context(io);
-        if (channel && channel->onclose) {
-            channel->onclose();
+        if (channel) {
+            channel->status = CLOSED;
+            if (channel->onclose) {
+                channel->onclose();
+            }
         }
     }
 };
 
 class SocketChannel : public Channel {
 public:
-    enum Status {
-        OPENED,
-        CONNECTING,
-        CONNECTED,
-        DISCONNECTED,
-        CLOSED,
-    } status;
+    // for TcpClient
+    std::function<void()>   onconnect;
 
     SocketChannel(hio_t* io) : Channel(io) {
-        status = isOpened() ? OPENED : CLOSED;
     }
     virtual ~SocketChannel() {}
 
+    int enableSSL() {
+        return hio_enable_ssl(io_);
+    }
+
+    void setConnectTimeout(int timeout_ms) {
+        hio_set_connect_timeout(io_, timeout_ms);
+    }
+
+    int startConnect(int port, const char* host = "127.0.0.1") {
+        sockaddr_u peeraddr;
+        memset(&peeraddr, 0, sizeof(peeraddr));
+        int ret = sockaddr_set_ipport(&peeraddr, host, port);
+        if (ret != 0) {
+            // hloge("unknown host %s", host);
+            return ret;
+        }
+        return startConnect(&peeraddr.sa);
+    }
+
+    int startConnect(struct sockaddr* peeraddr) {
+        hio_set_peeraddr(io_, peeraddr, SOCKADDR_LEN(peeraddr));
+        return startConnect();
+    }
+
+    int startConnect() {
+        status = CONNECTING;
+        hio_setcb_connect(io_, on_connect);
+        return hio_connect(io_);
+    }
+
     bool isConnected() {
         return isOpened() && status == CONNECTED;
     }
@@ -146,6 +199,17 @@ public:
     int send(const std::string& str) {
         return write(str);
     }
+
+private:
+    static void on_connect(hio_t* io) {
+        SocketChannel* channel = (SocketChannel*)hio_context(io);
+        if (channel) {
+            channel->status = CONNECTED;
+            if (channel->onconnect) {
+                channel->onconnect();
+            }
+        }
+    }
 };
 
 typedef std::shared_ptr<Channel>        ChannelPtr;

+ 13 - 22
evpp/TcpClient.h

@@ -81,6 +81,18 @@ public:
 
     int startConnect() {
         assert(channel != NULL);
+        if (tls) {
+            channel->enableSSL();
+        }
+        if (connect_timeout) {
+            channel->setConnectTimeout(connect_timeout);
+        }
+        channel->onconnect = [this]() {
+            channel->startRead();
+            if (onConnection) {
+                onConnection(channel);
+            }
+        };
         channel->onread = [this](Buffer* buf) {
             if (onMessage) {
                 onMessage(channel, buf);
@@ -92,7 +104,6 @@ public:
             }
         };
         channel->onclose = [this]() {
-            channel->status = SocketChannel::CLOSED;
             if (onConnection) {
                 onConnection(channel);
             }
@@ -102,16 +113,7 @@ public:
                 startReconnect();
             }
         };
-
-        hio_t* connio = channel->io();
-        hevent_set_userdata(connio, this);
-        if (tls) {
-            hio_enable_ssl(connio);
-        }
-        hio_set_connect_timeout(connio, connect_timeout);
-        hio_setcb_connect(connio, onConnect);
-        hio_connect(connio);
-        return 0;
+        return channel->startConnect();
     }
 
     int startReconnect() {
@@ -162,17 +164,6 @@ public:
         reconnect_info = *info;
     }
 
-private:
-    static void onConnect(hio_t* io) {
-        TcpClient* client = (TcpClient*)hevent_userdata(io);
-        SocketChannelPtr channel = client->channel;
-        channel->status = SocketChannel::CONNECTED;
-        channel->startRead();
-        if (client->onConnection) {
-            client->onConnection(channel);
-        }
-    }
-
 public:
     SocketChannelPtr        channel;
 

+ 4 - 4
evpp/TcpServer.h

@@ -74,7 +74,7 @@ public:
         return channel;
     }
 
-    void removeChannel(ChannelPtr channel) {
+    void removeChannel(const SocketChannelPtr& channel) {
         std::lock_guard<std::mutex> locker(mutex_);
         int fd = channel->fd();
         if (fd < channels.capacity()) {
@@ -94,17 +94,17 @@ private:
         channel->status = SocketChannel::CONNECTED;
         ++server->connection_num;
 
-        channel->onread = [server, channel](Buffer* buf) {
+        channel->onread = [server, &channel](Buffer* buf) {
             if (server->onMessage) {
                 server->onMessage(channel, buf);
             }
         };
-        channel->onwrite = [server, channel](Buffer* buf) {
+        channel->onwrite = [server, &channel](Buffer* buf) {
             if (server->onWriteComplete) {
                 server->onWriteComplete(channel, buf);
             }
         };
-        channel->onclose = [server, channel]() {
+        channel->onclose = [server, &channel]() {
             channel->status = SocketChannel::CLOSED;
             if (server->onConnection) {
                 server->onConnection(channel);

+ 49 - 52
http/client/AsyncHttpClient.cpp

@@ -2,6 +2,9 @@
 
 namespace hv {
 
+// createsocket => startConnect =>
+// onconnect => sendRequest => startRead =>
+// onread => HttpParser => resp_cb
 int AsyncHttpClient::doTask(const HttpClientTaskPtr& task) {
     const HttpRequestPtr& req = task->req;
     // queueInLoop timeout?
@@ -23,18 +26,13 @@ int AsyncHttpClient::doTask(const HttpClientTaskPtr& task) {
     }
 
     int connfd = -1;
-    hio_t* connio = NULL;
-    HttpClientContextPtr ctx = NULL;
-
     // first get from conn_pools
     char strAddr[SOCKADDR_STRLEN] = {0};
     SOCKADDR_STR(&peeraddr, strAddr);
     auto iter = conn_pools.find(strAddr);
     if (iter != conn_pools.end()) {
-        if (iter->second.get(connfd)) {
-            // hlogd("get from conn_pools");
-            ctx = getContext(connfd);
-        }
+        // hlogd("get from conn_pools");
+        iter->second.get(connfd);
     }
 
     if (connfd < 0) {
@@ -44,104 +42,102 @@ int AsyncHttpClient::doTask(const HttpClientTaskPtr& task) {
             perror("socket");
             return -30;
         }
-        connio = hio_get(loop_thread.hloop(), connfd);
+        hio_t* connio = hio_get(loop_thread.hloop(), connfd);
         assert(connio != NULL);
         hio_set_peeraddr(connio, &peeraddr.sa, sockaddr_len(&peeraddr));
+        addChannel(connio);
         // https
         if (req->https) {
             hio_enable_ssl(connio);
         }
     }
 
-    if (ctx == NULL) {
-        // new HttpClientContext
-        ctx.reset(new HttpClientContext);
-        ctx->channel.reset(new SocketChannel(connio));
-        addContext(ctx);
-    }
-
-    ctx->req = req;
-    ctx->cb = task->cb;
-    ctx->channel->onread = [this, ctx](Buffer* buf) {
+    const SocketChannelPtr& channel = getChannel(connfd);
+    assert(channel != NULL);
+    HttpClientContext* ctx = channel->getContext<HttpClientContext>();
+    ctx->task = task;
+    channel->onconnect = [this, &channel]() {
+        sendRequest(channel);
+    };
+    channel->onread = [this, &channel](Buffer* buf) {
+        HttpClientContext* ctx = channel->getContext<HttpClientContext>();
+        if (ctx->task == NULL) return;
         const char* data = (const char*)buf->data();
         int len = buf->size();
         int nparse = ctx->parser->FeedRecvData(data, len);
         if (nparse != len) {
             ctx->errorCallback();
-            ctx->channel->close();
+            channel->close();
             return;
         }
         if (ctx->parser->IsComplete()) {
-            std::string req_connection = ctx->req->GetHeader("Connection");
+            std::string req_connection = ctx->task->req->GetHeader("Connection");
             std::string resp_connection = ctx->resp->GetHeader("Connection");
             ctx->successCallback();
             if (stricmp(req_connection.c_str(), "keep-alive") == 0 &&
                 stricmp(resp_connection.c_str(), "keep-alive") == 0) {
-                // add into conn_pools to reuse
+                // NOTE: add into conn_pools to reuse
                 // hlogd("add into conn_pools");
-                conn_pools[ctx->channel->peeraddr()].add(ctx->channel->fd());
+                conn_pools[channel->peeraddr()].add(channel->fd());
             } else {
-                ctx->channel->close();
+                channel->close();
             }
         }
     };
-    ctx->channel->onclose = [this, ctx, task]() {
-        ctx->channel->status = SocketChannel::CLOSED;
-        removeContext(ctx);
-        if (task->retry_cnt-- > 0) {
+    channel->onclose = [this, &channel]() {
+        HttpClientContext* ctx = channel->getContext<HttpClientContext>();
+        // NOTE: remove from conn_pools
+        // hlogd("remove from conn_pools");
+        auto iter = conn_pools.find(channel->peeraddr());
+        if (iter != conn_pools.end()) {
+            iter->second.remove(channel->fd());
+        }
+        if (ctx->task && ctx->task->retry_cnt-- > 0) {
             // try again
-            send(task);
+            send(ctx->task);
         } else {
             ctx->errorCallback();
         }
+        removeChannel(channel);
     };
 
     // timer
     if (timeout_ms > 0) {
-        ctx->timerID = setTimeout(timeout_ms - elapsed_ms, [ctx](TimerID timerID){
-            hlogw("%s timeout!", ctx->req->url.c_str());
-            if (ctx->channel) {
-                ctx->channel->close();
+        ctx->timerID = setTimeout(timeout_ms - elapsed_ms, [&channel](TimerID timerID){
+            HttpClientContext* ctx = channel->getContext<HttpClientContext>();
+            assert(ctx->task != NULL);
+            hlogw("%s timeout!", ctx->task->req->url.c_str());
+            if (channel) {
+                channel->close();
             }
         });
     }
 
-    if (ctx->channel->isConnected()) {
+    if (channel->isConnected()) {
         // sendRequest
-        sendRequest(ctx);
+        sendRequest(channel);
     } else {
         // startConnect
-        hevent_set_userdata(connio, this);
-        hio_setcb_connect(connio, onconnect);
-        hio_connect(connio);
+        channel->startConnect();
     }
 
     return 0;
 }
 
-void AsyncHttpClient::onconnect(hio_t* io) {
-    AsyncHttpClient* client = (AsyncHttpClient*)hevent_userdata(io);
-    HttpClientContextPtr ctx = client->getContext(hio_fd(io));
-    assert(ctx != NULL && ctx->req != NULL && ctx->channel != NULL);
-
-    ctx->channel->status = SocketChannel::CONNECTED;
-    client->sendRequest(ctx);
-    ctx->channel->startRead();
-}
-
-int AsyncHttpClient::sendRequest(const HttpClientContextPtr ctx) {
-    assert(ctx != NULL && ctx->req != NULL && ctx->channel != NULL);
-    SocketChannelPtr channel = ctx->channel;
+// InitResponse => SubmitRequest => while(GetSendData) write => startRead
+int AsyncHttpClient::sendRequest(const SocketChannelPtr& channel) {
+    HttpClientContext* ctx = (HttpClientContext*)channel->context();
+    assert(ctx != NULL && ctx->task != NULL);
 
     if (ctx->parser == NULL) {
-        ctx->parser.reset(HttpParser::New(HTTP_CLIENT, (http_version)ctx->req->http_major));
+        ctx->parser.reset(HttpParser::New(HTTP_CLIENT, (http_version)ctx->task->req->http_major));
     }
     if (ctx->resp == NULL) {
         ctx->resp.reset(new HttpResponse);
     }
 
     ctx->parser->InitResponse(ctx->resp.get());
-    ctx->parser->SubmitRequest(ctx->req.get());
+    ctx->parser->SubmitRequest(ctx->task->req.get());
 
     char* data = NULL;
     size_t len = 0;
@@ -149,6 +145,7 @@ int AsyncHttpClient::sendRequest(const HttpClientContextPtr ctx) {
         Buffer buf(data, len);
         channel->write(&buf);
     }
+    channel->startRead();
 
     return 0;
 }

+ 25 - 35
http/client/AsyncHttpClient.h

@@ -59,13 +59,10 @@ struct HttpClientTask {
 typedef std::shared_ptr<HttpClientTask> HttpClientTaskPtr;
 
 struct HttpClientContext {
-    HttpRequestPtr          req;
-    HttpResponseCallback    cb;
+    HttpClientTaskPtr   task;
 
-    SocketChannelPtr    channel;
     HttpResponsePtr     resp;
     HttpParserPtr       parser;
-
     TimerID             timerID;
 
     HttpClientContext() {
@@ -77,11 +74,11 @@ struct HttpClientContext {
             killTimer(timerID);
             timerID = INVALID_TIMER_ID;
         }
-        if (cb) {
-            cb(resp);
-            // NOTE: ensure cb just call once
-            cb = NULL;
+        if (task && task->cb) {
+            task->cb(resp);
         }
+        // NOTE: task done
+        task = NULL;
     }
 
     void successCallback() {
@@ -94,7 +91,6 @@ struct HttpClientContext {
         callback();
     }
 };
-typedef std::shared_ptr<HttpClientContext>  HttpClientContextPtr;
 
 class AsyncHttpClient {
 public:
@@ -121,44 +117,38 @@ public:
     }
 
 protected:
-    void sendInLoop(const HttpClientTaskPtr& task) {
+    void sendInLoop(HttpClientTaskPtr task) {
         int err = doTask(task);
         if (err != 0 && task->cb) {
             task->cb(NULL);
         }
     }
-    // createsocket => startConnect =>
-    // onconnect => sendRequest => startRead =>
-    // onread => HttpParser => resp_cb
     int doTask(const HttpClientTaskPtr& task);
 
-    // InitResponse => SubmitRequest => while(GetSendData) write => startRead
-    static void onconnect(hio_t* io);
-    static int sendRequest(const HttpClientContextPtr ctx);
+    static int sendRequest(const SocketChannelPtr& channel);
 
-    HttpClientContextPtr getContext(int fd) {
-        return fd < client_ctxs.capacity() ? client_ctxs[fd] : NULL;
+    // channel
+    const SocketChannelPtr& getChannel(int fd) {
+        return channels[fd];
+        // return fd < channels.capacity() ? channels[fd] : NULL;
     }
 
-    void addContext(const HttpClientContextPtr& ctx) {
-        int fd = ctx->channel->fd();
-        if (fd >= client_ctxs.capacity()) {
-            client_ctxs.resize(2 * fd);
+    const SocketChannelPtr& addChannel(hio_t* io) {
+        SocketChannelPtr channel(new SocketChannel(io));
+        channel->newContext<HttpClientContext>();
+        int fd = channel->fd();
+        if (fd >= channels.capacity()) {
+            channels.resize(2 * fd);
         }
-        client_ctxs[fd] = ctx;
-        // NOTE: add into conn_pools after recv response completed
-        // conn_pools[ctx->channel->peeraddr()].add(fd);
+        channels[fd] = channel;
+        return channels[fd];
     }
 
-    void removeContext(const HttpClientContextPtr& ctx) {
-        int fd = ctx->channel->fd();
-        // NOTE: remove from conn_pools
-        auto iter = conn_pools.find(ctx->channel->peeraddr());
-        if (iter != conn_pools.end()) {
-            iter->second.remove(fd);
-        }
-        if (fd < client_ctxs.capacity()) {
-            client_ctxs[fd] = NULL;
+    void removeChannel(const SocketChannelPtr& channel) {
+        channel->deleteContext<HttpClientContext>();
+        int fd = channel->fd();
+        if (fd < channels.capacity()) {
+            channels[fd] = NULL;
         }
     }
 
@@ -166,7 +156,7 @@ private:
     EventLoopThread                         loop_thread;
     // NOTE: just one loop thread, no need mutex.
     // with fd as index
-    std::vector<HttpClientContextPtr>       client_ctxs;
+    std::vector<SocketChannelPtr>           channels;
     // peeraddr => ConnPool
     std::map<std::string, ConnPool<int>>    conn_pools;
 };