kernel: 5.4: import wireguard backport
[openwrt/openwrt.git] / target / linux / generic / backport-5.4 / 080-wireguard-0107-wireguard-device-avoid-circular-netns-references.patch
1 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
2 From: "Jason A. Donenfeld" <Jason@zx2c4.com>
3 Date: Tue, 23 Jun 2020 03:59:45 -0600
4 Subject: [PATCH] wireguard: device: avoid circular netns references
5
6 commit 900575aa33a3eaaef802b31de187a85c4a4b4bd0 upstream.
7
8 Before, we took a reference to the creating netns if the new netns was
9 different. This caused issues with circular references, with two
10 wireguard interfaces swapping namespaces. The solution is not to take
11 any extra references at all, but instead to simply invalidate the
12 creating netns pointer when that netns is deleted.
13
14 In order to prevent this from happening again, this commit improves the
15 rough object leak tracking by allowing it to account for created and
16 destroyed interfaces, in addition to peers and keys. That then makes it
17 possible to check for an object leak when two interfaces take
18 references to each other's namespaces.
19
20 Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
21 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
22 Signed-off-by: David S. Miller <davem@davemloft.net>
23 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
24 ---
25 drivers/net/wireguard/device.c | 58 ++++++++++------------
26 drivers/net/wireguard/device.h | 3 +-
27 drivers/net/wireguard/netlink.c | 14 ++++--
28 drivers/net/wireguard/socket.c | 25 +++++++---
29 tools/testing/selftests/wireguard/netns.sh | 13 ++++-
30 5 files changed, 67 insertions(+), 46 deletions(-)
31
32 --- a/drivers/net/wireguard/device.c
33 +++ b/drivers/net/wireguard/device.c
34 @@ -45,17 +45,18 @@ static int wg_open(struct net_device *de
35 if (dev_v6)
36 dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
37
38 + mutex_lock(&wg->device_update_lock);
39 ret = wg_socket_init(wg, wg->incoming_port);
40 if (ret < 0)
41 - return ret;
42 - mutex_lock(&wg->device_update_lock);
43 + goto out;
44 list_for_each_entry(peer, &wg->peer_list, peer_list) {
45 wg_packet_send_staged_packets(peer);
46 if (peer->persistent_keepalive_interval)
47 wg_packet_send_keepalive(peer);
48 }
49 +out:
50 mutex_unlock(&wg->device_update_lock);
51 - return 0;
52 + return ret;
53 }
54
55 #ifdef CONFIG_PM_SLEEP
56 @@ -225,6 +226,7 @@ static void wg_destruct(struct net_devic
57 list_del(&wg->device_list);
58 rtnl_unlock();
59 mutex_lock(&wg->device_update_lock);
60 + rcu_assign_pointer(wg->creating_net, NULL);
61 wg->incoming_port = 0;
62 wg_socket_reinit(wg, NULL, NULL);
63 /* The final references are cleared in the below calls to destroy_workqueue. */
64 @@ -240,13 +242,11 @@ static void wg_destruct(struct net_devic
65 skb_queue_purge(&wg->incoming_handshakes);
66 free_percpu(dev->tstats);
67 free_percpu(wg->incoming_handshakes_worker);
68 - if (wg->have_creating_net_ref)
69 - put_net(wg->creating_net);
70 kvfree(wg->index_hashtable);
71 kvfree(wg->peer_hashtable);
72 mutex_unlock(&wg->device_update_lock);
73
74 - pr_debug("%s: Interface deleted\n", dev->name);
75 + pr_debug("%s: Interface destroyed\n", dev->name);
76 free_netdev(dev);
77 }
78
79 @@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_ne
80 struct wg_device *wg = netdev_priv(dev);
81 int ret = -ENOMEM;
82
83 - wg->creating_net = src_net;
84 + rcu_assign_pointer(wg->creating_net, src_net);
85 init_rwsem(&wg->static_identity.lock);
86 mutex_init(&wg->socket_update_lock);
87 mutex_init(&wg->device_update_lock);
88 @@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __r
89 .newlink = wg_newlink,
90 };
91
92 -static int wg_netdevice_notification(struct notifier_block *nb,
93 - unsigned long action, void *data)
94 +static void wg_netns_pre_exit(struct net *net)
95 {
96 - struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
97 - struct wg_device *wg = netdev_priv(dev);
98 -
99 - ASSERT_RTNL();
100 -
101 - if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
102 - return 0;
103 + struct wg_device *wg;
104
105 - if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
106 - put_net(wg->creating_net);
107 - wg->have_creating_net_ref = false;
108 - } else if (dev_net(dev) != wg->creating_net &&
109 - !wg->have_creating_net_ref) {
110 - wg->have_creating_net_ref = true;
111 - get_net(wg->creating_net);
112 + rtnl_lock();
113 + list_for_each_entry(wg, &device_list, device_list) {
114 + if (rcu_access_pointer(wg->creating_net) == net) {
115 + pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
116 + netif_carrier_off(wg->dev);
117 + mutex_lock(&wg->device_update_lock);
118 + rcu_assign_pointer(wg->creating_net, NULL);
119 + wg_socket_reinit(wg, NULL, NULL);
120 + mutex_unlock(&wg->device_update_lock);
121 + }
122 }
123 - return 0;
124 + rtnl_unlock();
125 }
126
127 -static struct notifier_block netdevice_notifier = {
128 - .notifier_call = wg_netdevice_notification
129 +static struct pernet_operations pernet_ops = {
130 + .pre_exit = wg_netns_pre_exit
131 };
132
133 int __init wg_device_init(void)
134 @@ -429,18 +425,18 @@ int __init wg_device_init(void)
135 return ret;
136 #endif
137
138 - ret = register_netdevice_notifier(&netdevice_notifier);
139 + ret = register_pernet_device(&pernet_ops);
140 if (ret)
141 goto error_pm;
142
143 ret = rtnl_link_register(&link_ops);
144 if (ret)
145 - goto error_netdevice;
146 + goto error_pernet;
147
148 return 0;
149
150 -error_netdevice:
151 - unregister_netdevice_notifier(&netdevice_notifier);
152 +error_pernet:
153 + unregister_pernet_device(&pernet_ops);
154 error_pm:
155 #ifdef CONFIG_PM_SLEEP
156 unregister_pm_notifier(&pm_notifier);
157 @@ -451,7 +447,7 @@ error_pm:
158 void wg_device_uninit(void)
159 {
160 rtnl_link_unregister(&link_ops);
161 - unregister_netdevice_notifier(&netdevice_notifier);
162 + unregister_pernet_device(&pernet_ops);
163 #ifdef CONFIG_PM_SLEEP
164 unregister_pm_notifier(&pm_notifier);
165 #endif
166 --- a/drivers/net/wireguard/device.h
167 +++ b/drivers/net/wireguard/device.h
168 @@ -40,7 +40,7 @@ struct wg_device {
169 struct net_device *dev;
170 struct crypt_queue encrypt_queue, decrypt_queue;
171 struct sock __rcu *sock4, *sock6;
172 - struct net *creating_net;
173 + struct net __rcu *creating_net;
174 struct noise_static_identity static_identity;
175 struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
176 struct workqueue_struct *packet_crypt_wq;
177 @@ -56,7 +56,6 @@ struct wg_device {
178 unsigned int num_peers, device_update_gen;
179 u32 fwmark;
180 u16 incoming_port;
181 - bool have_creating_net_ref;
182 };
183
184 int wg_device_init(void);
185 --- a/drivers/net/wireguard/netlink.c
186 +++ b/drivers/net/wireguard/netlink.c
187 @@ -517,11 +517,15 @@ static int wg_set_device(struct sk_buff
188 if (flags & ~__WGDEVICE_F_ALL)
189 goto out;
190
191 - ret = -EPERM;
192 - if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
193 - info->attrs[WGDEVICE_A_FWMARK]) &&
194 - !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
195 - goto out;
196 + if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
197 + struct net *net;
198 + rcu_read_lock();
199 + net = rcu_dereference(wg->creating_net);
200 + ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
201 + rcu_read_unlock();
202 + if (ret)
203 + goto out;
204 + }
205
206 ++wg->device_update_gen;
207
208 --- a/drivers/net/wireguard/socket.c
209 +++ b/drivers/net/wireguard/socket.c
210 @@ -347,6 +347,7 @@ static void set_sock_opts(struct socket
211
212 int wg_socket_init(struct wg_device *wg, u16 port)
213 {
214 + struct net *net;
215 int ret;
216 struct udp_tunnel_sock_cfg cfg = {
217 .sk_user_data = wg,
218 @@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg,
219 };
220 #endif
221
222 + rcu_read_lock();
223 + net = rcu_dereference(wg->creating_net);
224 + net = net ? maybe_get_net(net) : NULL;
225 + rcu_read_unlock();
226 + if (unlikely(!net))
227 + return -ENONET;
228 +
229 #if IS_ENABLED(CONFIG_IPV6)
230 retry:
231 #endif
232
233 - ret = udp_sock_create(wg->creating_net, &port4, &new4);
234 + ret = udp_sock_create(net, &port4, &new4);
235 if (ret < 0) {
236 pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
237 - return ret;
238 + goto out;
239 }
240 set_sock_opts(new4);
241 - setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
242 + setup_udp_tunnel_sock(net, new4, &cfg);
243
244 #if IS_ENABLED(CONFIG_IPV6)
245 if (ipv6_mod_enabled()) {
246 port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
247 - ret = udp_sock_create(wg->creating_net, &port6, &new6);
248 + ret = udp_sock_create(net, &port6, &new6);
249 if (ret < 0) {
250 udp_tunnel_sock_release(new4);
251 if (ret == -EADDRINUSE && !port && retries++ < 100)
252 goto retry;
253 pr_err("%s: Could not create IPv6 socket\n",
254 wg->dev->name);
255 - return ret;
256 + goto out;
257 }
258 set_sock_opts(new6);
259 - setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
260 + setup_udp_tunnel_sock(net, new6, &cfg);
261 }
262 #endif
263
264 wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
265 - return 0;
266 + ret = 0;
267 +out:
268 + put_net(net);
269 + return ret;
270 }
271
272 void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
273 --- a/tools/testing/selftests/wireguard/netns.sh
274 +++ b/tools/testing/selftests/wireguard/netns.sh
275 @@ -587,9 +587,20 @@ ip0 link set wg0 up
276 kill $ncat_pid
277 ip0 link del wg0
278
279 +# Ensure there aren't circular reference loops
280 +ip1 link add wg1 type wireguard
281 +ip2 link add wg2 type wireguard
282 +ip1 link set wg1 netns $netns2
283 +ip2 link set wg2 netns $netns1
284 +pp ip netns delete $netns1
285 +pp ip netns delete $netns2
286 +pp ip netns add $netns1
287 +pp ip netns add $netns2
288 +
289 +sleep 2 # Wait for cleanup and grace periods
290 declare -A objects
291 while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
292 - [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
293 + [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
294 objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
295 done < /dev/kmsg
296 alldeleted=1