bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)

diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py
index 99464e3..77c15c0 100644
--- a/Lib/test/test_asyncgen.py
+++ b/Lib/test/test_asyncgen.py
@@ -372,11 +372,8 @@ def tearDown(self):
         self.loop = None
         asyncio.set_event_loop_policy(None)
 
-    def test_async_gen_anext(self):
-        async def gen():
-            yield 1
-            yield 2
-        g = gen()
+    def check_async_iterator_anext(self, ait_class):
+        g = ait_class()
         async def consume():
             results = []
             results.append(await anext(g))
@@ -388,6 +385,66 @@ async def consume():
         with self.assertRaises(StopAsyncIteration):
             self.loop.run_until_complete(consume())
 
+        async def test_2():
+            g1 = ait_class()  # fresh iterator; anext() called without default
+            self.assertEqual(await anext(g1), 1)
+            self.assertEqual(await anext(g1), 2)
+            with self.assertRaises(StopAsyncIteration):
+                await anext(g1)
+            with self.assertRaises(StopAsyncIteration):  # still exhausted
+                await anext(g1)
+
+            g2 = ait_class()  # fresh iterator; anext() called with a default
+            self.assertEqual(await anext(g2, "default"), 1)
+            self.assertEqual(await anext(g2, "default"), 2)
+            self.assertEqual(await anext(g2, "default"), "default")
+            self.assertEqual(await anext(g2, "default"), "default")
+
+            return "completed"  # sentinel compared by the caller
+
+        result = self.loop.run_until_complete(test_2())
+        self.assertEqual(result, "completed")
+
+    def test_async_generator_anext(self):
+        async def agen():  # async generator function: yields 1, then 2
+            yield 1
+            yield 2
+        self.check_async_iterator_anext(agen)
+
+    def test_python_async_iterator_anext(self):
+        class MyAsyncIter:
+            """Asynchronously yield 1, then 2."""
+            def __init__(self):
+                self.yielded = 0  # number of values produced so far
+            def __aiter__(self):
+                return self
+            async def __anext__(self):  # "async def" coroutine function
+                if self.yielded >= 2:
+                    raise StopAsyncIteration()
+                else:
+                    self.yielded += 1
+                    return self.yielded  # updated count is the value
+        self.check_async_iterator_anext(MyAsyncIter)
+
+    def test_python_async_iterator_types_coroutine_anext(self):
+        import types
+        class MyAsyncIterWithTypesCoro:
+            """Asynchronously yield 1, then 2."""
+            def __init__(self):
+                self.yielded = 0  # number of values produced so far
+            def __aiter__(self):
+                return self
+            @types.coroutine
+            def __anext__(self):
+                if False:  # never taken; the yield only makes this a generator
+                    yield "this is a generator-based coroutine"
+                if self.yielded >= 2:
+                    raise StopAsyncIteration()
+                else:
+                    self.yielded += 1
+                    return self.yielded  # updated count is the value
+        self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
+
     def test_async_gen_aiter(self):
         async def gen():
             yield 1
@@ -431,12 +488,85 @@ async def call_with_too_many_args():
             await anext(gen(), 1, 3)
         async def call_with_wrong_type_args():
             await anext(1, gen())
+        async def call_with_kwarg():
+            await anext(aiterator=gen())  # keyword form; TypeError expected
         with self.assertRaises(TypeError):
             self.loop.run_until_complete(call_with_too_few_args())
         with self.assertRaises(TypeError):
             self.loop.run_until_complete(call_with_too_many_args())
         with self.assertRaises(TypeError):
             self.loop.run_until_complete(call_with_wrong_type_args())
+        with self.assertRaises(TypeError):
+            self.loop.run_until_complete(call_with_kwarg())
+
+    def test_anext_bad_await(self):
+        async def bad_awaitable():
+            class BadAwaitable:
+                def __await__(self):
+                    return 42  # an int, not an iterator
+            class MyAsyncIter:
+                def __aiter__(self):
+                    return self
+                def __anext__(self):
+                    return BadAwaitable()
+            regex = r"__await__.*iterator"
+            awaitable = anext(MyAsyncIter(), "default")  # with a default
+            with self.assertRaisesRegex(TypeError, regex):
+                await awaitable
+            awaitable = anext(MyAsyncIter())  # without a default
+            with self.assertRaisesRegex(TypeError, regex):
+                await awaitable
+            return "completed"  # sentinel compared below
+        result = self.loop.run_until_complete(bad_awaitable())
+        self.assertEqual(result, "completed")
+
+    async def check_anext_returning_iterator(self, aiter_class):
+        awaitable = anext(aiter_class(), "default")  # with a default
+        with self.assertRaises(TypeError):
+            await awaitable
+        awaitable = anext(aiter_class())  # without a default
+        with self.assertRaises(TypeError):
+            await awaitable
+        return "completed"  # sentinel compared by the calling tests
+
+    def test_anext_return_iterator(self):
+        class WithIterAnext:
+            def __aiter__(self):
+                return self
+            def __anext__(self):
+                return iter("abc")  # a str iterator, not an awaitable
+        result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
+        self.assertEqual(result, "completed")
+
+    def test_anext_return_generator(self):
+        class WithGenAnext:
+            def __aiter__(self):
+                return self
+            def __anext__(self):
+                yield  # makes __anext__ return a plain generator object
+        result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
+        self.assertEqual(result, "completed")
+
+    def test_anext_await_raises(self):
+        class RaisingAwaitable:
+            def __await__(self):
+                raise ZeroDivisionError()
+                yield  # unreachable; only makes __await__ a generator
+        class WithRaisingAwaitableAnext:
+            def __aiter__(self):
+                return self
+            def __anext__(self):
+                return RaisingAwaitable()
+        async def do_test():
+            awaitable = anext(WithRaisingAwaitableAnext())  # no default
+            with self.assertRaises(ZeroDivisionError):
+                await awaitable
+            awaitable = anext(WithRaisingAwaitableAnext(), "default")
+            with self.assertRaises(ZeroDivisionError):  # default is not used
+                await awaitable
+            return "completed"  # sentinel compared below
+        result = self.loop.run_until_complete(do_test())
+        self.assertEqual(result, "completed")
 
     def test_aiter_bad_args(self):
         async def gen():