crypto: poly1305 - Pass key as first two message blocks to each desc_ctx

The Poly1305 authenticator requires a unique key for each generated tag. This
implies that we can't set the key per tfm, as a tfm may be shared by multiple
users that each need their own key. Instead, we pass a desc-specific key as the
first two blocks of the message to authenticate in update().

Signed-off-by: Martin Willi <martin@strongswan.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff --git a/crypto/poly1305_generic.c b/crypto/poly1305_generic.c
index 9c1159b..387b5c8 100644
--- a/crypto/poly1305_generic.c
+++ b/crypto/poly1305_generic.c
@@ -21,20 +21,21 @@
 #define POLY1305_KEY_SIZE	32
 #define POLY1305_DIGEST_SIZE	16
 
-struct poly1305_ctx {
+struct poly1305_desc_ctx {
 	/* key */
 	u32 r[5];
 	/* finalize key */
 	u32 s[4];
-};
-
-struct poly1305_desc_ctx {
 	/* accumulator */
 	u32 h[5];
 	/* partial buffer */
 	u8 buf[POLY1305_BLOCK_SIZE];
 	/* bytes used in partial buffer */
 	unsigned int buflen;
+	/* r key (first message block fed to update()) has been consumed */
+	bool rset;
+	/* s key (second message block fed to update()) has been consumed */
+	bool sset;
 };
 
 static inline u64 mlt(u64 a, u64 b)
@@ -63,6 +64,8 @@
 
 	memset(dctx->h, 0, sizeof(dctx->h));
 	dctx->buflen = 0;
+	dctx->rset = false;	/* key must be re-supplied via update() */
+	dctx->sset = false;
 
 	return 0;
 }
@@ -70,42 +73,60 @@
 static int poly1305_setkey(struct crypto_shash *tfm,
 			   const u8 *key, unsigned int keylen)
 {
-	struct poly1305_ctx *ctx = crypto_shash_ctx(tfm);
+	/* Poly1305 requires a unique key for each tag, which implies that
+	 * we can't set it on the tfm that gets accessed by multiple users
+	 * simultaneously. Instead we expect the key as the first 32 bytes
+	 * of data passed to update(), possibly split over several calls. */
+	return -EOPNOTSUPP;
+}
 
-	if (keylen != POLY1305_KEY_SIZE) {
-		crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
-		return -EINVAL;
-	}
-
+static void poly1305_setrkey(struct poly1305_desc_ctx *dctx, const u8 *key)
+{
 	/* r &= 0xffffffc0ffffffc0ffffffc0fffffff */
-	ctx->r[0] = (le32_to_cpuvp(key +  0) >> 0) & 0x3ffffff;
-	ctx->r[1] = (le32_to_cpuvp(key +  3) >> 2) & 0x3ffff03;
-	ctx->r[2] = (le32_to_cpuvp(key +  6) >> 4) & 0x3ffc0ff;
-	ctx->r[3] = (le32_to_cpuvp(key +  9) >> 6) & 0x3f03fff;
-	ctx->r[4] = (le32_to_cpuvp(key + 12) >> 8) & 0x00fffff;
+	dctx->r[0] = (le32_to_cpuvp(key +  0) >> 0) & 0x3ffffff;
+	dctx->r[1] = (le32_to_cpuvp(key +  3) >> 2) & 0x3ffff03;
+	dctx->r[2] = (le32_to_cpuvp(key +  6) >> 4) & 0x3ffc0ff;
+	dctx->r[3] = (le32_to_cpuvp(key +  9) >> 6) & 0x3f03fff;
+	dctx->r[4] = (le32_to_cpuvp(key + 12) >> 8) & 0x00fffff;
+}
 
-	ctx->s[0] = le32_to_cpuvp(key + 16);
-	ctx->s[1] = le32_to_cpuvp(key + 20);
-	ctx->s[2] = le32_to_cpuvp(key + 24);
-	ctx->s[3] = le32_to_cpuvp(key + 28);
-
-	return 0;
+static void poly1305_setskey(struct poly1305_desc_ctx *dctx, const u8 *key)
+{
+	dctx->s[0] = le32_to_cpuvp(key +  0);
+	dctx->s[1] = le32_to_cpuvp(key +  4);
+	dctx->s[2] = le32_to_cpuvp(key +  8);
+	dctx->s[3] = le32_to_cpuvp(key + 12);
 }
 
 static unsigned int poly1305_blocks(struct poly1305_desc_ctx *dctx,
-				    struct poly1305_ctx *ctx, const u8 *src,
-				    unsigned int srclen, u32 hibit)
+				    const u8 *src, unsigned int srclen,
+				    u32 hibit)
 {
 	u32 r0, r1, r2, r3, r4;
 	u32 s1, s2, s3, s4;
 	u32 h0, h1, h2, h3, h4;
 	u64 d0, d1, d2, d3, d4;
 
-	r0 = ctx->r[0];
-	r1 = ctx->r[1];
-	r2 = ctx->r[2];
-	r3 = ctx->r[3];
-	r4 = ctx->r[4];
+	if (unlikely(!dctx->sset)) {
+		if (!dctx->rset && srclen >= POLY1305_BLOCK_SIZE) {
+			poly1305_setrkey(dctx, src);
+			src += POLY1305_BLOCK_SIZE;
+			srclen -= POLY1305_BLOCK_SIZE;
+			dctx->rset = true;
+		}
+		if (srclen >= POLY1305_BLOCK_SIZE) {
+			poly1305_setskey(dctx, src);
+			src += POLY1305_BLOCK_SIZE;
+			srclen -= POLY1305_BLOCK_SIZE;
+			dctx->sset = true;
+		}
+	}
+
+	r0 = dctx->r[0];
+	r1 = dctx->r[1];
+	r2 = dctx->r[2];
+	r3 = dctx->r[3];
+	r4 = dctx->r[4];
 
 	s1 = r1 * 5;
 	s2 = r2 * 5;
@@ -164,7 +185,6 @@
 			   const u8 *src, unsigned int srclen)
 {
 	struct poly1305_desc_ctx *dctx = shash_desc_ctx(desc);
-	struct poly1305_ctx *ctx = crypto_shash_ctx(desc->tfm);
 	unsigned int bytes;
 
 	if (unlikely(dctx->buflen)) {
@@ -175,14 +195,14 @@
 		dctx->buflen += bytes;
 
 		if (dctx->buflen == POLY1305_BLOCK_SIZE) {
-			poly1305_blocks(dctx, ctx, dctx->buf,
+			poly1305_blocks(dctx, dctx->buf,
 					POLY1305_BLOCK_SIZE, 1 << 24);
 			dctx->buflen = 0;
 		}
 	}
 
 	if (likely(srclen >= POLY1305_BLOCK_SIZE)) {
-		bytes = poly1305_blocks(dctx, ctx, src, srclen, 1 << 24);
+		bytes = poly1305_blocks(dctx, src, srclen, 1 << 24);
 		src += srclen - bytes;
 		srclen = bytes;
 	}
@@ -198,18 +218,20 @@
 static int poly1305_final(struct shash_desc *desc, u8 *dst)
 {
 	struct poly1305_desc_ctx *dctx = shash_desc_ctx(desc);
-	struct poly1305_ctx *ctx = crypto_shash_ctx(desc->tfm);
 	__le32 *mac = (__le32 *)dst;
 	u32 h0, h1, h2, h3, h4;
 	u32 g0, g1, g2, g3, g4;
 	u32 mask;
 	u64 f = 0;
 
+	if (unlikely(!dctx->sset))
+		return -ENOKEY;	/* key not yet fully supplied via update() */
+
 	if (unlikely(dctx->buflen)) {
 		dctx->buf[dctx->buflen++] = 1;
 		memset(dctx->buf + dctx->buflen, 0,
 		       POLY1305_BLOCK_SIZE - dctx->buflen);
-		poly1305_blocks(dctx, ctx, dctx->buf, POLY1305_BLOCK_SIZE, 0);
+		poly1305_blocks(dctx, dctx->buf, POLY1305_BLOCK_SIZE, 0);
 	}
 
 	/* fully carry h */
@@ -253,10 +275,10 @@
 	h3 = (h3 >> 18) | (h4 <<  8);
 
 	/* mac = (h + s) % (2^128) */
-	f = (f >> 32) + h0 + ctx->s[0]; mac[0] = cpu_to_le32(f);
-	f = (f >> 32) + h1 + ctx->s[1]; mac[1] = cpu_to_le32(f);
-	f = (f >> 32) + h2 + ctx->s[2]; mac[2] = cpu_to_le32(f);
-	f = (f >> 32) + h3 + ctx->s[3]; mac[3] = cpu_to_le32(f);
+	f = (f >> 32) + h0 + dctx->s[0]; mac[0] = cpu_to_le32(f);
+	f = (f >> 32) + h1 + dctx->s[1]; mac[1] = cpu_to_le32(f);
+	f = (f >> 32) + h2 + dctx->s[2]; mac[2] = cpu_to_le32(f);
+	f = (f >> 32) + h3 + dctx->s[3]; mac[3] = cpu_to_le32(f);
 
 	return 0;
 }
@@ -275,7 +297,6 @@
 		.cra_flags		= CRYPTO_ALG_TYPE_SHASH,
 		.cra_alignmask		= sizeof(u32) - 1,
 		.cra_blocksize		= POLY1305_BLOCK_SIZE,
-		.cra_ctxsize		= sizeof(struct poly1305_ctx),
 		.cra_module		= THIS_MODULE,
 	},
 };