kernel: 5.4: import wireguard backport
[openwrt/openwrt.git] / target / linux / generic / backport-5.4 / 080-wireguard-0105-wireguard-noise-separate-receive-counter-from-send-c.patch
1 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
2 From: "Jason A. Donenfeld" <Jason@zx2c4.com>
3 Date: Tue, 19 May 2020 22:49:30 -0600
4 Subject: [PATCH] wireguard: noise: separate receive counter from send counter
5
6 commit a9e90d9931f3a474f04bab782ccd9d77904941e9 upstream.
7
8 In "wireguard: queueing: preserve flow hash across packet scrubbing", we
9 were required to slightly increase the size of the receive replay
10 counter to something still fairly small, but an increase nonetheless.
11 It turns out that we can recoup some of the additional memory overhead
12 by splitting up the prior union type into two distinct types. Before, we
13 used the same "noise_counter" union for both sending and receiving, with
14 sending just using a simple atomic64_t, while receiving used the full
15 replay counter checker. This meant that most of the memory being
16 allocated for the sending counter was being wasted. Since the old
17 "noise_counter" type increased in size in the prior commit, now is a
18 good time to split up that union type into a distinct
19 "noise_replay_counter" for receiving and a boring atomic64_t for sending, each using
20 neither more nor less memory than required.
21
22 Also, since sometimes the replay counter is accessed without
23 necessitating additional accesses to the bitmap, we can reduce cache
24 misses by hoisting the always-necessary lock above the bitmap in the
25 struct layout. We also change a "noise_replay_counter" stack allocation
26 to kmalloc in a -DDEBUG selftest so that KASAN doesn't trigger a stack
27 frame warning.
28
29 All in all, removing a bit of abstraction in this commit makes the code
30 simpler and smaller, in addition to the motivating memory usage
31 recuperation. For example, passing around raw "noise_symmetric_key"
32 structs is something that really only makes sense within noise.c, in the
33 one place where the sending and receiving keys can safely be thought of
34 as the same type of object; subsequent to that, it's important that we
35 uniformly access these through keypair->{sending,receiving}, where their
36 distinct roles are always made explicit. So this patch allows us to draw
37 that distinction clearly as well.
38
39 Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
40 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
41 Signed-off-by: David S. Miller <davem@davemloft.net>
42 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
43 ---
44 drivers/net/wireguard/noise.c | 16 +++------
45 drivers/net/wireguard/noise.h | 14 ++++----
46 drivers/net/wireguard/receive.c | 42 ++++++++++++------------
47 drivers/net/wireguard/selftest/counter.c | 17 +++++++---
48 drivers/net/wireguard/send.c | 12 +++----
49 5 files changed, 48 insertions(+), 53 deletions(-)
50
51 --- a/drivers/net/wireguard/noise.c
52 +++ b/drivers/net/wireguard/noise.c
53 @@ -104,6 +104,7 @@ static struct noise_keypair *keypair_cre
54
55 if (unlikely(!keypair))
56 return NULL;
57 + spin_lock_init(&keypair->receiving_counter.lock);
58 keypair->internal_id = atomic64_inc_return(&keypair_counter);
59 keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
60 keypair->entry.peer = peer;
61 @@ -358,25 +359,16 @@ out:
62 memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
63 }
64
65 -static void symmetric_key_init(struct noise_symmetric_key *key)
66 -{
67 - spin_lock_init(&key->counter.receive.lock);
68 - atomic64_set(&key->counter.counter, 0);
69 - memset(key->counter.receive.backtrack, 0,
70 - sizeof(key->counter.receive.backtrack));
71 - key->birthdate = ktime_get_coarse_boottime_ns();
72 - key->is_valid = true;
73 -}
74 -
75 static void derive_keys(struct noise_symmetric_key *first_dst,
76 struct noise_symmetric_key *second_dst,
77 const u8 chaining_key[NOISE_HASH_LEN])
78 {
79 + u64 birthdate = ktime_get_coarse_boottime_ns();
80 kdf(first_dst->key, second_dst->key, NULL, NULL,
81 NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
82 chaining_key);
83 - symmetric_key_init(first_dst);
84 - symmetric_key_init(second_dst);
85 + first_dst->birthdate = second_dst->birthdate = birthdate;
86 + first_dst->is_valid = second_dst->is_valid = true;
87 }
88
89 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
90 --- a/drivers/net/wireguard/noise.h
91 +++ b/drivers/net/wireguard/noise.h
92 @@ -15,18 +15,14 @@
93 #include <linux/mutex.h>
94 #include <linux/kref.h>
95
96 -union noise_counter {
97 - struct {
98 - u64 counter;
99 - unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
100 - spinlock_t lock;
101 - } receive;
102 - atomic64_t counter;
103 +struct noise_replay_counter {
104 + u64 counter;
105 + spinlock_t lock;
106 + unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
107 };
108
109 struct noise_symmetric_key {
110 u8 key[NOISE_SYMMETRIC_KEY_LEN];
111 - union noise_counter counter;
112 u64 birthdate;
113 bool is_valid;
114 };
115 @@ -34,7 +30,9 @@ struct noise_symmetric_key {
116 struct noise_keypair {
117 struct index_hashtable_entry entry;
118 struct noise_symmetric_key sending;
119 + atomic64_t sending_counter;
120 struct noise_symmetric_key receiving;
121 + struct noise_replay_counter receiving_counter;
122 __le32 remote_index;
123 bool i_am_the_initiator;
124 struct kref refcount;
125 --- a/drivers/net/wireguard/receive.c
126 +++ b/drivers/net/wireguard/receive.c
127 @@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_pee
128 }
129 }
130
131 -static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
132 +static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
133 {
134 struct scatterlist sg[MAX_SKB_FRAGS + 8];
135 struct sk_buff *trailer;
136 unsigned int offset;
137 int num_frags;
138
139 - if (unlikely(!key))
140 + if (unlikely(!keypair))
141 return false;
142
143 - if (unlikely(!READ_ONCE(key->is_valid) ||
144 - wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
145 - key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
146 - WRITE_ONCE(key->is_valid, false);
147 + if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
148 + wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
149 + keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
150 + WRITE_ONCE(keypair->receiving.is_valid, false);
151 return false;
152 }
153
154 @@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buf
155
156 if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
157 PACKET_CB(skb)->nonce,
158 - key->key))
159 + keypair->receiving.key))
160 return false;
161
162 /* Another ugly situation of pushing and pulling the header so as to
163 @@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buf
164 }
165
166 /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
167 -static bool counter_validate(union noise_counter *counter, u64 their_counter)
168 +static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
169 {
170 unsigned long index, index_current, top, i;
171 bool ret = false;
172
173 - spin_lock_bh(&counter->receive.lock);
174 + spin_lock_bh(&counter->lock);
175
176 - if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
177 + if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
178 their_counter >= REJECT_AFTER_MESSAGES))
179 goto out;
180
181 ++their_counter;
182
183 if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
184 - counter->receive.counter))
185 + counter->counter))
186 goto out;
187
188 index = their_counter >> ilog2(BITS_PER_LONG);
189
190 - if (likely(their_counter > counter->receive.counter)) {
191 - index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
192 + if (likely(their_counter > counter->counter)) {
193 + index_current = counter->counter >> ilog2(BITS_PER_LONG);
194 top = min_t(unsigned long, index - index_current,
195 COUNTER_BITS_TOTAL / BITS_PER_LONG);
196 for (i = 1; i <= top; ++i)
197 - counter->receive.backtrack[(i + index_current) &
198 + counter->backtrack[(i + index_current) &
199 ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
200 - counter->receive.counter = their_counter;
201 + counter->counter = their_counter;
202 }
203
204 index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
205 ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
206 - &counter->receive.backtrack[index]);
207 + &counter->backtrack[index]);
208
209 out:
210 - spin_unlock_bh(&counter->receive.lock);
211 + spin_unlock_bh(&counter->lock);
212 return ret;
213 }
214
215 @@ -472,12 +472,12 @@ int wg_packet_rx_poll(struct napi_struct
216 if (unlikely(state != PACKET_STATE_CRYPTED))
217 goto next;
218
219 - if (unlikely(!counter_validate(&keypair->receiving.counter,
220 + if (unlikely(!counter_validate(&keypair->receiving_counter,
221 PACKET_CB(skb)->nonce))) {
222 net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
223 peer->device->dev->name,
224 PACKET_CB(skb)->nonce,
225 - keypair->receiving.counter.receive.counter);
226 + keypair->receiving_counter.counter);
227 goto next;
228 }
229
230 @@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct wor
231 struct sk_buff *skb;
232
233 while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
234 - enum packet_state state = likely(decrypt_packet(skb,
235 - &PACKET_CB(skb)->keypair->receiving)) ?
236 + enum packet_state state =
237 + likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
238 PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
239 wg_queue_enqueue_per_peer_napi(skb, state);
240 if (need_resched())
241 --- a/drivers/net/wireguard/selftest/counter.c
242 +++ b/drivers/net/wireguard/selftest/counter.c
243 @@ -6,18 +6,24 @@
244 #ifdef DEBUG
245 bool __init wg_packet_counter_selftest(void)
246 {
247 + struct noise_replay_counter *counter;
248 unsigned int test_num = 0, i;
249 - union noise_counter counter;
250 bool success = true;
251
252 -#define T_INIT do { \
253 - memset(&counter, 0, sizeof(union noise_counter)); \
254 - spin_lock_init(&counter.receive.lock); \
255 + counter = kmalloc(sizeof(*counter), GFP_KERNEL);
256 + if (unlikely(!counter)) {
257 + pr_err("nonce counter self-test malloc: FAIL\n");
258 + return false;
259 + }
260 +
261 +#define T_INIT do { \
262 + memset(counter, 0, sizeof(*counter)); \
263 + spin_lock_init(&counter->lock); \
264 } while (0)
265 #define T_LIM (COUNTER_WINDOW_SIZE + 1)
266 #define T(n, v) do { \
267 ++test_num; \
268 - if (counter_validate(&counter, n) != (v)) { \
269 + if (counter_validate(counter, n) != (v)) { \
270 pr_err("nonce counter self-test %u: FAIL\n", \
271 test_num); \
272 success = false; \
273 @@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(v
274
275 if (success)
276 pr_info("nonce counter self-tests: pass\n");
277 + kfree(counter);
278 return success;
279 }
280 #endif
281 --- a/drivers/net/wireguard/send.c
282 +++ b/drivers/net/wireguard/send.c
283 @@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_pee
284 rcu_read_lock_bh();
285 keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
286 send = keypair && READ_ONCE(keypair->sending.is_valid) &&
287 - (atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES ||
288 + (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
289 (keypair->i_am_the_initiator &&
290 wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
291 rcu_read_unlock_bh();
292 @@ -349,7 +349,6 @@ void wg_packet_purge_staged_packets(stru
293
294 void wg_packet_send_staged_packets(struct wg_peer *peer)
295 {
296 - struct noise_symmetric_key *key;
297 struct noise_keypair *keypair;
298 struct sk_buff_head packets;
299 struct sk_buff *skb;
300 @@ -369,10 +368,9 @@ void wg_packet_send_staged_packets(struc
301 rcu_read_unlock_bh();
302 if (unlikely(!keypair))
303 goto out_nokey;
304 - key = &keypair->sending;
305 - if (unlikely(!READ_ONCE(key->is_valid)))
306 + if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
307 goto out_nokey;
308 - if (unlikely(wg_birthdate_has_expired(key->birthdate,
309 + if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
310 REJECT_AFTER_TIME)))
311 goto out_invalid;
312
313 @@ -387,7 +385,7 @@ void wg_packet_send_staged_packets(struc
314 */
315 PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
316 PACKET_CB(skb)->nonce =
317 - atomic64_inc_return(&key->counter.counter) - 1;
318 + atomic64_inc_return(&keypair->sending_counter) - 1;
319 if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
320 goto out_invalid;
321 }
322 @@ -399,7 +397,7 @@ void wg_packet_send_staged_packets(struc
323 return;
324
325 out_invalid:
326 - WRITE_ONCE(key->is_valid, false);
327 + WRITE_ONCE(keypair->sending.is_valid, false);
328 out_nokey:
329 wg_noise_keypair_put(keypair, false);
330