kernel: 5.4: import wireguard backport
[openwrt/openwrt.git] / target / linux / generic / backport-5.4 / 080-wireguard-0106-wireguard-noise-separate-receive-counter-from-send-c.patch
1 From 044b98abbb08fabca5c2cff426023f1f52448efc Mon Sep 17 00:00:00 2001
2 From: "Jason A. Donenfeld" <Jason@zx2c4.com>
3 Date: Tue, 19 May 2020 22:49:30 -0600
4 Subject: [PATCH 106/124] wireguard: noise: separate receive counter from send
5 counter
6
7 commit a9e90d9931f3a474f04bab782ccd9d77904941e9 upstream.
8
9 In "wireguard: queueing: preserve flow hash across packet scrubbing", we
10 were required to slightly increase the size of the receive replay
11 counter to something still fairly small, but an increase nonetheless.
12 It turns out that we can recoup some of the additional memory overhead
13 by splitting up the prior union type into two distinct types. Before, we
14 used the same "noise_counter" union for both sending and receiving, with
15 sending just using a simple atomic64_t, while receiving used the full
16 replay counter checker. This meant that most of the memory being
17 allocated for the sending counter was being wasted. Since the old
18 "noise_counter" type increased in size in the prior commit, now is a
19 good time to split up that union type into a distinct
20 "noise_replay_counter" for receiving and a boring atomic64_t for sending,
21 each using neither more nor less memory than required.
22
23 Also, since sometimes the replay counter is accessed without
24 necessitating additional accesses to the bitmap, we can reduce cache
25 misses by hoisting the always-necessary lock above the bitmap in the
26 struct layout. We also change a "noise_replay_counter" stack allocation
27 to kmalloc in a -DDEBUG selftest so that KASAN doesn't trigger a stack
28 frame warning.
29
30 All in all, removing a bit of abstraction in this commit makes the code
31 simpler and smaller, in addition to the motivating memory usage
32 recuperation. For example, passing around raw "noise_symmetric_key"
33 structs is something that really only makes sense within noise.c, in the
34 one place where the sending and receiving keys can safely be thought of
35 as the same type of object; subsequent to that, it's important that we
36 uniformly access these through keypair->{sending,receiving}, where their
37 distinct roles are always made explicit. So this patch allows us to draw
38 that distinction clearly as well.
39
40 Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
41 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
42 Signed-off-by: David S. Miller <davem@davemloft.net>
43 Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
44 ---
45 drivers/net/wireguard/noise.c | 16 +++------
46 drivers/net/wireguard/noise.h | 14 ++++----
47 drivers/net/wireguard/receive.c | 42 ++++++++++++------------
48 drivers/net/wireguard/selftest/counter.c | 17 +++++++---
49 drivers/net/wireguard/send.c | 12 +++----
50 5 files changed, 48 insertions(+), 53 deletions(-)
51
52 --- a/drivers/net/wireguard/noise.c
53 +++ b/drivers/net/wireguard/noise.c
54 @@ -104,6 +104,7 @@ static struct noise_keypair *keypair_cre
55
56 if (unlikely(!keypair))
57 return NULL;
58 + spin_lock_init(&keypair->receiving_counter.lock);
59 keypair->internal_id = atomic64_inc_return(&keypair_counter);
60 keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
61 keypair->entry.peer = peer;
62 @@ -358,25 +359,16 @@ out:
63 memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
64 }
65
66 -static void symmetric_key_init(struct noise_symmetric_key *key)
67 -{
68 - spin_lock_init(&key->counter.receive.lock);
69 - atomic64_set(&key->counter.counter, 0);
70 - memset(key->counter.receive.backtrack, 0,
71 - sizeof(key->counter.receive.backtrack));
72 - key->birthdate = ktime_get_coarse_boottime_ns();
73 - key->is_valid = true;
74 -}
75 -
76 static void derive_keys(struct noise_symmetric_key *first_dst,
77 struct noise_symmetric_key *second_dst,
78 const u8 chaining_key[NOISE_HASH_LEN])
79 {
80 + u64 birthdate = ktime_get_coarse_boottime_ns();
81 kdf(first_dst->key, second_dst->key, NULL, NULL,
82 NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
83 chaining_key);
84 - symmetric_key_init(first_dst);
85 - symmetric_key_init(second_dst);
86 + first_dst->birthdate = second_dst->birthdate = birthdate;
87 + first_dst->is_valid = second_dst->is_valid = true;
88 }
89
90 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
91 --- a/drivers/net/wireguard/noise.h
92 +++ b/drivers/net/wireguard/noise.h
93 @@ -15,18 +15,14 @@
94 #include <linux/mutex.h>
95 #include <linux/kref.h>
96
97 -union noise_counter {
98 - struct {
99 - u64 counter;
100 - unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
101 - spinlock_t lock;
102 - } receive;
103 - atomic64_t counter;
104 +struct noise_replay_counter {
105 + u64 counter;
106 + spinlock_t lock;
107 + unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
108 };
109
110 struct noise_symmetric_key {
111 u8 key[NOISE_SYMMETRIC_KEY_LEN];
112 - union noise_counter counter;
113 u64 birthdate;
114 bool is_valid;
115 };
116 @@ -34,7 +30,9 @@ struct noise_symmetric_key {
117 struct noise_keypair {
118 struct index_hashtable_entry entry;
119 struct noise_symmetric_key sending;
120 + atomic64_t sending_counter;
121 struct noise_symmetric_key receiving;
122 + struct noise_replay_counter receiving_counter;
123 __le32 remote_index;
124 bool i_am_the_initiator;
125 struct kref refcount;
126 --- a/drivers/net/wireguard/receive.c
127 +++ b/drivers/net/wireguard/receive.c
128 @@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_pee
129 }
130 }
131
132 -static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
133 +static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
134 {
135 struct scatterlist sg[MAX_SKB_FRAGS + 8];
136 struct sk_buff *trailer;
137 unsigned int offset;
138 int num_frags;
139
140 - if (unlikely(!key))
141 + if (unlikely(!keypair))
142 return false;
143
144 - if (unlikely(!READ_ONCE(key->is_valid) ||
145 - wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
146 - key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
147 - WRITE_ONCE(key->is_valid, false);
148 + if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
149 + wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
150 + keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
151 + WRITE_ONCE(keypair->receiving.is_valid, false);
152 return false;
153 }
154
155 @@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buf
156
157 if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
158 PACKET_CB(skb)->nonce,
159 - key->key))
160 + keypair->receiving.key))
161 return false;
162
163 /* Another ugly situation of pushing and pulling the header so as to
164 @@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buf
165 }
166
167 /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
168 -static bool counter_validate(union noise_counter *counter, u64 their_counter)
169 +static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
170 {
171 unsigned long index, index_current, top, i;
172 bool ret = false;
173
174 - spin_lock_bh(&counter->receive.lock);
175 + spin_lock_bh(&counter->lock);
176
177 - if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
178 + if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
179 their_counter >= REJECT_AFTER_MESSAGES))
180 goto out;
181
182 ++their_counter;
183
184 if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
185 - counter->receive.counter))
186 + counter->counter))
187 goto out;
188
189 index = their_counter >> ilog2(BITS_PER_LONG);
190
191 - if (likely(their_counter > counter->receive.counter)) {
192 - index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
193 + if (likely(their_counter > counter->counter)) {
194 + index_current = counter->counter >> ilog2(BITS_PER_LONG);
195 top = min_t(unsigned long, index - index_current,
196 COUNTER_BITS_TOTAL / BITS_PER_LONG);
197 for (i = 1; i <= top; ++i)
198 - counter->receive.backtrack[(i + index_current) &
199 + counter->backtrack[(i + index_current) &
200 ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
201 - counter->receive.counter = their_counter;
202 + counter->counter = their_counter;
203 }
204
205 index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
206 ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
207 - &counter->receive.backtrack[index]);
208 + &counter->backtrack[index]);
209
210 out:
211 - spin_unlock_bh(&counter->receive.lock);
212 + spin_unlock_bh(&counter->lock);
213 return ret;
214 }
215
216 @@ -472,12 +472,12 @@ int wg_packet_rx_poll(struct napi_struct
217 if (unlikely(state != PACKET_STATE_CRYPTED))
218 goto next;
219
220 - if (unlikely(!counter_validate(&keypair->receiving.counter,
221 + if (unlikely(!counter_validate(&keypair->receiving_counter,
222 PACKET_CB(skb)->nonce))) {
223 net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
224 peer->device->dev->name,
225 PACKET_CB(skb)->nonce,
226 - keypair->receiving.counter.receive.counter);
227 + keypair->receiving_counter.counter);
228 goto next;
229 }
230
231 @@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct wor
232 struct sk_buff *skb;
233
234 while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
235 - enum packet_state state = likely(decrypt_packet(skb,
236 - &PACKET_CB(skb)->keypair->receiving)) ?
237 + enum packet_state state =
238 + likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
239 PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
240 wg_queue_enqueue_per_peer_napi(skb, state);
241 if (need_resched())
242 --- a/drivers/net/wireguard/selftest/counter.c
243 +++ b/drivers/net/wireguard/selftest/counter.c
244 @@ -6,18 +6,24 @@
245 #ifdef DEBUG
246 bool __init wg_packet_counter_selftest(void)
247 {
248 + struct noise_replay_counter *counter;
249 unsigned int test_num = 0, i;
250 - union noise_counter counter;
251 bool success = true;
252
253 -#define T_INIT do { \
254 - memset(&counter, 0, sizeof(union noise_counter)); \
255 - spin_lock_init(&counter.receive.lock); \
256 + counter = kmalloc(sizeof(*counter), GFP_KERNEL);
257 + if (unlikely(!counter)) {
258 + pr_err("nonce counter self-test malloc: FAIL\n");
259 + return false;
260 + }
261 +
262 +#define T_INIT do { \
263 + memset(counter, 0, sizeof(*counter)); \
264 + spin_lock_init(&counter->lock); \
265 } while (0)
266 #define T_LIM (COUNTER_WINDOW_SIZE + 1)
267 #define T(n, v) do { \
268 ++test_num; \
269 - if (counter_validate(&counter, n) != (v)) { \
270 + if (counter_validate(counter, n) != (v)) { \
271 pr_err("nonce counter self-test %u: FAIL\n", \
272 test_num); \
273 success = false; \
274 @@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(v
275
276 if (success)
277 pr_info("nonce counter self-tests: pass\n");
278 + kfree(counter);
279 return success;
280 }
281 #endif
282 --- a/drivers/net/wireguard/send.c
283 +++ b/drivers/net/wireguard/send.c
284 @@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_pee
285 rcu_read_lock_bh();
286 keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
287 send = keypair && READ_ONCE(keypair->sending.is_valid) &&
288 - (atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES ||
289 + (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
290 (keypair->i_am_the_initiator &&
291 wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
292 rcu_read_unlock_bh();
293 @@ -349,7 +349,6 @@ void wg_packet_purge_staged_packets(stru
294
295 void wg_packet_send_staged_packets(struct wg_peer *peer)
296 {
297 - struct noise_symmetric_key *key;
298 struct noise_keypair *keypair;
299 struct sk_buff_head packets;
300 struct sk_buff *skb;
301 @@ -369,10 +368,9 @@ void wg_packet_send_staged_packets(struc
302 rcu_read_unlock_bh();
303 if (unlikely(!keypair))
304 goto out_nokey;
305 - key = &keypair->sending;
306 - if (unlikely(!READ_ONCE(key->is_valid)))
307 + if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
308 goto out_nokey;
309 - if (unlikely(wg_birthdate_has_expired(key->birthdate,
310 + if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
311 REJECT_AFTER_TIME)))
312 goto out_invalid;
313
314 @@ -387,7 +385,7 @@ void wg_packet_send_staged_packets(struc
315 */
316 PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
317 PACKET_CB(skb)->nonce =
318 - atomic64_inc_return(&key->counter.counter) - 1;
319 + atomic64_inc_return(&keypair->sending_counter) - 1;
320 if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
321 goto out_invalid;
322 }
323 @@ -399,7 +397,7 @@ void wg_packet_send_staged_packets(struc
324 return;
325
326 out_invalid:
327 - WRITE_ONCE(key->is_valid, false);
328 + WRITE_ONCE(keypair->sending.is_valid, false);
329 out_nokey:
330 wg_noise_keypair_put(keypair, false);
331