1 // SPDX-License-Identifier: LGPL-2.1+
3 * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 * Copyright (C) 2008-2012 Pablo Neira Ayuso <pablo@netfilter.org>.
10 #include <linux/genetlink.h>
11 #include <linux/if_link.h>
12 #include <linux/netlink.h>
13 #include <linux/rtnetlink.h>
14 #include <netinet/in.h>
19 #include <sys/socket.h>
25 #include "wireguard.h"
/* wireguard.h netlink uapi: */

#define WG_GENL_NAME "wireguard"
#define WG_GENL_VERSION 1

/*
 * Generic netlink commands, flags and attribute numbers of the WireGuard
 * family. These numbers are wire ABI shared with the kernel, so every value is
 * spelled out explicitly rather than left to implicit enumeration.
 */
enum wg_cmd {
	WG_CMD_GET_DEVICE = 0,
	WG_CMD_SET_DEVICE = 1,
	__WG_CMD_MAX
};

enum wgdevice_flag {
	WGDEVICE_F_REPLACE_PEERS = 1U << 0
};
enum wgdevice_attribute {
	WGDEVICE_A_UNSPEC = 0,
	WGDEVICE_A_IFINDEX = 1,
	WGDEVICE_A_IFNAME = 2,
	WGDEVICE_A_PRIVATE_KEY = 3,
	WGDEVICE_A_PUBLIC_KEY = 4,
	WGDEVICE_A_FLAGS = 5,
	WGDEVICE_A_LISTEN_PORT = 6,
	WGDEVICE_A_FWMARK = 7,
	WGDEVICE_A_PEERS = 8,
	__WGDEVICE_A_LAST
};

enum wgpeer_flag {
	WGPEER_F_REMOVE_ME = 1U << 0,
	WGPEER_F_REPLACE_ALLOWEDIPS = 1U << 1
};
enum wgpeer_attribute {
	WGPEER_A_UNSPEC = 0,
	WGPEER_A_PUBLIC_KEY = 1,
	WGPEER_A_PRESHARED_KEY = 2,
	WGPEER_A_FLAGS = 3,
	WGPEER_A_ENDPOINT = 4,
	WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL = 5,
	WGPEER_A_LAST_HANDSHAKE_TIME = 6,
	WGPEER_A_RX_BYTES = 7,
	WGPEER_A_TX_BYTES = 8,
	WGPEER_A_ALLOWEDIPS = 9,
	WGPEER_A_PROTOCOL_VERSION = 10,
	__WGPEER_A_LAST
};

enum wgallowedip_attribute {
	WGALLOWEDIP_A_UNSPEC = 0,
	WGALLOWEDIP_A_FAMILY = 1,
	WGALLOWEDIP_A_IPADDR = 2,
	WGALLOWEDIP_A_CIDR_MASK = 3,
	__WGALLOWEDIP_A_LAST
};
/* libmnl mini library: */

#define MNL_SOCKET_AUTOPID	0
#define MNL_ALIGNTO		4
#define MNL_ALIGN(len)		(((len)+MNL_ALIGNTO-1) & ~(MNL_ALIGNTO-1))
#define MNL_NLMSG_HDRLEN	MNL_ALIGN(sizeof(struct nlmsghdr))
#define MNL_ATTR_HDRLEN		MNL_ALIGN(sizeof(struct nlattr))

/* Attribute payload types understood by mnl_attr_validate(). */
enum mnl_attr_data_type {
	MNL_TYPE_UNSPEC,
	MNL_TYPE_U8,
	MNL_TYPE_U16,
	MNL_TYPE_U32,
	MNL_TYPE_U64,
	MNL_TYPE_STRING,
	MNL_TYPE_FLAG,
	MNL_TYPE_MSECS,
	MNL_TYPE_NESTED,
	MNL_TYPE_NESTED_COMPAT,
	MNL_TYPE_NUL_STRING,
	MNL_TYPE_BINARY,
	MNL_TYPE_MAX,
};

/* Iterate over the attributes of @nlh that follow @offset bytes of family header. */
#define mnl_attr_for_each(attr, nlh, offset) \
	for ((attr) = mnl_nlmsg_get_payload_offset((nlh), (offset)); \
	     mnl_attr_ok((attr), (char *)mnl_nlmsg_get_payload_tail(nlh) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Iterate over the attributes nested inside the attribute @nest. */
#define mnl_attr_for_each_nested(attr, nest) \
	for ((attr) = mnl_attr_get_payload(nest); \
	     mnl_attr_ok((attr), (char *)mnl_attr_get_payload(nest) + mnl_attr_get_payload_len(nest) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/*
 * Iterate over a raw attribute stream of @payload_size bytes starting at
 * @payload. As in upstream libmnl, the cursor is a variable that must be
 * named `attr` and already be declared in the caller's scope.
 * @payload_size is parenthesized so that expressions such as `a - b` expand
 * correctly.
 */
#define mnl_attr_for_each_payload(payload, payload_size) \
	for ((attr) = (payload); \
	     mnl_attr_ok((attr), (char *)(payload) + (payload_size) - (char *)(attr)); \
	     (attr) = mnl_attr_next(attr))

/* Callback return codes: a value <= MNL_CB_STOP ends processing. */
#define MNL_CB_ERROR		(-1)
#define MNL_CB_STOP		0
#define MNL_CB_OK		1

typedef int (*mnl_attr_cb_t)(const struct nlattr *attr, void *data);
typedef int (*mnl_cb_t)(const struct nlmsghdr *nlh, void *data);

#ifndef MNL_ARRAY_SIZE
#define MNL_ARRAY_SIZE(a)	(sizeof(a)/sizeof((a)[0]))
#endif
/*
 * Size used for netlink send/receive buffers: the system page size, capped at
 * 8192 bytes. The value is computed on first use and cached afterwards.
 */
static size_t mnl_ideal_socket_buffer_size(void)
{
	static size_t cached;

	if (!cached) {
		size_t page = (size_t)sysconf(_SC_PAGESIZE);

		cached = page > 8192 ? 8192 : page;
	}
	return cached;
}
143 static size_t mnl_nlmsg_size(size_t len
)
145 return len
+ MNL_NLMSG_HDRLEN
;
148 static struct nlmsghdr
*mnl_nlmsg_put_header(void *buf
)
150 int len
= MNL_ALIGN(sizeof(struct nlmsghdr
));
151 struct nlmsghdr
*nlh
= buf
;
154 nlh
->nlmsg_len
= len
;
158 static void *mnl_nlmsg_put_extra_header(struct nlmsghdr
*nlh
, size_t size
)
160 char *ptr
= (char *)nlh
+ nlh
->nlmsg_len
;
161 size_t len
= MNL_ALIGN(size
);
162 nlh
->nlmsg_len
+= len
;
167 static void *mnl_nlmsg_get_payload(const struct nlmsghdr
*nlh
)
169 return (void *)nlh
+ MNL_NLMSG_HDRLEN
;
172 static void *mnl_nlmsg_get_payload_offset(const struct nlmsghdr
*nlh
, size_t offset
)
174 return (void *)nlh
+ MNL_NLMSG_HDRLEN
+ MNL_ALIGN(offset
);
/*
 * Whether a complete, sane netlink header starts at @nlh with @len bytes left
 * in the buffer. @len is checked before @nlh is dereferenced.
 */
static bool mnl_nlmsg_ok(const struct nlmsghdr *nlh, int len)
{
	if (len < (int)sizeof(struct nlmsghdr))
		return false;
	if (nlh->nlmsg_len < sizeof(struct nlmsghdr))
		return false;
	return (int)nlh->nlmsg_len <= len;
}

/* Step to the next message in a batch, deducting the aligned length of @nlh from *@len. */
static struct nlmsghdr *mnl_nlmsg_next(const struct nlmsghdr *nlh, int *len)
{
	const unsigned int step = MNL_ALIGN(nlh->nlmsg_len);

	*len -= step;
	return (struct nlmsghdr *)((char *)nlh + step);
}

/* First byte past the aligned end of @nlh, i.e. where the next attribute goes. */
static void *mnl_nlmsg_get_payload_tail(const struct nlmsghdr *nlh)
{
	return (char *)nlh + MNL_ALIGN(nlh->nlmsg_len);
}

/* Sequence check; a zero on either side disables it. */
static bool mnl_nlmsg_seq_ok(const struct nlmsghdr *nlh, unsigned int seq)
{
	if (!nlh->nlmsg_seq || !seq)
		return true;
	return nlh->nlmsg_seq == seq;
}

/* Port-id check; a zero on either side disables it. */
static bool mnl_nlmsg_portid_ok(const struct nlmsghdr *nlh, unsigned int portid)
{
	if (!nlh->nlmsg_pid || !portid)
		return true;
	return nlh->nlmsg_pid == portid;
}
205 static uint16_t mnl_attr_get_type(const struct nlattr
*attr
)
207 return attr
->nla_type
& NLA_TYPE_MASK
;
210 static uint16_t mnl_attr_get_payload_len(const struct nlattr
*attr
)
212 return attr
->nla_len
- MNL_ATTR_HDRLEN
;
215 static void *mnl_attr_get_payload(const struct nlattr
*attr
)
217 return (void *)attr
+ MNL_ATTR_HDRLEN
;
220 static bool mnl_attr_ok(const struct nlattr
*attr
, int len
)
222 return len
>= (int)sizeof(struct nlattr
) &&
223 attr
->nla_len
>= sizeof(struct nlattr
) &&
224 (int)attr
->nla_len
<= len
;
227 static struct nlattr
*mnl_attr_next(const struct nlattr
*attr
)
229 return (struct nlattr
*)((void *)attr
+ MNL_ALIGN(attr
->nla_len
));
232 static int mnl_attr_type_valid(const struct nlattr
*attr
, uint16_t max
)
234 if (mnl_attr_get_type(attr
) > max
) {
241 static int __mnl_attr_validate(const struct nlattr
*attr
,
242 enum mnl_attr_data_type type
, size_t exp_len
)
244 uint16_t attr_len
= mnl_attr_get_payload_len(attr
);
245 const char *attr_data
= mnl_attr_get_payload(attr
);
247 if (attr_len
< exp_len
) {
258 case MNL_TYPE_NUL_STRING
:
263 if (attr_data
[attr_len
-1] != '\0') {
268 case MNL_TYPE_STRING
:
274 case MNL_TYPE_NESTED
:
279 if (attr_len
< MNL_ATTR_HDRLEN
) {
288 if (exp_len
&& attr_len
> exp_len
) {
295 static const size_t mnl_attr_data_type_len
[MNL_TYPE_MAX
] = {
296 [MNL_TYPE_U8
] = sizeof(uint8_t),
297 [MNL_TYPE_U16
] = sizeof(uint16_t),
298 [MNL_TYPE_U32
] = sizeof(uint32_t),
299 [MNL_TYPE_U64
] = sizeof(uint64_t),
300 [MNL_TYPE_MSECS
] = sizeof(uint64_t),
303 static int mnl_attr_validate(const struct nlattr
*attr
, enum mnl_attr_data_type type
)
307 if (type
>= MNL_TYPE_MAX
) {
311 exp_len
= mnl_attr_data_type_len
[type
];
312 return __mnl_attr_validate(attr
, type
, exp_len
);
315 static int mnl_attr_parse(const struct nlmsghdr
*nlh
, unsigned int offset
,
316 mnl_attr_cb_t cb
, void *data
)
319 const struct nlattr
*attr
;
321 mnl_attr_for_each(attr
, nlh
, offset
)
322 if ((ret
= cb(attr
, data
)) <= MNL_CB_STOP
)
327 static int mnl_attr_parse_nested(const struct nlattr
*nested
, mnl_attr_cb_t cb
,
331 const struct nlattr
*attr
;
333 mnl_attr_for_each_nested(attr
, nested
)
334 if ((ret
= cb(attr
, data
)) <= MNL_CB_STOP
)
/*
 * Typed payload readers. Callers are expected to have validated the payload
 * length (see mnl_attr_validate()). Multi-byte values are copied out with
 * memcpy, so the payload alignment does not matter.
 */
static uint8_t mnl_attr_get_u8(const struct nlattr *attr)
{
	return *(const uint8_t *)mnl_attr_get_payload(attr);
}

static uint16_t mnl_attr_get_u16(const struct nlattr *attr)
{
	uint16_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}

static uint32_t mnl_attr_get_u32(const struct nlattr *attr)
{
	uint32_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}

static uint64_t mnl_attr_get_u64(const struct nlattr *attr)
{
	uint64_t value;

	memcpy(&value, mnl_attr_get_payload(attr), sizeof(value));
	return value;
}

/* Payload as a C string. This does not check termination; validate first if needed. */
static const char *mnl_attr_get_str(const struct nlattr *attr)
{
	return mnl_attr_get_payload(attr);
}
/*
 * Append an attribute of @type carrying @len bytes from @data at the tail of
 * @nlh. Alignment padding is zeroed. There is no bounds check: use
 * mnl_attr_put_check() when the buffer may be too small.
 */
static void mnl_attr_put(struct nlmsghdr *nlh, uint16_t type, size_t len,
			 const void *data)
{
	struct nlattr *attr = mnl_nlmsg_get_payload_tail(nlh);
	const uint16_t attr_len = MNL_ALIGN(sizeof(struct nlattr)) + len;
	const int pad = MNL_ALIGN(len) - len;
	char *payload = mnl_attr_get_payload(attr);

	attr->nla_type = type;
	attr->nla_len = attr_len;
	memcpy(payload, data, len);
	nlh->nlmsg_len += MNL_ALIGN(attr_len);
	if (pad > 0)
		memset(payload + len, 0, pad);
}

/* Append a host-endian u16 attribute. */
static void mnl_attr_put_u16(struct nlmsghdr *nlh, uint16_t type, uint16_t data)
{
	mnl_attr_put(nlh, type, sizeof(data), &data);
}

/* Append a host-endian u32 attribute. */
static void mnl_attr_put_u32(struct nlmsghdr *nlh, uint16_t type, uint32_t data)
{
	mnl_attr_put(nlh, type, sizeof(data), &data);
}

/* Append @data as a string attribute, including its terminating NUL. */
static void mnl_attr_put_strz(struct nlmsghdr *nlh, uint16_t type, const char *data)
{
	const size_t with_nul = strlen(data) + 1;

	mnl_attr_put(nlh, type, with_nul, data);
}

/*
 * Open a nested attribute of @type (NLA_F_NESTED is set). Its nla_len is
 * filled in later by mnl_attr_nest_end().
 */
static struct nlattr *mnl_attr_nest_start(struct nlmsghdr *nlh, uint16_t type)
{
	struct nlattr *nest = mnl_nlmsg_get_payload_tail(nlh);

	nest->nla_type = NLA_F_NESTED | type;
	nlh->nlmsg_len += MNL_ALIGN(sizeof(struct nlattr));
	return nest;
}
406 static bool mnl_attr_put_check(struct nlmsghdr
*nlh
, size_t buflen
,
407 uint16_t type
, size_t len
, const void *data
)
409 if (nlh
->nlmsg_len
+ MNL_ATTR_HDRLEN
+ MNL_ALIGN(len
) > buflen
)
411 mnl_attr_put(nlh
, type
, len
, data
);
415 static bool mnl_attr_put_u8_check(struct nlmsghdr
*nlh
, size_t buflen
,
416 uint16_t type
, uint8_t data
)
418 return mnl_attr_put_check(nlh
, buflen
, type
, sizeof(uint8_t), &data
);
421 static bool mnl_attr_put_u16_check(struct nlmsghdr
*nlh
, size_t buflen
,
422 uint16_t type
, uint16_t data
)
424 return mnl_attr_put_check(nlh
, buflen
, type
, sizeof(uint16_t), &data
);
427 static bool mnl_attr_put_u32_check(struct nlmsghdr
*nlh
, size_t buflen
,
428 uint16_t type
, uint32_t data
)
430 return mnl_attr_put_check(nlh
, buflen
, type
, sizeof(uint32_t), &data
);
433 static struct nlattr
*mnl_attr_nest_start_check(struct nlmsghdr
*nlh
, size_t buflen
,
436 if (nlh
->nlmsg_len
+ MNL_ATTR_HDRLEN
> buflen
)
438 return mnl_attr_nest_start(nlh
, type
);
441 static void mnl_attr_nest_end(struct nlmsghdr
*nlh
, struct nlattr
*start
)
443 start
->nla_len
= mnl_nlmsg_get_payload_tail(nlh
) - (void *)start
;
446 static void mnl_attr_nest_cancel(struct nlmsghdr
*nlh
, struct nlattr
*start
)
448 nlh
->nlmsg_len
-= mnl_nlmsg_get_payload_tail(nlh
) - (void *)start
;
451 static int mnl_cb_noop(__attribute__((unused
)) const struct nlmsghdr
*nlh
, __attribute__((unused
)) void *data
)
456 static int mnl_cb_error(const struct nlmsghdr
*nlh
, __attribute__((unused
)) void *data
)
458 const struct nlmsgerr
*err
= mnl_nlmsg_get_payload(nlh
);
460 if (nlh
->nlmsg_len
< mnl_nlmsg_size(sizeof(struct nlmsgerr
))) {
470 return err
->error
== 0 ? MNL_CB_STOP
: MNL_CB_ERROR
;
473 static int mnl_cb_stop(__attribute__((unused
)) const struct nlmsghdr
*nlh
, __attribute__((unused
)) void *data
)
478 static const mnl_cb_t default_cb_array
[NLMSG_MIN_TYPE
] = {
479 [NLMSG_NOOP
] = mnl_cb_noop
,
480 [NLMSG_ERROR
] = mnl_cb_error
,
481 [NLMSG_DONE
] = mnl_cb_stop
,
482 [NLMSG_OVERRUN
] = mnl_cb_noop
,
485 static int __mnl_cb_run(const void *buf
, size_t numbytes
,
486 unsigned int seq
, unsigned int portid
,
487 mnl_cb_t cb_data
, void *data
,
488 const mnl_cb_t
*cb_ctl_array
,
489 unsigned int cb_ctl_array_len
)
491 int ret
= MNL_CB_OK
, len
= numbytes
;
492 const struct nlmsghdr
*nlh
= buf
;
494 while (mnl_nlmsg_ok(nlh
, len
)) {
496 if (!mnl_nlmsg_portid_ok(nlh
, portid
)) {
501 if (!mnl_nlmsg_seq_ok(nlh
, seq
)) {
506 if (nlh
->nlmsg_flags
& NLM_F_DUMP_INTR
) {
511 if (nlh
->nlmsg_type
>= NLMSG_MIN_TYPE
) {
513 ret
= cb_data(nlh
, data
);
514 if (ret
<= MNL_CB_STOP
)
517 } else if (nlh
->nlmsg_type
< cb_ctl_array_len
) {
518 if (cb_ctl_array
&& cb_ctl_array
[nlh
->nlmsg_type
]) {
519 ret
= cb_ctl_array
[nlh
->nlmsg_type
](nlh
, data
);
520 if (ret
<= MNL_CB_STOP
)
523 } else if (default_cb_array
[nlh
->nlmsg_type
]) {
524 ret
= default_cb_array
[nlh
->nlmsg_type
](nlh
, data
);
525 if (ret
<= MNL_CB_STOP
)
528 nlh
= mnl_nlmsg_next(nlh
, &len
);
534 static int mnl_cb_run2(const void *buf
, size_t numbytes
, unsigned int seq
,
535 unsigned int portid
, mnl_cb_t cb_data
, void *data
,
536 const mnl_cb_t
*cb_ctl_array
, unsigned int cb_ctl_array_len
)
538 return __mnl_cb_run(buf
, numbytes
, seq
, portid
, cb_data
, data
,
539 cb_ctl_array
, cb_ctl_array_len
);
542 static int mnl_cb_run(const void *buf
, size_t numbytes
, unsigned int seq
,
543 unsigned int portid
, mnl_cb_t cb_data
, void *data
)
545 return __mnl_cb_run(buf
, numbytes
, seq
, portid
, cb_data
, data
, NULL
, 0);
/* A netlink socket plus the local address it is bound to. */
struct mnl_socket {
	int fd;
	struct sockaddr_nl addr;
};

/* Port id assigned at bind time (read back with getsockname()). */
static unsigned int mnl_socket_get_portid(const struct mnl_socket *nl)
{
	return nl->addr.nl_pid;
}

/*
 * Allocate a zeroed mnl_socket and open an AF_NETLINK raw socket for @bus
 * with extra socket(2) type @flags. Returns NULL on failure, with errno
 * left from calloc()/socket() (free() may run before returning).
 */
static struct mnl_socket *__mnl_socket_open(int bus, int flags)
{
	struct mnl_socket *sock = calloc(1, sizeof(*sock));

	if (!sock)
		return NULL;

	sock->fd = socket(AF_NETLINK, SOCK_RAW | flags, bus);
	if (sock->fd == -1) {
		free(sock);
		return NULL;
	}
	return sock;
}

/* Open a netlink socket for @bus without extra socket flags. */
static struct mnl_socket *mnl_socket_open(int bus)
{
	const int no_flags = 0;

	return __mnl_socket_open(bus, no_flags);
}
580 static int mnl_socket_bind(struct mnl_socket
*nl
, unsigned int groups
, pid_t pid
)
585 nl
->addr
.nl_family
= AF_NETLINK
;
586 nl
->addr
.nl_groups
= groups
;
587 nl
->addr
.nl_pid
= pid
;
589 ret
= bind(nl
->fd
, (struct sockaddr
*) &nl
->addr
, sizeof (nl
->addr
));
593 addr_len
= sizeof(nl
->addr
);
594 ret
= getsockname(nl
->fd
, (struct sockaddr
*) &nl
->addr
, &addr_len
);
598 if (addr_len
!= sizeof(nl
->addr
)) {
602 if (nl
->addr
.nl_family
!= AF_NETLINK
) {
609 static ssize_t
mnl_socket_sendto(const struct mnl_socket
*nl
, const void *buf
,
612 static const struct sockaddr_nl snl
= {
613 .nl_family
= AF_NETLINK
615 return sendto(nl
->fd
, buf
, len
, 0,
616 (struct sockaddr
*) &snl
, sizeof(snl
));
619 static ssize_t
mnl_socket_recvfrom(const struct mnl_socket
*nl
, void *buf
,
623 struct sockaddr_nl addr
;
628 struct msghdr msg
= {
630 .msg_namelen
= sizeof(struct sockaddr_nl
),
637 ret
= recvmsg(nl
->fd
, &msg
, 0);
641 if (msg
.msg_flags
& MSG_TRUNC
) {
645 if (msg
.msg_namelen
!= sizeof(struct sockaddr_nl
)) {
652 static int mnl_socket_close(struct mnl_socket
*nl
)
654 int ret
= close(nl
->fd
);
/* mnlg mini library: */

/* Generic-netlink socket bound to one resolved family. */
struct mnlg_socket {
	struct mnl_socket *nl;	/* underlying netlink socket */
	char *buf;		/* shared buffer for requests and replies */
	uint16_t id;		/* resolved generic netlink family id */
	uint8_t version;	/* family version placed in each genlmsghdr */
	unsigned int seq;	/* sequence number of the last prepared request */
	unsigned int portid;	/* local port id, used to filter replies */
};

/*
 * Build a request header in nlg->buf for family @id: message type @id, @flags,
 * a fresh sequence number (time(NULL), also stored in nlg->seq), then a
 * genlmsghdr carrying @cmd and @version.
 */
static struct nlmsghdr *__mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd,
					   uint16_t flags, uint16_t id,
					   uint8_t version)
{
	struct nlmsghdr *nlh = mnl_nlmsg_put_header(nlg->buf);
	struct genlmsghdr *genl;

	nlh->nlmsg_type = id;
	nlh->nlmsg_flags = flags;
	nlg->seq = time(NULL);
	nlh->nlmsg_seq = nlg->seq;

	genl = mnl_nlmsg_put_extra_header(nlh, sizeof(*genl));
	genl->cmd = cmd;
	genl->version = version;
	return nlh;
}

/* __mnlg_msg_prepare() for the family this socket resolved. */
static struct nlmsghdr *mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd,
					 uint16_t flags)
{
	const uint16_t family = nlg->id;
	const uint8_t family_version = nlg->version;

	return __mnlg_msg_prepare(nlg, cmd, flags, family, family_version);
}

/* Send the complete message @nlh. */
static int mnlg_socket_send(struct mnlg_socket *nlg, const struct nlmsghdr *nlh)
{
	const size_t total = nlh->nlmsg_len;

	return mnl_socket_sendto(nlg->nl, nlh, total);
}
701 static int mnlg_cb_noop(const struct nlmsghdr
*nlh
, void *data
)
708 static int mnlg_cb_error(const struct nlmsghdr
*nlh
, void *data
)
710 const struct nlmsgerr
*err
= mnl_nlmsg_get_payload(nlh
);
713 if (nlh
->nlmsg_len
< mnl_nlmsg_size(sizeof(struct nlmsgerr
))) {
717 /* Netlink subsystems returns the errno value with different signess */
723 return err
->error
== 0 ? MNL_CB_STOP
: MNL_CB_ERROR
;
726 static int mnlg_cb_stop(const struct nlmsghdr
*nlh
, void *data
)
729 if (nlh
->nlmsg_flags
& NLM_F_MULTI
&& nlh
->nlmsg_len
== mnl_nlmsg_size(sizeof(int))) {
730 int error
= *(int *)mnl_nlmsg_get_payload(nlh
);
731 /* Netlink subsystems returns the errno value with different signess */
737 return error
== 0 ? MNL_CB_STOP
: MNL_CB_ERROR
;
742 static const mnl_cb_t mnlg_cb_array
[] = {
743 [NLMSG_NOOP
] = mnlg_cb_noop
,
744 [NLMSG_ERROR
] = mnlg_cb_error
,
745 [NLMSG_DONE
] = mnlg_cb_stop
,
746 [NLMSG_OVERRUN
] = mnlg_cb_noop
,
749 static int mnlg_socket_recv_run(struct mnlg_socket
*nlg
, mnl_cb_t data_cb
, void *data
)
754 err
= mnl_socket_recvfrom(nlg
->nl
, nlg
->buf
,
755 mnl_ideal_socket_buffer_size());
758 err
= mnl_cb_run2(nlg
->buf
, err
, nlg
->seq
, nlg
->portid
,
759 data_cb
, data
, mnlg_cb_array
, MNL_ARRAY_SIZE(mnlg_cb_array
));
765 static int get_family_id_attr_cb(const struct nlattr
*attr
, void *data
)
767 const struct nlattr
**tb
= data
;
768 int type
= mnl_attr_get_type(attr
);
770 if (mnl_attr_type_valid(attr
, CTRL_ATTR_MAX
) < 0)
773 if (type
== CTRL_ATTR_FAMILY_ID
&&
774 mnl_attr_validate(attr
, MNL_TYPE_U16
) < 0)
780 static int get_family_id_cb(const struct nlmsghdr
*nlh
, void *data
)
782 uint16_t *p_id
= data
;
783 struct nlattr
*tb
[CTRL_ATTR_MAX
+ 1] = { 0 };
785 mnl_attr_parse(nlh
, sizeof(struct genlmsghdr
), get_family_id_attr_cb
, tb
);
786 if (!tb
[CTRL_ATTR_FAMILY_ID
])
788 *p_id
= mnl_attr_get_u16(tb
[CTRL_ATTR_FAMILY_ID
]);
792 static struct mnlg_socket
*mnlg_socket_open(const char *family_name
, uint8_t version
)
794 struct mnlg_socket
*nlg
;
795 struct nlmsghdr
*nlh
;
798 nlg
= malloc(sizeof(*nlg
));
804 nlg
->buf
= malloc(mnl_ideal_socket_buffer_size());
808 nlg
->nl
= mnl_socket_open(NETLINK_GENERIC
);
811 goto err_mnl_socket_open
;
814 if (mnl_socket_bind(nlg
->nl
, 0, MNL_SOCKET_AUTOPID
) < 0) {
816 goto err_mnl_socket_bind
;
819 nlg
->portid
= mnl_socket_get_portid(nlg
->nl
);
821 nlh
= __mnlg_msg_prepare(nlg
, CTRL_CMD_GETFAMILY
,
822 NLM_F_REQUEST
| NLM_F_ACK
, GENL_ID_CTRL
, 1);
823 mnl_attr_put_strz(nlh
, CTRL_ATTR_FAMILY_NAME
, family_name
);
825 if (mnlg_socket_send(nlg
, nlh
) < 0) {
827 goto err_mnlg_socket_send
;
831 if (mnlg_socket_recv_run(nlg
, get_family_id_cb
, &nlg
->id
) < 0) {
832 errno
= errno
== ENOENT
? EPROTONOSUPPORT
: errno
;
833 err
= errno
? -errno
: -ENOSYS
;
834 goto err_mnlg_socket_recv_run
;
837 nlg
->version
= version
;
841 err_mnlg_socket_recv_run
:
842 err_mnlg_socket_send
:
844 mnl_socket_close(nlg
->nl
);
853 static void mnlg_socket_close(struct mnlg_socket
*nlg
)
855 mnl_socket_close(nlg
->nl
);
860 /* wireguard-specific parts: */
868 static int string_list_add(struct string_list
*list
, const char *str
)
870 size_t len
= strlen(str
) + 1;
875 if (len
>= list
->cap
- list
->len
) {
877 size_t new_cap
= list
->cap
* 2;
879 if (new_cap
< list
->len
+len
+ 1)
880 new_cap
= list
->len
+ len
+ 1;
881 new_buffer
= realloc(list
->buffer
, new_cap
);
884 list
->buffer
= new_buffer
;
887 memcpy(list
->buffer
+ list
->len
, str
, len
);
889 list
->buffer
[list
->len
] = '\0';
898 static int parse_linkinfo(const struct nlattr
*attr
, void *data
)
900 struct interface
*interface
= data
;
902 if (mnl_attr_get_type(attr
) == IFLA_INFO_KIND
&& !strcmp(WG_GENL_NAME
, mnl_attr_get_str(attr
)))
903 interface
->is_wireguard
= true;
907 static int parse_infomsg(const struct nlattr
*attr
, void *data
)
909 struct interface
*interface
= data
;
911 if (mnl_attr_get_type(attr
) == IFLA_LINKINFO
)
912 return mnl_attr_parse_nested(attr
, parse_linkinfo
, data
);
913 else if (mnl_attr_get_type(attr
) == IFLA_IFNAME
)
914 interface
->name
= mnl_attr_get_str(attr
);
918 static int read_devices_cb(const struct nlmsghdr
*nlh
, void *data
)
920 struct string_list
*list
= data
;
921 struct interface interface
= { 0 };
924 ret
= mnl_attr_parse(nlh
, sizeof(struct ifinfomsg
), parse_infomsg
, &interface
);
925 if (ret
!= MNL_CB_OK
)
927 if (interface
.name
&& interface
.is_wireguard
)
928 ret
= string_list_add(list
, interface
.name
);
931 if (nlh
->nlmsg_type
!= NLMSG_DONE
)
932 return MNL_CB_OK
+ 1;
936 static int fetch_device_names(struct string_list
*list
)
938 struct mnl_socket
*nl
= NULL
;
939 char *rtnl_buffer
= NULL
;
941 unsigned int portid
, seq
;
944 struct nlmsghdr
*nlh
;
945 struct ifinfomsg
*ifm
;
948 rtnl_buffer
= calloc(mnl_ideal_socket_buffer_size(), 1);
952 nl
= mnl_socket_open(NETLINK_ROUTE
);
958 if (mnl_socket_bind(nl
, 0, MNL_SOCKET_AUTOPID
) < 0) {
964 portid
= mnl_socket_get_portid(nl
);
965 nlh
= mnl_nlmsg_put_header(rtnl_buffer
);
966 nlh
->nlmsg_type
= RTM_GETLINK
;
967 nlh
->nlmsg_flags
= NLM_F_REQUEST
| NLM_F_ACK
| NLM_F_DUMP
;
968 nlh
->nlmsg_seq
= seq
;
969 ifm
= mnl_nlmsg_put_extra_header(nlh
, sizeof(*ifm
));
970 ifm
->ifi_family
= AF_UNSPEC
;
971 message_len
= nlh
->nlmsg_len
;
973 if (mnl_socket_sendto(nl
, rtnl_buffer
, message_len
) < 0) {
979 if ((len
= mnl_socket_recvfrom(nl
, rtnl_buffer
, mnl_ideal_socket_buffer_size())) < 0) {
983 if ((len
= mnl_cb_run(rtnl_buffer
, len
, seq
, portid
, read_devices_cb
, list
)) < 0) {
984 /* Netlink returns NLM_F_DUMP_INTR if the set of all tunnels changed
985 * during the dump. That's unfortunate, but is pretty common on busy
986 * systems that are adding and removing tunnels all the time. Rather
987 * than retrying, potentially indefinitely, we just work with the
988 * partial results. */
989 if (errno
!= EINTR
) {
994 if (len
== MNL_CB_OK
+ 1)
1001 mnl_socket_close(nl
);
1005 static int add_del_iface(const char *ifname
, bool add
)
1007 struct mnl_socket
*nl
= NULL
;
1011 struct nlmsghdr
*nlh
;
1012 struct ifinfomsg
*ifm
;
1013 struct nlattr
*nest
;
1015 rtnl_buffer
= calloc(mnl_ideal_socket_buffer_size(), 1);
1021 nl
= mnl_socket_open(NETLINK_ROUTE
);
1027 if (mnl_socket_bind(nl
, 0, MNL_SOCKET_AUTOPID
) < 0) {
1032 nlh
= mnl_nlmsg_put_header(rtnl_buffer
);
1033 nlh
->nlmsg_type
= add
? RTM_NEWLINK
: RTM_DELLINK
;
1034 nlh
->nlmsg_flags
= NLM_F_REQUEST
| NLM_F_ACK
| (add
? NLM_F_CREATE
| NLM_F_EXCL
: 0);
1035 nlh
->nlmsg_seq
= time(NULL
);
1036 ifm
= mnl_nlmsg_put_extra_header(nlh
, sizeof(*ifm
));
1037 ifm
->ifi_family
= AF_UNSPEC
;
1038 mnl_attr_put_strz(nlh
, IFLA_IFNAME
, ifname
);
1039 nest
= mnl_attr_nest_start(nlh
, IFLA_LINKINFO
);
1040 mnl_attr_put_strz(nlh
, IFLA_INFO_KIND
, WG_GENL_NAME
);
1041 mnl_attr_nest_end(nlh
, nest
);
1043 if (mnl_socket_sendto(nl
, rtnl_buffer
, nlh
->nlmsg_len
) < 0) {
1047 if ((len
= mnl_socket_recvfrom(nl
, rtnl_buffer
, mnl_ideal_socket_buffer_size())) < 0) {
1051 if (mnl_cb_run(rtnl_buffer
, len
, nlh
->nlmsg_seq
, mnl_socket_get_portid(nl
), NULL
, NULL
) < 0) {
1060 mnl_socket_close(nl
);
1064 int wg_set_device(wg_device
*dev
)
1067 wg_peer
*peer
= NULL
;
1068 wg_allowedip
*allowedip
= NULL
;
1069 struct nlattr
*peers_nest
, *peer_nest
, *allowedips_nest
, *allowedip_nest
;
1070 struct nlmsghdr
*nlh
;
1071 struct mnlg_socket
*nlg
;
1073 nlg
= mnlg_socket_open(WG_GENL_NAME
, WG_GENL_VERSION
);
1078 nlh
= mnlg_msg_prepare(nlg
, WG_CMD_SET_DEVICE
, NLM_F_REQUEST
| NLM_F_ACK
);
1079 mnl_attr_put_strz(nlh
, WGDEVICE_A_IFNAME
, dev
->name
);
1084 if (dev
->flags
& WGDEVICE_HAS_PRIVATE_KEY
)
1085 mnl_attr_put(nlh
, WGDEVICE_A_PRIVATE_KEY
, sizeof(dev
->private_key
), dev
->private_key
);
1086 if (dev
->flags
& WGDEVICE_HAS_LISTEN_PORT
)
1087 mnl_attr_put_u16(nlh
, WGDEVICE_A_LISTEN_PORT
, dev
->listen_port
);
1088 if (dev
->flags
& WGDEVICE_HAS_FWMARK
)
1089 mnl_attr_put_u32(nlh
, WGDEVICE_A_FWMARK
, dev
->fwmark
);
1090 if (dev
->flags
& WGDEVICE_REPLACE_PEERS
)
1091 flags
|= WGDEVICE_F_REPLACE_PEERS
;
1093 mnl_attr_put_u32(nlh
, WGDEVICE_A_FLAGS
, flags
);
1095 if (!dev
->first_peer
)
1097 peers_nest
= peer_nest
= allowedips_nest
= allowedip_nest
= NULL
;
1098 peers_nest
= mnl_attr_nest_start(nlh
, WGDEVICE_A_PEERS
);
1099 for (peer
= peer
? peer
: dev
->first_peer
; peer
; peer
= peer
->next_peer
) {
1102 peer_nest
= mnl_attr_nest_start_check(nlh
, mnl_ideal_socket_buffer_size(), 0);
1105 if (!mnl_attr_put_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_PUBLIC_KEY
, sizeof(peer
->public_key
), peer
->public_key
))
1107 if (peer
->flags
& WGPEER_REMOVE_ME
)
1108 flags
|= WGPEER_F_REMOVE_ME
;
1110 if (peer
->flags
& WGPEER_REPLACE_ALLOWEDIPS
)
1111 flags
|= WGPEER_F_REPLACE_ALLOWEDIPS
;
1112 if (peer
->flags
& WGPEER_HAS_PRESHARED_KEY
) {
1113 if (!mnl_attr_put_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_PRESHARED_KEY
, sizeof(peer
->preshared_key
), peer
->preshared_key
))
1116 if (peer
->endpoint
.addr
.sa_family
== AF_INET
) {
1117 if (!mnl_attr_put_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT
, sizeof(peer
->endpoint
.addr4
), &peer
->endpoint
.addr4
))
1119 } else if (peer
->endpoint
.addr
.sa_family
== AF_INET6
) {
1120 if (!mnl_attr_put_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_ENDPOINT
, sizeof(peer
->endpoint
.addr6
), &peer
->endpoint
.addr6
))
1123 if (peer
->flags
& WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL
) {
1124 if (!mnl_attr_put_u16_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL
, peer
->persistent_keepalive_interval
))
1129 if (!mnl_attr_put_u32_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_FLAGS
, flags
))
1132 if (peer
->first_allowedip
) {
1134 allowedip
= peer
->first_allowedip
;
1135 allowedips_nest
= mnl_attr_nest_start_check(nlh
, mnl_ideal_socket_buffer_size(), WGPEER_A_ALLOWEDIPS
);
1136 if (!allowedips_nest
)
1137 goto toobig_allowedips
;
1138 for (; allowedip
; allowedip
= allowedip
->next_allowedip
) {
1139 allowedip_nest
= mnl_attr_nest_start_check(nlh
, mnl_ideal_socket_buffer_size(), 0);
1140 if (!allowedip_nest
)
1141 goto toobig_allowedips
;
1142 if (!mnl_attr_put_u16_check(nlh
, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_FAMILY
, allowedip
->family
))
1143 goto toobig_allowedips
;
1144 if (allowedip
->family
== AF_INET
) {
1145 if (!mnl_attr_put_check(nlh
, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR
, sizeof(allowedip
->ip4
), &allowedip
->ip4
))
1146 goto toobig_allowedips
;
1147 } else if (allowedip
->family
== AF_INET6
) {
1148 if (!mnl_attr_put_check(nlh
, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_IPADDR
, sizeof(allowedip
->ip6
), &allowedip
->ip6
))
1149 goto toobig_allowedips
;
1151 if (!mnl_attr_put_u8_check(nlh
, mnl_ideal_socket_buffer_size(), WGALLOWEDIP_A_CIDR_MASK
, allowedip
->cidr
))
1152 goto toobig_allowedips
;
1153 mnl_attr_nest_end(nlh
, allowedip_nest
);
1154 allowedip_nest
= NULL
;
1156 mnl_attr_nest_end(nlh
, allowedips_nest
);
1157 allowedips_nest
= NULL
;
1160 mnl_attr_nest_end(nlh
, peer_nest
);
1163 mnl_attr_nest_end(nlh
, peers_nest
);
1168 mnl_attr_nest_cancel(nlh
, allowedip_nest
);
1169 if (allowedips_nest
)
1170 mnl_attr_nest_end(nlh
, allowedips_nest
);
1171 mnl_attr_nest_end(nlh
, peer_nest
);
1172 mnl_attr_nest_end(nlh
, peers_nest
);
1176 mnl_attr_nest_cancel(nlh
, peer_nest
);
1177 mnl_attr_nest_end(nlh
, peers_nest
);
1180 if (mnlg_socket_send(nlg
, nlh
) < 0) {
1185 if (mnlg_socket_recv_run(nlg
, NULL
, NULL
) < 0) {
1186 ret
= errno
? -errno
: -EINVAL
;
1193 mnlg_socket_close(nlg
);
1198 static int parse_allowedip(const struct nlattr
*attr
, void *data
)
1200 wg_allowedip
*allowedip
= data
;
1202 switch (mnl_attr_get_type(attr
)) {
1203 case WGALLOWEDIP_A_UNSPEC
:
1205 case WGALLOWEDIP_A_FAMILY
:
1206 if (!mnl_attr_validate(attr
, MNL_TYPE_U16
))
1207 allowedip
->family
= mnl_attr_get_u16(attr
);
1209 case WGALLOWEDIP_A_IPADDR
:
1210 if (mnl_attr_get_payload_len(attr
) == sizeof(allowedip
->ip4
))
1211 memcpy(&allowedip
->ip4
, mnl_attr_get_payload(attr
), sizeof(allowedip
->ip4
));
1212 else if (mnl_attr_get_payload_len(attr
) == sizeof(allowedip
->ip6
))
1213 memcpy(&allowedip
->ip6
, mnl_attr_get_payload(attr
), sizeof(allowedip
->ip6
));
1215 case WGALLOWEDIP_A_CIDR_MASK
:
1216 if (!mnl_attr_validate(attr
, MNL_TYPE_U8
))
1217 allowedip
->cidr
= mnl_attr_get_u8(attr
);
1224 static int parse_allowedips(const struct nlattr
*attr
, void *data
)
1226 wg_peer
*peer
= data
;
1227 wg_allowedip
*new_allowedip
= calloc(1, sizeof(wg_allowedip
));
1231 return MNL_CB_ERROR
;
1232 if (!peer
->first_allowedip
)
1233 peer
->first_allowedip
= peer
->last_allowedip
= new_allowedip
;
1235 peer
->last_allowedip
->next_allowedip
= new_allowedip
;
1236 peer
->last_allowedip
= new_allowedip
;
1238 ret
= mnl_attr_parse_nested(attr
, parse_allowedip
, new_allowedip
);
1241 if (!((new_allowedip
->family
== AF_INET
&& new_allowedip
->cidr
<= 32) || (new_allowedip
->family
== AF_INET6
&& new_allowedip
->cidr
<= 128))) {
1242 errno
= EAFNOSUPPORT
;
1243 return MNL_CB_ERROR
;
1248 bool wg_key_is_zero(const wg_key key
)
1250 volatile uint8_t acc
= 0;
1253 for (i
= 0; i
< sizeof(wg_key
); ++i
) {
1255 __asm__ ("" : "=r" (acc
) : "0" (acc
));
1257 return 1 & ((acc
- 1) >> 8);
1260 static int parse_peer(const struct nlattr
*attr
, void *data
)
1262 wg_peer
*peer
= data
;
1264 switch (mnl_attr_get_type(attr
)) {
1265 case WGPEER_A_UNSPEC
:
1267 case WGPEER_A_PUBLIC_KEY
:
1268 if (mnl_attr_get_payload_len(attr
) == sizeof(peer
->public_key
)) {
1269 memcpy(peer
->public_key
, mnl_attr_get_payload(attr
), sizeof(peer
->public_key
));
1270 peer
->flags
|= WGPEER_HAS_PUBLIC_KEY
;
1273 case WGPEER_A_PRESHARED_KEY
:
1274 if (mnl_attr_get_payload_len(attr
) == sizeof(peer
->preshared_key
)) {
1275 memcpy(peer
->preshared_key
, mnl_attr_get_payload(attr
), sizeof(peer
->preshared_key
));
1276 if (!wg_key_is_zero(peer
->preshared_key
))
1277 peer
->flags
|= WGPEER_HAS_PRESHARED_KEY
;
1280 case WGPEER_A_ENDPOINT
: {
1281 struct sockaddr
*addr
;
1283 if (mnl_attr_get_payload_len(attr
) < sizeof(*addr
))
1285 addr
= mnl_attr_get_payload(attr
);
1286 if (addr
->sa_family
== AF_INET
&& mnl_attr_get_payload_len(attr
) == sizeof(peer
->endpoint
.addr4
))
1287 memcpy(&peer
->endpoint
.addr4
, addr
, sizeof(peer
->endpoint
.addr4
));
1288 else if (addr
->sa_family
== AF_INET6
&& mnl_attr_get_payload_len(attr
) == sizeof(peer
->endpoint
.addr6
))
1289 memcpy(&peer
->endpoint
.addr6
, addr
, sizeof(peer
->endpoint
.addr6
));
1292 case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL
:
1293 if (!mnl_attr_validate(attr
, MNL_TYPE_U16
))
1294 peer
->persistent_keepalive_interval
= mnl_attr_get_u16(attr
);
1296 case WGPEER_A_LAST_HANDSHAKE_TIME
:
1297 if (mnl_attr_get_payload_len(attr
) == sizeof(peer
->last_handshake_time
))
1298 memcpy(&peer
->last_handshake_time
, mnl_attr_get_payload(attr
), sizeof(peer
->last_handshake_time
));
1300 case WGPEER_A_RX_BYTES
:
1301 if (!mnl_attr_validate(attr
, MNL_TYPE_U64
))
1302 peer
->rx_bytes
= mnl_attr_get_u64(attr
);
1304 case WGPEER_A_TX_BYTES
:
1305 if (!mnl_attr_validate(attr
, MNL_TYPE_U64
))
1306 peer
->tx_bytes
= mnl_attr_get_u64(attr
);
1308 case WGPEER_A_ALLOWEDIPS
:
1309 return mnl_attr_parse_nested(attr
, parse_allowedips
, peer
);
1315 static int parse_peers(const struct nlattr
*attr
, void *data
)
1317 wg_device
*device
= data
;
1318 wg_peer
*new_peer
= calloc(1, sizeof(wg_peer
));
1322 return MNL_CB_ERROR
;
1323 if (!device
->first_peer
)
1324 device
->first_peer
= device
->last_peer
= new_peer
;
1326 device
->last_peer
->next_peer
= new_peer
;
1327 device
->last_peer
= new_peer
;
1329 ret
= mnl_attr_parse_nested(attr
, parse_peer
, new_peer
);
1332 if (!(new_peer
->flags
& WGPEER_HAS_PUBLIC_KEY
)) {
1334 return MNL_CB_ERROR
;
1339 static int parse_device(const struct nlattr
*attr
, void *data
)
1341 wg_device
*device
= data
;
1343 switch (mnl_attr_get_type(attr
)) {
1344 case WGDEVICE_A_UNSPEC
:
1346 case WGDEVICE_A_IFINDEX
:
1347 if (!mnl_attr_validate(attr
, MNL_TYPE_U32
))
1348 device
->ifindex
= mnl_attr_get_u32(attr
);
1350 case WGDEVICE_A_IFNAME
:
1351 if (!mnl_attr_validate(attr
, MNL_TYPE_STRING
)) {
1352 strncpy(device
->name
, mnl_attr_get_str(attr
), sizeof(device
->name
) - 1);
1353 device
->name
[sizeof(device
->name
) - 1] = '\0';
1356 case WGDEVICE_A_PRIVATE_KEY
:
1357 if (mnl_attr_get_payload_len(attr
) == sizeof(device
->private_key
)) {
1358 memcpy(device
->private_key
, mnl_attr_get_payload(attr
), sizeof(device
->private_key
));
1359 device
->flags
|= WGDEVICE_HAS_PRIVATE_KEY
;
1362 case WGDEVICE_A_PUBLIC_KEY
:
1363 if (mnl_attr_get_payload_len(attr
) == sizeof(device
->public_key
)) {
1364 memcpy(device
->public_key
, mnl_attr_get_payload(attr
), sizeof(device
->public_key
));
1365 device
->flags
|= WGDEVICE_HAS_PUBLIC_KEY
;
1368 case WGDEVICE_A_LISTEN_PORT
:
1369 if (!mnl_attr_validate(attr
, MNL_TYPE_U16
))
1370 device
->listen_port
= mnl_attr_get_u16(attr
);
1372 case WGDEVICE_A_FWMARK
:
1373 if (!mnl_attr_validate(attr
, MNL_TYPE_U32
))
1374 device
->fwmark
= mnl_attr_get_u32(attr
);
1376 case WGDEVICE_A_PEERS
:
1377 return mnl_attr_parse_nested(attr
, parse_peers
, device
);
1383 static int read_device_cb(const struct nlmsghdr
*nlh
, void *data
)
1385 return mnl_attr_parse(nlh
, sizeof(struct genlmsghdr
), parse_device
, data
);
1388 static void coalesce_peers(wg_device
*device
)
1390 wg_peer
*old_next_peer
, *peer
= device
->first_peer
;
1392 while (peer
&& peer
->next_peer
) {
1393 if (memcmp(peer
->public_key
, peer
->next_peer
->public_key
, sizeof(wg_key
))) {
1394 peer
= peer
->next_peer
;
1397 if (!peer
->first_allowedip
) {
1398 peer
->first_allowedip
= peer
->next_peer
->first_allowedip
;
1399 peer
->last_allowedip
= peer
->next_peer
->last_allowedip
;
1401 peer
->last_allowedip
->next_allowedip
= peer
->next_peer
->first_allowedip
;
1402 peer
->last_allowedip
= peer
->next_peer
->last_allowedip
;
1404 old_next_peer
= peer
->next_peer
;
1405 peer
->next_peer
= old_next_peer
->next_peer
;
1406 free(old_next_peer
);
1410 int wg_get_device(wg_device
**device
, const char *device_name
)
1413 struct nlmsghdr
*nlh
;
1414 struct mnlg_socket
*nlg
;
1417 *device
= calloc(1, sizeof(wg_device
));
1421 nlg
= mnlg_socket_open(WG_GENL_NAME
, WG_GENL_VERSION
);
1423 wg_free_device(*device
);
1428 nlh
= mnlg_msg_prepare(nlg
, WG_CMD_GET_DEVICE
, NLM_F_REQUEST
| NLM_F_ACK
| NLM_F_DUMP
);
1429 mnl_attr_put_strz(nlh
, WGDEVICE_A_IFNAME
, device_name
);
1430 if (mnlg_socket_send(nlg
, nlh
) < 0) {
1435 if (mnlg_socket_recv_run(nlg
, read_device_cb
, *device
) < 0) {
1436 ret
= errno
? -errno
: -EINVAL
;
1439 coalesce_peers(*device
);
1443 mnlg_socket_close(nlg
);
1445 wg_free_device(*device
);
1454 /* first\0second\0third\0forth\0last\0\0 */
1455 char *wg_list_device_names(void)
1457 struct string_list list
= { 0 };
1458 int ret
= fetch_device_names(&list
);
1465 return list
.buffer
?: strdup("\0");
1468 int wg_add_device(const char *device_name
)
1470 return add_del_iface(device_name
, true);
1473 int wg_del_device(const char *device_name
)
1475 return add_del_iface(device_name
, false);
1478 void wg_free_device(wg_device
*dev
)
1481 wg_allowedip
*allowedip
, *na
;
1485 for (peer
= dev
->first_peer
, np
= peer
? peer
->next_peer
: NULL
; peer
; peer
= np
, np
= peer
? peer
->next_peer
: NULL
) {
1486 for (allowedip
= peer
->first_allowedip
, na
= allowedip
? allowedip
->next_allowedip
: NULL
; allowedip
; allowedip
= na
, na
= allowedip
? allowedip
->next_allowedip
: NULL
)
1493 static void encode_base64(char dest
[static 4], const uint8_t src
[static 3])
1495 const uint8_t input
[] = { (src
[0] >> 2) & 63, ((src
[0] << 4) | (src
[1] >> 4)) & 63, ((src
[1] << 2) | (src
[2] >> 6)) & 63, src
[2] & 63 };
1498 for (i
= 0; i
< 4; ++i
)
1499 dest
[i
] = input
[i
] + 'A'
1500 + (((25 - input
[i
]) >> 8) & 6)
1501 - (((51 - input
[i
]) >> 8) & 75)
1502 - (((61 - input
[i
]) >> 8) & 15)
1503 + (((62 - input
[i
]) >> 8) & 3);
1507 void wg_key_to_base64(wg_key_b64_string base64
, const wg_key key
)
1511 for (i
= 0; i
< 32 / 3; ++i
)
1512 encode_base64(&base64
[i
* 4], &key
[i
* 3]);
1513 encode_base64(&base64
[i
* 4], (const uint8_t[]){ key
[i
* 3 + 0], key
[i
* 3 + 1], 0 });
1514 base64
[sizeof(wg_key_b64_string
) - 2] = '=';
1515 base64
[sizeof(wg_key_b64_string
) - 1] = '\0';
1518 static int decode_base64(const char src
[static 4])
1523 for (i
= 0; i
< 4; ++i
)
1525 + ((((('A' - 1) - src
[i
]) & (src
[i
] - ('Z' + 1))) >> 8) & (src
[i
] - 64))
1526 + ((((('a' - 1) - src
[i
]) & (src
[i
] - ('z' + 1))) >> 8) & (src
[i
] - 70))
1527 + ((((('0' - 1) - src
[i
]) & (src
[i
] - ('9' + 1))) >> 8) & (src
[i
] + 5))
1528 + ((((('+' - 1) - src
[i
]) & (src
[i
] - ('+' + 1))) >> 8) & 63)
1529 + ((((('/' - 1) - src
[i
]) & (src
[i
] - ('/' + 1))) >> 8) & 64)
1534 int wg_key_from_base64(wg_key key
, const wg_key_b64_string base64
)
1538 volatile uint8_t ret
= 0;
1540 if (strlen(base64
) != sizeof(wg_key_b64_string
) - 1 || base64
[sizeof(wg_key_b64_string
) - 2] != '=') {
1545 for (i
= 0; i
< 32 / 3; ++i
) {
1546 val
= decode_base64(&base64
[i
* 4]);
1547 ret
|= (uint32_t)val
>> 31;
1548 key
[i
* 3 + 0] = (val
>> 16) & 0xff;
1549 key
[i
* 3 + 1] = (val
>> 8) & 0xff;
1550 key
[i
* 3 + 2] = val
& 0xff;
1552 val
= decode_base64((const char[]){ base64
[i
* 4 + 0], base64
[i
* 4 + 1], base64
[i
* 4 + 2], 'A' });
1553 ret
|= ((uint32_t)val
>> 31) | (val
& 0xff);
1554 key
[i
* 3 + 0] = (val
>> 16) & 0xff;
1555 key
[i
* 3 + 1] = (val
>> 8) & 0xff;
1556 errno
= EINVAL
& ~((ret
- 1) >> 8);
1561 typedef int64_t fe
[16];
1563 static __attribute__((noinline
)) void memzero_explicit(void *s
, size_t count
)
1565 memset(s
, 0, count
);
1566 __asm__
__volatile__("": :"r"(s
) :"memory");
1569 static void carry(fe o
)
1573 for (i
= 0; i
< 16; ++i
) {
1574 o
[(i
+ 1) % 16] += (i
== 15 ? 38 : 1) * (o
[i
] >> 16);
1579 static void cswap(fe p
, fe q
, int b
)
1582 int64_t t
, c
= ~(b
- 1);
1584 for (i
= 0; i
< 16; ++i
) {
1585 t
= c
& (p
[i
] ^ q
[i
]);
1590 memzero_explicit(&t
, sizeof(t
));
1591 memzero_explicit(&c
, sizeof(c
));
1592 memzero_explicit(&b
, sizeof(b
));
1595 static void pack(uint8_t *o
, const fe n
)
1600 memcpy(t
, n
, sizeof(t
));
1604 for (j
= 0; j
< 2; ++j
) {
1605 m
[0] = t
[0] - 0xffed;
1606 for (i
= 1; i
< 15; ++i
) {
1607 m
[i
] = t
[i
] - 0xffff - ((m
[i
- 1] >> 16) & 1);
1610 m
[15] = t
[15] - 0x7fff - ((m
[14] >> 16) & 1);
1611 b
= (m
[15] >> 16) & 1;
1615 for (i
= 0; i
< 16; ++i
) {
1616 o
[2 * i
] = t
[i
] & 0xff;
1617 o
[2 * i
+ 1] = t
[i
] >> 8;
1620 memzero_explicit(m
, sizeof(m
));
1621 memzero_explicit(t
, sizeof(t
));
1622 memzero_explicit(&b
, sizeof(b
));
1625 static void add(fe o
, const fe a
, const fe b
)
1629 for (i
= 0; i
< 16; ++i
)
1633 static void subtract(fe o
, const fe a
, const fe b
)
1637 for (i
= 0; i
< 16; ++i
)
1641 static void multmod(fe o
, const fe a
, const fe b
)
1644 int64_t t
[31] = { 0 };
1646 for (i
= 0; i
< 16; ++i
) {
1647 for (j
= 0; j
< 16; ++j
)
1648 t
[i
+ j
] += a
[i
] * b
[j
];
1650 for (i
= 0; i
< 15; ++i
)
1651 t
[i
] += 38 * t
[i
+ 16];
1652 memcpy(o
, t
, sizeof(fe
));
1656 memzero_explicit(t
, sizeof(t
));
1659 static void invert(fe o
, const fe i
)
1664 memcpy(c
, i
, sizeof(c
));
1665 for (a
= 253; a
>= 0; --a
) {
1667 if (a
!= 2 && a
!= 4)
1670 memcpy(o
, c
, sizeof(fe
));
1672 memzero_explicit(c
, sizeof(c
));
1675 static void clamp_key(uint8_t *z
)
1677 z
[31] = (z
[31] & 127) | 64;
1681 void wg_generate_public_key(wg_key public_key
, const wg_key private_key
)
1685 fe a
= { 1 }, b
= { 9 }, c
= { 0 }, d
= { 1 }, e
, f
;
1687 memcpy(z
, private_key
, sizeof(z
));
1690 for (i
= 254; i
>= 0; --i
) {
1691 r
= (z
[i
>> 3] >> (i
& 7)) & 1;
1706 multmod(a
, c
, (const fe
){ 0xdb41, 1 });
1710 multmod(d
, b
, (const fe
){ 9 });
1717 pack(public_key
, a
);
1719 memzero_explicit(&r
, sizeof(r
));
1720 memzero_explicit(z
, sizeof(z
));
1721 memzero_explicit(a
, sizeof(a
));
1722 memzero_explicit(b
, sizeof(b
));
1723 memzero_explicit(c
, sizeof(c
));
1724 memzero_explicit(d
, sizeof(d
));
1725 memzero_explicit(e
, sizeof(e
));
1726 memzero_explicit(f
, sizeof(f
));
1729 void wg_generate_private_key(wg_key private_key
)
1731 wg_generate_preshared_key(private_key
);
1732 clamp_key(private_key
);
1735 void wg_generate_preshared_key(wg_key preshared_key
)
1740 #if defined(__OpenBSD__) || (defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12) || (defined(__GLIBC__) && (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25)))
1741 if (!getentropy(preshared_key
, sizeof(wg_key
)))
1744 #if defined(__NR_getrandom) && defined(__linux__)
1745 if (syscall(__NR_getrandom
, preshared_key
, sizeof(wg_key
), 0) == sizeof(wg_key
))
1748 fd
= open("/dev/urandom", O_RDONLY
);
1750 for (i
= 0; i
< sizeof(wg_key
); i
+= ret
) {
1751 ret
= read(fd
, preshared_key
+ i
, sizeof(wg_key
) - i
);