Ver código fonte

WebSocketClient::open(url, headers)

ithewei 4 anos atrás
pai
commit
e71b9a9b86

+ 3 - 1
examples/websocket_client_test.cpp

@@ -37,7 +37,9 @@ int main(int argc, char** argv) {
     reconn.delay_policy = 2;
     reconn.delay_policy = 2;
     ws.setReconnect(&reconn);
     ws.setReconnect(&reconn);
 
 
-    ws.open(url);
+    http_headers headers;
+    headers["Origin"] = "http://example.com/";
+    ws.open(url, headers);
 
 
     std::string str;
     std::string str;
     while (std::getline(std::cin, str)) {
     while (std::getline(std::cin, str)) {

+ 4 - 2
http/HttpMessage.cpp

@@ -9,6 +9,8 @@
 
 
 using namespace hv;
 using namespace hv;
 
 
+http_headers DefaultHeaders;
+http_body    NoBody;
 char HttpMessage::s_date[32] = {0};
 char HttpMessage::s_date[32] = {0};
 
 
 bool HttpCookie::parse(const std::string& str) {
 bool HttpCookie::parse(const std::string& str) {
@@ -424,11 +426,11 @@ int HttpMessage::ParseBody() {
     {
     {
         auto iter = headers.find("Content-Type");
         auto iter = headers.find("Content-Type");
         if (iter == headers.end()) {
         if (iter == headers.end()) {
-            return false;
+            return -1;
         }
         }
         const char* boundary = strstr(iter->second.c_str(), "boundary=");
         const char* boundary = strstr(iter->second.c_str(), "boundary=");
         if (boundary == NULL) {
         if (boundary == NULL) {
-            return false;
+            return -1;
         }
         }
         boundary += strlen("boundary=");
         boundary += strlen("boundary=");
         std::string strBoundary(boundary);
         std::string strBoundary(boundary);

+ 3 - 0
http/HttpMessage.h

@@ -82,6 +82,9 @@ typedef std::function<void(const http_headers& headers)>        http_head_cb;
 typedef std::function<void(const char* data, size_t size)>      http_body_cb;
 typedef std::function<void(const char* data, size_t size)>      http_body_cb;
 typedef std::function<void(const char* data, size_t size)>      http_chunked_cb;
 typedef std::function<void(const char* data, size_t size)>      http_chunked_cb;
 
 
+HV_EXPORT extern http_headers DefaultHeaders;
+HV_EXPORT extern http_body    NoBody;
+
 class HV_EXPORT HttpMessage {
 class HV_EXPORT HttpMessage {
 public:
 public:
     static char         s_date[32];
     static char         s_date[32];

+ 18 - 10
http/client/WebSocketClient.cpp

@@ -25,7 +25,7 @@ WebSocketClient::~WebSocketClient() {
  * TCP::onMessage => WebSocketParser => WS::onmessage =>
  * TCP::onMessage => WebSocketParser => WS::onmessage =>
  * TCP::onConnection => WS::onclose
  * TCP::onConnection => WS::onclose
  */
  */
-int WebSocketClient::open(const char* _url) {
+int WebSocketClient::open(const char* _url, const http_headers& headers) {
     close();
     close();
 
 
     // ParseUrl
     // ParseUrl
@@ -54,22 +54,30 @@ int WebSocketClient::open(const char* _url) {
         withTLS();
         withTLS();
     }
     }
 
 
+    for (auto& header : headers) {
+        http_req_->headers[header.first] = header.second;
+    }
+
     onConnection = [this](const WebSocketChannelPtr& channel) {
     onConnection = [this](const WebSocketChannelPtr& channel) {
         if (channel->isConnected()) {
         if (channel->isConnected()) {
             state = CONNECTED;
             state = CONNECTED;
             // websocket_handshake
             // websocket_handshake
             http_req_->headers["Connection"] = "Upgrade";
             http_req_->headers["Connection"] = "Upgrade";
             http_req_->headers["Upgrade"] = "websocket";
             http_req_->headers["Upgrade"] = "websocket";
-            // generate SEC_WEBSOCKET_KEY
-            unsigned char rand_key[16] = {0};
-            int *p = (int*)rand_key;
-            for (int i = 0; i < 4; ++i, ++p) {
-                *p = rand();
+            if (http_req_->GetHeader(SEC_WEBSOCKET_KEY).empty()) {
+                // generate SEC_WEBSOCKET_KEY
+                unsigned char rand_key[16] = {0};
+                int *p = (int*)rand_key;
+                for (int i = 0; i < 4; ++i, ++p) {
+                    *p = rand();
+                }
+                char ws_key[32] = {0};
+                hv_base64_encode(rand_key, 16, ws_key);
+                http_req_->headers[SEC_WEBSOCKET_KEY] = ws_key;
+            }
+            if (http_req_->GetHeader(SEC_WEBSOCKET_VERSION).empty()) {
+                http_req_->headers[SEC_WEBSOCKET_VERSION] = "13";
             }
             }
-            char ws_key[32] = {0};
-            hv_base64_encode(rand_key, 16, ws_key);
-            http_req_->headers[SEC_WEBSOCKET_KEY] = ws_key;
-            http_req_->headers[SEC_WEBSOCKET_VERSION] = "13";
             std::string http_msg = http_req_->Dump(true, true);
             std::string http_msg = http_req_->Dump(true, true);
             // printf("%s", http_msg.c_str());
             // printf("%s", http_msg.c_str());
             // NOTE: not use WebSocketChannel::send
             // NOTE: not use WebSocketChannel::send

+ 1 - 1
http/client/WebSocketClient.h

@@ -27,7 +27,7 @@ public:
 
 
     // url = ws://ip:port/path
     // url = ws://ip:port/path
     // url = wss://ip:port/path
     // url = wss://ip:port/path
-    int open(const char* url);
+    int open(const char* url, const http_headers& headers = DefaultHeaders);
     int close();
     int close();
     int send(const std::string& msg);
     int send(const std::string& msg);
     int send(const char* buf, int len, enum ws_opcode opcode = WS_OPCODE_BINARY);
     int send(const char* buf, int len, enum ws_opcode opcode = WS_OPCODE_BINARY);