1 // SPDX-License-Identifier: GPL-2.0 OR MIT
3 * Copyright (C) 2016-2017 INRIA and Microsoft Corporation.
4 * Copyright (C) 2018-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
6 * This is a machine-generated formally verified implementation of Curve25519
7 * ECDH from: <https://github.com/mitls/hacl-star>. Though originally machine
8 * generated, it has been tweaked to be suitable for use in the kernel. It is
9 * optimized for 64-bit machines that can efficiently work with 128-bit
13 typedef __uint128_t u128
;
15 static __always_inline u64
u64_eq_mask(u64 a
, u64 b
)
18 u64 minus_x
= ~x
+ (u64
)1U;
19 u64 x_or_minus_x
= x
| minus_x
;
20 u64 xnx
= x_or_minus_x
>> (u32
)63U;
21 u64 c
= xnx
- (u64
)1U;
25 static __always_inline u64
u64_gte_mask(u64 a
, u64 b
)
31 u64 x_sub_y_xor_y
= x_sub_y
^ y
;
32 u64 q
= x_xor_y
| x_sub_y_xor_y
;
34 u64 x_xor_q_
= x_xor_q
>> (u32
)63U;
35 u64 c
= x_xor_q_
- (u64
)1U;
39 static __always_inline
void modulo_carry_top(u64
*b
)
43 u64 b4_
= b4
& 0x7ffffffffffffLLU
;
44 u64 b0_
= b0
+ 19 * (b4
>> 51);
49 static __always_inline
void fproduct_copy_from_wide_(u64
*output
, u128
*input
)
53 output
[0] = ((u64
)(xi
));
57 output
[1] = ((u64
)(xi
));
61 output
[2] = ((u64
)(xi
));
65 output
[3] = ((u64
)(xi
));
69 output
[4] = ((u64
)(xi
));
73 static __always_inline
void
74 fproduct_sum_scalar_multiplication_(u128
*output
, u64
*input
, u64 s
)
76 output
[0] += (u128
)input
[0] * s
;
77 output
[1] += (u128
)input
[1] * s
;
78 output
[2] += (u128
)input
[2] * s
;
79 output
[3] += (u128
)input
[3] * s
;
80 output
[4] += (u128
)input
[4] * s
;
83 static __always_inline
void fproduct_carry_wide_(u128
*tmp
)
88 u128 tctrp1
= tmp
[ctr
+ 1];
89 u64 r0
= ((u64
)(tctr
)) & 0x7ffffffffffffLLU
;
90 u128 c
= ((tctr
) >> (51));
91 tmp
[ctr
] = ((u128
)(r0
));
92 tmp
[ctr
+ 1] = ((tctrp1
) + (c
));
97 u128 tctrp1
= tmp
[ctr
+ 1];
98 u64 r0
= ((u64
)(tctr
)) & 0x7ffffffffffffLLU
;
99 u128 c
= ((tctr
) >> (51));
100 tmp
[ctr
] = ((u128
)(r0
));
101 tmp
[ctr
+ 1] = ((tctrp1
) + (c
));
106 u128 tctr
= tmp
[ctr
];
107 u128 tctrp1
= tmp
[ctr
+ 1];
108 u64 r0
= ((u64
)(tctr
)) & 0x7ffffffffffffLLU
;
109 u128 c
= ((tctr
) >> (51));
110 tmp
[ctr
] = ((u128
)(r0
));
111 tmp
[ctr
+ 1] = ((tctrp1
) + (c
));
115 u128 tctr
= tmp
[ctr
];
116 u128 tctrp1
= tmp
[ctr
+ 1];
117 u64 r0
= ((u64
)(tctr
)) & 0x7ffffffffffffLLU
;
118 u128 c
= ((tctr
) >> (51));
119 tmp
[ctr
] = ((u128
)(r0
));
120 tmp
[ctr
+ 1] = ((tctrp1
) + (c
));
124 static __always_inline
void fmul_shift_reduce(u64
*output
)
130 u64 z
= output
[ctr
- 1];
135 u64 z
= output
[ctr
- 1];
140 u64 z
= output
[ctr
- 1];
145 u64 z
= output
[ctr
- 1];
153 static __always_inline
void fmul_mul_shift_reduce_(u128
*output
, u64
*input
,
159 u64 input2i
= input21
[0];
160 fproduct_sum_scalar_multiplication_(output
, input
, input2i
);
161 fmul_shift_reduce(input
);
164 u64 input2i
= input21
[1];
165 fproduct_sum_scalar_multiplication_(output
, input
, input2i
);
166 fmul_shift_reduce(input
);
169 u64 input2i
= input21
[2];
170 fproduct_sum_scalar_multiplication_(output
, input
, input2i
);
171 fmul_shift_reduce(input
);
174 u64 input2i
= input21
[3];
175 fproduct_sum_scalar_multiplication_(output
, input
, input2i
);
176 fmul_shift_reduce(input
);
179 input2i
= input21
[i
];
180 fproduct_sum_scalar_multiplication_(output
, input
, input2i
);
183 static __always_inline
void fmul_fmul(u64
*output
, u64
*input
, u64
*input21
)
185 u64 tmp
[5] = { input
[0], input
[1], input
[2], input
[3], input
[4] };
196 fmul_mul_shift_reduce_(t
, tmp
, input21
);
197 fproduct_carry_wide_(t
);
200 b4_
= ((b4
) & (((u128
)(0x7ffffffffffffLLU
))));
201 b0_
= ((b0
) + (((u128
)(19) * (((u64
)(((b4
) >> (51))))))));
204 fproduct_copy_from_wide_(output
, t
);
207 i0_
= i0
& 0x7ffffffffffffLLU
;
208 i1_
= i1
+ (i0
>> 51);
214 static __always_inline
void fsquare_fsquare__(u128
*tmp
, u64
*output
)
223 u64 d2
= r2
* 2 * 19;
226 u128 s0
= ((((((u128
)(r0
) * (r0
))) + (((u128
)(d4
) * (r1
))))) +
227 (((u128
)(d2
) * (r3
))));
228 u128 s1
= ((((((u128
)(d0
) * (r1
))) + (((u128
)(d4
) * (r2
))))) +
229 (((u128
)(r3
* 19) * (r3
))));
230 u128 s2
= ((((((u128
)(d0
) * (r2
))) + (((u128
)(r1
) * (r1
))))) +
231 (((u128
)(d4
) * (r3
))));
232 u128 s3
= ((((((u128
)(d0
) * (r3
))) + (((u128
)(d1
) * (r2
))))) +
233 (((u128
)(r4
) * (d419
))));
234 u128 s4
= ((((((u128
)(d0
) * (r4
))) + (((u128
)(d1
) * (r3
))))) +
235 (((u128
)(r2
) * (r2
))));
243 static __always_inline
void fsquare_fsquare_(u128
*tmp
, u64
*output
)
253 fsquare_fsquare__(tmp
, output
);
254 fproduct_carry_wide_(tmp
);
257 b4_
= ((b4
) & (((u128
)(0x7ffffffffffffLLU
))));
258 b0_
= ((b0
) + (((u128
)(19) * (((u64
)(((b4
) >> (51))))))));
261 fproduct_copy_from_wide_(output
, tmp
);
264 i0_
= i0
& 0x7ffffffffffffLLU
;
265 i1_
= i1
+ (i0
>> 51);
270 static __always_inline
void fsquare_fsquare_times_(u64
*output
, u128
*tmp
,
274 fsquare_fsquare_(tmp
, output
);
275 for (i
= 1; i
< count1
; ++i
)
276 fsquare_fsquare_(tmp
, output
);
279 static __always_inline
void fsquare_fsquare_times(u64
*output
, u64
*input
,
283 memcpy(output
, input
, 5 * sizeof(*input
));
284 fsquare_fsquare_times_(output
, t
, count1
);
287 static __always_inline
void fsquare_fsquare_times_inplace(u64
*output
,
291 fsquare_fsquare_times_(output
, t
, count1
);
294 static __always_inline
void crecip_crecip(u64
*out
, u64
*z
)
307 fsquare_fsquare_times(a0
, z
, 1);
308 fsquare_fsquare_times(t00
, a0
, 2);
309 fmul_fmul(b0
, t00
, z
);
310 fmul_fmul(a0
, b0
, a0
);
311 fsquare_fsquare_times(t00
, a0
, 1);
312 fmul_fmul(b0
, t00
, b0
);
313 fsquare_fsquare_times(t00
, b0
, 5);
317 fmul_fmul(b1
, t01
, b1
);
318 fsquare_fsquare_times(t01
, b1
, 10);
319 fmul_fmul(c0
, t01
, b1
);
320 fsquare_fsquare_times(t01
, c0
, 20);
321 fmul_fmul(t01
, t01
, c0
);
322 fsquare_fsquare_times_inplace(t01
, 10);
323 fmul_fmul(b1
, t01
, b1
);
324 fsquare_fsquare_times(t01
, b1
, 50);
330 fsquare_fsquare_times(t0
, c
, 100);
331 fmul_fmul(t0
, t0
, c
);
332 fsquare_fsquare_times_inplace(t0
, 50);
333 fmul_fmul(t0
, t0
, b
);
334 fsquare_fsquare_times_inplace(t0
, 5);
335 fmul_fmul(out
, t0
, a
);
338 static __always_inline
void fsum(u64
*a
, u64
*b
)
347 static __always_inline
void fdifference(u64
*a
, u64
*b
)
355 memcpy(tmp
, b
, 5 * sizeof(*b
));
361 tmp
[0] = b0
+ 0x3fffffffffff68LLU
;
362 tmp
[1] = b1
+ 0x3ffffffffffff8LLU
;
363 tmp
[2] = b2
+ 0x3ffffffffffff8LLU
;
364 tmp
[3] = b3
+ 0x3ffffffffffff8LLU
;
365 tmp
[4] = b4
+ 0x3ffffffffffff8LLU
;
393 static __always_inline
void fscalar(u64
*output
, u64
*b
, u64 s
)
402 tmp
[0] = ((u128
)(xi
) * (s
));
406 tmp
[1] = ((u128
)(xi
) * (s
));
410 tmp
[2] = ((u128
)(xi
) * (s
));
414 tmp
[3] = ((u128
)(xi
) * (s
));
418 tmp
[4] = ((u128
)(xi
) * (s
));
420 fproduct_carry_wide_(tmp
);
423 b4_
= ((b4
) & (((u128
)(0x7ffffffffffffLLU
))));
424 b0_
= ((b0
) + (((u128
)(19) * (((u64
)(((b4
) >> (51))))))));
427 fproduct_copy_from_wide_(output
, tmp
);
430 static __always_inline
void fmul(u64
*output
, u64
*a
, u64
*b
)
432 fmul_fmul(output
, a
, b
);
435 static __always_inline
void crecip(u64
*output
, u64
*input
)
437 crecip_crecip(output
, input
);
440 static __always_inline
void point_swap_conditional_step(u64
*a
, u64
*b
,
446 u64 x
= swap1
& (ai
^ bi
);
453 static __always_inline
void point_swap_conditional5(u64
*a
, u64
*b
, u64 swap1
)
455 point_swap_conditional_step(a
, b
, swap1
, 5);
456 point_swap_conditional_step(a
, b
, swap1
, 4);
457 point_swap_conditional_step(a
, b
, swap1
, 3);
458 point_swap_conditional_step(a
, b
, swap1
, 2);
459 point_swap_conditional_step(a
, b
, swap1
, 1);
462 static __always_inline
void point_swap_conditional(u64
*a
, u64
*b
, u64 iswap
)
464 u64 swap1
= 0 - iswap
;
465 point_swap_conditional5(a
, b
, swap1
);
466 point_swap_conditional5(a
+ 5, b
+ 5, swap1
);
469 static __always_inline
void point_copy(u64
*output
, u64
*input
)
471 memcpy(output
, input
, 5 * sizeof(*input
));
472 memcpy(output
+ 5, input
+ 5, 5 * sizeof(*input
));
475 static __always_inline
void addanddouble_fmonty(u64
*pp
, u64
*ppq
, u64
*p
,
486 u64
*zprime
= pq
+ 5;
489 u64
*origxprime0
= buf
+ 5;
495 memcpy(origx
, x
, 5 * sizeof(*x
));
497 fdifference(z
, origx
);
498 memcpy(origxprime0
, xprime
, 5 * sizeof(*xprime
));
499 fsum(xprime
, zprime
);
500 fdifference(zprime
, origxprime0
);
501 fmul(xxprime0
, xprime
, z
);
502 fmul(zzprime0
, x
, zprime
);
503 origxprime
= buf
+ 5;
515 memcpy(origxprime
, xxprime
, 5 * sizeof(*xxprime
));
516 fsum(xxprime
, zzprime
);
517 fdifference(zzprime
, origxprime
);
518 fsquare_fsquare_times(x3
, xxprime
, 1);
519 fsquare_fsquare_times(zzzprime
, zzprime
, 1);
520 fmul(z3
, zzzprime
, qx
);
521 fsquare_fsquare_times(xx0
, x
, 1);
522 fsquare_fsquare_times(zz0
, z
, 1);
534 fscalar(zzz
, zz
, scalar
);
541 static __always_inline
void
542 ladder_smallloop_cmult_small_loop_step(u64
*nq
, u64
*nqpq
, u64
*nq2
, u64
*nqpq2
,
545 u64 bit0
= (u64
)(byt
>> 7);
547 point_swap_conditional(nq
, nqpq
, bit0
);
548 addanddouble_fmonty(nq2
, nqpq2
, nq
, nqpq
, q
);
549 bit
= (u64
)(byt
>> 7);
550 point_swap_conditional(nq2
, nqpq2
, bit
);
553 static __always_inline
void
554 ladder_smallloop_cmult_small_loop_double_step(u64
*nq
, u64
*nqpq
, u64
*nq2
,
555 u64
*nqpq2
, u64
*q
, u8 byt
)
558 ladder_smallloop_cmult_small_loop_step(nq
, nqpq
, nq2
, nqpq2
, q
, byt
);
560 ladder_smallloop_cmult_small_loop_step(nq2
, nqpq2
, nq
, nqpq
, q
, byt1
);
563 static __always_inline
void
564 ladder_smallloop_cmult_small_loop(u64
*nq
, u64
*nqpq
, u64
*nq2
, u64
*nqpq2
,
565 u64
*q
, u8 byt
, u32 i
)
568 ladder_smallloop_cmult_small_loop_double_step(nq
, nqpq
, nq2
,
574 static __always_inline
void ladder_bigloop_cmult_big_loop(u8
*n1
, u64
*nq
,
581 ladder_smallloop_cmult_small_loop(nq
, nqpq
, nq2
, nqpq2
, q
,
586 static void ladder_cmult(u64
*result
, u8
*n1
, u64
*q
)
588 u64 point_buf
[40] = { 0 };
590 u64
*nqpq
= point_buf
+ 10;
591 u64
*nq2
= point_buf
+ 20;
592 u64
*nqpq2
= point_buf
+ 30;
595 ladder_bigloop_cmult_big_loop(n1
, nq
, nqpq
, nq2
, nqpq2
, q
, 32);
596 point_copy(result
, nq
);
599 static __always_inline
void format_fexpand(u64
*output
, const u8
*input
)
601 const u8
*x00
= input
+ 6;
602 const u8
*x01
= input
+ 12;
603 const u8
*x02
= input
+ 19;
604 const u8
*x0
= input
+ 24;
605 u64 i0
, i1
, i2
, i3
, i4
, output0
, output1
, output2
, output3
, output4
;
606 i0
= get_unaligned_le64(input
);
607 i1
= get_unaligned_le64(x00
);
608 i2
= get_unaligned_le64(x01
);
609 i3
= get_unaligned_le64(x02
);
610 i4
= get_unaligned_le64(x0
);
611 output0
= i0
& 0x7ffffffffffffLLU
;
612 output1
= i1
>> 3 & 0x7ffffffffffffLLU
;
613 output2
= i2
>> 6 & 0x7ffffffffffffLLU
;
614 output3
= i3
>> 1 & 0x7ffffffffffffLLU
;
615 output4
= i4
>> 12 & 0x7ffffffffffffLLU
;
623 static __always_inline
void format_fcontract_first_carry_pass(u64
*input
)
630 u64 t1_
= t1
+ (t0
>> 51);
631 u64 t0_
= t0
& 0x7ffffffffffffLLU
;
632 u64 t2_
= t2
+ (t1_
>> 51);
633 u64 t1__
= t1_
& 0x7ffffffffffffLLU
;
634 u64 t3_
= t3
+ (t2_
>> 51);
635 u64 t2__
= t2_
& 0x7ffffffffffffLLU
;
636 u64 t4_
= t4
+ (t3_
>> 51);
637 u64 t3__
= t3_
& 0x7ffffffffffffLLU
;
645 static __always_inline
void format_fcontract_first_carry_full(u64
*input
)
647 format_fcontract_first_carry_pass(input
);
648 modulo_carry_top(input
);
651 static __always_inline
void format_fcontract_second_carry_pass(u64
*input
)
658 u64 t1_
= t1
+ (t0
>> 51);
659 u64 t0_
= t0
& 0x7ffffffffffffLLU
;
660 u64 t2_
= t2
+ (t1_
>> 51);
661 u64 t1__
= t1_
& 0x7ffffffffffffLLU
;
662 u64 t3_
= t3
+ (t2_
>> 51);
663 u64 t2__
= t2_
& 0x7ffffffffffffLLU
;
664 u64 t4_
= t4
+ (t3_
>> 51);
665 u64 t3__
= t3_
& 0x7ffffffffffffLLU
;
673 static __always_inline
void format_fcontract_second_carry_full(u64
*input
)
679 format_fcontract_second_carry_pass(input
);
680 modulo_carry_top(input
);
683 i0_
= i0
& 0x7ffffffffffffLLU
;
684 i1_
= i1
+ (i0
>> 51);
689 static __always_inline
void format_fcontract_trim(u64
*input
)
696 u64 mask0
= u64_gte_mask(a0
, 0x7ffffffffffedLLU
);
697 u64 mask1
= u64_eq_mask(a1
, 0x7ffffffffffffLLU
);
698 u64 mask2
= u64_eq_mask(a2
, 0x7ffffffffffffLLU
);
699 u64 mask3
= u64_eq_mask(a3
, 0x7ffffffffffffLLU
);
700 u64 mask4
= u64_eq_mask(a4
, 0x7ffffffffffffLLU
);
701 u64 mask
= (((mask0
& mask1
) & mask2
) & mask3
) & mask4
;
702 u64 a0_
= a0
- (0x7ffffffffffedLLU
& mask
);
703 u64 a1_
= a1
- (0x7ffffffffffffLLU
& mask
);
704 u64 a2_
= a2
- (0x7ffffffffffffLLU
& mask
);
705 u64 a3_
= a3
- (0x7ffffffffffffLLU
& mask
);
706 u64 a4_
= a4
- (0x7ffffffffffffLLU
& mask
);
714 static __always_inline
void format_fcontract_store(u8
*output
, u64
*input
)
721 u64 o0
= t1
<< 51 | t0
;
722 u64 o1
= t2
<< 38 | t1
>> 13;
723 u64 o2
= t3
<< 25 | t2
>> 26;
724 u64 o3
= t4
<< 12 | t3
>> 39;
727 u8
*b2
= output
+ 16;
728 u8
*b3
= output
+ 24;
729 put_unaligned_le64(o0
, b0
);
730 put_unaligned_le64(o1
, b1
);
731 put_unaligned_le64(o2
, b2
);
732 put_unaligned_le64(o3
, b3
);
735 static __always_inline
void format_fcontract(u8
*output
, u64
*input
)
737 format_fcontract_first_carry_full(input
);
738 format_fcontract_second_carry_full(input
);
739 format_fcontract_trim(input
);
740 format_fcontract_store(output
, input
);
743 static __always_inline
void format_scalar_of_point(u8
*scalar
, u64
*point
)
747 u64 buf
[10] __aligned(32) = { 0 };
752 format_fcontract(scalar
, sc
);
755 static void curve25519_generic(u8 mypublic
[CURVE25519_KEY_SIZE
],
756 const u8 secret
[CURVE25519_KEY_SIZE
],
757 const u8 basepoint
[CURVE25519_KEY_SIZE
])
759 u64 buf0
[10] __aligned(32) = { 0 };
763 format_fexpand(x0
, basepoint
);
767 u8 e
[32] __aligned(32) = { 0 };
769 memcpy(e
, secret
, 32);
770 curve25519_clamp_secret(e
);
777 ladder_cmult(nq
, scalar
, q
);
778 format_scalar_of_point(mypublic
, nq
);
779 memzero_explicit(buf
, sizeof(buf
));
781 memzero_explicit(e
, sizeof(e
));
783 memzero_explicit(buf0
, sizeof(buf0
));