Selaa lähdekoodia

Add Channel::setContextPtr/getContextPtr

ithewei 3 vuotta sitten
vanhempi
commit
9a0aecc836
2 muutettua tiedostoa jossa 37 lisäystä ja 8 poistoa
  1. 30 4
      evpp/Channel.h
  2. 7 4
      examples/websocket_server_test.cpp

+ 30 - 4
evpp/Channel.h

@@ -41,10 +41,12 @@ public:
     }
 
     virtual ~Channel() {
-        close();
-        // NOTE: Detach after destructor to avoid triggering onclose
-        if (io_ && id_ == hio_id(io_)) {
-            hio_set_context(io_, NULL);
+        if (isOpened()) {
+            close();
+            // NOTE: Detach after destructor to avoid triggering onclose
+            if (io_ && id_ == hio_id(io_)) {
+                hio_set_context(io_, NULL);
+            }
         }
     }
 
@@ -77,6 +79,29 @@ public:
         }
     }
 
+    // contextPtr
+    std::shared_ptr<void> contextPtr() {
+        return contextPtr_;
+    }
+    void setContextPtr(const std::shared_ptr<void>& ctx) {
+        contextPtr_ = ctx;
+    }
+    void setContextPtr(std::shared_ptr<void>&& ctx) {
+        contextPtr_ = std::move(ctx);
+    }
+    template<class T>
+    std::shared_ptr<T> newContextPtr() {
+        contextPtr_ = std::make_shared<T>();
+        return std::static_pointer_cast<T>(contextPtr_);
+    }
+    template<class T>
+    std::shared_ptr<T> getContextPtr() {
+        return std::static_pointer_cast<T>(contextPtr_);
+    }
+    void deleteContextPtr() {
+        contextPtr_.reset();
+    }
+
     bool isOpened() {
         if (io_ == NULL || status >= DISCONNECTED) return false;
         return id_ == hio_id(io_) && hio_is_opened(io_);
@@ -169,6 +194,7 @@ public:
     // NOTE: Use Channel::isWriteComplete in onwrite callback to determine whether all data has been written.
     std::function<void(Buffer*)> onwrite;
     std::function<void()>        onclose;
+    std::shared_ptr<void>        contextPtr_;
 
 private:
     static void on_read(hio_t* io, void* data, int readbytes) {

+ 7 - 4
examples/websocket_server_test.cpp

@@ -31,9 +31,11 @@ using namespace hv;
 class MyContext {
 public:
     MyContext() {
+        printf("MyContext::MyContext()\n");
         timerID = INVALID_TIMER_ID;
     }
     ~MyContext() {
+        printf("MyContext::~MyContext()\n");
     }
 
     int handleMessage(const std::string& msg, enum ws_opcode opcode) {
@@ -60,7 +62,7 @@ int main(int argc, char** argv) {
     WebSocketService ws;
     ws.onopen = [](const WebSocketChannelPtr& channel, const HttpRequestPtr& req) {
         printf("onopen: GET %s\n", req->Path().c_str());
-        MyContext* ctx = channel->newContext<MyContext>();
+        auto ctx = channel->newContextPtr<MyContext>();
         // send(time) every 1s
         ctx->timerID = setInterval(1000, [channel](TimerID id) {
             if (channel->isConnected() && channel->isWriteComplete()) {
@@ -72,16 +74,17 @@ int main(int argc, char** argv) {
         });
     };
     ws.onmessage = [](const WebSocketChannelPtr& channel, const std::string& msg) {
-        MyContext* ctx = channel->getContext<MyContext>();
+        auto ctx = channel->getContextPtr<MyContext>();
         ctx->handleMessage(msg, channel->opcode);
     };
     ws.onclose = [](const WebSocketChannelPtr& channel) {
         printf("onclose\n");
-        MyContext* ctx = channel->getContext<MyContext>();
+        auto ctx = channel->getContextPtr<MyContext>();
         if (ctx->timerID != INVALID_TIMER_ID) {
             killTimer(ctx->timerID);
+            ctx->timerID = INVALID_TIMER_ID;
         }
-        channel->deleteContext<MyContext>();
+        // channel->deleteContextPtr();
     };
 
     websocket_server_t server;