Explorar el Código

feat: HttpService::IsTrustProxy support wildcard match

ithewei hace 2 años
padre
commit
0e032b773c
Se han modificado 7 ficheros con 50 adiciones y 19 borrados
  1. 18 0
      base/hbase.c
  2. 1 0
      base/hbase.h
  3. 1 0
      docs/API.md
  4. 1 19
      http/server/HttpHandler.cpp
  5. 23 0
      http/server/HttpService.cpp
  6. 1 0
      http/server/HttpService.h
  7. 5 0
      unittest/hbase_test.c

+ 18 - 0
base/hbase.c

@@ -167,6 +167,24 @@ bool hv_strcontains(const char* str, const char* sub) {
     return strstr(str, sub) != NULL;
 }
 
+bool hv_wildcard_match(const char* str, const char* pattern) {
+    assert(str != NULL && pattern != NULL);
+    bool match = false;
+    while (*str && *pattern) {
+        if (*pattern == '*') {
+            match = hv_strendswith(str, pattern + 1);
+            break;
+        } else if (*str != *pattern) {
+            match = false;
+            break;
+        } else {
+            ++str;
+            ++pattern;
+        }
+    }
+    return match ? match : (*str == '\0' && *pattern == '\0');
+}
+
 char* hv_strnchr(const char* s, char c, size_t n) {
     assert(s != NULL);
     const char* p = s;

+ 1 - 0
base/hbase.h

@@ -63,6 +63,7 @@ HV_EXPORT char* hv_strreverse(char* str);
 HV_EXPORT bool hv_strstartswith(const char* str, const char* start);
 HV_EXPORT bool hv_strendswith(const char* str, const char* end);
 HV_EXPORT bool hv_strcontains(const char* str, const char* sub);
+HV_EXPORT bool hv_wildcard_match(const char* str, const char* pattern);
 
 // strncpy n = sizeof(dest_buf)-1
 // hv_strncpy n = sizeof(dest_buf)

+ 1 - 0
docs/API.md

@@ -107,6 +107,7 @@
 - hv_strstartswith
 - hv_strendswith
 - hv_strcontains
+- hv_wildcard_match
 - hv_strnchr
 - hv_strrchr_dot
 - hv_strrchr_dir

+ 1 - 19
http/server/HttpHandler.cpp

@@ -1014,25 +1014,7 @@ int HttpHandler::connectProxy(const std::string& strUrl) {
         }
     }
 
-    bool allow_proxy = true;
-    if (service && service->trustProxies.size() != 0) {
-        allow_proxy = false;
-        for (const auto& trust_proxy : service->trustProxies) {
-            if (trust_proxy == url.host) {
-                allow_proxy = true;
-                break;
-            }
-        }
-    }
-    if (service && service->noProxies.size() != 0) {
-        for (const auto& no_proxy : service->noProxies) {
-            if (no_proxy == url.host) {
-                allow_proxy = false;
-                break;
-            }
-        }
-    }
-    if (!allow_proxy) {
+    if (!service || !service->IsTrustProxy(url.host.c_str())) {
         hlogw("Forbidden to proxy %s", url.host.c_str());
         SetError(HTTP_STATUS_FORBIDDEN, HTTP_STATUS_FORBIDDEN);
         return 0;

+ 23 - 0
http/server/HttpService.cpp

@@ -179,6 +179,29 @@ void HttpService::AddNoProxy(const char* host) {
     noProxies.emplace_back(host);
 }
 
+bool HttpService::IsTrustProxy(const char* host) {
+    if (!host || *host == '\0') return false;
+    bool trust = true;
+    if (trustProxies.size() != 0) {
+        trust = false;
+        for (const auto& trust_proxy : trustProxies) {
+            if (hv_wildcard_match(host, trust_proxy.c_str())) {
+                trust = true;
+                break;
+            }
+        }
+    }
+    if (noProxies.size() != 0) {
+        for (const auto& no_proxy : noProxies) {
+            if (hv_wildcard_match(host, no_proxy.c_str())) {
+                trust = false;
+                break;
+            }
+        }
+    }
+    return trust;
+}
+
 void HttpService::AllowCORS() {
     Use(HttpMiddleware::CORS);
 }

+ 1 - 0
http/server/HttpService.h

@@ -194,6 +194,7 @@ struct HV_EXPORT HttpService {
     // proxy
     void AddTrustProxy(const char* host);
     void AddNoProxy(const char* host);
+    bool IsTrustProxy(const char* host);
     // forward proxy
     void EnableForwardProxy() { enable_forward_proxy = 1; }
     // reverse proxy

+ 5 - 0
unittest/hbase_test.c

@@ -8,6 +8,11 @@ int main(int argc, char* argv[]) {
     assert(hv_getboolean("1"));
     assert(hv_getboolean("yes"));
 
+    assert(hv_wildcard_match("www.example.com", "www.example.com"));
+    assert(hv_wildcard_match("www.example.com", "*.example.com"));
+    assert(hv_wildcard_match("www.example.com", "www.*.com"));
+    assert(hv_wildcard_match("www.example.com", "www.example.*"));
+
     assert(hv_parse_size("256") == 256);
     assert(hv_parse_size("1K") == 1024);
     assert(hv_parse_size("1G2M3K4B") ==