dns.c 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. #include "dns.h"
  2. #include "hdef.h"
  3. #include "hsocket.h"
  4. #include "herr.h"
  5. void dns_free(dns_t* dns) {
  6. SAFE_FREE(dns->questions);
  7. SAFE_FREE(dns->answers);
  8. SAFE_FREE(dns->authorities);
  9. SAFE_FREE(dns->addtionals);
  10. }
  // Encodes a dotted domain name into DNS wire format:
  //   www.example.com => 3www7example3com
  // Every label is preceded by a one-byte length and the name ends with a
  // zero byte. A trailing '.' in `domain` does not produce an extra label.
  // No bounds or label-length checks are done: `buf` must have room for
  // strlen(domain) + 2 bytes.
  // Returns the wire length of the name, including the terminating zero byte.
  int dns_name_encode(const char* domain, char* buf) {
      int len_at = 0;     // index in buf of the current label's length byte
      int label_len = 0;  // characters seen so far in the current label
      int i;
      for (i = 0; domain[i] != '\0'; ++i) {
          if (domain[i] == '.') {
              // Close the current label; this '.' slot becomes the next length byte.
              buf[len_at] = (char)label_len;
              len_at = i + 1;
              label_len = 0;
          } else {
              buf[i + 1] = domain[i];
              ++label_len;
          }
      }
      buf[len_at] = (char)label_len;
      buf[i + 1] = '\0';
      // When the name ended with '.', the closing zero is already counted.
      return (label_len == 0) ? i + 1 : i + 2;
  }
  // Decodes a DNS wire-format name into dotted text:
  //   3www7example3com => www.example.com
  // The root name (a single zero byte) decodes to "" and consumes 1 byte.
  // A compression pointer (top two bits set) is not followed. Decoding stops
  // there, `domain` holds only the labels seen before it, and the pointer's
  // 2 bytes are counted as consumed.
  // Returns the number of wire bytes consumed, or -1 if a reserved label
  // type (0x40/0x80) is found or the labels exceed the 255-byte name limit.
  // On -1, `domain` is set to "".
  // `domain` receives at most 254 bytes including its NUL terminator.
  // NOTE(review): `buf` carries no length, so up to 255 bytes may be read.
  // Callers must check the result against the bytes actually available.
  int dns_name_decode(const char* buf, char* domain) {
      // Length bytes are unsigned; reading them through plain char breaks for values >= 128.
      const unsigned char* p = (const unsigned char*)buf;
      char* out = domain;
      int buflen = 0;  // wire bytes consumed so far
      for (;;) {
          unsigned int labellen = p[buflen];
          if (labellen == 0) {
              ++buflen;  // terminating root label
              break;
          }
          if ((labellen & 0xC0) == 0xC0) {
              buflen += 2;  // compression pointer: 2 bytes, target not followed
              break;
          }
          if (labellen > 63) {
              *domain = '\0';  // 0x40/0x80 label types are reserved
              return -1;
          }
          // Labels plus the final zero byte may not exceed 255 bytes.
          if (buflen + 1 + (int)labellen > 254) {
              *domain = '\0';
              return -1;
          }
          if (out != domain) {
              *out++ = '.';
          }
          memcpy(out, p + buflen + 1, labellen);
          out += labellen;
          buflen += 1 + (int)labellen;
      }
      *out = '\0';
      return buflen;
  }
  63. int dns_rr_pack(dns_rr_t* rr, char* buf, int len) {
  64. char* p = buf;
  65. char encoded_name[256];
  66. int encoded_namelen = dns_name_encode(rr->name, encoded_name);
  67. int packetlen = encoded_namelen + 2 + 2 + (rr->data ? (4+2+rr->datalen) : 0);
  68. if (len < packetlen) {
  69. return -1;
  70. }
  71. memcpy(p, encoded_name, encoded_namelen);
  72. p += encoded_namelen;
  73. uint16_t* pushort = (uint16_t*)p;
  74. *pushort = htons(rr->rtype);
  75. p += 2;
  76. pushort = (uint16_t*)p;
  77. *pushort = htons(rr->rclass);
  78. p += 2;
  79. // ...
  80. if (rr->datalen && rr->data) {
  81. uint32_t* puint = (uint32_t*)p;
  82. *puint = htonl(rr->ttl);
  83. p += 4;
  84. pushort = (uint16_t*)p;
  85. *pushort = htons(rr->datalen);
  86. p += 2;
  87. memcpy(p, rr->data, rr->datalen);
  88. p += rr->datalen;
  89. }
  90. return packetlen;
  91. }
  92. int dns_rr_unpack(char* buf, int len, dns_rr_t* rr, int is_question) {
  93. char* p = buf;
  94. int off = 0;
  95. int namelen = 0;
  96. if (*(uint8_t*)p >= 192) {
  97. // name off, we ignore
  98. namelen = 2;
  99. //uint16_t nameoff = (*(uint8_t*)p - 192) * 256 + *(uint8_t*)(p+1);
  100. }
  101. else {
  102. namelen = dns_name_decode(buf, rr->name);
  103. }
  104. if (namelen < 0) return -1;
  105. p += namelen;
  106. off += namelen;
  107. if (len < off + 4) return -1;
  108. uint16_t* pushort = (uint16_t*)p;
  109. rr->rtype = ntohs(*pushort);
  110. p += 2;
  111. pushort = (uint16_t*)p;
  112. rr->rclass = ntohs(*pushort);
  113. p += 2;
  114. off += 4;
  115. if (!is_question) {
  116. if (len < off + 6) return -1;
  117. uint32_t* puint = (uint32_t*)p;
  118. rr->ttl = ntohl(*puint);
  119. p += 4;
  120. pushort = (uint16_t*)p;
  121. rr->datalen = ntohs(*pushort);
  122. p += 2;
  123. off += 6;
  124. if (len < off + rr->datalen) return -1;
  125. rr->data = p;
  126. p += rr->datalen;
  127. off += rr->datalen;
  128. }
  129. return off;
  130. }
  131. int dns_pack(dns_t* dns, char* buf, int len) {
  132. if (len < sizeof(dnshdr_t)) return -1;
  133. int off = 0;
  134. dnshdr_t* hdr = &dns->hdr;
  135. dnshdr_t htonhdr = dns->hdr;
  136. htonhdr.transaction_id = htons(hdr->transaction_id);
  137. htonhdr.nquestion = htons(hdr->nquestion);
  138. htonhdr.nanswer = htons(hdr->nanswer);
  139. htonhdr.nauthority = htons(hdr->nauthority);
  140. htonhdr.naddtional = htons(hdr->naddtional);
  141. memcpy(buf, &htonhdr, sizeof(dnshdr_t));
  142. off += sizeof(dnshdr_t);
  143. int i;
  144. for (i = 0; i < hdr->nquestion; ++i) {
  145. int packetlen = dns_rr_pack(dns->questions+i, buf+off, len-off);
  146. if (packetlen < 0) return -1;
  147. off += packetlen;
  148. }
  149. for (i = 0; i < hdr->nanswer; ++i) {
  150. int packetlen = dns_rr_pack(dns->answers+i, buf+off, len-off);
  151. if (packetlen < 0) return -1;
  152. off += packetlen;
  153. }
  154. for (i = 0; i < hdr->nauthority; ++i) {
  155. int packetlen = dns_rr_pack(dns->authorities+i, buf+off, len-off);
  156. if (packetlen < 0) return -1;
  157. off += packetlen;
  158. }
  159. for (i = 0; i < hdr->naddtional; ++i) {
  160. int packetlen = dns_rr_pack(dns->addtionals+i, buf+off, len-off);
  161. if (packetlen < 0) return -1;
  162. off += packetlen;
  163. }
  164. return off;
  165. }
  166. int dns_unpack(char* buf, int len, dns_t* dns) {
  167. memset(dns, 0, sizeof(dns_t));
  168. if (len < sizeof(dnshdr_t)) return -1;
  169. int off = 0;
  170. dnshdr_t* hdr = &dns->hdr;
  171. memcpy(hdr, buf, sizeof(dnshdr_t));
  172. off += sizeof(dnshdr_t);
  173. hdr->transaction_id = ntohs(hdr->transaction_id);
  174. hdr->nquestion = ntohs(hdr->nquestion);
  175. hdr->nanswer = ntohs(hdr->nanswer);
  176. hdr->nauthority = ntohs(hdr->nauthority);
  177. hdr->naddtional = ntohs(hdr->naddtional);
  178. int i;
  179. if (hdr->nquestion) {
  180. int bytes = hdr->nquestion * sizeof(dns_rr_t);
  181. dns->questions = (dns_rr_t*)malloc(bytes);
  182. memset(dns->questions, 0, bytes);
  183. for (i = 0; i < hdr->nquestion; ++i) {
  184. int packetlen = dns_rr_unpack(buf+off, len-off, dns->questions+i, 1);
  185. if (packetlen < 0) return -1;
  186. off += packetlen;
  187. }
  188. }
  189. if (hdr->nanswer) {
  190. int bytes = hdr->nanswer * sizeof(dns_rr_t);
  191. dns->answers = (dns_rr_t*)malloc(bytes);
  192. memset(dns->answers, 0, bytes);
  193. for (i = 0; i < hdr->nanswer; ++i) {
  194. int packetlen = dns_rr_unpack(buf+off, len-off, dns->answers+i, 0);
  195. if (packetlen < 0) return -1;
  196. off += packetlen;
  197. }
  198. }
  199. if (hdr->nauthority) {
  200. int bytes = hdr->nauthority * sizeof(dns_rr_t);
  201. dns->authorities = (dns_rr_t*)malloc(bytes);
  202. memset(dns->authorities, 0, bytes);
  203. for (i = 0; i < hdr->nauthority; ++i) {
  204. int packetlen = dns_rr_unpack(buf+off, len-off, dns->authorities+i, 0);
  205. if (packetlen < 0) return -1;
  206. off += packetlen;
  207. }
  208. }
  209. if (hdr->naddtional) {
  210. int bytes = hdr->naddtional * sizeof(dns_rr_t);
  211. dns->addtionals = (dns_rr_t*)malloc(bytes);
  212. memset(dns->addtionals, 0, bytes);
  213. for (i = 0; i < hdr->naddtional; ++i) {
  214. int packetlen = dns_rr_unpack(buf+off, len-off, dns->addtionals+i, 0);
  215. if (packetlen < 0) return -1;
  216. off += packetlen;
  217. }
  218. }
  219. return off;
  220. }
  221. // dns_pack -> sendto -> recvfrom -> dns_unpack
  222. int dns_query(dns_t* query, dns_t* response, const char* nameserver) {
  223. char buf[1024];
  224. int buflen = sizeof(buf);
  225. buflen = dns_pack(query, buf, buflen);
  226. if (buflen < 0) {
  227. return buflen;
  228. }
  229. int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
  230. if (sockfd < 0) {
  231. perror("socket");
  232. return ERR_SOCKET;
  233. }
  234. so_sndtimeo(sockfd, 5000);
  235. so_rcvtimeo(sockfd, 5000);
  236. int ret = 0;
  237. int nsend, nrecv;
  238. int nparse;
  239. struct sockaddr_in addr;
  240. socklen_t addrlen = sizeof(addr);
  241. memset(&addr, 0, addrlen);
  242. addr.sin_family = AF_INET;
  243. addr.sin_addr.s_addr = inet_addr(nameserver);
  244. addr.sin_port = htons(DNS_PORT);
  245. nsend = sendto(sockfd, buf, buflen, 0, (struct sockaddr*)&addr, addrlen);
  246. if (nsend != buflen) {
  247. ret = ERR_SENDTO;
  248. goto error;
  249. }
  250. nrecv = recvfrom(sockfd, buf, sizeof(buf), 0, (struct sockaddr*)&addr, &addrlen);
  251. if (nrecv <= 0) {
  252. ret = ERR_RECVFROM;
  253. goto error;
  254. }
  255. nparse = dns_unpack(buf, nrecv, response);
  256. if (nparse != nrecv) {
  257. ret = -ERR_INVALID_PACKAGE;
  258. goto error;
  259. }
  260. error:
  261. if (sockfd != INVALID_SOCKET) {
  262. closesocket(sockfd);
  263. }
  264. return ret;
  265. }
  266. int nslookup(const char* domain, uint32_t* addrs, int naddr, const char* nameserver) {
  267. dns_t query;
  268. memset(&query, 0, sizeof(query));
  269. query.hdr.transaction_id = getpid();
  270. query.hdr.qr = DNS_QUERY;
  271. query.hdr.rd = 1;
  272. query.hdr.nquestion = 1;
  273. dns_rr_t question;
  274. memset(&question, 0, sizeof(question));
  275. strncpy(question.name, domain, sizeof(question.name));
  276. question.rtype = DNS_TYPE_A;
  277. question.rclass = DNS_CLASS_IN;
  278. query.questions = &question;
  279. dns_t resp;
  280. memset(&resp, 0, sizeof(resp));
  281. int ret = dns_query(&query, &resp, nameserver);
  282. if (ret != 0) {
  283. return ret;
  284. }
  285. dns_rr_t* rr = resp.answers;
  286. int addr_cnt = 0;
  287. if (resp.hdr.transaction_id != query.hdr.transaction_id ||
  288. resp.hdr.qr != DNS_RESPONSE ||
  289. resp.hdr.rcode != 0) {
  290. ret = -ERR_MISMATCH;
  291. goto end;
  292. }
  293. if (resp.hdr.nanswer == 0) {
  294. ret = 0;
  295. goto end;
  296. }
  297. for (int i = 0; i < resp.hdr.nanswer; ++i, ++rr) {
  298. if (rr->rtype == DNS_TYPE_A) {
  299. if (addr_cnt < naddr && rr->datalen == 4) {
  300. memcpy(addrs+addr_cnt, rr->data, 4);
  301. }
  302. ++addr_cnt;
  303. }
  304. /*
  305. else if (rr->rtype == DNS_TYPE_CNAME) {
  306. char name[256];
  307. dns_name_decode(rr->data, name);
  308. }
  309. */
  310. }
  311. ret = addr_cnt;
  312. end:
  313. dns_free(&resp);
  314. return ret;
  315. }