diff --git a/crypto/blkcipher.c b/crypto/blkcipher.c index 8a66221d95bb..1c48c2e31cee 100644 --- a/crypto/blkcipher.c +++ b/crypto/blkcipher.c @@ -70,19 +70,18 @@ static inline u8 *blkcipher_get_spot(u8 *start, unsigned int len) return max(start, end_page); } -static inline unsigned int blkcipher_done_slow(struct blkcipher_walk *walk, - unsigned int bsize) +static inline void blkcipher_done_slow(struct blkcipher_walk *walk, + unsigned int bsize) { u8 *addr; addr = (u8 *)ALIGN((unsigned long)walk->buffer, walk->alignmask + 1); addr = blkcipher_get_spot(addr, bsize); scatterwalk_copychunks(addr, &walk->out, bsize, 1); - return bsize; } -static inline unsigned int blkcipher_done_fast(struct blkcipher_walk *walk, - unsigned int n) +static inline void blkcipher_done_fast(struct blkcipher_walk *walk, + unsigned int n) { if (walk->flags & BLKCIPHER_WALK_COPY) { blkcipher_map_dst(walk); @@ -96,49 +95,48 @@ static inline unsigned int blkcipher_done_fast(struct blkcipher_walk *walk, scatterwalk_advance(&walk->in, n); scatterwalk_advance(&walk->out, n); - - return n; } int blkcipher_walk_done(struct blkcipher_desc *desc, struct blkcipher_walk *walk, int err) { - unsigned int nbytes = 0; + unsigned int n; /* bytes processed */ + bool more; - if (likely(err >= 0)) { - unsigned int n = walk->nbytes - err; + if (unlikely(err < 0)) + goto finish; - if (likely(!(walk->flags & BLKCIPHER_WALK_SLOW))) - n = blkcipher_done_fast(walk, n); - else if (WARN_ON(err)) { + n = walk->nbytes - err; + walk->total -= n; + more = (walk->total != 0); + + if (likely(!(walk->flags & BLKCIPHER_WALK_SLOW))) { + blkcipher_done_fast(walk, n); + } else { + if (WARN_ON(err)) { + /* unexpected case; didn't process all bytes */ err = -EINVAL; - goto err; - } else - n = blkcipher_done_slow(walk, n); - - nbytes = walk->total - n; - err = 0; + goto finish; + } + blkcipher_done_slow(walk, n); } - scatterwalk_done(&walk->in, 0, nbytes); - scatterwalk_done(&walk->out, 1, nbytes); + scatterwalk_done(&walk->in, 0, more); + scatterwalk_done(&walk->out, 1, more); -err: - walk->total = nbytes; - walk->nbytes = nbytes; - - if (nbytes) { + if (more) { crypto_yield(desc->flags); return blkcipher_walk_next(desc, walk); } - + err = 0; +finish: + walk->nbytes = 0; if (walk->iv != desc->info) memcpy(desc->info, walk->iv, walk->ivsize); if (walk->buffer != walk->page) kfree(walk->buffer); if (walk->page) free_page((unsigned long)walk->page); - return err; } EXPORT_SYMBOL_GPL(blkcipher_walk_done);