aboutsummaryrefslogtreecommitdiff
path: root/crypto/rsa/rsa_ossl.c
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/rsa/rsa_ossl.c')
-rw-r--r--crypto/rsa/rsa_ossl.c269
1 files changed, 240 insertions, 29 deletions
diff --git a/crypto/rsa/rsa_ossl.c b/crypto/rsa/rsa_ossl.c
index 0fc642e777fd..0c0c73c65c67 100644
--- a/crypto/rsa/rsa_ossl.c
+++ b/crypto/rsa/rsa_ossl.c
@@ -1,5 +1,5 @@
/*
- * Copyright 1995-2023 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 1995-2024 The OpenSSL Project Authors. All Rights Reserved.
*
* Licensed under the Apache License 2.0 (the "License"). You may not use
* this file except in compliance with the License. You can obtain a copy
@@ -17,6 +17,9 @@
#include "crypto/bn.h"
#include "rsa_local.h"
#include "internal/constant_time.h"
+#include <openssl/evp.h>
+#include <openssl/sha.h>
+#include <openssl/hmac.h>
static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
unsigned char *to, RSA *rsa, int padding);
@@ -30,6 +33,27 @@ static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
BN_CTX *ctx);
static int rsa_ossl_init(RSA *rsa);
static int rsa_ossl_finish(RSA *rsa);
+#ifdef S390X_MOD_EXP
+static int rsa_ossl_s390x_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
+ BN_CTX *ctx);
+static RSA_METHOD rsa_pkcs1_ossl_meth = {
+ "OpenSSL PKCS#1 RSA",
+ rsa_ossl_public_encrypt,
+ rsa_ossl_public_decrypt, /* signature verification */
+ rsa_ossl_private_encrypt, /* signing */
+ rsa_ossl_private_decrypt,
+ rsa_ossl_s390x_mod_exp,
+ s390x_mod_exp,
+ rsa_ossl_init,
+ rsa_ossl_finish,
+ RSA_FLAG_FIPS_METHOD, /* flags */
+ NULL,
+ 0, /* rsa_sign */
+ 0, /* rsa_verify */
+ NULL, /* rsa_keygen */
+ NULL /* rsa_multi_prime_keygen */
+};
+#else
static RSA_METHOD rsa_pkcs1_ossl_meth = {
"OpenSSL PKCS#1 RSA",
rsa_ossl_public_encrypt,
@@ -48,6 +72,7 @@ static RSA_METHOD rsa_pkcs1_ossl_meth = {
NULL, /* rsa_keygen */
NULL /* rsa_multi_prime_keygen */
};
+#endif
static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
@@ -104,10 +129,8 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
ret = BN_CTX_get(ctx);
num = BN_num_bytes(rsa->n);
buf = OPENSSL_malloc(num);
- if (ret == NULL || buf == NULL) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ if (ret == NULL || buf == NULL)
goto err;
- }
switch (padding) {
case RSA_PKCS1_PADDING:
@@ -132,10 +155,35 @@ static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
if (BN_bin2bn(buf, num, f) == NULL)
goto err;
- if (BN_ucmp(f, rsa->n) >= 0) {
- /* usually the padding functions would catch this */
- ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
- goto err;
+#ifdef FIPS_MODULE
+ /*
+ * See SP800-56Br2, section 7.1.1.1
+ * RSAEP: 1 < f < (n – 1).
+ * (where f is the plaintext).
+ */
+ if (padding == RSA_NO_PADDING) {
+ BIGNUM *nminus1 = BN_CTX_get(ctx);
+
+ if (BN_ucmp(f, BN_value_one()) <= 0) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_SMALL);
+ goto err;
+ }
+ if (nminus1 == NULL
+ || BN_copy(nminus1, rsa->n) == NULL
+ || !BN_sub_word(nminus1, 1))
+ goto err;
+ if (BN_ucmp(f, nminus1) >= 0) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
+ goto err;
+ }
+ } else
+#endif
+ {
+ if (BN_ucmp(f, rsa->n) >= 0) {
+ /* usually the padding functions would catch this */
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
+ goto err;
+ }
}
if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
@@ -163,11 +211,21 @@ static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
{
BN_BLINDING *ret;
- if (!CRYPTO_THREAD_write_lock(rsa->lock))
+ if (!CRYPTO_THREAD_read_lock(rsa->lock))
return NULL;
if (rsa->blinding == NULL) {
- rsa->blinding = RSA_setup_blinding(rsa, ctx);
+ /*
+ * This dance with upgrading the lock from read to write will be
+ * slower in cases of a single use RSA object, but should be
+ * significantly better in multi-thread cases (e.g. servers). It's
+ * probably worth it.
+ */
+ CRYPTO_THREAD_unlock(rsa->lock);
+ if (!CRYPTO_THREAD_write_lock(rsa->lock))
+ return NULL;
+ if (rsa->blinding == NULL)
+ rsa->blinding = RSA_setup_blinding(rsa, ctx);
}
ret = rsa->blinding;
@@ -189,7 +247,11 @@ static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
*local = 0;
if (rsa->mt_blinding == NULL) {
- rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
+ CRYPTO_THREAD_unlock(rsa->lock);
+ if (!CRYPTO_THREAD_write_lock(rsa->lock))
+ return NULL;
+ if (rsa->mt_blinding == NULL)
+ rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
}
ret = rsa->mt_blinding;
}
@@ -262,10 +324,8 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
ret = BN_CTX_get(ctx);
num = BN_num_bytes(rsa->n);
buf = OPENSSL_malloc(num);
- if (ret == NULL || buf == NULL) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ if (ret == NULL || buf == NULL)
goto err;
- }
switch (padding) {
case RSA_PKCS1_PADDING:
@@ -308,7 +368,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
if (blinding != NULL) {
if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB);
goto err;
}
if (!rsa_blinding_convert(blinding, f, unblind, ctx))
@@ -325,7 +385,7 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
} else {
BIGNUM *d = BN_new();
if (d == NULL) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB);
goto err;
}
if (rsa->d == NULL) {
@@ -371,12 +431,98 @@ static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
return r;
}
+static int derive_kdk(int flen, const unsigned char *from, RSA *rsa,
+ unsigned char *buf, int num, unsigned char *kdk)
+{
+ int ret = 0;
+ HMAC_CTX *hmac = NULL;
+ EVP_MD *md = NULL;
+ unsigned int md_len = SHA256_DIGEST_LENGTH;
+ unsigned char d_hash[SHA256_DIGEST_LENGTH] = {0};
+ /*
+ * because we use d as a handle to rsa->d we need to keep it local and
+ * free before any further use of rsa->d
+ */
+ BIGNUM *d = BN_new();
+
+ if (d == NULL) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_CRYPTO_LIB);
+ goto err;
+ }
+ if (rsa->d == NULL) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
+ BN_free(d);
+ goto err;
+ }
+ BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
+ if (BN_bn2binpad(d, buf, num) < 0) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
+ BN_free(d);
+ goto err;
+ }
+ BN_free(d);
+
+ /*
+ * we use hardcoded hash so that migrating between versions that use
+ * different hash doesn't provide a Bleichenbacher oracle:
+ * if the attacker can see that different versions return different
+ * messages for the same ciphertext, they'll know that the message is
+ * synthetically generated, which means that the padding check failed
+ */
+ md = EVP_MD_fetch(rsa->libctx, "sha256", NULL);
+ if (md == NULL) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_FETCH_FAILED);
+ goto err;
+ }
+
+ if (EVP_Digest(buf, num, d_hash, NULL, md, NULL) <= 0) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
+ goto err;
+ }
+
+ hmac = HMAC_CTX_new();
+ if (hmac == NULL) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_CRYPTO_LIB);
+ goto err;
+ }
+
+ if (HMAC_Init_ex(hmac, d_hash, sizeof(d_hash), md, NULL) <= 0) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
+ goto err;
+ }
+
+ if (flen < num) {
+ memset(buf, 0, num - flen);
+ if (HMAC_Update(hmac, buf, num - flen) <= 0) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
+ goto err;
+ }
+ }
+ if (HMAC_Update(hmac, from, flen) <= 0) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
+ goto err;
+ }
+
+ md_len = SHA256_DIGEST_LENGTH;
+ if (HMAC_Final(hmac, kdk, &md_len) <= 0) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
+ goto err;
+ }
+ ret = 1;
+
+ err:
+ HMAC_CTX_free(hmac);
+ EVP_MD_free(md);
+ return ret;
+}
+
static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
unsigned char *to, RSA *rsa, int padding)
{
BIGNUM *f, *ret;
int j, num = 0, r = -1;
unsigned char *buf = NULL;
+ unsigned char kdk[SHA256_DIGEST_LENGTH] = {0};
BN_CTX *ctx = NULL;
int local_blinding = 0;
/*
@@ -387,17 +533,25 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
BIGNUM *unblind = NULL;
BN_BLINDING *blinding = NULL;
+ /*
+ * we need the value of the private exponent to perform implicit rejection
+ */
+ if ((rsa->flags & RSA_FLAG_EXT_PKEY) && (padding == RSA_PKCS1_PADDING))
+ padding = RSA_PKCS1_NO_IMPLICIT_REJECT_PADDING;
+
if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
goto err;
BN_CTX_start(ctx);
f = BN_CTX_get(ctx);
ret = BN_CTX_get(ctx);
+ if (ret == NULL) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB);
+ goto err;
+ }
num = BN_num_bytes(rsa->n);
buf = OPENSSL_malloc(num);
- if (ret == NULL || buf == NULL) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ if (buf == NULL)
goto err;
- }
/*
* This check was for equality but PGP does evil things and chops off the
@@ -408,15 +562,44 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
goto err;
}
+ if (flen < 1) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_SMALL);
+ goto err;
+ }
+
/* make data into a big number */
if (BN_bin2bn(from, (int)flen, f) == NULL)
goto err;
- if (BN_ucmp(f, rsa->n) >= 0) {
- ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
- goto err;
- }
+#ifdef FIPS_MODULE
+ /*
+ * See SP800-56Br2, section 7.1.2.1
+ * RSADP: 1 < f < (n – 1)
+ * (where f is the ciphertext).
+ */
+ if (padding == RSA_NO_PADDING) {
+ BIGNUM *nminus1 = BN_CTX_get(ctx);
+ if (BN_ucmp(f, BN_value_one()) <= 0) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_SMALL);
+ goto err;
+ }
+ if (nminus1 == NULL
+ || BN_copy(nminus1, rsa->n) == NULL
+ || !BN_sub_word(nminus1, 1))
+ goto err;
+ if (BN_ucmp(f, nminus1) >= 0) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
+ goto err;
+ }
+ } else
+#endif
+ {
+ if (BN_ucmp(f, rsa->n) >= 0) {
+ ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
+ goto err;
+ }
+ }
if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
rsa->n, ctx))
@@ -432,7 +615,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
if (blinding != NULL) {
if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB);
goto err;
}
if (!rsa_blinding_convert(blinding, f, unblind, ctx))
@@ -450,7 +633,7 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
} else {
BIGNUM *d = BN_new();
if (d == NULL) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB);
goto err;
}
if (rsa->d == NULL) {
@@ -472,14 +655,26 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
goto err;
+ /*
+ * derive the Key Derivation Key from private exponent and public
+ * ciphertext
+ */
+ if (padding == RSA_PKCS1_PADDING) {
+ if (derive_kdk(flen, from, rsa, buf, num, kdk) == 0)
+ goto err;
+ }
+
j = BN_bn2binpad(ret, buf, num);
if (j < 0)
goto err;
switch (padding) {
- case RSA_PKCS1_PADDING:
+ case RSA_PKCS1_NO_IMPLICIT_REJECT_PADDING:
r = RSA_padding_check_PKCS1_type_2(to, num, buf, j, num);
break;
+ case RSA_PKCS1_PADDING:
+ r = ossl_rsa_padding_check_PKCS1_type_2(rsa->libctx, to, num, buf, j, num, kdk);
+ break;
case RSA_PKCS1_OAEP_PADDING:
r = RSA_padding_check_PKCS1_OAEP(to, num, buf, j, num, NULL, 0);
break;
@@ -539,12 +734,14 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
BN_CTX_start(ctx);
f = BN_CTX_get(ctx);
ret = BN_CTX_get(ctx);
+ if (ret == NULL) {
+ ERR_raise(ERR_LIB_RSA, ERR_R_BN_LIB);
+ goto err;
+ }
num = BN_num_bytes(rsa->n);
buf = OPENSSL_malloc(num);
- if (ret == NULL || buf == NULL) {
- ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
+ if (buf == NULL)
goto err;
- }
/*
* This check was for equality but PGP does evil things and chops off the
@@ -572,6 +769,7 @@ static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
rsa->_method_mod_n))
goto err;
+ /* For X9.31: Assuming e is odd it does a 12 mod 16 test */
if ((padding == RSA_X931_PADDING) && ((bn_get_words(ret)[0] & 0xf) != 12))
if (!BN_sub(ret, rsa->n, ret))
goto err;
@@ -998,3 +1196,16 @@ static int rsa_ossl_finish(RSA *rsa)
BN_MONT_CTX_free(rsa->_method_mod_q);
return 1;
}
+
+#ifdef S390X_MOD_EXP
+static int rsa_ossl_s390x_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
+ BN_CTX *ctx)
+{
+ if (rsa->version != RSA_ASN1_VERSION_MULTI) {
+ if (s390x_crt(r0, i, rsa->p, rsa->q, rsa->dmp1, rsa->dmq1, rsa->iqmp) == 1)
+ return 1;
+ }
+ return rsa_ossl_mod_exp(r0, i, rsa, ctx);
+}
+
+#endif