TcpClient.h 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #ifndef HV_TCP_CLIENT_HPP_
  2. #define HV_TCP_CLIENT_HPP_
  3. #include "hsocket.h"
  4. #include "hssl.h"
  5. #include "hlog.h"
  6. #include "EventLoopThread.h"
  7. #include "Channel.h"
  8. namespace hv {
  9. template<class TSocketChannel = SocketChannel>
  10. class TcpClientEventLoopTmpl {
  11. public:
  12. typedef std::shared_ptr<TSocketChannel> TSocketChannelPtr;
  13. TcpClientEventLoopTmpl(EventLoopPtr loop = NULL) {
  14. loop_ = loop ? loop : std::make_shared<EventLoop>();
  15. connect_timeout = HIO_DEFAULT_CONNECT_TIMEOUT;
  16. tls = false;
  17. tls_setting = NULL;
  18. reconn_setting = NULL;
  19. unpack_setting = NULL;
  20. }
  21. virtual ~TcpClientEventLoopTmpl() {
  22. HV_FREE(tls_setting);
  23. HV_FREE(reconn_setting);
  24. HV_FREE(unpack_setting);
  25. }
  26. const EventLoopPtr& loop() {
  27. return loop_;
  28. }
  29. // NOTE: By default, not bind local port. If necessary, you can call bind() after createsocket().
  30. // @retval >=0 connfd, <0 error
  31. int createsocket(int remote_port, const char* remote_host = "127.0.0.1") {
  32. memset(&remote_addr, 0, sizeof(remote_addr));
  33. int ret = sockaddr_set_ipport(&remote_addr, remote_host, remote_port);
  34. if (ret != 0) {
  35. return NABS(ret);
  36. }
  37. this->remote_host = remote_host;
  38. this->remote_port = remote_port;
  39. return createsocket(&remote_addr.sa);
  40. }
  41. int createsocket(struct sockaddr* remote_addr) {
  42. int connfd = ::socket(remote_addr->sa_family, SOCK_STREAM, 0);
  43. // SOCKADDR_PRINT(remote_addr);
  44. if (connfd < 0) {
  45. perror("socket");
  46. return -2;
  47. }
  48. hio_t* io = hio_get(loop_->loop(), connfd);
  49. assert(io != NULL);
  50. hio_set_peeraddr(io, remote_addr, SOCKADDR_LEN(remote_addr));
  51. channel.reset(new TSocketChannel(io));
  52. return connfd;
  53. }
  54. int bind(int local_port, const char* local_host = "0.0.0.0") {
  55. sockaddr_u local_addr;
  56. memset(&local_addr, 0, sizeof(local_addr));
  57. int ret = sockaddr_set_ipport(&local_addr, local_host, local_port);
  58. if (ret != 0) {
  59. return NABS(ret);
  60. }
  61. return bind(&local_addr.sa);
  62. }
  63. int bind(struct sockaddr* local_addr) {
  64. if (channel == NULL || channel->isClosed()) {
  65. return -1;
  66. }
  67. int ret = ::bind(channel->fd(), local_addr, SOCKADDR_LEN(local_addr));
  68. if (ret != 0) {
  69. perror("bind");
  70. }
  71. return ret;
  72. }
  73. // closesocket thread-safe
  74. void closesocket() {
  75. if (channel) {
  76. loop_->runInLoop([this](){
  77. if (channel) {
  78. setReconnect(NULL);
  79. channel->close();
  80. }
  81. });
  82. }
  83. }
  84. int startConnect() {
  85. if (channel == NULL || channel->isClosed()) {
  86. int connfd = createsocket(&remote_addr.sa);
  87. if (connfd < 0) {
  88. hloge("createsocket %s:%d return %d!\n", remote_host.c_str(), remote_port, connfd);
  89. return connfd;
  90. }
  91. }
  92. if (channel == NULL || channel->status >= SocketChannel::CONNECTING) {
  93. return -1;
  94. }
  95. if (connect_timeout) {
  96. channel->setConnectTimeout(connect_timeout);
  97. }
  98. if (tls) {
  99. channel->enableSSL();
  100. if (tls_setting) {
  101. int ret = channel->newSslCtx(tls_setting);
  102. if (ret != 0) {
  103. hloge("new SSL_CTX failed: %d", ret);
  104. closesocket();
  105. return ret;
  106. }
  107. }
  108. if (!is_ipaddr(remote_host.c_str())) {
  109. channel->setHostname(remote_host);
  110. }
  111. }
  112. channel->onconnect = [this]() {
  113. if (unpack_setting) {
  114. channel->setUnpack(unpack_setting);
  115. }
  116. channel->startRead();
  117. if (onConnection) {
  118. onConnection(channel);
  119. }
  120. if (reconn_setting) {
  121. reconn_setting_reset(reconn_setting);
  122. }
  123. };
  124. channel->onread = [this](Buffer* buf) {
  125. if (onMessage) {
  126. onMessage(channel, buf);
  127. }
  128. };
  129. channel->onwrite = [this](Buffer* buf) {
  130. if (onWriteComplete) {
  131. onWriteComplete(channel, buf);
  132. }
  133. };
  134. channel->onclose = [this]() {
  135. if (onConnection) {
  136. onConnection(channel);
  137. }
  138. // reconnect
  139. if (reconn_setting) {
  140. startReconnect();
  141. }
  142. };
  143. return channel->startConnect();
  144. }
  145. int startReconnect() {
  146. if (!reconn_setting) return -1;
  147. if (!reconn_setting_can_retry(reconn_setting)) return -2;
  148. uint32_t delay = reconn_setting_calc_delay(reconn_setting);
  149. hlogi("reconnect... cnt=%d, delay=%d", reconn_setting->cur_retry_cnt, reconn_setting->cur_delay);
  150. loop_->setTimeout(delay, [this](TimerID timerID){
  151. startConnect();
  152. });
  153. return 0;
  154. }
  155. // start thread-safe
  156. void start() {
  157. loop_->runInLoop(std::bind(&TcpClientEventLoopTmpl::startConnect, this));
  158. }
  159. bool isConnected() {
  160. if (channel == NULL) return false;
  161. return channel->isConnected();
  162. }
  163. // send thread-safe
  164. int send(const void* data, int size) {
  165. if (!isConnected()) return -1;
  166. return channel->write(data, size);
  167. }
  168. int send(Buffer* buf) {
  169. return send(buf->data(), buf->size());
  170. }
  171. int send(const std::string& str) {
  172. return send(str.data(), str.size());
  173. }
  174. int withTLS(hssl_ctx_opt_t* opt = NULL) {
  175. tls = true;
  176. if (opt) {
  177. if (tls_setting == NULL) {
  178. HV_ALLOC_SIZEOF(tls_setting);
  179. }
  180. opt->endpoint = HSSL_CLIENT;
  181. *tls_setting = *opt;
  182. }
  183. return 0;
  184. }
  185. void setConnectTimeout(int ms) {
  186. connect_timeout = ms;
  187. }
  188. void setReconnect(reconn_setting_t* setting) {
  189. if (setting == NULL) {
  190. HV_FREE(reconn_setting);
  191. return;
  192. }
  193. if (reconn_setting == NULL) {
  194. HV_ALLOC_SIZEOF(reconn_setting);
  195. }
  196. *reconn_setting = *setting;
  197. }
  198. bool isReconnect() {
  199. return reconn_setting && reconn_setting->cur_retry_cnt > 0;
  200. }
  201. void setUnpack(unpack_setting_t* setting) {
  202. if (setting == NULL) {
  203. HV_FREE(unpack_setting);
  204. return;
  205. }
  206. if (unpack_setting == NULL) {
  207. HV_ALLOC_SIZEOF(unpack_setting);
  208. }
  209. *unpack_setting = *setting;
  210. }
  211. public:
  212. TSocketChannelPtr channel;
  213. std::string remote_host;
  214. int remote_port;
  215. sockaddr_u remote_addr;
  216. int connect_timeout;
  217. bool tls;
  218. hssl_ctx_opt_t* tls_setting;
  219. reconn_setting_t* reconn_setting;
  220. unpack_setting_t* unpack_setting;
  221. // Callback
  222. std::function<void(const TSocketChannelPtr&)> onConnection;
  223. std::function<void(const TSocketChannelPtr&, Buffer*)> onMessage;
  224. // NOTE: Use Channel::isWriteComplete in onWriteComplete callback to determine whether all data has been written.
  225. std::function<void(const TSocketChannelPtr&, Buffer*)> onWriteComplete;
  226. private:
  227. EventLoopPtr loop_;
  228. };
  229. template<class TSocketChannel = SocketChannel>
  230. class TcpClientTmpl : private EventLoopThread, public TcpClientEventLoopTmpl<TSocketChannel> {
  231. public:
  232. TcpClientTmpl(EventLoopPtr loop = NULL)
  233. : EventLoopThread(loop)
  234. , TcpClientEventLoopTmpl<TSocketChannel>(EventLoopThread::loop())
  235. {}
  236. virtual ~TcpClientTmpl() {
  237. stop(true);
  238. }
  239. const EventLoopPtr& loop() {
  240. return EventLoopThread::loop();
  241. }
  242. // start thread-safe
  243. void start(bool wait_threads_started = true) {
  244. if (isRunning()) {
  245. TcpClientEventLoopTmpl<TSocketChannel>::start();
  246. } else {
  247. EventLoopThread::start(wait_threads_started, std::bind(&TcpClientTmpl::startConnect, this));
  248. }
  249. }
  250. // stop thread-safe
  251. void stop(bool wait_threads_stopped = true) {
  252. TcpClientEventLoopTmpl<TSocketChannel>::closesocket();
  253. EventLoopThread::stop(wait_threads_stopped);
  254. }
  255. };
  256. typedef TcpClientTmpl<SocketChannel> TcpClient;
  257. }
  258. #endif // HV_TCP_CLIENT_HPP_