8025209: Intermittent test failure java/net/Socket/asyncClose/AsyncClose.java

Co-authored-by: Eric Wang <yiming.wang@oracle.com>
Reviewed-by: chegar
diff --git a/jdk/test/java/net/Socket/asyncClose/AsyncClose.java b/jdk/test/java/net/Socket/asyncClose/AsyncClose.java
index db9dddf..4de5d48 100644
--- a/jdk/test/java/net/Socket/asyncClose/AsyncClose.java
+++ b/jdk/test/java/net/Socket/asyncClose/AsyncClose.java
@@ -21,15 +21,17 @@
  * questions.
  */
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import static java.util.concurrent.CompletableFuture.*;
+
 /*
  * @test
  * @bug 4344135
  * @summary Check that {Socket,ServerSocket,DatagramSocket}.close will
  *          cause any thread blocked on the socket to throw a SocketException.
- * @run main/timeout=60 AsyncClose
  */
-import java.net.*;
-import java.io.*;
 
 public class AsyncClose {
 
@@ -37,34 +39,35 @@
 
         AsyncCloseTest tests[] = {
             new Socket_getInputStream_read(),
-            new Socket_getInputStream_read(5000),
+            new Socket_getInputStream_read(20000),
             new Socket_getOutputStream_write(),
             new DatagramSocket_receive(),
-            new DatagramSocket_receive(5000),
+            new DatagramSocket_receive(20000),
             new ServerSocket_accept(),
-            new ServerSocket_accept(5000),
+            new ServerSocket_accept(20000),
         };
 
         int failures = 0;
 
-        for (int i=0; i<tests.length; i++) {
-            AsyncCloseTest tst = tests[i];
+        List<CompletableFuture<AsyncCloseTest>> cfs = new ArrayList<>();
+        for (AsyncCloseTest test : tests)
+            cfs.add( supplyAsync(() -> test.go()));
+
+        for (CompletableFuture<AsyncCloseTest> cf : cfs) {
+            AsyncCloseTest test = cf.get();
 
             System.out.println("******************************");
-            System.out.println("Test: " + tst.description());
-
-            if (tst.go()) {
+            System.out.println("Test: " + test.description());
+            if (test.hasPassed()) {
                 System.out.println("Passed.");
             } else {
-                System.out.println("Failed: " + tst.failureReason());
+                System.out.println("Failed: " + test.failureReason());
                 failures++;
             }
             System.out.println("");
-
         }
 
-        if (failures > 0) {
+        if (failures > 0)
             throw new Exception(failures + " sub-tests failed - see log.");
-        }
     }
 }
diff --git a/jdk/test/java/net/Socket/asyncClose/AsyncCloseTest.java b/jdk/test/java/net/Socket/asyncClose/AsyncCloseTest.java
index ef6fc43..e7fd6fc 100644
--- a/jdk/test/java/net/Socket/asyncClose/AsyncCloseTest.java
+++ b/jdk/test/java/net/Socket/asyncClose/AsyncCloseTest.java
@@ -29,11 +29,22 @@
 
     public abstract String description();
 
-    public abstract boolean go() throws Exception;
+    public abstract AsyncCloseTest go();
 
+    public synchronized boolean hasPassed() {
+        return passed;
+    }
 
-    protected synchronized void failed(String reason) {
-        this.reason = reason;
+    protected synchronized AsyncCloseTest passed() {
+        if (reason == null)
+            passed = true;
+        return this;
+    }
+
+    protected synchronized AsyncCloseTest failed(String r) {
+        passed = false;
+        reason = r;
+        return this;
     }
 
     public synchronized String failureReason() {
@@ -48,7 +59,7 @@
         return closed;
     }
 
+    private boolean passed;
     private String reason;
     private boolean closed;
-
 }
diff --git a/jdk/test/java/net/Socket/asyncClose/DatagramSocket_receive.java b/jdk/test/java/net/Socket/asyncClose/DatagramSocket_receive.java
index b35ebdf..48a842e 100644
--- a/jdk/test/java/net/Socket/asyncClose/DatagramSocket_receive.java
+++ b/jdk/test/java/net/Socket/asyncClose/DatagramSocket_receive.java
@@ -26,20 +26,25 @@
  * throws a SocketException if the socket is asynchronously closed.
  */
 import java.net.*;
+import java.util.concurrent.CountDownLatch;
 
 public class DatagramSocket_receive extends AsyncCloseTest implements Runnable {
-    DatagramSocket s;
-    int timeout = 0;
+    private final DatagramSocket s;
+    private final int timeout;
+    private final CountDownLatch latch;
 
-    public DatagramSocket_receive() {
+    public DatagramSocket_receive() throws SocketException {
+        this(0);
     }
 
-    public DatagramSocket_receive(int timeout) {
+    public DatagramSocket_receive(int timeout) throws SocketException {
         this.timeout = timeout;
+        latch = new CountDownLatch(1);
+        s = new DatagramSocket();
     }
 
     public String description() {
-        String s = "DatagramSocket.receive";
+        String s = "DatagramSocket.receive(DatagramPacket)";
         if (timeout > 0) {
             s += " (timeout specified)";
         }
@@ -47,46 +52,45 @@
     }
 
     public void run() {
-        DatagramPacket p;
         try {
-
             byte b[] = new byte[1024];
-            p  = new DatagramPacket(b, b.length);
-
+            DatagramPacket p  = new DatagramPacket(b, b.length);
             if (timeout > 0) {
                 s.setSoTimeout(timeout);
             }
-        } catch (Exception e) {
-            failed(e.getMessage());
-            return;
-        }
-
-        try {
+            latch.countDown();
             s.receive(p);
+            failed("DatagramSocket.receive(DatagramPacket) returned unexpectly!!");
         } catch (SocketException se) {
-            closed();
+            if (latch.getCount() != 1) {
+                closed();
+            }
         } catch (Exception e) {
             failed(e.getMessage());
+        } finally {
+            if (latch.getCount() == 1) {
+                latch.countDown();
+            }
         }
     }
 
-    public boolean go() throws Exception {
-        s = new DatagramSocket();
+    public AsyncCloseTest go() {
+        try {
+            Thread thr = new Thread(this);
+            thr.start();
+            latch.await();
+            Thread.sleep(5000); // give receive(DatagramPacket) time to block before the socket is closed
+            s.close();
+            thr.join();
 
-        Thread thr = new Thread(this);
-        thr.start();
-
-        Thread.currentThread().sleep(1000);
-
-        s.close();
-
-        Thread.currentThread().sleep(1000);
-
-        if (isClosed()) {
-            return true;
-        } else {
-            failed("DatagramSocket.receive wasn't preempted");
-            return false;
+            if (isClosed()) {
+                return passed();
+            } else {
+                return failed("DatagramSocket.receive(DatagramPacket) wasn't preempted");
+            }
+        } catch (Exception x) {
+            failed(x.getMessage());
+            throw new RuntimeException(x);
         }
     }
 }
diff --git a/jdk/test/java/net/Socket/asyncClose/ServerSocket_accept.java b/jdk/test/java/net/Socket/asyncClose/ServerSocket_accept.java
index 07fbcc9..0461ba5 100644
--- a/jdk/test/java/net/Socket/asyncClose/ServerSocket_accept.java
+++ b/jdk/test/java/net/Socket/asyncClose/ServerSocket_accept.java
@@ -25,17 +25,23 @@
  * Tests that a thread blocked in ServerSocket.accept
  * throws a SocketException if the socket is asynchronously closed.
  */
+import java.io.IOException;
 import java.net.*;
+import java.util.concurrent.CountDownLatch;
 
 public class ServerSocket_accept extends AsyncCloseTest implements Runnable {
-    ServerSocket ss;
-    int timeout = 0;
+    private final ServerSocket ss;
+    private final int timeout;
+    private final CountDownLatch latch;
 
-    public ServerSocket_accept() {
+    public ServerSocket_accept() throws IOException {
+       this(0);
     }
 
-    public ServerSocket_accept(int timeout) {
+    public ServerSocket_accept(int timeout) throws IOException {
         this.timeout = timeout;
+        latch = new CountDownLatch(1);
+        ss = new ServerSocket(0);
     }
 
     public String description() {
@@ -48,7 +54,9 @@
 
     public void run() {
         try {
+            latch.countDown();
             Socket s = ss.accept();
+            failed("ServerSocket.accept() returned unexpectly!!");
         } catch (SocketException se) {
             closed();
         } catch (Exception e) {
@@ -56,23 +64,23 @@
         }
     }
 
-    public boolean go() throws Exception {
-        ss = new ServerSocket(0);
+    public AsyncCloseTest go(){
+        try {
+            Thread thr = new Thread(this);
+            thr.start();
+            latch.await();
+            Thread.sleep(5000); // give ServerSocket.accept() time to block before the socket is closed
+            ss.close();
+            thr.join();
 
-        Thread thr = new Thread(this);
-        thr.start();
-
-        Thread.currentThread().sleep(1000);
-
-        ss.close();
-
-        Thread.currentThread().sleep(1000);
-
-        if (isClosed()) {
-            return true;
-        } else {
-            failed("ServerSocket.accept() wasn't preempted");
-            return false;
+            if (isClosed()) {
+                return passed();
+            } else {
+                return failed("ServerSocket.accept() wasn't preempted");
+            }
+        } catch (Exception x) {
+            failed(x.getMessage());
+            throw new RuntimeException(x);
         }
     }
 }
diff --git a/jdk/test/java/net/Socket/asyncClose/Socket_getInputStream_read.java b/jdk/test/java/net/Socket/asyncClose/Socket_getInputStream_read.java
index a4d7602..25f8589 100644
--- a/jdk/test/java/net/Socket/asyncClose/Socket_getInputStream_read.java
+++ b/jdk/test/java/net/Socket/asyncClose/Socket_getInputStream_read.java
@@ -27,16 +27,21 @@
  */
 import java.net.*;
 import java.io.*;
+import java.util.concurrent.CountDownLatch;
 
 public class Socket_getInputStream_read extends AsyncCloseTest implements Runnable {
-    Socket s;
-    int timeout = 0;
+    private final Socket s;
+    private final int timeout;
+    private final CountDownLatch latch;
 
     public Socket_getInputStream_read() {
+        this(0);
     }
 
     public Socket_getInputStream_read(int timeout) {
         this.timeout = timeout;
+        latch = new CountDownLatch(1);
+        s = new Socket();
     }
 
     public String description() {
@@ -48,53 +53,48 @@
     }
 
     public void run() {
-        InputStream in;
-
         try {
-            in = s.getInputStream();
+            InputStream in = s.getInputStream();
             if (timeout > 0) {
                 s.setSoTimeout(timeout);
             }
-        } catch (Exception e) {
-            failed(e.getMessage());
-            return;
-        }
-
-        try {
+            latch.countDown();
             int n = in.read();
-            failed("getInptuStream().read() returned unexpectly!!");
+            failed("Socket.getInputStream().read() returned unexpectly!!");
         } catch (SocketException se) {
-            closed();
+            if (latch.getCount() != 1) {
+                closed();
+            }
         } catch (Exception e) {
             failed(e.getMessage());
+        } finally {
+            if (latch.getCount() == 1) {
+                latch.countDown();
+            }
         }
     }
 
-    public boolean go() throws Exception {
+    public AsyncCloseTest go() {
+        try {
+            ServerSocket ss = new ServerSocket(0);
+            InetAddress lh = InetAddress.getLocalHost();
+            s.connect( new InetSocketAddress(lh, ss.getLocalPort()) );
+            Socket s2 = ss.accept();
+            Thread thr = new Thread(this);
+            thr.start();
+            latch.await();
+            Thread.sleep(5000); // give Socket.getInputStream().read() time to block before the socket is closed
+            s.close();
+            thr.join();
 
-        ServerSocket ss = new ServerSocket(0);
-
-        InetAddress lh = InetAddress.getLocalHost();
-        s = new Socket();
-        s.connect( new InetSocketAddress(lh, ss.getLocalPort()) );
-
-        Socket s2 = ss.accept();
-
-        Thread thr = new Thread(this);
-        thr.start();
-
-        Thread.currentThread().sleep(1000);
-
-        s.close();
-
-        Thread.currentThread().sleep(1000);
-
-        if (isClosed()) {
-            return true;
-        } else {
-            failed("getInputStream().read() wasn't preempted");
-            return false;
+            if (isClosed()) {
+                return passed();
+            } else {
+                return failed("Socket.getInputStream().read() wasn't preempted");
+            }
+        } catch (Exception x) {
+            failed(x.getMessage());
+            throw new RuntimeException(x);
         }
-
     }
 }
diff --git a/jdk/test/java/net/Socket/asyncClose/Socket_getOutputStream_write.java b/jdk/test/java/net/Socket/asyncClose/Socket_getOutputStream_write.java
index 841861d..78cbf45 100644
--- a/jdk/test/java/net/Socket/asyncClose/Socket_getOutputStream_write.java
+++ b/jdk/test/java/net/Socket/asyncClose/Socket_getOutputStream_write.java
@@ -27,9 +27,16 @@
  */
 import java.net.*;
 import java.io.*;
+import java.util.concurrent.CountDownLatch;
 
 public class Socket_getOutputStream_write extends AsyncCloseTest implements Runnable {
-    Socket s;
+    private final Socket s;
+    private final CountDownLatch latch;
+
+    public Socket_getOutputStream_write() {
+        latch = new CountDownLatch(1);
+        s = new Socket();
+    }
 
     public String description() {
         return "Socket.getOutputStream().write()";
@@ -38,40 +45,45 @@
     public void run() {
         try {
             OutputStream out = s.getOutputStream();
+            byte b[] = new byte[8192];
+            latch.countDown();
             for (;;) {
-                byte b[] = new byte[8192];
                 out.write(b);
             }
         } catch (SocketException se) {
-            closed();
+            if (latch.getCount() != 1) {
+                closed();
+            }
         } catch (Exception e) {
             failed(e.getMessage());
+        } finally {
+            if (latch.getCount() == 1) {
+                latch.countDown();
+            }
         }
     }
 
-    public boolean go() throws Exception {
-        ServerSocket ss = new ServerSocket(0);
+    public AsyncCloseTest go() {
+        try {
+            ServerSocket ss = new ServerSocket(0);
+            InetAddress lh = InetAddress.getLocalHost();
+            s.connect( new InetSocketAddress(lh, ss.getLocalPort()) );
+            Socket s2 = ss.accept();
+            Thread thr = new Thread(this);
+            thr.start();
+            latch.await();
+            Thread.sleep(1000);
+            s.close();
+            thr.join();
 
-        InetAddress lh = InetAddress.getLocalHost();
-        s = new Socket();
-        s.connect( new InetSocketAddress(lh, ss.getLocalPort()) );
-
-        Socket s2 = ss.accept();
-
-        Thread thr = new Thread(this);
-        thr.start();
-
-        Thread.currentThread().sleep(2000);
-
-        s.close();
-
-        Thread.currentThread().sleep(2000);
-
-        if (isClosed()) {
-            return true;
-        } else {
-            failed("getOutputStream().write() wasn't preempted");
-            return false;
+            if (isClosed()) {
+                return passed();
+            } else {
+                return failed("Socket.getOutputStream().write() wasn't preempted");
+            }
+        } catch (Exception x) {
+            failed(x.getMessage());
+            throw new RuntimeException(x);
         }
     }
 }