diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c index cd778b1752b5..3db3d242c3ee 100644 --- a/drivers/virtio/virtio_balloon.c +++ b/drivers/virtio/virtio_balloon.c @@ -51,6 +51,7 @@ struct virtio_balloon u32 pfns[256]; /* Memory statistics */ + int need_stats_update; struct virtio_balloon_stat stats[VIRTIO_BALLOON_S_NR]; }; @@ -193,20 +194,30 @@ static void update_balloon_stats(struct virtio_balloon *vb) * the stats queue operates in reverse. The driver initializes the virtqueue * with a single buffer. From that point forward, all conversations consist of * a hypervisor request (a call to this function) which directs us to refill - * the virtqueue with a fresh stats buffer. + * the virtqueue with a fresh stats buffer. Since stats collection can sleep, + * we notify our kthread which does the actual work via stats_handle_request(). */ -static void stats_ack(struct virtqueue *vq) +static void stats_request(struct virtqueue *vq) { struct virtio_balloon *vb; unsigned int len; - struct scatterlist sg; vb = vq->vq_ops->get_buf(vq, &len); if (!vb) return; + vb->need_stats_update = 1; + wake_up(&vb->config_change); +} +static void stats_handle_request(struct virtio_balloon *vb) +{ + struct virtqueue *vq; + struct scatterlist sg; + + vb->need_stats_update = 0; update_balloon_stats(vb); + vq = vb->stats_vq; sg_init_one(&sg, vb->stats, sizeof(vb->stats)); if (vq->vq_ops->add_buf(vq, &sg, 1, 0, vb) < 0) BUG(); @@ -249,8 +260,11 @@ static int balloon(void *_vballoon) try_to_freeze(); wait_event_interruptible(vb->config_change, (diff = towards_target(vb)) != 0 + || vb->need_stats_update || kthread_should_stop() || freezing(current)); + if (vb->need_stats_update) + stats_handle_request(vb); if (diff > 0) fill_balloon(vb, diff); else if (diff < 0) @@ -264,7 +278,7 @@ static int virtballoon_probe(struct virtio_device *vdev) { struct virtio_balloon *vb; struct virtqueue *vqs[3]; - vq_callback_t *callbacks[] = { balloon_ack, balloon_ack, stats_ack }; + vq_callback_t *callbacks[] = { balloon_ack, balloon_ack, stats_request }; const char *names[] = { "inflate", "deflate", "stats" }; int err, nvqs;