Przeglądaj źródła

Update TcpServer

ithewei 4 lat temu
rodzic
commit
1ff2d7ebfb
3 zmienionych plików z 31 dodań i 12 usunięć
  1. 1 0
      evpp/EventLoop.h
  2. 28 10
      evpp/TcpServer.h
  3. 2 2
      evpp/TcpServer_test.cpp

+ 1 - 0
evpp/EventLoop.h

@@ -218,6 +218,7 @@ typedef std::shared_ptr<EventLoop> EventLoopPtr;
 static inline EventLoop* tlsEventLoop() {
     return (EventLoop*)ThreadLocalStorage::get(ThreadLocalStorage::EVENT_LOOP);
 }
+#define currentThreadEventLoop tlsEventLoop()
 
 static inline TimerID setTimer(int timeout_ms, TimerCallback cb, int repeat = INFINITE) {
     EventLoop* loop = tlsEventLoop();

+ 28 - 10
evpp/TcpServer.h

@@ -24,7 +24,7 @@ public:
     }
 
     EventLoopPtr loop(int idx = -1) {
-        return loop_threads.loop(idx);
+        return worker_threads.loop(idx);
     }
 
     //@retval >=0 listenfd, <0 error
@@ -43,25 +43,26 @@ public:
         max_connections = num;
     }
     void setThreadNum(int num) {
-        loop_threads.setThreadNum(num);
+        worker_threads.setThreadNum(num);
     }
 
-    void startAccept(const EventLoopPtr& loop) {
+    int startAccept() {
         assert(listenfd >= 0);
-        hio_t* listenio = haccept(loop->loop(), listenfd, onAccept);
+        hio_t* listenio = haccept(acceptor_thread.hloop(), listenfd, onAccept);
         hevent_set_userdata(listenio, this);
         if (tls) {
             hio_enable_ssl(listenio);
         }
+        return 0;
     }
 
     void start(bool wait_threads_started = true) {
-        loop_threads.start(wait_threads_started, [this](const EventLoopPtr& loop){
-            startAccept(loop);
-        });
+        worker_threads.start(wait_threads_started);
+        acceptor_thread.start(wait_threads_started, std::bind(&TcpServer::startAccept, this));
     }
     void stop(bool wait_threads_stopped = true) {
-        loop_threads.stop(wait_threads_stopped);
+        acceptor_thread.stop(wait_threads_stopped);
+        worker_threads.stop(wait_threads_stopped);
     }
 
     int withTLS(const char* cert_file, const char* key_file) {
@@ -128,13 +129,19 @@ public:
     }
 
 private:
-    static void onAccept(hio_t* connio) {
+    static void newConnEvent(hio_t* connio) {
         TcpServer* server = (TcpServer*)hevent_userdata(connio);
         if (server->connectionNum() >= server->max_connections) {
             hlogw("over max_connections");
             hio_close(connio);
             return;
         }
+
+        // NOTE: attach to worker loop
+        EventLoop* worker_loop = currentThreadEventLoop;
+        assert(worker_loop != NULL);
+        hio_attach(worker_loop->loop(), connio);
+
         const SocketChannelPtr& channel = server->addChannel(connio);
         channel->status = SocketChannel::CONNECTED;
 
@@ -167,6 +174,15 @@ private:
         }
     }
 
+    static void onAccept(hio_t* connio) {
+        TcpServer* server = (TcpServer*)hevent_userdata(connio);
+        // NOTE: detach from acceptor loop
+        hio_detach(connio);
+        // Load Banlance: Round-Robin
+        EventLoopPtr worker_loop = server->worker_threads.nextLoop();
+        worker_loop->queueInLoop(std::bind(&TcpServer::newConnEvent, connio));
+    }
+
 public:
     int                     listenfd;
     bool                    tls;
@@ -183,7 +199,9 @@ private:
     // fd => SocketChannelPtr
     std::map<int, SocketChannelPtr> channels; // GUAREDE_BY(mutex_)
     std::mutex                      mutex_;
-    EventLoopThreadPool             loop_threads;
+
+    EventLoopThread                 acceptor_thread;
+    EventLoopThreadPool             worker_threads;
 };
 
 }

+ 2 - 2
evpp/TcpServer_test.cpp

@@ -25,9 +25,9 @@ int main(int argc, char* argv[]) {
     srv.onConnection = [](const SocketChannelPtr& channel) {
         std::string peeraddr = channel->peeraddr();
         if (channel->isConnected()) {
-            printf("%s connected! connfd=%d\n", peeraddr.c_str(), channel->fd());
+            printf("%s connected! connfd=%d tid=%ld\n", peeraddr.c_str(), channel->fd(), currentThreadEventLoop->tid());
         } else {
-            printf("%s disconnected! connfd=%d\n", peeraddr.c_str(), channel->fd());
+            printf("%s disconnected! connfd=%d tid=%ld\n", peeraddr.c_str(), channel->fd(), currentThreadEventLoop->tid());
         }
     };
     srv.onMessage = [](const SocketChannelPtr& channel, Buffer* buf) {