Merge pull request #20007 from dhewg/prometheus-node-exporter-ucode
[feed/packages.git] / utils / rpcd-mod-wireguard / src / wireguard.c
1 // SPDX-License-Identifier: LGPL-2.1+
2 /*
3 * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 * Copyright (C) 2008-2012 Pablo Neira Ayuso <pablo@netfilter.org>.
5 */
6
7 #define _GNU_SOURCE
8
9 #include <errno.h>
10 #include <linux/genetlink.h>
11 #include <linux/if_link.h>
12 #include <linux/netlink.h>
13 #include <linux/rtnetlink.h>
14 #include <netinet/in.h>
15 #include <stdbool.h>
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <string.h>
19 #include <sys/socket.h>
20 #include <time.h>
21 #include <unistd.h>
22 #include <fcntl.h>
23 #include <assert.h>
24
25 #include "wireguard.h"
26
/* wireguard.h netlink uapi: generic-netlink family name, commands, flag
 * bits and attribute numbers used to talk to the WireGuard kernel module.
 * Values are spelled out explicitly so the wire numbering is obvious. */

#define WG_GENL_NAME "wireguard"
#define WG_GENL_VERSION 1

/* Commands of the "wireguard" generic-netlink family. */
enum wg_cmd {
	WG_CMD_GET_DEVICE = 0,
	WG_CMD_SET_DEVICE = 1,
	__WG_CMD_MAX
};

/* Bits for WGDEVICE_A_FLAGS. */
enum wgdevice_flag {
	WGDEVICE_F_REPLACE_PEERS = 0x1U
};

/* Top-level attributes of a device message. */
enum wgdevice_attribute {
	WGDEVICE_A_UNSPEC = 0,
	WGDEVICE_A_IFINDEX = 1,
	WGDEVICE_A_IFNAME = 2,
	WGDEVICE_A_PRIVATE_KEY = 3,
	WGDEVICE_A_PUBLIC_KEY = 4,
	WGDEVICE_A_FLAGS = 5,
	WGDEVICE_A_LISTEN_PORT = 6,
	WGDEVICE_A_FWMARK = 7,
	WGDEVICE_A_PEERS = 8,
	__WGDEVICE_A_LAST
};

/* Bits for WGPEER_A_FLAGS. */
enum wgpeer_flag {
	WGPEER_F_REMOVE_ME = 0x1U,
	WGPEER_F_REPLACE_ALLOWEDIPS = 0x2U
};

/* Attributes nested inside each WGDEVICE_A_PEERS entry. */
enum wgpeer_attribute {
	WGPEER_A_UNSPEC = 0,
	WGPEER_A_PUBLIC_KEY = 1,
	WGPEER_A_PRESHARED_KEY = 2,
	WGPEER_A_FLAGS = 3,
	WGPEER_A_ENDPOINT = 4,
	WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL = 5,
	WGPEER_A_LAST_HANDSHAKE_TIME = 6,
	WGPEER_A_RX_BYTES = 7,
	WGPEER_A_TX_BYTES = 8,
	WGPEER_A_ALLOWEDIPS = 9,
	WGPEER_A_PROTOCOL_VERSION = 10,
	__WGPEER_A_LAST
};

/* Attributes nested inside each WGPEER_A_ALLOWEDIPS entry. */
enum wgallowedip_attribute {
	WGALLOWEDIP_A_UNSPEC = 0,
	WGALLOWEDIP_A_FAMILY = 1,
	WGALLOWEDIP_A_IPADDR = 2,
	WGALLOWEDIP_A_CIDR_MASK = 3,
	__WGALLOWEDIP_A_LAST
};
80
/* libmnl mini library: */

#define MNL_SOCKET_AUTOPID 0
#define MNL_ALIGNTO 4
/* Round len up to the netlink alignment (MNL_ALIGNTO bytes). */
#define MNL_ALIGN(len) (((len)+MNL_ALIGNTO-1) & ~(MNL_ALIGNTO-1))
#define MNL_NLMSG_HDRLEN MNL_ALIGN(sizeof(struct nlmsghdr))
#define MNL_ATTR_HDRLEN MNL_ALIGN(sizeof(struct nlattr))

/* Payload kinds accepted by mnl_attr_validate(). */
enum mnl_attr_data_type {
	MNL_TYPE_UNSPEC,
	MNL_TYPE_U8,
	MNL_TYPE_U16,
	MNL_TYPE_U32,
	MNL_TYPE_U64,
	MNL_TYPE_STRING,
	MNL_TYPE_FLAG,
	MNL_TYPE_MSECS,
	MNL_TYPE_NESTED,
	MNL_TYPE_NESTED_COMPAT,
	MNL_TYPE_NUL_STRING,
	MNL_TYPE_BINARY,
	MNL_TYPE_MAX,
};

/* Iterate `attr` over the attributes of message `nlh`, starting `offset`
 * bytes (rounded up to the alignment) into its payload. */
#define mnl_attr_for_each(attr, nlh, offset) \
	for ((attr) = mnl_nlmsg_get_payload_offset((nlh), (offset)); \
	     mnl_attr_ok((attr), (char *)mnl_nlmsg_get_payload_tail((nlh)) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Iterate `attr` over the attributes nested inside attribute `nest`. */
#define mnl_attr_for_each_nested(attr, nest) \
	for ((attr) = mnl_attr_get_payload(nest); \
	     mnl_attr_ok((attr), (char *)mnl_attr_get_payload(nest) + mnl_attr_get_payload_len(nest) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Iterate over a bare run of attributes `payload_size` bytes long.
 * NOTE: this macro takes no iterator argument; it assigns to a variable
 * literally named `attr` that must already be declared by the caller.
 * `payload_size` is parenthesized so expression arguments (e.g. `a - b`)
 * keep their meaning inside the length computation. */
#define mnl_attr_for_each_payload(payload, payload_size) \
	for ((attr) = (payload); \
	     mnl_attr_ok((attr), (char *)(payload) + (payload_size) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Callback verdicts: <= MNL_CB_STOP ends a parse/run loop. */
#define MNL_CB_ERROR (-1)
#define MNL_CB_STOP 0
#define MNL_CB_OK 1

typedef int (*mnl_attr_cb_t)(const struct nlattr *attr, void *data);
typedef int (*mnl_cb_t)(const struct nlmsghdr *nlh, void *data);

#ifndef MNL_ARRAY_SIZE
#define MNL_ARRAY_SIZE(a) (sizeof(a)/sizeof((a)[0]))
#endif
130
/* Buffer size used for netlink I/O: the system page size, capped at
 * 8192 bytes. The first non-zero result is cached in a static. */
static size_t mnl_ideal_socket_buffer_size(void)
{
	static size_t cached;
	size_t page;

	if (cached != 0)
		return cached;

	page = (size_t)sysconf(_SC_PAGESIZE);
	cached = page > 8192 ? 8192 : page;
	return cached;
}
142
143 static size_t mnl_nlmsg_size(size_t len)
144 {
145 return len + MNL_NLMSG_HDRLEN;
146 }
147
/* Begin a new message at `buf`: zero the aligned header and set nlmsg_len
 * to just the header. Returns `buf` viewed as the header. */
static struct nlmsghdr *mnl_nlmsg_put_header(void *buf)
{
	struct nlmsghdr *hdr = buf;
	const int hdr_len = MNL_ALIGN(sizeof(struct nlmsghdr));

	memset(hdr, 0, hdr_len);
	hdr->nlmsg_len = hdr_len;
	return hdr;
}
157
/* Reserve a zeroed, aligned block of `size` bytes at the current end of the
 * message (e.g. a genlmsghdr/ifinfomsg), growing nlmsg_len to include it.
 * Returns the start of the reserved block. */
static void *mnl_nlmsg_put_extra_header(struct nlmsghdr *nlh, size_t size)
{
	const size_t padded = MNL_ALIGN(size);
	char *extra = (char *)nlh + nlh->nlmsg_len;

	nlh->nlmsg_len += padded;
	return memset(extra, 0, padded);
}
166
167 static void *mnl_nlmsg_get_payload(const struct nlmsghdr *nlh)
168 {
169 return (void *)nlh + MNL_NLMSG_HDRLEN;
170 }
171
172 static void *mnl_nlmsg_get_payload_offset(const struct nlmsghdr *nlh, size_t offset)
173 {
174 return (void *)nlh + MNL_NLMSG_HDRLEN + MNL_ALIGN(offset);
175 }
176
/* True when the `len` bytes remaining at `nlh` hold a complete header
 * whose declared nlmsg_len is at least a header and fits in those bytes.
 * `len` may be negative once a walk has consumed the buffer. */
static bool mnl_nlmsg_ok(const struct nlmsghdr *nlh, int len)
{
	if (len < (int)sizeof(struct nlmsghdr))
		return false;
	if (nlh->nlmsg_len < sizeof(struct nlmsghdr))
		return false;
	/* Compare unsigned: casting nlmsg_len to int made lengths >= 2^31
	 * negative, so they passed this check and callers then read far past
	 * the end of the receive buffer. len is known positive here. */
	return nlh->nlmsg_len <= (unsigned int)len;
}
183
/* Step to the message after `nlh`, deducting its aligned length from *len. */
static struct nlmsghdr *mnl_nlmsg_next(const struct nlmsghdr *nlh, int *len)
{
	const uint32_t step = MNL_ALIGN(nlh->nlmsg_len);

	*len -= step;
	return (struct nlmsghdr *)((char *)nlh + step);
}
189
/* Address just past the aligned end of the message: where the next
 * attribute would be appended. */
static void *mnl_nlmsg_get_payload_tail(const struct nlmsghdr *nlh)
{
	const size_t used = MNL_ALIGN(nlh->nlmsg_len);

	return (char *)nlh + used;
}
194
/* Sequence filter: a zero sequence on either side matches anything;
 * otherwise the message's sequence must equal `seq`. */
static bool mnl_nlmsg_seq_ok(const struct nlmsghdr *nlh, unsigned int seq)
{
	if (nlh->nlmsg_seq == 0 || seq == 0)
		return true;
	return nlh->nlmsg_seq == seq;
}
199
/* Port-id filter: a zero port id on either side matches anything;
 * otherwise the message's nlmsg_pid must equal `portid`. */
static bool mnl_nlmsg_portid_ok(const struct nlmsghdr *nlh, unsigned int portid)
{
	if (nlh->nlmsg_pid == 0 || portid == 0)
		return true;
	return nlh->nlmsg_pid == portid;
}
204
/* Attribute type with the flag bits outside NLA_TYPE_MASK removed. */
static uint16_t mnl_attr_get_type(const struct nlattr *attr)
{
	const uint16_t raw = attr->nla_type;

	return raw & NLA_TYPE_MASK;
}
209
210 static uint16_t mnl_attr_get_payload_len(const struct nlattr *attr)
211 {
212 return attr->nla_len - MNL_ATTR_HDRLEN;
213 }
214
215 static void *mnl_attr_get_payload(const struct nlattr *attr)
216 {
217 return (void *)attr + MNL_ATTR_HDRLEN;
218 }
219
/* True when the `len` remaining bytes hold a whole attribute header and
 * the attribute's declared nla_len is sane and fits in those bytes. */
static bool mnl_attr_ok(const struct nlattr *attr, int len)
{
	if (len < (int)sizeof(struct nlattr))
		return false;
	if (attr->nla_len < sizeof(struct nlattr))
		return false;
	return (int)attr->nla_len <= len;
}
226
/* Following attribute: skip this one's nla_len rounded up to the alignment. */
static struct nlattr *mnl_attr_next(const struct nlattr *attr)
{
	const int stride = MNL_ALIGN(attr->nla_len);

	return (struct nlattr *)((char *)attr + stride);
}
231
/* Returns 1 if the attribute's type is <= max, else -1 with errno = EOPNOTSUPP. */
static int mnl_attr_type_valid(const struct nlattr *attr, uint16_t max)
{
	if (mnl_attr_get_type(attr) <= max)
		return 1;

	errno = EOPNOTSUPP;
	return -1;
}
240
241 static int __mnl_attr_validate(const struct nlattr *attr,
242 enum mnl_attr_data_type type, size_t exp_len)
243 {
244 uint16_t attr_len = mnl_attr_get_payload_len(attr);
245 const char *attr_data = mnl_attr_get_payload(attr);
246
247 if (attr_len < exp_len) {
248 errno = ERANGE;
249 return -1;
250 }
251 switch(type) {
252 case MNL_TYPE_FLAG:
253 if (attr_len > 0) {
254 errno = ERANGE;
255 return -1;
256 }
257 break;
258 case MNL_TYPE_NUL_STRING:
259 if (attr_len == 0) {
260 errno = ERANGE;
261 return -1;
262 }
263 if (attr_data[attr_len-1] != '\0') {
264 errno = EINVAL;
265 return -1;
266 }
267 break;
268 case MNL_TYPE_STRING:
269 if (attr_len == 0) {
270 errno = ERANGE;
271 return -1;
272 }
273 break;
274 case MNL_TYPE_NESTED:
275
276 if (attr_len == 0)
277 break;
278
279 if (attr_len < MNL_ATTR_HDRLEN) {
280 errno = ERANGE;
281 return -1;
282 }
283 break;
284 default:
285
286 break;
287 }
288 if (exp_len && attr_len > exp_len) {
289 errno = ERANGE;
290 return -1;
291 }
292 return 0;
293 }
294
295 static const size_t mnl_attr_data_type_len[MNL_TYPE_MAX] = {
296 [MNL_TYPE_U8] = sizeof(uint8_t),
297 [MNL_TYPE_U16] = sizeof(uint16_t),
298 [MNL_TYPE_U32] = sizeof(uint32_t),
299 [MNL_TYPE_U64] = sizeof(uint64_t),
300 [MNL_TYPE_MSECS] = sizeof(uint64_t),
301 };
302
303 static int mnl_attr_validate(const struct nlattr *attr, enum mnl_attr_data_type type)
304 {
305 int exp_len;
306
307 if (type >= MNL_TYPE_MAX) {
308 errno = EINVAL;
309 return -1;
310 }
311 exp_len = mnl_attr_data_type_len[type];
312 return __mnl_attr_validate(attr, type, exp_len);
313 }
314
315 static int mnl_attr_parse(const struct nlmsghdr *nlh, unsigned int offset,
316 mnl_attr_cb_t cb, void *data)
317 {
318 int ret = MNL_CB_OK;
319 const struct nlattr *attr;
320
321 mnl_attr_for_each(attr, nlh, offset)
322 if ((ret = cb(attr, data)) <= MNL_CB_STOP)
323 return ret;
324 return ret;
325 }
326
327 static int mnl_attr_parse_nested(const struct nlattr *nested, mnl_attr_cb_t cb,
328 void *data)
329 {
330 int ret = MNL_CB_OK;
331 const struct nlattr *attr;
332
333 mnl_attr_for_each_nested(attr, nested)
334 if ((ret = cb(attr, data)) <= MNL_CB_STOP)
335 return ret;
336 return ret;
337 }
338
/* First payload byte of the attribute (caller validates the size). */
static uint8_t mnl_attr_get_u8(const struct nlattr *attr)
{
	const uint8_t *bytes = mnl_attr_get_payload(attr);

	return bytes[0];
}
343
/* Payload read as a host-order uint16_t (caller validates the size). */
static uint16_t mnl_attr_get_u16(const struct nlattr *attr)
{
	uint16_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}
348
/* Payload read as a host-order uint32_t (caller validates the size). */
static uint32_t mnl_attr_get_u32(const struct nlattr *attr)
{
	uint32_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}
353
/* Payload read as a host-order uint64_t. It is copied out because payloads
 * are only guaranteed MNL_ALIGNTO (4-byte) alignment. */
static uint64_t mnl_attr_get_u64(const struct nlattr *attr)
{
	const unsigned char *src = mnl_attr_get_payload(attr);
	uint64_t out = 0;

	memcpy(&out, src, sizeof out);
	return out;
}
360
/* Payload viewed as a C string. No termination check is done here. */
static const char *mnl_attr_get_str(const struct nlattr *attr)
{
	const char *text = mnl_attr_get_payload(attr);

	return text;
}
365
/* Append attribute `type` with `len` bytes copied from `data` at the message
 * tail. nla_len covers header + payload, nlmsg_len grows by the aligned size,
 * and alignment padding is zeroed. No bounds check (see mnl_attr_put_check). */
static void mnl_attr_put(struct nlmsghdr *nlh, uint16_t type, size_t len,
			 const void *data)
{
	struct nlattr *slot = mnl_nlmsg_get_payload_tail(nlh);
	const uint16_t attr_len = MNL_ALIGN(sizeof(struct nlattr)) + len;
	char *body = mnl_attr_get_payload(slot);
	const int padding = MNL_ALIGN(len) - len;

	slot->nla_type = type;
	slot->nla_len = attr_len;
	memcpy(body, data, len);
	nlh->nlmsg_len += MNL_ALIGN(attr_len);
	if (padding > 0)
		memset(body + len, 0, padding);
}
381
/* Append a host-order uint16_t attribute (unchecked). */
static void mnl_attr_put_u16(struct nlmsghdr *nlh, uint16_t type, uint16_t data)
{
	const uint16_t value = data;

	mnl_attr_put(nlh, type, sizeof(value), &value);
}
386
/* Append a host-order uint32_t attribute (unchecked). */
static void mnl_attr_put_u32(struct nlmsghdr *nlh, uint16_t type, uint32_t data)
{
	const uint32_t value = data;

	mnl_attr_put(nlh, type, sizeof(value), &value);
}
391
/* Append a string attribute whose payload includes the terminating NUL (unchecked). */
static void mnl_attr_put_strz(struct nlmsghdr *nlh, uint16_t type, const char *data)
{
	const size_t with_nul = strlen(data) + 1;

	mnl_attr_put(nlh, type, with_nul, data);
}
396
397 static struct nlattr *mnl_attr_nest_start(struct nlmsghdr *nlh, uint16_t type)
398 {
399 struct nlattr *start = mnl_nlmsg_get_payload_tail(nlh);
400
401 start->nla_type = NLA_F_NESTED | type;
402 nlh->nlmsg_len += MNL_ALIGN(sizeof(struct nlattr));
403 return start;
404 }
405
406 static bool mnl_attr_put_check(struct nlmsghdr *nlh, size_t buflen,
407 uint16_t type, size_t len, const void *data)
408 {
409 if (nlh->nlmsg_len + MNL_ATTR_HDRLEN + MNL_ALIGN(len) > buflen)
410 return false;
411 mnl_attr_put(nlh, type, len, data);
412 return true;
413 }
414
/* Bounded append of a uint8_t attribute; false if it would not fit. */
static bool mnl_attr_put_u8_check(struct nlmsghdr *nlh, size_t buflen,
				  uint16_t type, uint8_t data)
{
	const uint8_t byte = data;

	return mnl_attr_put_check(nlh, buflen, type, sizeof(byte), &byte);
}
420
/* Bounded append of a host-order uint16_t attribute; false if it would not fit. */
static bool mnl_attr_put_u16_check(struct nlmsghdr *nlh, size_t buflen,
				   uint16_t type, uint16_t data)
{
	const uint16_t value = data;

	return mnl_attr_put_check(nlh, buflen, type, sizeof(value), &value);
}
426
/* Bounded append of a host-order uint32_t attribute; false if it would not fit. */
static bool mnl_attr_put_u32_check(struct nlmsghdr *nlh, size_t buflen,
				   uint16_t type, uint32_t data)
{
	const uint32_t value = data;

	return mnl_attr_put_check(nlh, buflen, type, sizeof(value), &value);
}
432
433 static struct nlattr *mnl_attr_nest_start_check(struct nlmsghdr *nlh, size_t buflen,
434 uint16_t type)
435 {
436 if (nlh->nlmsg_len + MNL_ATTR_HDRLEN > buflen)
437 return NULL;
438 return mnl_attr_nest_start(nlh, type);
439 }
440
/* Close a nest: its nla_len becomes the distance from `start` to the tail. */
static void mnl_attr_nest_end(struct nlmsghdr *nlh, struct nlattr *start)
{
	const char *tail = mnl_nlmsg_get_payload_tail(nlh);

	start->nla_len = tail - (const char *)start;
}
445
/* Discard a nest and everything appended after `start` by shrinking nlmsg_len. */
static void mnl_attr_nest_cancel(struct nlmsghdr *nlh, struct nlattr *start)
{
	const char *tail = mnl_nlmsg_get_payload_tail(nlh);

	nlh->nlmsg_len -= tail - (const char *)start;
}
450
451 static int mnl_cb_noop(__attribute__((unused)) const struct nlmsghdr *nlh, __attribute__((unused)) void *data)
452 {
453 return MNL_CB_OK;
454 }
455
456 static int mnl_cb_error(const struct nlmsghdr *nlh, __attribute__((unused)) void *data)
457 {
458 const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
459
460 if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) {
461 errno = EBADMSG;
462 return MNL_CB_ERROR;
463 }
464
465 if (err->error < 0)
466 errno = -err->error;
467 else
468 errno = err->error;
469
470 return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR;
471 }
472
473 static int mnl_cb_stop(__attribute__((unused)) const struct nlmsghdr *nlh, __attribute__((unused)) void *data)
474 {
475 return MNL_CB_STOP;
476 }
477
478 static const mnl_cb_t default_cb_array[NLMSG_MIN_TYPE] = {
479 [NLMSG_NOOP] = mnl_cb_noop,
480 [NLMSG_ERROR] = mnl_cb_error,
481 [NLMSG_DONE] = mnl_cb_stop,
482 [NLMSG_OVERRUN] = mnl_cb_noop,
483 };
484
485 static int __mnl_cb_run(const void *buf, size_t numbytes,
486 unsigned int seq, unsigned int portid,
487 mnl_cb_t cb_data, void *data,
488 const mnl_cb_t *cb_ctl_array,
489 unsigned int cb_ctl_array_len)
490 {
491 int ret = MNL_CB_OK, len = numbytes;
492 const struct nlmsghdr *nlh = buf;
493
494 while (mnl_nlmsg_ok(nlh, len)) {
495
496 if (!mnl_nlmsg_portid_ok(nlh, portid)) {
497 errno = ESRCH;
498 return -1;
499 }
500
501 if (!mnl_nlmsg_seq_ok(nlh, seq)) {
502 errno = EPROTO;
503 return -1;
504 }
505
506 if (nlh->nlmsg_flags & NLM_F_DUMP_INTR) {
507 errno = EINTR;
508 return -1;
509 }
510
511 if (nlh->nlmsg_type >= NLMSG_MIN_TYPE) {
512 if (cb_data){
513 ret = cb_data(nlh, data);
514 if (ret <= MNL_CB_STOP)
515 goto out;
516 }
517 } else if (nlh->nlmsg_type < cb_ctl_array_len) {
518 if (cb_ctl_array && cb_ctl_array[nlh->nlmsg_type]) {
519 ret = cb_ctl_array[nlh->nlmsg_type](nlh, data);
520 if (ret <= MNL_CB_STOP)
521 goto out;
522 }
523 } else if (default_cb_array[nlh->nlmsg_type]) {
524 ret = default_cb_array[nlh->nlmsg_type](nlh, data);
525 if (ret <= MNL_CB_STOP)
526 goto out;
527 }
528 nlh = mnl_nlmsg_next(nlh, &len);
529 }
530 out:
531 return ret;
532 }
533
534 static int mnl_cb_run2(const void *buf, size_t numbytes, unsigned int seq,
535 unsigned int portid, mnl_cb_t cb_data, void *data,
536 const mnl_cb_t *cb_ctl_array, unsigned int cb_ctl_array_len)
537 {
538 return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data,
539 cb_ctl_array, cb_ctl_array_len);
540 }
541
542 static int mnl_cb_run(const void *buf, size_t numbytes, unsigned int seq,
543 unsigned int portid, mnl_cb_t cb_data, void *data)
544 {
545 return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data, NULL, 0);
546 }
547
/* A netlink socket together with its local address. */
struct mnl_socket {
	int fd;				/* descriptor from socket(AF_NETLINK, ...) */
	struct sockaddr_nl addr;	/* refreshed by mnl_socket_bind() via getsockname() */
};
552
553 static unsigned int mnl_socket_get_portid(const struct mnl_socket *nl)
554 {
555 return nl->addr.nl_pid;
556 }
557
558 static struct mnl_socket *__mnl_socket_open(int bus, int flags)
559 {
560 struct mnl_socket *nl;
561
562 nl = calloc(1, sizeof(struct mnl_socket));
563 if (nl == NULL)
564 return NULL;
565
566 nl->fd = socket(AF_NETLINK, SOCK_RAW | flags, bus);
567 if (nl->fd == -1) {
568 free(nl);
569 return NULL;
570 }
571
572 return nl;
573 }
574
/* Open a netlink socket on protocol `bus` with no extra socket-type flags. */
static struct mnl_socket *mnl_socket_open(int bus)
{
	const int no_extra_flags = 0;

	return __mnl_socket_open(bus, no_extra_flags);
}
579
580 static int mnl_socket_bind(struct mnl_socket *nl, unsigned int groups, pid_t pid)
581 {
582 int ret;
583 socklen_t addr_len;
584
585 nl->addr.nl_family = AF_NETLINK;
586 nl->addr.nl_groups = groups;
587 nl->addr.nl_pid = pid;
588
589 ret = bind(nl->fd, (struct sockaddr *) &nl->addr, sizeof (nl->addr));
590 if (ret < 0)
591 return ret;
592
593 addr_len = sizeof(nl->addr);
594 ret = getsockname(nl->fd, (struct sockaddr *) &nl->addr, &addr_len);
595 if (ret < 0)
596 return ret;
597
598 if (addr_len != sizeof(nl->addr)) {
599 errno = EINVAL;
600 return -1;
601 }
602 if (nl->addr.nl_family != AF_NETLINK) {
603 errno = EINVAL;
604 return -1;
605 }
606 return 0;
607 }
608
609 static ssize_t mnl_socket_sendto(const struct mnl_socket *nl, const void *buf,
610 size_t len)
611 {
612 static const struct sockaddr_nl snl = {
613 .nl_family = AF_NETLINK
614 };
615 return sendto(nl->fd, buf, len, 0,
616 (struct sockaddr *) &snl, sizeof(snl));
617 }
618
619 static ssize_t mnl_socket_recvfrom(const struct mnl_socket *nl, void *buf,
620 size_t bufsiz)
621 {
622 ssize_t ret;
623 struct sockaddr_nl addr;
624 struct iovec iov = {
625 .iov_base = buf,
626 .iov_len = bufsiz,
627 };
628 struct msghdr msg = {
629 .msg_name = &addr,
630 .msg_namelen = sizeof(struct sockaddr_nl),
631 .msg_iov = &iov,
632 .msg_iovlen = 1,
633 .msg_control = NULL,
634 .msg_controllen = 0,
635 .msg_flags = 0,
636 };
637 ret = recvmsg(nl->fd, &msg, 0);
638 if (ret == -1)
639 return ret;
640
641 if (msg.msg_flags & MSG_TRUNC) {
642 errno = ENOSPC;
643 return -1;
644 }
645 if (msg.msg_namelen != sizeof(struct sockaddr_nl)) {
646 errno = EINVAL;
647 return -1;
648 }
649 return ret;
650 }
651
652 static int mnl_socket_close(struct mnl_socket *nl)
653 {
654 int ret = close(nl->fd);
655 free(nl);
656 return ret;
657 }
658
659 /* mnlg mini library: */
660
/* Generic-netlink client state. */
struct mnlg_socket {
	struct mnl_socket *nl;	/* underlying NETLINK_GENERIC socket */
	char *buf;		/* shared send/receive buffer */
	uint16_t id;		/* resolved family id (0 until resolved) */
	uint8_t version;	/* family version put in each genlmsghdr */
	unsigned int seq;	/* sequence number of the last prepared request */
	unsigned int portid;	/* local port id after bind */
};
669
670 static struct nlmsghdr *__mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd,
671 uint16_t flags, uint16_t id,
672 uint8_t version)
673 {
674 struct nlmsghdr *nlh;
675 struct genlmsghdr *genl;
676
677 nlh = mnl_nlmsg_put_header(nlg->buf);
678 nlh->nlmsg_type = id;
679 nlh->nlmsg_flags = flags;
680 nlg->seq = time(NULL);
681 nlh->nlmsg_seq = nlg->seq;
682
683 genl = mnl_nlmsg_put_extra_header(nlh, sizeof(struct genlmsghdr));
684 genl->cmd = cmd;
685 genl->version = version;
686
687 return nlh;
688 }
689
690 static struct nlmsghdr *mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd,
691 uint16_t flags)
692 {
693 return __mnlg_msg_prepare(nlg, cmd, flags, nlg->id, nlg->version);
694 }
695
696 static int mnlg_socket_send(struct mnlg_socket *nlg, const struct nlmsghdr *nlh)
697 {
698 return mnl_socket_sendto(nlg->nl, nlh, nlh->nlmsg_len);
699 }
700
701 static int mnlg_cb_noop(const struct nlmsghdr *nlh, void *data)
702 {
703 (void)nlh;
704 (void)data;
705 return MNL_CB_OK;
706 }
707
708 static int mnlg_cb_error(const struct nlmsghdr *nlh, void *data)
709 {
710 const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
711 (void)data;
712
713 if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) {
714 errno = EBADMSG;
715 return MNL_CB_ERROR;
716 }
717 /* Netlink subsystems returns the errno value with different signess */
718 if (err->error < 0)
719 errno = -err->error;
720 else
721 errno = err->error;
722
723 return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR;
724 }
725
726 static int mnlg_cb_stop(const struct nlmsghdr *nlh, void *data)
727 {
728 (void)data;
729 if (nlh->nlmsg_flags & NLM_F_MULTI && nlh->nlmsg_len == mnl_nlmsg_size(sizeof(int))) {
730 int error = *(int *)mnl_nlmsg_get_payload(nlh);
731 /* Netlink subsystems returns the errno value with different signess */
732 if (error < 0)
733 errno = -error;
734 else
735 errno = error;
736
737 return error == 0 ? MNL_CB_STOP : MNL_CB_ERROR;
738 }
739 return MNL_CB_STOP;
740 }
741
742 static const mnl_cb_t mnlg_cb_array[] = {
743 [NLMSG_NOOP] = mnlg_cb_noop,
744 [NLMSG_ERROR] = mnlg_cb_error,
745 [NLMSG_DONE] = mnlg_cb_stop,
746 [NLMSG_OVERRUN] = mnlg_cb_noop,
747 };
748
749 static int mnlg_socket_recv_run(struct mnlg_socket *nlg, mnl_cb_t data_cb, void *data)
750 {
751 int err;
752
753 do {
754 err = mnl_socket_recvfrom(nlg->nl, nlg->buf,
755 mnl_ideal_socket_buffer_size());
756 if (err <= 0)
757 break;
758 err = mnl_cb_run2(nlg->buf, err, nlg->seq, nlg->portid,
759 data_cb, data, mnlg_cb_array, MNL_ARRAY_SIZE(mnlg_cb_array));
760 } while (err > 0);
761
762 return err;
763 }
764
765 static int get_family_id_attr_cb(const struct nlattr *attr, void *data)
766 {
767 const struct nlattr **tb = data;
768 int type = mnl_attr_get_type(attr);
769
770 if (mnl_attr_type_valid(attr, CTRL_ATTR_MAX) < 0)
771 return MNL_CB_ERROR;
772
773 if (type == CTRL_ATTR_FAMILY_ID &&
774 mnl_attr_validate(attr, MNL_TYPE_U16) < 0)
775 return MNL_CB_ERROR;
776 tb[type] = attr;
777 return MNL_CB_OK;
778 }
779
780 static int get_family_id_cb(const struct nlmsghdr *nlh, void *data)
781 {
782 uint16_t *p_id = data;
783 struct nlattr *tb[CTRL_ATTR_MAX + 1] = { 0 };
784
785 mnl_attr_parse(nlh, sizeof(struct genlmsghdr), get_family_id_attr_cb, tb);
786 if (!tb[CTRL_ATTR_FAMILY_ID])
787 return MNL_CB_ERROR;
788 *p_id = mnl_attr_get_u16(tb[CTRL_ATTR_FAMILY_ID]);
789 return MNL_CB_OK;
790 }
791
792 static struct mnlg_socket *mnlg_socket_open(const char *family_name, uint8_t version)
793 {
794 struct mnlg_socket *nlg;
795 struct nlmsghdr *nlh;
796 int err;
797
798 nlg = malloc(sizeof(*nlg));
799 if (!nlg)
800 return NULL;
801 nlg->id = 0;
802
803 err = -ENOMEM;
804 nlg->buf = malloc(mnl_ideal_socket_buffer_size());
805 if (!nlg->buf)
806 goto err_buf_alloc;
807
808 nlg->nl = mnl_socket_open(NETLINK_GENERIC);
809 if (!nlg->nl) {
810 err = -errno;
811 goto err_mnl_socket_open;
812 }
813
814 if (mnl_socket_bind(nlg->nl, 0, MNL_SOCKET_AUTOPID) < 0) {
815 err = -errno;
816 goto err_mnl_socket_bind;
817 }
818
819 nlg->portid = mnl_socket_get_portid(nlg->nl);
820
821 nlh = __mnlg_msg_prepare(nlg, CTRL_CMD_GETFAMILY,
822 NLM_F_REQUEST | NLM_F_ACK, GENL_ID_CTRL, 1);
823 mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name);
824
825 if (mnlg_socket_send(nlg, nlh) < 0) {
826 err = -errno;
827 goto err_mnlg_socket_send;
828 }
829
830 errno = 0;
831 if (mnlg_socket_recv_run(nlg, get_family_id_cb, &nlg->id) < 0) {
832 errno = errno == ENOENT ? EPROTONOSUPPORT : errno;
833 err = errno ? -errno : -ENOSYS;
834 goto err_mnlg_socket_recv_run;
835 }
836
837 nlg->version = version;
838 errno = 0;
839 return nlg;
840
841 err_mnlg_socket_recv_run:
842 err_mnlg_socket_send:
843 err_mnl_socket_bind:
844 mnl_socket_close(nlg->nl);
845 err_mnl_socket_open:
846 free(nlg->buf);
847 err_buf_alloc:
848 free(nlg);
849 errno = -err;
850 return NULL;
851 }
852
853 static void mnlg_socket_close(struct mnlg_socket *nlg)
854 {
855 mnl_socket_close(nlg->nl);
856 free(nlg->buf);
857 free(nlg);
858 }
859
860 /* wireguard-specific parts: */
861
/* Growable buffer of consecutive NUL-terminated strings, followed by one
 * extra NUL that ends the list. */
struct string_list {
	char *buffer;	/* heap storage, NULL before the first append */
	size_t len;	/* bytes used, not counting the list-ending NUL */
	size_t cap;	/* bytes allocated */
};

/* Append `str` (with its NUL) to `list`; empty strings are skipped.
 * Storage grows to at least double the capacity. Returns 0, or -errno if
 * realloc fails (the list is then left unchanged). */
static int string_list_add(struct string_list *list, const char *str)
{
	const size_t need = strlen(str) + 1;
	const size_t room = list->cap - list->len;

	if (need == 1)
		return 0;

	if (need >= room) {
		size_t grown = list->cap * 2;
		const size_t minimum = list->len + need + 1;
		char *fresh;

		if (grown < minimum)
			grown = minimum;
		fresh = realloc(list->buffer, grown);
		if (fresh == NULL)
			return -errno;
		list->buffer = fresh;
		list->cap = grown;
	}

	memcpy(&list->buffer[list->len], str, need);
	list->len += need;
	list->buffer[list->len] = '\0';
	return 0;
}
892
/* Scratch state filled while parsing one link from the RTM_GETLINK dump. */
struct interface {
	const char *name;	/* IFLA_IFNAME payload (points into the receive buffer) */
	bool is_wireguard;	/* set when IFLA_INFO_KIND matched WG_GENL_NAME */
};
897
898 static int parse_linkinfo(const struct nlattr *attr, void *data)
899 {
900 struct interface *interface = data;
901
902 if (mnl_attr_get_type(attr) == IFLA_INFO_KIND && !strcmp(WG_GENL_NAME, mnl_attr_get_str(attr)))
903 interface->is_wireguard = true;
904 return MNL_CB_OK;
905 }
906
907 static int parse_infomsg(const struct nlattr *attr, void *data)
908 {
909 struct interface *interface = data;
910
911 if (mnl_attr_get_type(attr) == IFLA_LINKINFO)
912 return mnl_attr_parse_nested(attr, parse_linkinfo, data);
913 else if (mnl_attr_get_type(attr) == IFLA_IFNAME)
914 interface->name = mnl_attr_get_str(attr);
915 return MNL_CB_OK;
916 }
917
918 static int read_devices_cb(const struct nlmsghdr *nlh, void *data)
919 {
920 struct string_list *list = data;
921 struct interface interface = { 0 };
922 int ret;
923
924 ret = mnl_attr_parse(nlh, sizeof(struct ifinfomsg), parse_infomsg, &interface);
925 if (ret != MNL_CB_OK)
926 return ret;
927 if (interface.name && interface.is_wireguard)
928 ret = string_list_add(list, interface.name);
929 if (ret < 0)
930 return ret;
931 if (nlh->nlmsg_type != NLMSG_DONE)
932 return MNL_CB_OK + 1;
933 return MNL_CB_OK;
934 }
935
936 static int fetch_device_names(struct string_list *list)
937 {
938 struct mnl_socket *nl = NULL;
939 char *rtnl_buffer = NULL;
940 size_t message_len;
941 unsigned int portid, seq;
942 ssize_t len;
943 int ret = 0;
944 struct nlmsghdr *nlh;
945 struct ifinfomsg *ifm;
946
947 ret = -ENOMEM;
948 rtnl_buffer = calloc(mnl_ideal_socket_buffer_size(), 1);
949 if (!rtnl_buffer)
950 goto cleanup;
951
952 nl = mnl_socket_open(NETLINK_ROUTE);
953 if (!nl) {
954 ret = -errno;
955 goto cleanup;
956 }
957
958 if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
959 ret = -errno;
960 goto cleanup;
961 }
962
963 seq = time(NULL);
964 portid = mnl_socket_get_portid(nl);
965 nlh = mnl_nlmsg_put_header(rtnl_buffer);
966 nlh->nlmsg_type = RTM_GETLINK;
967 nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
968 nlh->nlmsg_seq = seq;
969 ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
970 ifm->ifi_family = AF_UNSPEC;
971 message_len = nlh->nlmsg_len;
972
973 if (mnl_socket_sendto(nl, rtnl_buffer, message_len) < 0) {
974 ret = -errno;
975 goto cleanup;
976 }
977
978 another:
979 if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, mnl_ideal_socket_buffer_size())) < 0) {
980 ret = -errno;
981 goto cleanup;
982 }
983 if ((len = mnl_cb_run(rtnl_buffer, len, seq, portid, read_devices_cb, list)) < 0) {
984 /* Netlink returns NLM_F_DUMP_INTR if the set of all tunnels changed
985 * during the dump. That's unfortunate, but is pretty common on busy
986 * systems that are adding and removing tunnels all the time. Rather
987 * than retrying, potentially indefinitely, we just work with the
988 * partial results. */
989 if (errno != EINTR) {
990 ret = -errno;
991 goto cleanup;
992 }
993 }
994 if (len == MNL_CB_OK + 1)
995 goto another;
996 ret = 0;
997
998 cleanup:
999 free(rtnl_buffer);
1000 if (nl)
1001 mnl_socket_close(nl);
1002 return ret;
1003 }
1004
1005 static int add_del_iface(const char *ifname, bool add)
1006 {
1007 struct mnl_socket *nl = NULL;
1008 char *rtnl_buffer;
1009 ssize_t len;
1010 int ret;
1011 struct nlmsghdr *nlh;
1012 struct ifinfomsg *ifm;
1013 struct nlattr *nest;
1014
1015 rtnl_buffer = calloc(mnl_ideal_socket_buffer_size(), 1);
1016 if (!rtnl_buffer) {
1017 ret = -ENOMEM;
1018 goto cleanup;
1019 }
1020
1021 nl = mnl_socket_open(NETLINK_ROUTE);
1022 if (!nl) {
1023 ret = -errno;
1024 goto cleanup;
1025 }
1026
1027 if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
1028 ret = -errno;
1029 goto cleanup;
1030 }
1031
1032 nlh = mnl_nlmsg_put_header(rtnl_buffer);
1033 nlh->nlmsg_type = add ? RTM_NEWLINK : RTM_DELLINK;
1034 nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | (add ? NLM_F_CREATE | NLM_F_EXCL : 0);
1035 nlh->nlmsg_seq = time(NULL);
1036 ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
1037 ifm->ifi_family = AF_UNSPEC;
1038 mnl_attr_put_strz(nlh, IFLA_IFNAME, ifname);
1039 nest = mnl_attr_nest_start(nlh, IFLA_LINKINFO);
1040 mnl_attr_put_strz(nlh, IFLA_INFO_KIND, WG_GENL_NAME);
1041 mnl_attr_nest_end(nlh, nest);
1042
1043 if (mnl_socket_sendto(nl, rtnl_buffer, nlh->nlmsg_len) < 0) {
1044 ret = -errno;
1045 goto cleanup;
1046 }
1047 if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, mnl_ideal_socket_buffer_size())) < 0) {
1048 ret = -errno;
1049 goto cleanup;
1050 }
1051 if (mnl_cb_run(rtnl_buffer, len, nlh->nlmsg_seq, mnl_socket_get_portid(nl), NULL, NULL) < 0) {
1052 ret = -errno;
1053 goto cleanup;
1054 }
1055 ret = 0;
1056
1057 cleanup:
1058 free(rtnl_buffer);
1059 if (nl)
1060 mnl_socket_close(nl);
1061 return ret;
1062 }
1063
1064 int wg_set_device(wg_device *dev)
1065 {
1066 int ret = 0;
1067 wg_peer *peer = NULL;
1068 wg_allowedip *allowedip = NULL;
1069 struct nlattr *peers_nest, *peer_nest, *allowedips_nest, *allowedip_nest;
1070 struct nlmsghdr *nlh;
1071 struct mnlg_socket *nlg;
1072
1073 nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
1074 if (!nlg)
1075 return -errno;
1076
1077 again:
1078 nlh = mnlg_msg_prepare(nlg, WG_CMD_SET_DEVICE, NLM_F_REQUEST | NLM_F_ACK);
1079 mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, dev->name);
1080
1081 if (!peer) {
1082 uint32_t flags = 0;
1083
1084 if (dev->flags & WGDEVICE_HAS_PRIVATE_KEY)
1085 mnl_attr_put(nlh, WGDEVICE_A_PRIVATE_KEY, sizeof(dev->private_key), dev->private_key);
1086 if (dev->flags & WGDEVICE_HAS_LISTEN_PORT)
1087 mnl_attr_put_u16(nlh, WGDEVICE_A_LISTEN_PORT, dev->listen_port);
1088 if (dev->flags & WGDEVICE_HAS_FWMARK)
1089 mnl_attr_put_u32(nlh, WGDEVICE_A_FWMARK, dev->fwmark);
1090 if (dev->flags & WGDEVICE_REPLACE_PEERS)
1091 flags |= WGDEVICE_F_REPLACE_PEERS;
1092 if (flags)
1093 mnl_attr_put_u32(nlh, WGDEVICE_A_FLAGS, flags);
1094 }
1095 if (!dev->first_peer)
1096 goto send;
1097 peers_nest = peer_nest = allowedips_nest = allowedip_nest = NULL;
1098 peers_nest = mnl_attr_nest_start(nlh, WGDEVICE_A_PEERS);
1099 for (peer = peer ? peer : dev->first_peer; peer; peer = peer->next_peer) {
1100 uint32_t flags = 0;
1101
1102 peer_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), 0);
1103 if (!peer_nest)
1104 goto toobig_peers;
1105 if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PUBLIC_KEY, sizeof(peer->public_key), peer->public_key))
1106 goto toobig_peers;
1107 if (peer->flags & WGPEER_REMOVE_ME)
1108 flags |= WGPEER_F_REMOVE_ME;
1109 if (!allowedip) {
1110 if (peer->flags & WGPEER_REPLACE_ALLOWEDIPS)
1111 flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
1112 if (peer->flags & WGPEER_HAS_PRESHARED_KEY) {
1113 if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PRESHARED_KEY, sizeof(peer->preshared_key), peer->preshared_key))
1114 goto toobig_peers;
1115 }
1116 if (peer->endpoint.addr.sa_family == AF_INET) {
1117 if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr4), &peer->endpoint.addr4))
1118 goto toobig_peers;
1119 } else if (peer->endpoint.addr.sa_family == AF_INET6) {
1120 if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr6), &peer->endpoint.addr6))
1121 goto toobig_peers;
1122 }
1123 if (peer->flags & WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL) {
1124 if (!mnl_attr_put_u16_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval))
1125 goto toobig_peers;
1126 }
1127 }
1128 if (flags) {
1129 if (!mnl_attr_put_u32_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_FLAGS, flags))
1130 goto toobig_peers;
1131 }
1132 if (peer->first_allowedip) {
1133 if (!allowedip)
1134 allowedip = peer->first_allowedip;
1135 allowedips_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), WGPEER_A_ALLOWEDIPS);
1136 if (!allowedips_nest)
1137 goto toobig_allowedips;
1138 for (; allowedip; allowedip = allowedip->next_allowedip) {
1139 allowedip_nest = mnl_attr_nest_start_check(nlh, mnl_ideal_socket_buffer_size(), 0);
1140 if (!allowedip_nest)
1141 goto toobig_allowedips;
1142 if (!mnl_attr_put_u16_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_FAMILY, allowedip->family))
1143 goto toobig_allowedips;
1144 if (allowedip->family == AF_INET) {
1145 if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip4), &allowedip->ip4))
1146 goto toobig_allowedips;
1147 } else if (allowedip->family == AF_INET6) {
1148 if (!mnl_attr_put_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip6), &allowedip->ip6))
1149 goto toobig_allowedips;
1150 }
1151 if (!mnl_attr_put_u8_check(nlh, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_CIDR_MASK, allowedip->cidr))
1152 goto toobig_allowedips;
1153 mnl_attr_nest_end(nlh, allowedip_nest);
1154 allowedip_nest = NULL;
1155 }
1156 mnl_attr_nest_end(nlh, allowedips_nest);
1157 allowedips_nest = NULL;
1158 }
1159
1160 mnl_attr_nest_end(nlh, peer_nest);
1161 peer_nest = NULL;
1162 }
1163 mnl_attr_nest_end(nlh, peers_nest);
1164 peers_nest = NULL;
1165 goto send;
1166 toobig_allowedips:
1167 if (allowedip_nest)
1168 mnl_attr_nest_cancel(nlh, allowedip_nest);
1169 if (allowedips_nest)
1170 mnl_attr_nest_end(nlh, allowedips_nest);
1171 mnl_attr_nest_end(nlh, peer_nest);
1172 mnl_attr_nest_end(nlh, peers_nest);
1173 goto send;
1174 toobig_peers:
1175 if (peer_nest)
1176 mnl_attr_nest_cancel(nlh, peer_nest);
1177 mnl_attr_nest_end(nlh, peers_nest);
1178 goto send;
1179 send:
1180 if (mnlg_socket_send(nlg, nlh) < 0) {
1181 ret = -errno;
1182 goto out;
1183 }
1184 errno = 0;
1185 if (mnlg_socket_recv_run(nlg, NULL, NULL) < 0) {
1186 ret = errno ? -errno : -EINVAL;
1187 goto out;
1188 }
1189 if (peer)
1190 goto again;
1191
1192 out:
1193 mnlg_socket_close(nlg);
1194 errno = -ret;
1195 return ret;
1196 }
1197
1198 static int parse_allowedip(const struct nlattr *attr, void *data)
1199 {
1200 wg_allowedip *allowedip = data;
1201
1202 switch (mnl_attr_get_type(attr)) {
1203 case WGALLOWEDIP_A_UNSPEC:
1204 break;
1205 case WGALLOWEDIP_A_FAMILY:
1206 if (!mnl_attr_validate(attr, MNL_TYPE_U16))
1207 allowedip->family = mnl_attr_get_u16(attr);
1208 break;
1209 case WGALLOWEDIP_A_IPADDR:
1210 if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4))
1211 memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4));
1212 else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6))
1213 memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6));
1214 break;
1215 case WGALLOWEDIP_A_CIDR_MASK:
1216 if (!mnl_attr_validate(attr, MNL_TYPE_U8))
1217 allowedip->cidr = mnl_attr_get_u8(attr);
1218 break;
1219 }
1220
1221 return MNL_CB_OK;
1222 }
1223
1224 static int parse_allowedips(const struct nlattr *attr, void *data)
1225 {
1226 wg_peer *peer = data;
1227 wg_allowedip *new_allowedip = calloc(1, sizeof(wg_allowedip));
1228 int ret;
1229
1230 if (!new_allowedip)
1231 return MNL_CB_ERROR;
1232 if (!peer->first_allowedip)
1233 peer->first_allowedip = peer->last_allowedip = new_allowedip;
1234 else {
1235 peer->last_allowedip->next_allowedip = new_allowedip;
1236 peer->last_allowedip = new_allowedip;
1237 }
1238 ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip);
1239 if (!ret)
1240 return ret;
1241 if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128))) {
1242 errno = EAFNOSUPPORT;
1243 return MNL_CB_ERROR;
1244 }
1245 return MNL_CB_OK;
1246 }
1247
/*
 * Return true when every byte of @key (sizeof(wg_key) bytes) is zero.
 *
 * All bytes are OR-ed into an accumulator with no early exit, and the final
 * test is arithmetic rather than a comparison. This is presumably so the
 * running time does not depend on the key contents; keep it that way.
 */
1248 bool wg_key_is_zero(const wg_key key)
1249 {
1250 volatile uint8_t acc = 0;
1251 unsigned int i;
1252
1253 for (i = 0; i < sizeof(wg_key); ++i) {
1254 acc |= key[i];
/*
 * Empty asm that passes acc through a register (output tied to input).
 * It is intended as an optimization barrier so the compiler cannot
 * short-circuit the loop once acc becomes non-zero.
 */
1255 __asm__ ("" : "=r" (acc) : "0" (acc));
1256 }
/*
 * acc is in 0..255. acc - 1 (computed as int) is -1 only when acc == 0.
 * Shifting right by 8 then leaves all bits set (arithmetic shift assumed)
 * rather than 0, so bit 0 is 1 exactly when every byte was zero.
 */
1257 return 1 & ((acc - 1) >> 8);
1258 }
1259
1260 static int parse_peer(const struct nlattr *attr, void *data)
1261 {
1262 wg_peer *peer = data;
1263
1264 switch (mnl_attr_get_type(attr)) {
1265 case WGPEER_A_UNSPEC:
1266 break;
1267 case WGPEER_A_PUBLIC_KEY:
1268 if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key)) {
1269 memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key));
1270 peer->flags |= WGPEER_HAS_PUBLIC_KEY;
1271 }
1272 break;
1273 case WGPEER_A_PRESHARED_KEY:
1274 if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key)) {
1275 memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key));
1276 if (!wg_key_is_zero(peer->preshared_key))
1277 peer->flags |= WGPEER_HAS_PRESHARED_KEY;
1278 }
1279 break;
1280 case WGPEER_A_ENDPOINT: {
1281 struct sockaddr *addr;
1282
1283 if (mnl_attr_get_payload_len(attr) < sizeof(*addr))
1284 break;
1285 addr = mnl_attr_get_payload(attr);
1286 if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4))
1287 memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4));
1288 else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6))
1289 memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6));
1290 break;
1291 }
1292 case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL:
1293 if (!mnl_attr_validate(attr, MNL_TYPE_U16))
1294 peer->persistent_keepalive_interval = mnl_attr_get_u16(attr);
1295 break;
1296 case WGPEER_A_LAST_HANDSHAKE_TIME:
1297 if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time))
1298 memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time));
1299 break;
1300 case WGPEER_A_RX_BYTES:
1301 if (!mnl_attr_validate(attr, MNL_TYPE_U64))
1302 peer->rx_bytes = mnl_attr_get_u64(attr);
1303 break;
1304 case WGPEER_A_TX_BYTES:
1305 if (!mnl_attr_validate(attr, MNL_TYPE_U64))
1306 peer->tx_bytes = mnl_attr_get_u64(attr);
1307 break;
1308 case WGPEER_A_ALLOWEDIPS:
1309 return mnl_attr_parse_nested(attr, parse_allowedips, peer);
1310 }
1311
1312 return MNL_CB_OK;
1313 }
1314
1315 static int parse_peers(const struct nlattr *attr, void *data)
1316 {
1317 wg_device *device = data;
1318 wg_peer *new_peer = calloc(1, sizeof(wg_peer));
1319 int ret;
1320
1321 if (!new_peer)
1322 return MNL_CB_ERROR;
1323 if (!device->first_peer)
1324 device->first_peer = device->last_peer = new_peer;
1325 else {
1326 device->last_peer->next_peer = new_peer;
1327 device->last_peer = new_peer;
1328 }
1329 ret = mnl_attr_parse_nested(attr, parse_peer, new_peer);
1330 if (!ret)
1331 return ret;
1332 if (!(new_peer->flags & WGPEER_HAS_PUBLIC_KEY)) {
1333 errno = ENXIO;
1334 return MNL_CB_ERROR;
1335 }
1336 return MNL_CB_OK;
1337 }
1338
1339 static int parse_device(const struct nlattr *attr, void *data)
1340 {
1341 wg_device *device = data;
1342
1343 switch (mnl_attr_get_type(attr)) {
1344 case WGDEVICE_A_UNSPEC:
1345 break;
1346 case WGDEVICE_A_IFINDEX:
1347 if (!mnl_attr_validate(attr, MNL_TYPE_U32))
1348 device->ifindex = mnl_attr_get_u32(attr);
1349 break;
1350 case WGDEVICE_A_IFNAME:
1351 if (!mnl_attr_validate(attr, MNL_TYPE_STRING)) {
1352 strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1);
1353 device->name[sizeof(device->name) - 1] = '\0';
1354 }
1355 break;
1356 case WGDEVICE_A_PRIVATE_KEY:
1357 if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key)) {
1358 memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key));
1359 device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
1360 }
1361 break;
1362 case WGDEVICE_A_PUBLIC_KEY:
1363 if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key)) {
1364 memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key));
1365 device->flags |= WGDEVICE_HAS_PUBLIC_KEY;
1366 }
1367 break;
1368 case WGDEVICE_A_LISTEN_PORT:
1369 if (!mnl_attr_validate(attr, MNL_TYPE_U16))
1370 device->listen_port = mnl_attr_get_u16(attr);
1371 break;
1372 case WGDEVICE_A_FWMARK:
1373 if (!mnl_attr_validate(attr, MNL_TYPE_U32))
1374 device->fwmark = mnl_attr_get_u32(attr);
1375 break;
1376 case WGDEVICE_A_PEERS:
1377 return mnl_attr_parse_nested(attr, parse_peers, device);
1378 }
1379
1380 return MNL_CB_OK;
1381 }
1382
1383 static int read_device_cb(const struct nlmsghdr *nlh, void *data)
1384 {
1385 return mnl_attr_parse(nlh, sizeof(struct genlmsghdr), parse_device, data);
1386 }
1387
1388 static void coalesce_peers(wg_device *device)
1389 {
1390 wg_peer *old_next_peer, *peer = device->first_peer;
1391
1392 while (peer && peer->next_peer) {
1393 if (memcmp(peer->public_key, peer->next_peer->public_key, sizeof(wg_key))) {
1394 peer = peer->next_peer;
1395 continue;
1396 }
1397 if (!peer->first_allowedip) {
1398 peer->first_allowedip = peer->next_peer->first_allowedip;
1399 peer->last_allowedip = peer->next_peer->last_allowedip;
1400 } else {
1401 peer->last_allowedip->next_allowedip = peer->next_peer->first_allowedip;
1402 peer->last_allowedip = peer->next_peer->last_allowedip;
1403 }
1404 old_next_peer = peer->next_peer;
1405 peer->next_peer = old_next_peer->next_peer;
1406 free(old_next_peer);
1407 }
1408 }
1409
1410 int wg_get_device(wg_device **device, const char *device_name)
1411 {
1412 int ret = 0;
1413 struct nlmsghdr *nlh;
1414 struct mnlg_socket *nlg;
1415
1416 try_again:
1417 *device = calloc(1, sizeof(wg_device));
1418 if (!*device)
1419 return -errno;
1420
1421 nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
1422 if (!nlg) {
1423 wg_free_device(*device);
1424 *device = NULL;
1425 return -errno;
1426 }
1427
1428 nlh = mnlg_msg_prepare(nlg, WG_CMD_GET_DEVICE, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP);
1429 mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, device_name);
1430 if (mnlg_socket_send(nlg, nlh) < 0) {
1431 ret = -errno;
1432 goto out;
1433 }
1434 errno = 0;
1435 if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) {
1436 ret = errno ? -errno : -EINVAL;
1437 goto out;
1438 }
1439 coalesce_peers(*device);
1440
1441 out:
1442 if (nlg)
1443 mnlg_socket_close(nlg);
1444 if (ret) {
1445 wg_free_device(*device);
1446 if (ret == -EINTR)
1447 goto try_again;
1448 *device = NULL;
1449 }
1450 errno = -ret;
1451 return ret;
1452 }
1453
1454 /* first\0second\0third\0forth\0last\0\0 */
1455 char *wg_list_device_names(void)
1456 {
1457 struct string_list list = { 0 };
1458 int ret = fetch_device_names(&list);
1459
1460 errno = -ret;
1461 if (errno) {
1462 free(list.buffer);
1463 return NULL;
1464 }
1465 return list.buffer ?: strdup("\0");
1466 }
1467
/*
 * Ask add_del_iface() to add an interface named @device_name.
 * Returns add_del_iface()'s result unchanged.
 */
int wg_add_device(const char *device_name)
{
	const bool add = true;

	return add_del_iface(device_name, add);
}
1472
/*
 * Ask add_del_iface() to delete the interface named @device_name.
 * Returns add_del_iface()'s result unchanged.
 */
int wg_del_device(const char *device_name)
{
	const bool add = false;

	return add_del_iface(device_name, add);
}
1477
1478 void wg_free_device(wg_device *dev)
1479 {
1480 wg_peer *peer, *np;
1481 wg_allowedip *allowedip, *na;
1482
1483 if (!dev)
1484 return;
1485 for (peer = dev->first_peer, np = peer ? peer->next_peer : NULL; peer; peer = np, np = peer ? peer->next_peer : NULL) {
1486 for (allowedip = peer->first_allowedip, na = allowedip ? allowedip->next_allowedip : NULL; allowedip; allowedip = na, na = allowedip ? allowedip->next_allowedip : NULL)
1487 free(allowedip);
1488 free(peer);
1489 }
1490 free(dev);
1491 }
1492
/*
 * Encode 3 bytes from @src as 4 base64 characters in @dest (no terminator).
 *
 * The 24 input bits are split into four 6-bit values. Each value x is mapped
 * to the alphabet without data-dependent branches: start from x + 'A' and
 * add corrections. The expression (K - x) >> 8 is all ones when x > K (the
 * difference is negative; arithmetic right shift assumed) and 0 otherwise,
 * so each term applies only above a range boundary:
 *   x > 25: +6   ('a'..'z' for 26..51)
 *   x > 51: -75  ('0'..'9' for 52..61)
 *   x > 61: -15  ('+' for 62)
 *   x > 62: +3   ('/' for 63)
 */
1493 static void encode_base64(char dest[static 4], const uint8_t src[static 3])
1494 {
1495 const uint8_t input[] = { (src[0] >> 2) & 63, ((src[0] << 4) | (src[1] >> 4)) & 63, ((src[1] << 2) | (src[2] >> 6)) & 63, src[2] & 63 };
1496 unsigned int i;
1497
1498 for (i = 0; i < 4; ++i)
1499 dest[i] = input[i] + 'A'
1500 + (((25 - input[i]) >> 8) & 6)
1501 - (((51 - input[i]) >> 8) & 75)
1502 - (((61 - input[i]) >> 8) & 15)
1503 + (((62 - input[i]) >> 8) & 3);
1504
1505 }
1506
1507 void wg_key_to_base64(wg_key_b64_string base64, const wg_key key)
1508 {
1509 unsigned int i;
1510
1511 for (i = 0; i < 32 / 3; ++i)
1512 encode_base64(&base64[i * 4], &key[i * 3]);
1513 encode_base64(&base64[i * 4], (const uint8_t[]){ key[i * 3 + 0], key[i * 3 + 1], 0 });
1514 base64[sizeof(wg_key_b64_string) - 2] = '=';
1515 base64[sizeof(wg_key_b64_string) - 1] = '\0';
1516 }
1517
/*
 * Decode 4 base64 characters from @src into a 24-bit value: the first
 * character lands in bits 23..18 and the last in bits 5..0.
 * Returns a negative value if any character is outside the alphabet.
 *
 * Branch-free range test: ((lo - 1) - c) & (c - (hi + 1)) is negative only
 * when lo <= c <= hi, so ">> 8" (arithmetic shift assumed) yields an
 * all-ones mask selecting that range's offset. Each offset gives the 6-bit
 * value plus one, cancelled by the leading -1. A character matching no range
 * contributes just -1, which after the shift sets the sign bit of val.
 *
 * NOTE(review): for an invalid character this left-shifts a negative int,
 * which ISO C leaves undefined. GCC/Clang define it as the two's-complement
 * result, which the sign-bit check in the caller relies on.
 */
1518 static int decode_base64(const char src[static 4])
1519 {
1520 int val = 0;
1521 unsigned int i;
1522
1523 for (i = 0; i < 4; ++i)
1524 val |= (-1
1525 + ((((('A' - 1) - src[i]) & (src[i] - ('Z' + 1))) >> 8) & (src[i] - 64))
1526 + ((((('a' - 1) - src[i]) & (src[i] - ('z' + 1))) >> 8) & (src[i] - 70))
1527 + ((((('0' - 1) - src[i]) & (src[i] - ('9' + 1))) >> 8) & (src[i] + 5))
1528 + ((((('+' - 1) - src[i]) & (src[i] - ('+' + 1))) >> 8) & 63)
1529 + ((((('/' - 1) - src[i]) & (src[i] - ('/' + 1))) >> 8) & 64)
1530 ) << (18 - 6 * i);
1531 return val;
1532 }
1533
/*
 * Decode a base64 key string into @key.
 *
 * Returns 0 on success with errno set to 0. Returns -EINVAL with errno set
 * to EINVAL when:
 *   - the string length is wrong, or its last character is not '=';
 *   - any character is outside the base64 alphabet;
 *   - the two unused low bits of the final group are non-zero, i.e. the
 *     encoding is not canonical.
 *
 * After the length check passes, @key is written even if a later check fails.
 * Validity is accumulated in @ret without branching, so decoding does no
 * early exit on bad input.
 */
1534 int wg_key_from_base64(wg_key key, const wg_key_b64_string base64)
1535 {
1536 unsigned int i;
1537 int val;
1538 volatile uint8_t ret = 0;
1539
1540 if (strlen(base64) != sizeof(wg_key_b64_string) - 1 || base64[sizeof(wg_key_b64_string) - 2] != '=') {
1541 errno = EINVAL;
1542 goto out;
1543 }
1544
/* Ten full 4-character groups -> 30 bytes; bit 31 of val flags an invalid character. */
1545 for (i = 0; i < 32 / 3; ++i) {
1546 val = decode_base64(&base64[i * 4]);
1547 ret |= (uint32_t)val >> 31;
1548 key[i * 3 + 0] = (val >> 16) & 0xff;
1549 key[i * 3 + 1] = (val >> 8) & 0xff;
1550 key[i * 3 + 2] = val & 0xff;
1551 }
/*
 * Last group: three characters plus 'A' (value 0) standing in for '='.
 * It yields 2 bytes; the low byte must be zero, otherwise the third
 * character's two spare bits were set.
 */
1552 val = decode_base64((const char[]){ base64[i * 4 + 0], base64[i * 4 + 1], base64[i * 4 + 2], 'A' });
1553 ret |= ((uint32_t)val >> 31) | (val & 0xff);
1554 key[i * 3 + 0] = (val >> 16) & 0xff;
1555 key[i * 3 + 1] = (val >> 8) & 0xff;
/*
 * ret == 0 -> (ret - 1) >> 8 is all ones -> errno = 0.
 * Any other ret -> the shift gives 0 -> errno = EINVAL.
 */
1556 errno = EINVAL & ~((ret - 1) >> 8);
1557 out:
1558 return -errno;
1559 }
1560
/*
 * Field element modulo p = 2^255 - 19 for the key-derivation code below:
 * 16 limbs in little-endian order (value = sum of fe[i] * 2^(16 * i)).
 * Normalized limbs hold 16 bits (see carry() and pack()). They are stored as
 * int64_t so that sums, differences and partial products can temporarily
 * exceed 16 bits or go negative before being carried.
 */
1561 typedef int64_t fe[16];
1562
/*
 * Zero @count bytes at @s in a way the compiler should not optimize away.
 * It is used below to wipe secret intermediates before they go out of scope.
 *
 * A plain memset() of a buffer that is never read again is a dead store the
 * optimizer may delete. The empty asm takes @s as an input and clobbers
 * "memory", so the compiler must assume the zeroed bytes are observed.
 * noinline presumably adds a further barrier against folding the call into
 * the caller.
 */
1563 static __attribute__((noinline)) void memzero_explicit(void *s, size_t count)
1564 {
1565 memset(s, 0, count);
1566 __asm__ __volatile__("": :"r"(s) :"memory");
1567 }
1568
1569 static void carry(fe o)
1570 {
1571 int i;
1572
1573 for (i = 0; i < 16; ++i) {
1574 o[(i + 1) % 16] += (i == 15 ? 38 : 1) * (o[i] >> 16);
1575 o[i] &= 0xffff;
1576 }
1577 }
1578
/*
 * Conditionally swap field elements @p and @q.
 * When @b is 1 they are swapped; when @b is 0 both are left unchanged.
 * Other values of @b are not expected by callers.
 *
 * c = ~(b - 1) is all ones for b == 1 and zero for b == 0. The XOR mask t is
 * therefore p ^ q or 0, and the same instructions execute either way, with
 * no branch on @b. The temporaries and the copy of @b are wiped afterwards.
 */
1579 static void cswap(fe p, fe q, int b)
1580 {
1581 int i;
1582 int64_t t, c = ~(b - 1);
1583
1584 for (i = 0; i < 16; ++i) {
1585 t = c & (p[i] ^ q[i]);
1586 p[i] ^= t;
1587 q[i] ^= t;
1588 }
1589
1590 memzero_explicit(&t, sizeof(t));
1591 memzero_explicit(&c, sizeof(c));
1592 memzero_explicit(&b, sizeof(b));
1593 }
1594
/*
 * Serialize field element @n into 32 little-endian bytes at @o, fully
 * reduced modulo p = 2^255 - 19.
 *
 * Steps:
 *  1. Three carry() passes normalize a copy of @n to 16-bit limbs.
 *  2. Twice, compute m = t - p with explicit borrow propagation. The limbs
 *     0xffed, 0xffff x14, 0x7fff spell out p.
 *  3. b is the final borrow. cswap(t, m, 1 - b) replaces t with m only when
 *     no borrow occurred (t >= p). This selection has no value-dependent
 *     branch.
 *  4. Each 16-bit limb is written low byte first.
 *
 * Scratch values are wiped before return.
 */
1595 static void pack(uint8_t *o, const fe n)
1596 {
1597 int i, j, b;
1598 fe m, t;
1599
1600 memcpy(t, n, sizeof(t));
1601 carry(t);
1602 carry(t);
1603 carry(t);
1604 for (j = 0; j < 2; ++j) {
1605 m[0] = t[0] - 0xffed;
1606 for (i = 1; i < 15; ++i) {
1607 m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
1608 m[i - 1] &= 0xffff;
1609 }
1610 m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
1611 b = (m[15] >> 16) & 1;
1612 m[14] &= 0xffff;
1613 cswap(t, m, 1 - b);
1614 }
1615 for (i = 0; i < 16; ++i) {
1616 o[2 * i] = t[i] & 0xff;
1617 o[2 * i + 1] = t[i] >> 8;
1618 }
1619
1620 memzero_explicit(m, sizeof(m));
1621 memzero_explicit(t, sizeof(t));
1622 memzero_explicit(&b, sizeof(b));
1623 }
1624
1625 static void add(fe o, const fe a, const fe b)
1626 {
1627 int i;
1628
1629 for (i = 0; i < 16; ++i)
1630 o[i] = a[i] + b[i];
1631 }
1632
1633 static void subtract(fe o, const fe a, const fe b)
1634 {
1635 int i;
1636
1637 for (i = 0; i < 16; ++i)
1638 o[i] = a[i] - b[i];
1639 }
1640
/*
 * o = a * b modulo p = 2^255 - 19, followed by two carry() passes.
 *
 * The 16x16-limb schoolbook product is accumulated in the separate 31-limb
 * buffer t, so @o may alias @a or @b (callers do pass e.g. c, c, c). Limbs
 * 16..30 carry weight 2^256 and above and are folded back with factor 38,
 * because 2^256 = 38 (mod p). The buffer is wiped before return.
 */
1641 static void multmod(fe o, const fe a, const fe b)
1642 {
1643 int i, j;
1644 int64_t t[31] = { 0 };
1645
1646 for (i = 0; i < 16; ++i) {
1647 for (j = 0; j < 16; ++j)
1648 t[i + j] += a[i] * b[j];
1649 }
1650 for (i = 0; i < 15; ++i)
1651 t[i] += 38 * t[i + 16];
1652 memcpy(o, t, sizeof(fe));
1653 carry(o);
1654 carry(o);
1655
1656 memzero_explicit(t, sizeof(t));
1657 }
1658
/*
 * o = i^-1 modulo p = 2^255 - 19. It is computed as i^(p - 2) (Fermat's
 * little theorem) by square-and-multiply over a fixed schedule.
 *
 * p - 2 = 2^255 - 21 has bits 254..0 all set except bits 4 and 2. Hence the
 * loop starts from c = i (bit 254), squares once per lower bit, and skips
 * the multiply for a == 4 and a == 2. The schedule depends only on this
 * public exponent, never on @i.
 *
 * @o may alias @i: the result is built in the local c, and @i is not written
 * until the final copy. The local is wiped before return.
 */
1659 static void invert(fe o, const fe i)
1660 {
1661 fe c;
1662 int a;
1663
1664 memcpy(c, i, sizeof(c));
1665 for (a = 253; a >= 0; --a) {
1666 multmod(c, c, c);
1667 if (a != 2 && a != 4)
1668 multmod(c, c, i);
1669 }
1670 memcpy(o, c, sizeof(fe));
1671
1672 memzero_explicit(c, sizeof(c));
1673 }
1674
/*
 * Apply Curve25519 scalar clamping to the 32-byte little-endian scalar @z,
 * in place:
 *  - clear the three low bits of the first byte;
 *  - clear the top bit of the last byte;
 *  - set the next-highest bit (bit 254).
 */
static void clamp_key(uint8_t *z)
{
	z[0] &= 0xf8;
	z[31] &= 0x7f;
	z[31] |= 0x40;
}
1680
/*
 * Derive the public key for @private_key. This is X25519 scalar
 * multiplication (RFC 7748) of the clamped private scalar with the base
 * point u = 9.
 *
 * The Montgomery ladder works as follows:
 *  - (a : c) and (b : d) hold the two projective x-coordinates, starting at
 *    (1 : 0) and (9 : 1).
 *  - For each scalar bit from 254 down to 0 (bit 255 is cleared by
 *    clamp_key()), the pair is conditionally swapped on the bit, one combined
 *    differential add/double step runs, and the pair is swapped back. The
 *    sequence of field operations therefore does not depend on the key.
 *  - Finally x = a / c is computed with invert() and serialized with pack().
 *
 * All secret intermediates are wiped before return.
 */
1681 void wg_generate_public_key(wg_key public_key, const wg_key private_key)
1682 {
1683 int i, r;
1684 uint8_t z[32];
1685 fe a = { 1 }, b = { 9 }, c = { 0 }, d = { 1 }, e, f;
1686
1687 memcpy(z, private_key, sizeof(z));
1688 clamp_key(z);
1689
1690 for (i = 254; i >= 0; --i) {
1691 r = (z[i >> 3] >> (i & 7)) & 1;
1692 cswap(a, b, r);
1693 cswap(c, d, r);
1694 add(e, a, c);
1695 subtract(a, a, c);
1696 add(c, b, d);
1697 subtract(b, b, d);
1698 multmod(d, e, e);
1699 multmod(f, a, a);
1700 multmod(a, c, a);
1701 multmod(c, b, e);
1702 add(e, a, c);
1703 subtract(a, a, c);
1704 multmod(b, a, a);
1705 subtract(c, d, f);
/* { 0xdb41, 1 } is 121665 = 0x1db41 in 16-bit limbs: the RFC 7748 ladder constant a24. */
1706 multmod(a, c, (const fe){ 0xdb41, 1 });
1707 add(a, a, d);
1708 multmod(c, c, a);
1709 multmod(a, d, f);
/* Multiply by the base point's x-coordinate (9) for the differential addition. */
1710 multmod(d, b, (const fe){ 9 });
1711 multmod(b, e, e);
1712 cswap(a, b, r);
1713 cswap(c, d, r);
1714 }
1715 invert(c, c);
1716 multmod(a, a, c);
1717 pack(public_key, a);
1718
1719 memzero_explicit(&r, sizeof(r));
1720 memzero_explicit(z, sizeof(z));
1721 memzero_explicit(a, sizeof(a));
1722 memzero_explicit(b, sizeof(b));
1723 memzero_explicit(c, sizeof(c));
1724 memzero_explicit(d, sizeof(d));
1725 memzero_explicit(e, sizeof(e));
1726 memzero_explicit(f, sizeof(f));
1727 }
1728
1729 void wg_generate_private_key(wg_key private_key)
1730 {
1731 wg_generate_preshared_key(private_key);
1732 clamp_key(private_key);
1733 }
1734
1735 void wg_generate_preshared_key(wg_key preshared_key)
1736 {
1737 ssize_t ret;
1738 size_t i;
1739 int fd;
1740 #if defined(__OpenBSD__) || (defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12) || (defined(__GLIBC__) && (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25)))
1741 if (!getentropy(preshared_key, sizeof(wg_key)))
1742 return;
1743 #endif
1744 #if defined(__NR_getrandom) && defined(__linux__)
1745 if (syscall(__NR_getrandom, preshared_key, sizeof(wg_key), 0) == sizeof(wg_key))
1746 return;
1747 #endif
1748 fd = open("/dev/urandom", O_RDONLY);
1749 assert(fd >= 0);
1750 for (i = 0; i < sizeof(wg_key); i += ret) {
1751 ret = read(fd, preshared_key + i, sizeof(wg_key) - i);
1752 assert(ret > 0);
1753 }
1754 close(fd);
1755 }