Skip to content

Commit

Permalink
Improve Dilithium (speed) verification stack usage
Browse files Browse the repository at this point in the history
Once upon a time, we wrote a paper on memory-efficient Dilithium [1]
which included a speed-optimized version of verification that still
included some memory optimizations that don't come at a performance
penalty.

Unfortunately, with the update of the reference code to round 3, that
version did not get migrated, leading to some complaints about
verification memory consumption.

I finally found some time to port these.
Verification speed is essentially unchanged, but stack consumption is
much better.

[1] https://eprint.iacr.org/2020/1278.pdf
  • Loading branch information
mkannwischer committed Jul 4, 2024
1 parent 006a109 commit 4e4e563
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 40 deletions.
106 changes: 105 additions & 1 deletion crypto_sign/dilithium2/m4f/packing.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "packing.h"
#include "polyvec.h"
#include "poly.h"
#include <stddef.h>

/*************************************************
* Name: pack_pk
Expand Down Expand Up @@ -49,6 +50,21 @@ void unpack_pk(uint8_t rho[SEEDBYTES],
polyt1_unpack(&t1->vec[i], pk + i*POLYT1_PACKEDBYTES);
}

/*************************************************
* Name:        unpack_pk_t1
*
* Description: Unpack a single polynomial of t1 from the public key
*              pk = (rho, t1). The seed rho at the start of pk is skipped.
*
* Arguments:   - poly *t1: pointer to output polynomial
*              - size_t idx: index of the t1 polynomial to unpack
*              - const unsigned char pk[]: byte array containing bit-packed pk
**************************************************/
void unpack_pk_t1(poly *t1, size_t idx, const unsigned char pk[CRYPTO_PUBLICKEYBYTES]) {
    /* Packed t1 polynomials follow rho, POLYT1_PACKEDBYTES bytes each. */
    const unsigned char *packed_t1 = pk + SEEDBYTES + idx * POLYT1_PACKEDBYTES;

    polyt1_unpack(t1, packed_t1);
}


/*************************************************
* Name: pack_sk
*
Expand Down Expand Up @@ -283,4 +299,92 @@ int unpack_sig(uint8_t c[CTILDEBYTES],
return 1;

return 0;
}
}

/*************************************************
* Name:        unpack_sig_c
*
* Description: Copy only the challenge seed c out of the signature.
*              c occupies the first CTILDEBYTES bytes of sig; the
*              remaining bytes of the signature are not read.
*
* Arguments:   - uint8_t c[]: output byte array of CTILDEBYTES bytes
*              - const unsigned char sig[]: byte array containing
*                                           bit-packed signature
*
* Returns 0 always; c is copied verbatim and is not validated here.
**************************************************/
int unpack_sig_c(uint8_t c[CTILDEBYTES], const unsigned char sig[CRYPTO_BYTES]) {
    for (size_t i = 0; i < CTILDEBYTES; ++i) {
        c[i] = sig[i];
    }
    return 0;
}

/*************************************************
* Name:        unpack_sig_z
*
* Description: Unpack only the vector z from the signature. The packed
*              polynomials of z start CTILDEBYTES bytes into sig and
*              take POLYZ_PACKEDBYTES bytes each.
*
* Arguments:   - polyvecl *z: pointer to output vector z
*              - const unsigned char sig[]: byte array containing
*                                           bit-packed signature
*
* Returns 0; no validity checks are performed in this function.
**************************************************/
int unpack_sig_z(polyvecl *z, const unsigned char sig[CRYPTO_BYTES]) {
    /* Skip the challenge seed that precedes z. */
    const unsigned char *packed_z = sig + CTILDEBYTES;
    size_t col = 0;

    while (col < L) {
        polyz_unpack(&z->vec[col], packed_z + col * POLYZ_PACKEDBYTES);
        ++col;
    }

    return 0;
}

/*************************************************
* Name:        unpack_sig_h
*
* Description: Unpack a single polynomial of the hint vector h from the
*              signature. The encoding of the complete hint (all K
*              polynomials) is validated on every call, but only the
*              polynomial with index idx is written to the output.
*
* Arguments:   - poly *h: pointer to output hint polynomial; every
*                         coefficient is set to 0 or 1
*              - size_t idx: index of the hint polynomial to extract;
*                            if idx >= K the output is all zero
*              - const unsigned char sig[]: byte array containing
*                                           bit-packed signature
*
* Returns 1 in case of a malformed hint encoding; otherwise 0.
**************************************************/
int unpack_sig_h(poly *h, size_t idx, const unsigned char sig[CRYPTO_BYTES]) {
    /* The hint section follows the challenge seed and the packed z. */
    sig += CTILDEBYTES;
    sig += L * POLYZ_PACKEDBYTES;

    /* Clear the output once, up front, so that h is fully initialised on
     * every return path instead of only after polynomial idx is reached. */
    for (size_t j = 0; j < N; ++j) {
        h->coeffs[j] = 0;
    }

    /* Decode h */
    size_t k = 0;
    for (size_t i = 0; i < K; ++i) {
        /* sig[OMEGA + i] is the end offset of polynomial i's index list;
         * the offsets must be non-decreasing and at most OMEGA. */
        if (sig[OMEGA + i] < k || sig[OMEGA + i] > OMEGA) {
            return 1;
        }

        for (size_t j = k; j < sig[OMEGA + i]; ++j) {
            /* Coefficients are ordered for strong unforgeability */
            if (j > k && sig[j] <= sig[j - 1]) {
                return 1;
            }
            if (i == idx) {
                h->coeffs[sig[j]] = 1;
            }
        }

        k = sig[OMEGA + i];
    }

    /* Extra indices are zero for strong unforgeability */
    for (size_t j = k; j < OMEGA; ++j) {
        if (sig[j]) {
            return 1;
        }
    }
    return 0;
}

13 changes: 13 additions & 0 deletions crypto_sign/dilithium2/m4f/packing.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PACKING_H

#include <stdint.h>
#include <stddef.h>
#include "params.h"
#include "polyvec.h"
#include "smallpoly.h"
Expand All @@ -24,6 +25,9 @@ void pack_sig(uint8_t sig[CRYPTO_BYTES], const uint8_t c[CTILDEBYTES], const pol
/* Unpack the public key pk into rho and the vector t1. */
#define unpack_pk DILITHIUM_NAMESPACE(unpack_pk)
void unpack_pk(uint8_t rho[SEEDBYTES], polyveck *t1, const uint8_t pk[CRYPTO_PUBLICKEYBYTES]);

/* Unpack only the idx-th polynomial of t1 from pk; the SEEDBYTES-byte
 * seed at the start of pk is skipped. */
#define unpack_pk_t1 DILITHIUM_NAMESPACE(unpack_pk_t1)
void unpack_pk_t1(poly *t1, size_t idx, const unsigned char pk[CRYPTO_PUBLICKEYBYTES]);

#define unpack_sk DILITHIUM_NAMESPACE(unpack_sk)
void unpack_sk(uint8_t rho[SEEDBYTES],
uint8_t tr[TRBYTES],
Expand All @@ -36,6 +40,15 @@ void unpack_sk(uint8_t rho[SEEDBYTES],
/* Unpack the complete signature into c, z and h in one call; returns
 * non-zero for a malformed signature. */
#define unpack_sig DILITHIUM_NAMESPACE(unpack_sig)
int unpack_sig(uint8_t c[CTILDEBYTES], polyvecl *z, polyveck *h, const uint8_t sig[CRYPTO_BYTES]);


/* Piecewise unpacking of a signature: each function extracts one
 * component so that the caller does not need to hold all of them at once. */

/* Unpack only the vector z; always returns 0. */
#define unpack_sig_z DILITHIUM_NAMESPACE(unpack_sig_z)
int unpack_sig_z(polyvecl *z, const unsigned char sig[CRYPTO_BYTES]);
/* Unpack only the idx-th hint polynomial; the whole hint encoding is
 * validated on each call. Returns 1 if it is malformed, otherwise 0. */
#define unpack_sig_h DILITHIUM_NAMESPACE(unpack_sig_h)
int unpack_sig_h(poly *h, size_t idx, const unsigned char sig[CRYPTO_BYTES]);
/* Copy only the CTILDEBYTES-byte challenge seed c; always returns 0. */
#define unpack_sig_c DILITHIUM_NAMESPACE(unpack_sig_c)
int unpack_sig_c(uint8_t c[CTILDEBYTES], const unsigned char sig[CRYPTO_BYTES]);


/* NOTE(review): presumably writes c into its place in sig -- the
 * definition is not shown next to this declaration; confirm in packing.c. */
#define pack_sig_c DILITHIUM_NAMESPACE(pack_sig_c)
void pack_sig_c(uint8_t sig[CRYPTO_BYTES], const uint8_t c[CTILDEBYTES]);

Expand Down
12 changes: 12 additions & 0 deletions crypto_sign/dilithium2/m4f/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ void poly_caddq(poly *a) {
asm_caddq(a->coeffs);
}

/*************************************************
* Name:        poly_csubq
*
* Description: For all coefficients of input polynomial subtract Q if
*              coefficient is bigger than Q; add Q if coefficient is negative.
*              The work is delegated to the asm_csubq routine declared
*              in vector.h.
*
* Arguments:   - poly *a: pointer to input/output polynomial
**************************************************/
void poly_csubq(poly *a) {
    /* Must be asm_csubq: calling asm_caddq here would make this function
     * a duplicate of poly_caddq and contradict the description above. */
    asm_csubq(a->coeffs);
}

#if 0
/*************************************************
* Name: poly_freeze
Expand Down
2 changes: 2 additions & 0 deletions crypto_sign/dilithium2/m4f/poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ typedef struct {
void poly_reduce(poly *a);
/* Applies asm_caddq to the coefficient array of a, in place. */
#define poly_caddq DILITHIUM_NAMESPACE(poly_caddq)
void poly_caddq(poly *a);
/* Documented in poly.c as: for every coefficient of a, subtract Q if it
 * is bigger than Q and add Q if it is negative. Operates in place. */
#define poly_csubq DILITHIUM_NAMESPACE(poly_csubq)
void poly_csubq(poly *a);
#define poly_freeze DILITHIUM_NAMESPACE(poly_freeze)
void poly_freeze(poly *a);

Expand Down
117 changes: 78 additions & 39 deletions crypto_sign/dilithium2/m4f/sign.c
Original file line number Diff line number Diff line change
Expand Up @@ -225,44 +225,60 @@ int crypto_sign(uint8_t *sm,
*smlen += mlen;
return 0;
}
/*************************************************
* Name:        expand_mat_elem
*
* Description: Generates a single element a_{k,l} of the matrix A
*              (ExpandA, one entry at a time) by calling poly_uniform
*              with the seed rho and the 16-bit nonce (k_idx << 8) + l_idx.
*
* Arguments:   - poly *mat_elem: pointer to output matrix element
*              - const unsigned char rho[]: byte array containing seed rho
*              - size_t k_idx: matrix row index
*              - size_t l_idx: matrix column index
**************************************************/
static void expand_mat_elem(poly *mat_elem, const unsigned char rho[SEEDBYTES], size_t k_idx, size_t l_idx)
{
    /* Row index in the high byte, column index in the low byte. */
    const uint16_t nonce = (uint16_t)((k_idx << 8) + l_idx);

    poly_uniform(mat_elem, rho, nonce);
}

/*************************************************
* Name: crypto_sign_verify
*
* Description: Verifies signature.
*
* Arguments: - const uint8_t *sig: pointer to input signature
* - size_t siglen: length of signature
* - const uint8_t *m: pointer to message
* - size_t mlen: length of message
* - const uint8_t *pk: pointer to bit-packed public key
*
* Returns 0 if signature could be verified correctly and -1 otherwise
**************************************************/
* Name: crypto_sign_verify
*
* Description: Verifies signature.
*
* Arguments: - const uint8_t *sig: pointer to input signature
* - size_t siglen: length of signature
* - const uint8_t *m: pointer to message
* - size_t mlen: length of message
* - const uint8_t *pk: pointer to bit-packed public key
*
* Returns 0 if signature could be verified correctly and -1 otherwise
**************************************************/
int crypto_sign_verify(const uint8_t *sig,
size_t siglen,
const uint8_t *m,
size_t mlen,
const uint8_t *pk)
{
unsigned int i;
uint8_t buf[K*POLYW1_PACKEDBYTES];
uint8_t rho[SEEDBYTES];
const uint8_t *rho = pk;
uint8_t mu[CRHBYTES];
uint8_t c[CTILDEBYTES];
uint8_t c2[CTILDEBYTES];
poly cp;
polyvecl mat[K], z;
polyveck t1, w1, h;
polyvecl z;
shake256incctx state;

if(siglen != CRYPTO_BYTES)
poly tmp_elem, w1_elem;

if (siglen != CRYPTO_BYTES)
return -1;

unpack_pk(rho, &t1, pk);
if(unpack_sig(c, &z, &h, sig))
if (unpack_sig_z(&z, sig) != 0) {
return -1;
if(polyvecl_chknorm(&z, GAMMA1 - BETA))
}
if (polyvecl_chknorm(&z, GAMMA1 - BETA))
return -1;

/* Compute CRH(h(rho, t1), msg) */
Expand All @@ -273,35 +289,58 @@ int crypto_sign_verify(const uint8_t *sig,
shake256_inc_finalize(&state);
shake256_inc_squeeze(mu, CRHBYTES, &state);

// Hash [mu || w1'] to get c.
shake256_inc_init(&state);
shake256_inc_absorb(&state, mu, CRHBYTES);

/* Matrix-vector multiplication; compute Az - c2^dt1 */
if (unpack_sig_c(c, sig) != 0) {
return -1;
}
poly_challenge(&cp, c);
polyvec_matrix_expand(mat, rho);

poly_ntt(&cp);
polyvecl_ntt(&z);
polyvec_matrix_pointwise_montgomery(&w1, mat, &z);

poly_ntt(&cp);
polyveck_shiftl(&t1);
polyveck_ntt(&t1);
polyveck_pointwise_poly_montgomery(&t1, &cp, &t1);

polyveck_sub(&w1, &w1, &t1);
polyveck_reduce(&w1);
polyveck_invntt_tomont(&w1);
for (size_t k_idx = 0; k_idx < K; k_idx++)
{
// Sample the current element from A.
expand_mat_elem(&tmp_elem, rho, k_idx, 0);
poly_pointwise_montgomery(&w1_elem, &tmp_elem, &z.vec[0]);

for (size_t l_idx = 1; l_idx < L; l_idx++)
{
// Sample the element from A.
expand_mat_elem(&tmp_elem, rho, k_idx, l_idx);
poly_pointwise_acc_montgomery(&w1_elem, &tmp_elem, &z.vec[l_idx]);
}

// Subtract c*(t1_{k_idx} * 2^d)
unpack_pk_t1(&tmp_elem, k_idx, pk);
poly_shiftl(&tmp_elem);
poly_ntt(&tmp_elem);
poly_pointwise_montgomery(&tmp_elem, &cp, &tmp_elem);
poly_sub(&w1_elem, &w1_elem, &tmp_elem);
poly_reduce(&w1_elem);
poly_invntt_tomont(&w1_elem);

// Reconstruct w1
poly_csubq(&w1_elem);
if (unpack_sig_h(&tmp_elem, k_idx, sig) != 0) {
return -1;
}
poly_use_hint(&w1_elem, &w1_elem, &tmp_elem);
uint8_t w1_packed[POLYW1_PACKEDBYTES];
polyw1_pack(w1_packed, &w1_elem);
shake256_inc_absorb(&state, w1_packed, POLYW1_PACKEDBYTES);
}

/* Reconstruct w1 */
polyveck_caddq(&w1);
polyveck_use_hint(&w1, &w1, &h);
polyveck_pack_w1(buf, &w1);

/* Call random oracle and verify challenge */
shake256_inc_init(&state);
shake256_inc_absorb(&state, mu, CRHBYTES);
shake256_inc_absorb(&state, buf, K*POLYW1_PACKEDBYTES);
shake256_inc_finalize(&state);
shake256_inc_squeeze(c2, CTILDEBYTES, &state);
for(i = 0; i < CTILDEBYTES; ++i)
if(c[i] != c2[i])
for (i = 0; i < CTILDEBYTES; ++i)
if (c[i] != c2[i])
return -1;

return 0;
Expand Down
2 changes: 2 additions & 0 deletions crypto_sign/dilithium2/m4f/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ void asm_reduce32(int32_t a[N]);
void small_asm_reduce32_central(int32_t a[N]);
/* In-place routine over an array of N coefficients; poly_caddq forwards
 * to it. NOTE(review): implementation is not visible from this header --
 * presumably adds Q to negative coefficients; confirm in the assembly. */
#define asm_caddq DILITHIUM_NAMESPACE(asm_caddq)
void asm_caddq(int32_t a[N]);
/* In-place routine over an array of N coefficients.
 * NOTE(review): implementation is not visible from this header -- the name
 * and the description of poly_csubq suggest it subtracts Q from
 * coefficients bigger than Q and adds Q to negative ones; confirm. */
#define asm_csubq DILITHIUM_NAMESPACE(asm_csubq)
void asm_csubq(int32_t a[N]);
#define asm_freeze DILITHIUM_NAMESPACE(asm_freeze)
void asm_freeze(int32_t a[N]);
#define asm_rej_uniform DILITHIUM_NAMESPACE(asm_rej_uniform)
Expand Down
Loading

0 comments on commit 4e4e563

Please sign in to comment.