diff --git a/crypto/ahash.c b/crypto/ahash.c index 00eab7281b0f..cdafd8d2abc3 100644 --- a/crypto/ahash.c +++ b/crypto/ahash.c @@ -169,6 +169,21 @@ static int ahash_setkey_unaligned(struct crypto_ahash *tfm, const u8 *key, return ret; } +static int ahash_nosetkey(struct crypto_ahash *tfm, const u8 *key, + unsigned int keylen) +{ + return -ENOSYS; +} + +static void ahash_set_needkey(struct crypto_ahash *tfm) +{ + const struct hash_alg_common *alg = crypto_hash_alg_common(tfm); + + if (tfm->setkey != ahash_nosetkey && + !(alg->base.cra_flags & CRYPTO_ALG_OPTIONAL_KEY)) + crypto_ahash_set_flags(tfm, CRYPTO_TFM_NEED_KEY); +} + int crypto_ahash_setkey(struct crypto_ahash *tfm, const u8 *key, unsigned int keylen) { @@ -180,20 +195,16 @@ int crypto_ahash_setkey(struct crypto_ahash *tfm, const u8 *key, else err = tfm->setkey(tfm, key, keylen); - if (err) + if (unlikely(err)) { + ahash_set_needkey(tfm); return err; + } crypto_ahash_clear_flags(tfm, CRYPTO_TFM_NEED_KEY); return 0; } EXPORT_SYMBOL_GPL(crypto_ahash_setkey); -static int ahash_nosetkey(struct crypto_ahash *tfm, const u8 *key, - unsigned int keylen) -{ - return -ENOSYS; -} - static inline unsigned int ahash_align_buffer_size(unsigned len, unsigned long mask) { @@ -462,8 +473,7 @@ static int crypto_ahash_init_tfm(struct crypto_tfm *tfm) if (alg->setkey) { hash->setkey = alg->setkey; - if (!(alg->halg.base.cra_flags & CRYPTO_ALG_OPTIONAL_KEY)) - crypto_ahash_set_flags(hash, CRYPTO_TFM_NEED_KEY); + ahash_set_needkey(hash); } if (alg->export) hash->export = alg->export; diff --git a/crypto/shash.c b/crypto/shash.c index b17cffd1124a..961cb691dc84 100644 --- a/crypto/shash.c +++ b/crypto/shash.c @@ -52,6 +52,13 @@ static int shash_setkey_unaligned(struct crypto_shash *tfm, const u8 *key, return err; } +static void shash_set_needkey(struct crypto_shash *tfm, struct shash_alg *alg) +{ + if (crypto_shash_alg_has_setkey(alg) && + !(alg->base.cra_flags & CRYPTO_ALG_OPTIONAL_KEY)) + crypto_shash_set_flags(tfm, CRYPTO_TFM_NEED_KEY); +} + int crypto_shash_setkey(struct crypto_shash *tfm, const u8 *key, unsigned int keylen) { @@ -64,8 +71,10 @@ int crypto_shash_setkey(struct crypto_shash *tfm, const u8 *key, else err = shash->setkey(tfm, key, keylen); - if (err) + if (unlikely(err)) { + shash_set_needkey(tfm, shash); return err; + } crypto_shash_clear_flags(tfm, CRYPTO_TFM_NEED_KEY); return 0; @@ -366,7 +375,8 @@ int crypto_init_shash_ops_async(struct crypto_tfm *tfm) crt->final = shash_async_final; crt->finup = shash_async_finup; crt->digest = shash_async_digest; - crt->setkey = shash_async_setkey; + if (crypto_shash_alg_has_setkey(alg)) + crt->setkey = shash_async_setkey; crypto_ahash_set_flags(crt, crypto_shash_get_flags(shash) & CRYPTO_TFM_NEED_KEY); @@ -533,9 +543,7 @@ static int crypto_shash_init_tfm(struct crypto_tfm *tfm) hash->descsize = alg->descsize; - if (crypto_shash_alg_has_setkey(alg) && - !(alg->base.cra_flags & CRYPTO_ALG_OPTIONAL_KEY)) - crypto_shash_set_flags(hash, CRYPTO_TFM_NEED_KEY); + shash_set_needkey(hash, alg); return 0; }