Merge pull request #8084 from jtattermusch/throw_rpcexception_on_failure

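Fix wrong exceptions being thrown on send failure: a failed client-side write now completes with an RpcException carrying the call's final status, and server-side send failures throw IOException instead of InvalidOperationException.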
diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs
index c35aaf6..0979012 100644
--- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs
+++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs
@@ -33,6 +33,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.IO;
 using System.Runtime.InteropServices;
 using System.Threading.Tasks;
 
@@ -149,8 +150,7 @@
 
             var writeTask = responseStream.WriteAsync("request1");
             fakeCall.SendCompletionHandler(false);
-            // TODO(jtattermusch): should we throw a different exception type instead?
-            Assert.ThrowsAsync(typeof(InvalidOperationException), async () => await writeTask);
+            Assert.ThrowsAsync(typeof(IOException), async () => await writeTask);  // server-side send failures surface as IOException
 
             fakeCall.ReceivedCloseOnServerHandler(true, cancelled: true);
             AssertFinished(asyncCallServer, fakeCall, finishedTask);
diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
index 98e27a1..616bc06 100644
--- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
+++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
@@ -180,21 +180,74 @@
         }
 
         [Test]
-        public void ClientStreaming_WriteCompletionFailure()
+        public void ClientStreaming_WriteFailureThrowsRpcException()
         {
             var resultTask = asyncCall.ClientStreamingCallAsync();
             var requestStream = new ClientRequestStream<string, string>(asyncCall);
 
             var writeTask = requestStream.WriteAsync("request1");
             fakeCall.SendCompletionHandler(false);
-            // TODO: maybe IOException or waiting for RPCException is more appropriate here.
-            Assert.ThrowsAsync(typeof(InvalidOperationException), async () => await writeTask);
+
+            // The failed write stays pending until the call finishes and its status code is known.
+            Assert.IsFalse(writeTask.IsCompleted);
 
             fakeCall.UnaryResponseClientHandler(true,
                 CreateClientSideStatus(StatusCode.Internal),
                 null,
                 new Metadata());
 
+            var ex = Assert.ThrowsAsync<RpcException>(async () => await writeTask);
+            Assert.AreEqual(StatusCode.Internal, ex.Status.StatusCode);
+
+            AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal);
+        }
+
+        [Test]
+        public void ClientStreaming_WriteFailureThrowsRpcException2()
+        {
+            var resultTask = asyncCall.ClientStreamingCallAsync();
+            var requestStream = new ClientRequestStream<string, string>(asyncCall);
+
+            var writeTask = requestStream.WriteAsync("request1");
+
+            fakeCall.UnaryResponseClientHandler(true,
+                CreateClientSideStatus(StatusCode.Internal),
+                null,
+                new Metadata());
+
+            fakeCall.SendCompletionHandler(false);
+
+            var ex = Assert.ThrowsAsync<RpcException>(async () => await writeTask);
+            Assert.AreEqual(StatusCode.Internal, ex.Status.StatusCode);
+
+            AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal);
+        }
+
+        [Test]
+        public void ClientStreaming_WriteFailureThrowsRpcException3()
+        {
+            var resultTask = asyncCall.ClientStreamingCallAsync();
+            var requestStream = new ClientRequestStream<string, string>(asyncCall);
+
+            var writeTask = requestStream.WriteAsync("request1");
+            fakeCall.SendCompletionHandler(false);
+
+            // Until the delayed write completion has been triggered,
+            // we still act as if there was an active write.
+            Assert.Throws(typeof(InvalidOperationException), () => requestStream.WriteAsync("request2"));
+
+            fakeCall.UnaryResponseClientHandler(true,
+                CreateClientSideStatus(StatusCode.Internal),
+                null,
+                new Metadata());
+
+            var ex = Assert.ThrowsAsync<RpcException>(async () => await writeTask);
+            Assert.AreEqual(StatusCode.Internal, ex.Status.StatusCode);
+
+            // Subsequent write attempts keep failing with the same status.
+            var ex2 = Assert.ThrowsAsync<RpcException>(async () => await requestStream.WriteAsync("after call has finished"));
+            Assert.AreEqual(StatusCode.Internal, ex2.Status.StatusCode);
+
             AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal);
         }
 
@@ -416,6 +469,49 @@
         }
 
         [Test]
+        public void DuplexStreaming_WriteFailureThrowsRpcException()
+        {
+            asyncCall.StartDuplexStreamingCall();
+            var requestStream = new ClientRequestStream<string, string>(asyncCall);
+            var responseStream = new ClientResponseStream<string, string>(asyncCall);
+
+            var writeTask = requestStream.WriteAsync("request1");
+            fakeCall.SendCompletionHandler(false);
+
+            // The failed write stays pending until the call finishes and its status code is known.
+            Assert.IsFalse(writeTask.IsCompleted);
+
+            var readTask = responseStream.MoveNext();
+            fakeCall.ReceivedMessageHandler(true, null);
+            fakeCall.ReceivedStatusOnClientHandler(true, CreateClientSideStatus(StatusCode.PermissionDenied));
+
+            var ex = Assert.ThrowsAsync<RpcException>(async () => await writeTask);
+            Assert.AreEqual(StatusCode.PermissionDenied, ex.Status.StatusCode);
+
+            AssertStreamingResponseError(asyncCall, fakeCall, readTask, StatusCode.PermissionDenied);
+        }
+
+        [Test]
+        public void DuplexStreaming_WriteFailureThrowsRpcException2()
+        {
+            asyncCall.StartDuplexStreamingCall();
+            var requestStream = new ClientRequestStream<string, string>(asyncCall);
+            var responseStream = new ClientResponseStream<string, string>(asyncCall);
+
+            var writeTask = requestStream.WriteAsync("request1");
+
+            var readTask = responseStream.MoveNext();
+            fakeCall.ReceivedMessageHandler(true, null);
+            fakeCall.ReceivedStatusOnClientHandler(true, CreateClientSideStatus(StatusCode.PermissionDenied));
+            fakeCall.SendCompletionHandler(false);
+
+            var ex = Assert.ThrowsAsync<RpcException>(async () => await writeTask);
+            Assert.AreEqual(StatusCode.PermissionDenied, ex.Status.StatusCode);
+
+            AssertStreamingResponseError(asyncCall, fakeCall, readTask, StatusCode.PermissionDenied);
+        }
+
+        [Test]
         public void DuplexStreaming_WriteAfterCancellationRequestThrowsTaskCanceledException()
         {
             asyncCall.StartDuplexStreamingCall();
diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs
index f549c52..9abaf11 100644
--- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs
+++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs
@@ -341,6 +341,11 @@
             get { return true; }
         }
 
+        protected override Exception GetRpcExceptionClientOnly()
+        {
+            return new RpcException(finishedStatus.Value.Status);  // valid only once the call has finished and finishedStatus is set
+        }
+
         protected override Task CheckSendAllowedOrEarlyResult()
         {
             var earlyResult = CheckSendPreconditionsClientSide();
@@ -452,6 +457,7 @@
 
             using (Profilers.ForCurrentThread().NewScope("AsyncCall.HandleUnaryResponse"))
             {
+                TaskCompletionSource<object> delayedStreamingWriteTcs = null;  // set if a failed write is waiting for the call's final status
                 TResponse msg = default(TResponse);
                 var deserializeException = TryDeserialize(receivedMessage, out msg);
 
@@ -465,13 +471,23 @@
                     }
                     finishedStatus = receivedStatus;
 
+                    if (isStreamingWriteCompletionDelayed)
+                    {
+                        delayedStreamingWriteTcs = streamingWriteTcs;  // completed below, after the lock is released
+                        streamingWriteTcs = null;
+                    }
+
                     ReleaseResourcesIfPossible();
                 }
 
                 responseHeadersTcs.SetResult(responseHeaders);
 
-                var status = receivedStatus.Status;
+                if (delayedStreamingWriteTcs != null)
+                {
+                    delayedStreamingWriteTcs.SetException(GetRpcExceptionClientOnly());
+                }
 
+                var status = receivedStatus.Status;
                 if (status.StatusCode != StatusCode.OK)
                 {
                     unaryResponseTcs.SetException(new RpcException(status));
@@ -490,16 +506,27 @@
             // NOTE: because this event is a result of batch containing GRPC_OP_RECV_STATUS_ON_CLIENT,
             // success will be always set to true.
 
+            TaskCompletionSource<object> delayedStreamingWriteTcs = null;
+
             lock (myLock)
             {
                 finished = true;
                 finishedStatus = receivedStatus;
+                if (isStreamingWriteCompletionDelayed)
+                {
+                    delayedStreamingWriteTcs = streamingWriteTcs;  // completed below, after the lock is released
+                    streamingWriteTcs = null;
+                }
 
                 ReleaseResourcesIfPossible();
             }
 
-            var status = receivedStatus.Status;
+            if (delayedStreamingWriteTcs != null)
+            {
+                delayedStreamingWriteTcs.SetException(GetRpcExceptionClientOnly());
+            }
 
+            var status = receivedStatus.Status;
             if (status.StatusCode != StatusCode.OK)
             {
                 streamingCallFinishedTcs.SetException(new RpcException(status));
diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
index eb9c3ea..9f9d260 100644
--- a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
+++ b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs
@@ -69,6 +69,7 @@
         protected TaskCompletionSource<TRead> streamingReadTcs;  // Completion of a pending streaming read if not null.
         protected TaskCompletionSource<object> streamingWriteTcs;  // Completion of a pending streaming write or send close from client if not null.
         protected TaskCompletionSource<object> sendStatusFromServerTcs;
+        protected bool isStreamingWriteCompletionDelayed;  // Client side only: true once a failed write's completion has been postponed until the call finishes.
 
         protected bool readingDone;  // True if last read (i.e. read with null payload) was already received.
         protected bool halfcloseRequested;  // True if send close have been initiated.
@@ -200,6 +201,12 @@
             get;
         }
 
+        /// <summary>
+        /// Returns the exception used to fail a pending send operation, based on the call's final status.
+        /// Only valid for client calls, and only after the call has finished.
+        /// </summary>
+        protected abstract Exception GetRpcExceptionClientOnly();
+
         private void ReleaseResources()
         {
             if (call != null)
@@ -252,18 +259,43 @@
         /// </summary>
         protected void HandleSendFinished(bool success)
         {
+            bool delayCompletion = false;
             TaskCompletionSource<object> origTcs = null;
             lock (myLock)
             {
-                origTcs = streamingWriteTcs;
-                streamingWriteTcs = null;
+                if (!success && !finished && IsClient)
+                {
+                    // This is set at most once per call: later writes are short-circuited
+                    // because they cannot start until the entire call finishes.
+                    GrpcPreconditions.CheckState(!isStreamingWriteCompletionDelayed);
+                    // Leave streamingWriteTcs set; it is completed with the call's status once the call finishes.
+                    isStreamingWriteCompletionDelayed = true;
+                    delayCompletion = true;
+                }
+                else
+                {
+                    origTcs = streamingWriteTcs;
+                    streamingWriteTcs = null;
+                }
 
                 ReleaseResourcesIfPossible();
             }
 
             if (!success)
             {
-                origTcs.SetException(new InvalidOperationException("Send failed"));
+                if (!delayCompletion)
+                {
+                    if (IsClient)
+                    {
+                        GrpcPreconditions.CheckState(finished);  // implied by !success && !delayCompletion && IsClient
+                        origTcs.SetException(GetRpcExceptionClientOnly());
+                    }
+                    else
+                    {
+                        origTcs.SetException(new IOException("Error sending from server."));
+                    }
+                }
+                // If delayCompletion is true, SetException is postponed until the call finishes and its status is known.
             }
             else
             {
@@ -283,7 +315,7 @@
 
             if (!success)
             {
-                sendStatusFromServerTcs.SetException(new InvalidOperationException("Error sending status from server."));
+                sendStatusFromServerTcs.SetException(new IOException("Error sending status from server."));
             }
             else
             {
diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs
index 56c23ba..50fdfa9 100644
--- a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs
+++ b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs
@@ -33,6 +33,7 @@
 
 using System;
 using System.Diagnostics;
+using System.IO;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using System.Threading;
@@ -193,6 +194,12 @@
             get { return false; }
         }
 
+        protected override Exception GetRpcExceptionClientOnly()
+        {
+            // Only client calls use this; HandleSendFinished reports server-side send failures as IOException instead.
+            throw new InvalidOperationException("Can only be called for client calls");
+        }
+
         protected override void OnAfterReleaseResources()
         {
             server.RemoveCallReference(this);