1
0
ithewei 6 жил өмнө
parent
commit
ac2b444248

+ 1 - 3
base/RAII.cpp

@@ -3,9 +3,7 @@
 #ifdef OS_WIN
 class WsaRAII {
 public:
-    WsaRAII() {
-        WSADATA wsadata;
-        WSAStartup(MAKEWORD(2,2), &wsadata);
+    WsaRAII() { WSADATA wsadata; WSAStartup(MAKEWORD(2,2), &wsadata);
     }
     ~WsaRAII() {
         WSACleanup();

+ 249 - 96
http/client/http_client.cpp

@@ -2,12 +2,134 @@
 
 #include "hstring.h"
 
-#ifdef WITH_CURL
+#define MAX_CONNECT_TIMEOUT 3000 // ms
 
-/***************************************************************
-HttpClient based libcurl
-***************************************************************/
+#ifdef WITH_CURL
 #include "curl/curl.h"
+#else
+#include "herr.h"
+#include "hsocket.h"
+#include "HttpParser.h"
+#include "ssl_ctx.h"
+#endif
+
+#ifdef WITH_OPENSSL
+#include "openssl/ssl.h"
+#endif
+
+struct http_session_s {
+    int          use_tls;
+    std::string  host;
+    int          port;
+    int          timeout;
+    http_headers headers;
+//private:
+#ifdef WITH_CURL
+    CURL* curl;
+#else
+    int fd;
+#endif
+#ifdef WITH_OPENSSL
+    SSL* ssl;
+#endif
+
+    http_session_s() {
+        use_tls = 0;
+        port = DEFAULT_HTTP_PORT;
+        timeout = DEFAULT_HTTP_TIMEOUT;
+#ifdef WITH_CURL
+        curl = NULL;
+#else
+        fd = -1;
+#endif
+#ifdef WITH_OPENSSL
+        ssl = NULL;
+#endif
+    }
+
+    ~http_session_s() {
+#ifdef WITH_OPENSSL
+        if (ssl) {
+            SSL_free(ssl);
+            ssl = NULL;
+        }
+#endif
+#ifdef WITH_CURL
+        if (curl) {
+            curl_easy_cleanup(curl);
+            curl = NULL;
+        }
+#else
+        if (fd > 0) {
+            closesocket(fd);
+            fd = -1;
+        }
+#endif
+    }
+};
+
+static int __http_session_send(http_session_t* hss, HttpRequest* req, HttpResponse* res);
+
+http_session_t* http_session_new(const char* host, int port) {
+    http_session_t* hss = new http_session_t;
+    hss->host = host;
+    hss->port = port;
+    hss->headers["Host"] = asprintf("%s:%d", host, port);
+    hss->headers["Connection"] = "keep-alive";
+    return hss;
+}
+
+int http_session_del(http_session_t* hss) {
+    if (hss == NULL) return 0;
+    delete hss;
+    return 0;
+}
+
+int http_session_set_timeout(http_session_t* hss, int timeout) {
+    hss->timeout = timeout;
+    return 0;
+}
+
+int http_session_clear_headers(http_session_t* hss) {
+    hss->headers.clear();
+    return 0;
+}
+
+int http_session_set_header(http_session_t* hss, const char* key, const char* value) {
+    hss->headers[key] = value;
+    return 0;
+}
+
+int http_session_del_header(http_session_t* hss, const char* key) {
+    auto iter = hss->headers.find(key);
+    if (iter != hss->headers.end()) {
+        hss->headers.erase(iter);
+    }
+    return 0;
+}
+
+const char* http_session_get_header(http_session_t* hss, const char* key) {
+    auto iter = hss->headers.find(key);
+    if (iter != hss->headers.end()) {
+        return iter->second.c_str();
+    }
+    return NULL;
+}
+
+int http_session_send(http_session_t* hss, HttpRequest* req, HttpResponse* res) {
+    for (auto& pair : hss->headers) {
+        req->headers[pair.first] = pair.second;
+    }
+    return __http_session_send(hss, req, res);
+}
+
+int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
+    http_session_t hss;
+    hss.timeout = timeout;
+    return __http_session_send(&hss, req, res);
+}
+
+#ifdef WITH_CURL
 
 static size_t s_formget_cb(void *arg, const char *buf, size_t len) {
     return len;
@@ -49,26 +171,30 @@ static size_t s_body_cb(char *buf, size_t size, size_t cnt, void *userdata) {
     return size*cnt;
 }
 
-int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
+int __http_session_send(http_session_t* hss, HttpRequest* req, HttpResponse* res) {
     if (req == NULL || res == NULL) {
         return -1;
     }
 
-    CURL* handle = curl_easy_init();
+    if (hss->curl == NULL) {
+        hss->curl = curl_easy_init();
+    }
+    CURL* curl = hss->curl;
+    int timeout = hss->timeout;
 
     // SSL
-    curl_easy_setopt(handle, CURLOPT_SSL_VERIFYPEER, 0);
-    curl_easy_setopt(handle, CURLOPT_SSL_VERIFYHOST, 0);
+    curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0);
+    curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0);
 
     // TCP_NODELAY
-    curl_easy_setopt(handle, CURLOPT_TCP_NODELAY, 1);
+    curl_easy_setopt(curl, CURLOPT_TCP_NODELAY, 1);
 
     // method
-    curl_easy_setopt(handle, CURLOPT_CUSTOMREQUEST, http_method_str(req->method));
+    curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, http_method_str(req->method));
 
     // url
     std::string url = req->dump_url();
-    curl_easy_setopt(handle, CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
     //hlogd("%s %s HTTP/%d.%d", http_method_str(req->method), url.c_str(), req->http_major, req->http_minor);
 
     // header
@@ -80,7 +206,7 @@ int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
         header += pair.second;
         headers = curl_slist_append(headers, header.c_str());
     }
-    curl_easy_setopt(handle, CURLOPT_HTTPHEADER, headers);
+    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
 
     // body
     struct curl_httppost* httppost = NULL;
@@ -104,28 +230,28 @@ int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
                 }
             }
             if (httppost) {
-                curl_easy_setopt(handle, CURLOPT_HTTPPOST, httppost);
+                curl_easy_setopt(curl, CURLOPT_HTTPPOST, httppost);
                 curl_formget(httppost, NULL, s_formget_cb);
             }
         }
     }
     if (req->body.size() != 0) {
-        curl_easy_setopt(handle, CURLOPT_POSTFIELDS, req->body.c_str());
-        curl_easy_setopt(handle, CURLOPT_POSTFIELDSIZE, req->body.size());
+        curl_easy_setopt(curl, CURLOPT_POSTFIELDS, req->body.c_str());
+        curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, req->body.size());
     }
 
     if (timeout > 0) {
-        curl_easy_setopt(handle, CURLOPT_TIMEOUT, timeout);
+        curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout);
     }
 
-    curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, s_body_cb);
-    curl_easy_setopt(handle, CURLOPT_WRITEDATA, res);
+    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, s_body_cb);
+    curl_easy_setopt(curl, CURLOPT_WRITEDATA, res);
 
-    curl_easy_setopt(handle, CURLOPT_HEADER, 0);
-    curl_easy_setopt(handle, CURLOPT_HEADERFUNCTION, s_header_cb);
-    curl_easy_setopt(handle, CURLOPT_HEADERDATA, res);
+    curl_easy_setopt(curl, CURLOPT_HEADER, 0);
+    curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, s_header_cb);
+    curl_easy_setopt(curl, CURLOPT_HEADERDATA, res);
 
-    int ret = curl_easy_perform(handle);
+    int ret = curl_easy_perform(curl);
     /*
     if (ret != 0) {
         hloge("curl error: %d: %s", ret, curl_easy_strerror((CURLcode)ret));
@@ -134,10 +260,10 @@ int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
         hlogd("[Response]\n%s", res->body.c_str());
     }
     double total_time, name_time, conn_time, pre_time;
-    curl_easy_getinfo(handle, CURLINFO_TOTAL_TIME, &total_time);
-    curl_easy_getinfo(handle, CURLINFO_NAMELOOKUP_TIME, &name_time);
-    curl_easy_getinfo(handle, CURLINFO_CONNECT_TIME, &conn_time);
-    curl_easy_getinfo(handle, CURLINFO_PRETRANSFER_TIME, &pre_time);
+    curl_easy_getinfo(curl, CURLINFO_TOTAL_TIME, &total_time);
+    curl_easy_getinfo(curl, CURLINFO_NAMELOOKUP_TIME, &name_time);
+    curl_easy_getinfo(curl, CURLINFO_CONNECT_TIME, &conn_time);
+    curl_easy_getinfo(curl, CURLINFO_PRETRANSFER_TIME, &pre_time);
     hlogd("TIME_INFO: %lf,%lf,%lf,%lf", total_time, name_time, conn_time, pre_time);
     */
 
@@ -148,8 +274,6 @@ int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
         curl_formfree(httppost);
     }
 
-    curl_easy_cleanup(handle);
-
     return ret;
 }
 
@@ -157,88 +281,128 @@ const char* http_client_strerror(int errcode) {
     return curl_easy_strerror((CURLcode)errcode);
 }
 #else
-#include "herr.h"
-#include "hsocket.h"
-#include "HttpParser.h"
-#include "ssl_ctx.h"
-#ifdef WITH_OPENSSL
-#include "openssl/ssl.h"
-#endif
-int http_client_send(HttpRequest* req, HttpResponse* res, int timeout) {
-    // connect -> send -> recv -> http_parser
-    int ssl_enable = 0;
-    if (strncmp(req->url.c_str(), "https", 5) == 0) {
-        ssl_enable = 1;
+static int __http_session_connect(http_session_t* hss) {
+    int blocktime = MAX_CONNECT_TIMEOUT;
+    if (hss->timeout > 0) {
+        blocktime = MIN(hss->timeout*1000, blocktime);
+    }
+    int connfd = ConnectTimeout(hss->host.c_str(), hss->port, blocktime);
+    if (connfd < 0) {
+        return socket_errno();
+    }
+    tcp_nodelay(connfd, 1);
+
+    if (hss->use_tls) {
 #ifdef WITH_OPENSSL
         if (g_ssl_ctx == NULL) {
             ssl_ctx_init(NULL, NULL, NULL);
         }
+        hss->ssl = SSL_new((SSL_CTX*)g_ssl_ctx);
+        SSL_set_fd(hss->ssl, connfd);
+        if (SSL_connect(hss->ssl) != 1) {
+            int err = SSL_get_error(hss->ssl, -1);
+            fprintf(stderr, "SSL handshark failed: %d\n", err);
+            SSL_free(hss->ssl);
+            hss->ssl = NULL;
+            closesocket(connfd);
+            return err;
+        }
 #else
         fprintf(stderr, "Please recompile WITH_OPENSSL\n");
+        closesocket(connfd);
         return ERR_INVALID_PROTOCOL;
 #endif
     }
-    time_t start_time = time(NULL);
-    time_t cur_time;
-    std::string http = req->dump(true, true);
-    auto Host = req->headers.find("Host");
-    if (Host == req->headers.end()) {
-        return ERR_INVALID_PARAM;
+    hss->fd = connfd;
+    return 0;
+}
+
+static int __http_session_close(http_session_t* hss) {
+#ifdef WITH_OPENSSL
+    if (hss->ssl) {
+        SSL_free(hss->ssl);
+        hss->ssl = NULL;
     }
-    StringList strlist = split(Host->second, ':');
-    std::string host;
-    int port = 80;
-    host = strlist[0];
-    if (strlist.size() == 2) {
-        port = atoi(strlist[1].c_str());
+#endif
+    if (hss->fd > 0) {
+        closesocket(hss->fd);
+        hss->fd = -1;
     }
-    int blocktime = 3000;
-    if (timeout > 0) {
-        blocktime = MIN(timeout*1000, blocktime);
+    return 0;
+}
+
+static int __http_session_send(http_session_t* hss, HttpRequest* req, HttpResponse* res) {
+    // connect -> send -> recv -> http_parser
+    int err = 0;
+    int timeout = hss->timeout;
+    SOCKET connfd = hss->fd;
+
+    // use_tls ?
+    int use_tls = hss->use_tls;
+    if (strncmp(req->url.c_str(), "https", 5) == 0) {
+        hss->use_tls = use_tls = 1;
     }
-    SOCKET connfd = ConnectTimeout(host.c_str(), port, blocktime);
-    if (connfd < 0) {
-        return socket_errno();
+
+    // parse host:port from Headers
+    std::string http = req->dump(true, true);
+    if (hss->host.size() == 0) {
+        auto Host = req->headers.find("Host");
+        if (Host == req->headers.end()) {
+            return ERR_INVALID_PARAM;
+        }
+        StringList strlist = split(Host->second, ':');
+        hss->host = strlist[0];
+        if (strlist.size() == 2) {
+            hss->port = atoi(strlist[1].c_str());
+        }
+        else {
+            hss->port = DEFAULT_HTTP_PORT;
+        }
     }
-#ifdef WITH_OPENSSL
-    SSL* ssl = NULL;
-    if (ssl_enable) {
-        ssl = SSL_new((SSL_CTX*)g_ssl_ctx);
-        SSL_set_fd(ssl, connfd);
-        if (SSL_connect(ssl) != 1) {
-            fprintf(stderr, "SSL handshark failed: %d\n", SSL_get_error(ssl, -1));
+
+    time_t start_time = time(NULL);
+    time_t cur_time;
+    int fail_cnt = 0;
+connect:
+    if (connfd <= 0) {
+        int ret = __http_session_connect(hss);
+        if (ret != 0) {
+            return ret;
         }
+        connfd = hss->fd;
     }
-#endif
-    tcp_nodelay(connfd, 1);
-    int err = 0;
+
     HttpParser parser;
     parser.parser_response_init(res);
     char recvbuf[1024] = {0};
+    int total_nsend, nsend, nrecv;
 send:
-    int total_nsend = 0;
-    int nsend = 0;
-    int nrecv = 0;
+    total_nsend = nsend = nrecv = 0;
     while (1) {
         if (timeout > 0) {
             cur_time = time(NULL);
             if (cur_time - start_time >= timeout) {
-                err = ERR_TASK_TIMEOUT;
-                goto ret;
+                return ERR_TASK_TIMEOUT;
             }
             so_sndtimeo(connfd, (timeout-(cur_time-start_time)) * 1000);
         }
 #ifdef WITH_OPENSSL
-        if (ssl_enable) {
-            nsend = SSL_write(ssl, http.c_str()+total_nsend, http.size()-total_nsend);
+        if (use_tls) {
+            nsend = SSL_write(hss->ssl, http.c_str()+total_nsend, http.size()-total_nsend);
         }
 #endif
-        if (!ssl_enable) {
+        if (!use_tls) {
             nsend = send(connfd, http.c_str()+total_nsend, http.size()-total_nsend, 0);
         }
         if (nsend <= 0) {
-            err = socket_errno();
-            goto ret;
+            if (++fail_cnt == 1) {
+                // maybe keep-alive timeout, try again
+                __http_session_close(hss);
+                goto connect;
+            }
+            else {
+                return socket_errno();
+            }
         }
         total_nsend += nsend;
         if (total_nsend == http.size()) {
@@ -250,27 +414,24 @@ recv:
         if (timeout > 0) {
             cur_time = time(NULL);
             if (cur_time - start_time >= timeout) {
-                err = ERR_TASK_TIMEOUT;
-                goto ret;
+                return ERR_TASK_TIMEOUT;
             }
             so_rcvtimeo(connfd, (timeout-(cur_time-start_time)) * 1000);
         }
 #ifdef WITH_OPENSSL
-        if (ssl_enable) {
-            nrecv = SSL_read(ssl, recvbuf, sizeof(recvbuf));
+        if (use_tls) {
+            nrecv = SSL_read(hss->ssl, recvbuf, sizeof(recvbuf));
         }
 #endif
-        if (!ssl_enable) {
+        if (!use_tls) {
             nrecv = recv(connfd, recvbuf, sizeof(recvbuf), 0);
         }
         if (nrecv <= 0) {
-            err = socket_errno();
-            goto ret;
+            return socket_errno();
         }
         int nparse = parser.execute(recvbuf, nrecv);
         if (nparse != nrecv || parser.get_errno() != HPE_OK) {
-            err = ERR_PARSE;
-            goto ret;
+            return ERR_PARSE;
         }
         if (parser.get_state() == HP_MESSAGE_COMPLETE) {
             err = 0;
@@ -279,19 +440,11 @@ recv:
         if (timeout > 0) {
             cur_time = time(NULL);
             if (cur_time - start_time >= timeout) {
-                err = ERR_TASK_TIMEOUT;
-                goto ret;
+                return ERR_TASK_TIMEOUT;
             }
             so_rcvtimeo(connfd, (timeout-(cur_time-start_time)) * 1000);
         }
     }
-ret:
-#ifdef WITH_OPENSSL
-    if (ssl) {
-        SSL_free(ssl);
-    }
-#endif
-    closesocket(connfd);
     return err;
 }
 

+ 14 - 0
http/client/http_client.h

@@ -26,7 +26,21 @@ int main(int argc, char* argv[]) {
 */
 
 #define DEFAULT_HTTP_TIMEOUT    10 // s
+#define DEFAULT_HTTP_PORT       80
 int http_client_send(HttpRequest* req, HttpResponse* res, int timeout = DEFAULT_HTTP_TIMEOUT);
 const char* http_client_strerror(int errcode);
 
+// http_session: Connection: keep-alive
+typedef struct http_session_s http_session_t;
+http_session_t* http_session_new(const char* host, int port = DEFAULT_HTTP_PORT);
+int http_session_del(http_session_t* hss);
+
+int http_session_set_timeout(http_session_t* hss, int timeout);
+int http_session_clear_headers(http_session_t* hss);
+int http_session_set_header(http_session_t* hss, const char* key, const char* value);
+int http_session_del_header(http_session_t* hss, const char* key);
+const char* http_session_get_header(http_session_t* hss, const char* key);
+
+int http_session_send(http_session_t* hss, HttpRequest* req, HttpResponse* res);
+
 #endif  // HTTP_CLIENT_H_

+ 1 - 1
winbuild/libhw/libhw.vcxproj

@@ -163,7 +163,6 @@
     <ClInclude Include="..\..\base\queue.h" />
     <ClInclude Include="..\..\base\ssl_ctx.h" />
     <ClInclude Include="..\..\event\hevent.h" />
-    <ClInclude Include="..\..\event\hio.h" />
     <ClInclude Include="..\..\event\hloop.h" />
     <ClInclude Include="..\..\event\iowatcher.h" />
     <ClInclude Include="..\..\event\nlog.h" />
@@ -208,6 +207,7 @@
     <ClCompile Include="..\..\base\ssl_ctx.c" />
     <ClCompile Include="..\..\event\epoll.c" />
     <ClCompile Include="..\..\event\evport.c" />
+    <ClCompile Include="..\..\event\hevent.c" />
     <ClCompile Include="..\..\event\hloop.c" />
     <ClCompile Include="..\..\event\iocp.c" />
     <ClCompile Include="..\..\event\kqueue.c" />