ALSA: rawmidi: Make snd_rawmidi_transmit() race-free

A kernel WARNING in snd_rawmidi_transmit_ack() is triggered by
syzkaller fuzzer:
  WARNING: CPU: 1 PID: 20739 at sound/core/rawmidi.c:1136
Call Trace:
 [<     inline     >] __dump_stack lib/dump_stack.c:15
 [<ffffffff82999e2d>] dump_stack+0x6f/0xa2 lib/dump_stack.c:50
 [<ffffffff81352089>] warn_slowpath_common+0xd9/0x140 kernel/panic.c:482
 [<ffffffff813522b9>] warn_slowpath_null+0x29/0x30 kernel/panic.c:515
 [<ffffffff84f80bd5>] snd_rawmidi_transmit_ack+0x275/0x400 sound/core/rawmidi.c:1136
 [<ffffffff84fdb3c1>] snd_virmidi_output_trigger+0x4b1/0x5a0 sound/core/seq/seq_virmidi.c:163
 [<     inline     >] snd_rawmidi_output_trigger sound/core/rawmidi.c:150
 [<ffffffff84f87ed9>] snd_rawmidi_kernel_write1+0x549/0x780 sound/core/rawmidi.c:1223
 [<ffffffff84f89fd3>] snd_rawmidi_write+0x543/0xb30 sound/core/rawmidi.c:1273
 [<ffffffff817b0323>] __vfs_write+0x113/0x480 fs/read_write.c:528
 [<ffffffff817b1db7>] vfs_write+0x167/0x4a0 fs/read_write.c:577
 [<     inline     >] SYSC_write fs/read_write.c:624
 [<ffffffff817b50a1>] SyS_write+0x111/0x220 fs/read_write.c:616
 [<ffffffff86336c36>] entry_SYSCALL_64_fastpath+0x16/0x7a arch/x86/entry/entry_64.S:185

Also a similar warning is found but in another path:
Call Trace:
 [<     inline     >] __dump_stack lib/dump_stack.c:15
 [<ffffffff82be2c0d>] dump_stack+0x6f/0xa2 lib/dump_stack.c:50
 [<ffffffff81355139>] warn_slowpath_common+0xd9/0x140 kernel/panic.c:482
 [<ffffffff81355369>] warn_slowpath_null+0x29/0x30 kernel/panic.c:515
 [<ffffffff8527e69a>] rawmidi_transmit_ack+0x24a/0x3b0 sound/core/rawmidi.c:1133
 [<ffffffff8527e851>] snd_rawmidi_transmit_ack+0x51/0x80 sound/core/rawmidi.c:1163
 [<ffffffff852d9046>] snd_virmidi_output_trigger+0x2b6/0x570 sound/core/seq/seq_virmidi.c:185
 [<     inline     >] snd_rawmidi_output_trigger sound/core/rawmidi.c:150
 [<ffffffff85285a0b>] snd_rawmidi_kernel_write1+0x4bb/0x760 sound/core/rawmidi.c:1252
 [<ffffffff85287b73>] snd_rawmidi_write+0x543/0xb30 sound/core/rawmidi.c:1302
 [<ffffffff817ba5f3>] __vfs_write+0x113/0x480 fs/read_write.c:528
 [<ffffffff817bc087>] vfs_write+0x167/0x4a0 fs/read_write.c:577
 [<     inline     >] SYSC_write fs/read_write.c:624
 [<ffffffff817bf371>] SyS_write+0x111/0x220 fs/read_write.c:616
 [<ffffffff86660276>] entry_SYSCALL_64_fastpath+0x16/0x7a arch/x86/entry/entry_64.S:185

In the former case, the reason is that virmidi has an open code
calling snd_rawmidi_transmit_ack() with the value calculated outside
the spinlock.   We may use snd_rawmidi_transmit() in a loop just for
consuming the input data, but even there, there is a race between
snd_rawmidi_transmit_peek() and snd_rawmidi_tranmit_ack().

Similarly in the latter case, it calls snd_rawmidi_transmit_peek() and
snd_rawmidi_tranmit_ack() separately without protection, so they are
racy as well.

The patch tries to address these issues by the following ways:
- Introduce the unlocked versions of snd_rawmidi_transmit_peek() and
  snd_rawmidi_transmit_ack() to be called inside the explicit lock.
- Rewrite snd_rawmidi_transmit() to be race-free (the former case).
- Make the split calls (the latter case) protected in the rawmidi spin
  lock.

BugLink: http://lkml.kernel.org/r/CACT4Y+YPq1+cYLkadwjWa5XjzF1_Vki1eHnVn-Lm0hzhSpu5PA@mail.gmail.com
BugLink: http://lkml.kernel.org/r/CACT4Y+acG4iyphdOZx47Nyq_VHGbpJQK-6xNpiqUjaZYqsXOGw@mail.gmail.com
Reported-by: Dmitry Vyukov <dvyukov@google.com>
Tested-by: Dmitry Vyukov <dvyukov@google.com>
Cc: <stable@vger.kernel.org>
Signed-off-by: Takashi Iwai <tiwai@suse.de>
diff --git a/sound/core/rawmidi.c b/sound/core/rawmidi.c
index f75d165..26ca022 100644
--- a/sound/core/rawmidi.c
+++ b/sound/core/rawmidi.c
@@ -1055,23 +1055,16 @@
 EXPORT_SYMBOL(snd_rawmidi_transmit_empty);
 
 /**
- * snd_rawmidi_transmit_peek - copy data from the internal buffer
+ * __snd_rawmidi_transmit_peek - copy data from the internal buffer
  * @substream: the rawmidi substream
  * @buffer: the buffer pointer
  * @count: data size to transfer
  *
- * Copies data from the internal output buffer to the given buffer.
- *
- * Call this in the interrupt handler when the midi output is ready,
- * and call snd_rawmidi_transmit_ack() after the transmission is
- * finished.
- *
- * Return: The size of copied data, or a negative error code on failure.
+ * This is a variant of snd_rawmidi_transmit_peek() without spinlock.
  */
-int snd_rawmidi_transmit_peek(struct snd_rawmidi_substream *substream,
+int __snd_rawmidi_transmit_peek(struct snd_rawmidi_substream *substream,
 			      unsigned char *buffer, int count)
 {
-	unsigned long flags;
 	int result, count1;
 	struct snd_rawmidi_runtime *runtime = substream->runtime;
 
@@ -1081,7 +1074,6 @@
 		return -EINVAL;
 	}
 	result = 0;
-	spin_lock_irqsave(&runtime->lock, flags);
 	if (runtime->avail >= runtime->buffer_size) {
 		/* warning: lowlevel layer MUST trigger down the hardware */
 		goto __skip;
@@ -1106,12 +1098,68 @@
 		}
 	}
       __skip:
+	return result;
+}
+EXPORT_SYMBOL(__snd_rawmidi_transmit_peek);
+
+/**
+ * snd_rawmidi_transmit_peek - copy data from the internal buffer
+ * @substream: the rawmidi substream
+ * @buffer: the buffer pointer
+ * @count: data size to transfer
+ *
+ * Copies data from the internal output buffer to the given buffer.
+ *
+ * Call this in the interrupt handler when the midi output is ready,
+ * and call snd_rawmidi_transmit_ack() after the transmission is
+ * finished.
+ *
+ * Return: The size of copied data, or a negative error code on failure.
+ */
+int snd_rawmidi_transmit_peek(struct snd_rawmidi_substream *substream,
+			      unsigned char *buffer, int count)
+{
+	struct snd_rawmidi_runtime *runtime = substream->runtime;
+	int result;
+	unsigned long flags;
+
+	spin_lock_irqsave(&runtime->lock, flags);
+	result = __snd_rawmidi_transmit_peek(substream, buffer, count);
 	spin_unlock_irqrestore(&runtime->lock, flags);
 	return result;
 }
 EXPORT_SYMBOL(snd_rawmidi_transmit_peek);
 
 /**
+ * __snd_rawmidi_transmit_ack - acknowledge the transmission
+ * @substream: the rawmidi substream
+ * @count: the transferred count
+ *
+ * This is a variant of __snd_rawmidi_transmit_ack() without spinlock.
+ */
+int __snd_rawmidi_transmit_ack(struct snd_rawmidi_substream *substream, int count)
+{
+	struct snd_rawmidi_runtime *runtime = substream->runtime;
+
+	if (runtime->buffer == NULL) {
+		rmidi_dbg(substream->rmidi,
+			  "snd_rawmidi_transmit_ack: output is not active!!!\n");
+		return -EINVAL;
+	}
+	snd_BUG_ON(runtime->avail + count > runtime->buffer_size);
+	runtime->hw_ptr += count;
+	runtime->hw_ptr %= runtime->buffer_size;
+	runtime->avail += count;
+	substream->bytes += count;
+	if (count > 0) {
+		if (runtime->drain || snd_rawmidi_ready(substream))
+			wake_up(&runtime->sleep);
+	}
+	return count;
+}
+EXPORT_SYMBOL(__snd_rawmidi_transmit_ack);
+
+/**
  * snd_rawmidi_transmit_ack - acknowledge the transmission
  * @substream: the rawmidi substream
  * @count: the transferred count
@@ -1124,26 +1172,14 @@
  */
 int snd_rawmidi_transmit_ack(struct snd_rawmidi_substream *substream, int count)
 {
-	unsigned long flags;
 	struct snd_rawmidi_runtime *runtime = substream->runtime;
+	int result;
+	unsigned long flags;
 
-	if (runtime->buffer == NULL) {
-		rmidi_dbg(substream->rmidi,
-			  "snd_rawmidi_transmit_ack: output is not active!!!\n");
-		return -EINVAL;
-	}
 	spin_lock_irqsave(&runtime->lock, flags);
-	snd_BUG_ON(runtime->avail + count > runtime->buffer_size);
-	runtime->hw_ptr += count;
-	runtime->hw_ptr %= runtime->buffer_size;
-	runtime->avail += count;
-	substream->bytes += count;
-	if (count > 0) {
-		if (runtime->drain || snd_rawmidi_ready(substream))
-			wake_up(&runtime->sleep);
-	}
+	result = __snd_rawmidi_transmit_ack(substream, count);
 	spin_unlock_irqrestore(&runtime->lock, flags);
-	return count;
+	return result;
 }
 EXPORT_SYMBOL(snd_rawmidi_transmit_ack);
 
@@ -1160,12 +1196,22 @@
 int snd_rawmidi_transmit(struct snd_rawmidi_substream *substream,
 			 unsigned char *buffer, int count)
 {
+	struct snd_rawmidi_runtime *runtime = substream->runtime;
+	int result;
+	unsigned long flags;
+
+	spin_lock_irqsave(&runtime->lock, flags);
 	if (!substream->opened)
-		return -EBADFD;
-	count = snd_rawmidi_transmit_peek(substream, buffer, count);
-	if (count < 0)
-		return count;
-	return snd_rawmidi_transmit_ack(substream, count);
+		result = -EBADFD;
+	else {
+		count = __snd_rawmidi_transmit_peek(substream, buffer, count);
+		if (count <= 0)
+			result = count;
+		else
+			result = __snd_rawmidi_transmit_ack(substream, count);
+	}
+	spin_unlock_irqrestore(&runtime->lock, flags);
+	return result;
 }
 EXPORT_SYMBOL(snd_rawmidi_transmit);