1
0

gnutls.c 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #include "hssl.h"
  2. #ifdef WITH_GNUTLS
  3. #include "gnutls/gnutls.h"
  // Name of the TLS backend compiled into this build.
/* Returns a pointer to a static string; callers must not modify or free it. */
const char* hssl_backend() {
    static const char backend_name[] = "gnutls";
    return backend_name;
}
  7. typedef gnutls_certificate_credentials_t gnutls_ctx_t;
  8. hssl_ctx_t hssl_ctx_init(hssl_ctx_init_param_t* param) {
  9. static int s_initialized = 0;
  10. if (s_initialized == 0) {
  11. gnutls_global_init();
  12. s_initialized = 1;
  13. }
  14. gnutls_ctx_t ctx;
  15. const char* crt_file = NULL;
  16. const char* key_file = NULL;
  17. const char* ca_file = NULL;
  18. const char* ca_path = NULL;
  19. int ret = gnutls_certificate_allocate_credentials(&ctx);
  20. if (ret != GNUTLS_E_SUCCESS) {
  21. return NULL;
  22. }
  23. if (param) {
  24. if (param->crt_file && *param->crt_file) {
  25. crt_file = param->crt_file;
  26. }
  27. if (param->key_file && *param->key_file) {
  28. key_file = param->key_file;
  29. }
  30. if (param->ca_file && *param->ca_file) {
  31. ca_file = param->ca_file;
  32. }
  33. if (param->ca_path && *param->ca_path) {
  34. ca_path = param->ca_path;
  35. }
  36. if (ca_file) {
  37. ret = gnutls_certificate_set_x509_trust_file(ctx, ca_file, GNUTLS_X509_FMT_PEM);
  38. if (ret < 0) {
  39. fprintf(stderr, "ssl ca_file failed!\n");
  40. goto error;
  41. }
  42. }
  43. if (ca_path) {
  44. ret = gnutls_certificate_set_x509_trust_dir(ctx, ca_path, GNUTLS_X509_FMT_PEM);
  45. if (ret < 0) {
  46. fprintf(stderr, "ssl ca_path failed!\n");
  47. goto error;
  48. }
  49. }
  50. if (crt_file && key_file) {
  51. ret = gnutls_certificate_set_x509_key_file(ctx, crt_file, key_file, GNUTLS_X509_FMT_PEM);
  52. if (ret != GNUTLS_E_SUCCESS) {
  53. fprintf(stderr, "ssl crt_file/key_file error!\n");
  54. goto error;
  55. }
  56. }
  57. if (param->verify_peer && !ca_file && !ca_path) {
  58. gnutls_certificate_set_x509_system_trust(ctx);
  59. }
  60. }
  61. g_ssl_ctx = ctx;
  62. return ctx;
  63. error:
  64. gnutls_certificate_free_credentials(ctx);
  65. return NULL;
  66. }
  67. void hssl_ctx_cleanup(hssl_ctx_t ssl_ctx) {
  68. if (!ssl_ctx) return;
  69. if (g_ssl_ctx == ssl_ctx) {
  70. g_ssl_ctx = NULL;
  71. }
  72. gnutls_ctx_t ctx = (gnutls_ctx_t)ssl_ctx;
  73. gnutls_certificate_free_credentials(ctx);
  74. }
  75. typedef struct gnutls_s {
  76. gnutls_session_t session;
  77. gnutls_ctx_t ctx;
  78. int fd;
  79. } gnutls_t;
  80. hssl_t hssl_new(hssl_ctx_t ssl_ctx, int fd) {
  81. gnutls_t* gnutls = (gnutls_t*)malloc(sizeof(gnutls_t));
  82. if (gnutls == NULL) return NULL;
  83. gnutls->session = NULL;
  84. gnutls->ctx = (gnutls_ctx_t)ssl_ctx;
  85. gnutls->fd = fd;
  86. return (hssl_t)gnutls;
  87. }
  88. static int hssl_init(hssl_t ssl, int endpoint) {
  89. if (ssl == NULL) return HSSL_ERROR;
  90. gnutls_t* gnutls = (gnutls_t*)ssl;
  91. if (gnutls->session == NULL) {
  92. gnutls_init(&gnutls->session, endpoint);
  93. gnutls_priority_set_direct(gnutls->session, "NORMAL", NULL);
  94. gnutls_credentials_set(gnutls->session, GNUTLS_CRD_CERTIFICATE, gnutls->ctx);
  95. gnutls_transport_set_ptr(gnutls->session, (gnutls_transport_ptr_t)(ptrdiff_t)gnutls->fd);
  96. }
  97. return HSSL_OK;
  98. }
  99. void hssl_free(hssl_t ssl) {
  100. if (ssl == NULL) return;
  101. gnutls_t* gnutls = (gnutls_t*)ssl;
  102. if (gnutls->session) {
  103. gnutls_deinit(gnutls->session);
  104. gnutls->session = NULL;
  105. }
  106. free(gnutls);
  107. }
  108. static int hssl_handshake(hssl_t ssl) {
  109. if (ssl == NULL) return HSSL_ERROR;
  110. gnutls_t* gnutls = (gnutls_t*)ssl;
  111. if (gnutls->session == NULL) return HSSL_ERROR;
  112. int ret = 0;
  113. while (1) {
  114. ret = gnutls_handshake(gnutls->session);
  115. if (ret == GNUTLS_E_SUCCESS) {
  116. return HSSL_OK;
  117. }
  118. else if (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) {
  119. return gnutls_record_get_direction(gnutls->session) == 0 ? HSSL_WANT_READ : HSSL_WANT_WRITE;
  120. }
  121. else if (gnutls_error_is_fatal(ret)) {
  122. // fprintf(stderr, "gnutls_handshake failed: %s\n", gnutls_strerror(ret));
  123. return HSSL_ERROR;
  124. }
  125. }
  126. return HSSL_OK;
  127. }
  128. int hssl_accept(hssl_t ssl) {
  129. if (ssl == NULL) return HSSL_ERROR;
  130. gnutls_t* gnutls = (gnutls_t*)ssl;
  131. if (gnutls->session == NULL) {
  132. hssl_init(ssl, GNUTLS_SERVER);
  133. }
  134. return hssl_handshake(ssl);
  135. }
  136. int hssl_connect(hssl_t ssl) {
  137. if (ssl == NULL) return HSSL_ERROR;
  138. gnutls_t* gnutls = (gnutls_t*)ssl;
  139. if (gnutls->session == NULL) {
  140. hssl_init(ssl, GNUTLS_CLIENT);
  141. }
  142. return hssl_handshake(ssl);
  143. }
  144. int hssl_read(hssl_t ssl, void* buf, int len) {
  145. if (ssl == NULL) return HSSL_ERROR;
  146. gnutls_t* gnutls = (gnutls_t*)ssl;
  147. if (gnutls->session == NULL) return HSSL_ERROR;
  148. int ret = 0;
  149. while ((ret = gnutls_record_recv(gnutls->session, buf, len)) == GNUTLS_E_INTERRUPTED);
  150. return ret;
  151. }
  152. int hssl_write(hssl_t ssl, const void* buf, int len) {
  153. if (ssl == NULL) return HSSL_ERROR;
  154. gnutls_t* gnutls = (gnutls_t*)ssl;
  155. if (gnutls->session == NULL) return HSSL_ERROR;
  156. int ret = 0;
  157. while ((ret = gnutls_record_send(gnutls->session, buf, len)) == GNUTLS_E_INTERRUPTED);
  158. return ret;
  159. }
  160. int hssl_close(hssl_t ssl) {
  161. if (ssl == NULL) return HSSL_ERROR;
  162. gnutls_t* gnutls = (gnutls_t*)ssl;
  163. if (gnutls->session == NULL) return HSSL_ERROR;
  164. gnutls_bye(gnutls->session, GNUTLS_SHUT_RDWR);
  165. return HSSL_OK;
  166. }
  167. int hssl_set_sni_hostname(hssl_t ssl, const char* hostname) {
  168. if (ssl == NULL) return HSSL_ERROR;
  169. gnutls_t* gnutls = (gnutls_t*)ssl;
  170. if (gnutls->session == NULL) {
  171. hssl_init(ssl, GNUTLS_CLIENT);
  172. }
  173. gnutls_server_name_set(gnutls->session, GNUTLS_NAME_DNS, hostname, strlen(hostname));
  174. return 0;
  175. }
  176. #endif // WITH_GNUTLS