Don't call overriden methods with @OnLifecycleEvent twice

bug:62658006
Test: ReflectiveGenericLifecycleObserver
Change-Id: I382a6afc565a8a13902e31d44044bdfc4bd22f00
diff --git a/lifecycle/common/src/main/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserver.java b/lifecycle/common/src/main/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserver.java
index 51cc94c..b0761aa 100644
--- a/lifecycle/common/src/main/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserver.java
+++ b/lifecycle/common/src/main/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserver.java
@@ -17,7 +17,6 @@
 package android.arch.lifecycle;
 
 import android.arch.lifecycle.Lifecycle.Event;
-import android.support.annotation.Nullable;
 
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
@@ -25,6 +24,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 
 /**
  * An internal implementation of {@link GenericLifecycleObserver} that relies on reflection.
@@ -57,19 +57,8 @@
 
     @SuppressWarnings("ConstantConditions")
     private void invokeCallbacks(CallbackInfo info, LifecycleOwner source, Event event) {
-        invokeMethodsForEvent(info.mEventHandlers.get(event), source, event);
-        invokeMethodsForEvent(info.mEventHandlers.get(Event.ON_ANY), source, event);
-
-        // TODO prevent duplicate calls into the same method. Preferably while parsing
-        if (info.mSuper != null) {
-            invokeCallbacks(info.mSuper, source, event);
-        }
-        if (info.mInterfaces != null) {
-            final int size = info.mInterfaces.size();
-            for (int i = 0; i < size; i++) {
-                invokeCallbacks(info.mInterfaces.get(i), source, event);
-            }
-        }
+        invokeMethodsForEvent(info.mEventToHandlers.get(event), source, event);
+        invokeMethodsForEvent(info.mEventToHandlers.get(Event.ON_ANY), source, event);
     }
 
     private void invokeCallback(MethodReference reference, LifecycleOwner source, Event event) {
@@ -107,9 +96,40 @@
         return existing;
     }
 
+    private static void verifyAndPutHandler(Map<MethodReference, Event> handlers,
+            MethodReference newHandler, Event newEvent, Class klass) {
+        Event event = handlers.get(newHandler);
+        if (event != null && newEvent != event) {
+            Method method = newHandler.mMethod;
+            throw new IllegalArgumentException(
+                    "Method " + method.getName() + " in " + klass.getName()
+                            + " already declared with different @OnLifecycleEvent value: previous"
+                            + " value " + event + ", new value " + newEvent);
+        }
+        if (event == null) {
+            handlers.put(newHandler, newEvent);
+        }
+    }
+
     private static CallbackInfo createInfo(Class klass) {
+        Class superclass = klass.getSuperclass();
+        Map<MethodReference, Event> handlerToEvent = new HashMap<>();
+        if (superclass != null) {
+            CallbackInfo superInfo = getInfo(superclass);
+            if (superInfo != null) {
+                handlerToEvent.putAll(superInfo.mHandlerToEvent);
+            }
+        }
+
         Method[] methods = klass.getDeclaredMethods();
-        Map<Event, List<MethodReference>> eventHandlers = new HashMap<>();
+
+        Class[] interfaces = klass.getInterfaces();
+        for (Class intrfc : interfaces) {
+            for (Entry<MethodReference, Event> entry : getInfo(intrfc).mHandlerToEvent.entrySet()) {
+                verifyAndPutHandler(handlerToEvent, entry.getKey(), entry.getValue(), klass);
+            }
+        }
+
         for (Method method : methods) {
             OnLifecycleEvent annotation = method.getAnnotation(OnLifecycleEvent.class);
             if (annotation == null) {
@@ -135,57 +155,62 @@
                 throw new IllegalArgumentException("cannot have more than 2 params");
             }
             Event event = annotation.value();
-            List<MethodReference> methodReferences = eventHandlers.get(event);
-            if (methodReferences == null) {
-                methodReferences = new ArrayList<>();
-                eventHandlers.put(event, methodReferences);
-            }
-            methodReferences.add(new MethodReference(event, callType, method));
+            MethodReference methodReference = new MethodReference(callType, method);
+            verifyAndPutHandler(handlerToEvent, methodReference, event, klass);
         }
-        CallbackInfo info = new CallbackInfo(eventHandlers);
+        CallbackInfo info = new CallbackInfo(handlerToEvent);
         sInfoCache.put(klass, info);
-        Class superclass = klass.getSuperclass();
-        if (superclass != null) {
-            info.mSuper = getInfo(superclass);
-        }
-        Class[] interfaces = klass.getInterfaces();
-        for (Class intrfc : interfaces) {
-            CallbackInfo interfaceInfo = getInfo(intrfc);
-            if (!interfaceInfo.mEventHandlers.isEmpty()) {
-                if (info.mInterfaces == null) {
-                    info.mInterfaces = new ArrayList<>();
-                }
-                info.mInterfaces.add(interfaceInfo);
-            }
-        }
         return info;
     }
 
     @SuppressWarnings("WeakerAccess")
     static class CallbackInfo {
-        final Map<Event, List<MethodReference>> mEventHandlers;
-        @Nullable
-        List<CallbackInfo> mInterfaces;
-        @Nullable
-        CallbackInfo mSuper;
+        final Map<Event, List<MethodReference>> mEventToHandlers;
+        final Map<MethodReference, Event> mHandlerToEvent;
 
-        CallbackInfo(Map<Event, List<MethodReference>> eventHandlers) {
-            mEventHandlers = eventHandlers;
+        CallbackInfo(Map<MethodReference, Event> handlerToEvent) {
+            mHandlerToEvent = handlerToEvent;
+            mEventToHandlers = new HashMap<>();
+            for (Entry<MethodReference, Event> entry : handlerToEvent.entrySet()) {
+                Event event = entry.getValue();
+                List<MethodReference> methodReferences = mEventToHandlers.get(event);
+                if (methodReferences == null) {
+                    methodReferences = new ArrayList<>();
+                    mEventToHandlers.put(event, methodReferences);
+                }
+                methodReferences.add(entry.getKey());
+            }
         }
     }
 
     @SuppressWarnings("WeakerAccess")
     static class MethodReference {
-        final Event mEvent;
         final int mCallType;
         final Method mMethod;
 
-        MethodReference(Event event, int callType, Method method) {
-            mEvent = event;
+        MethodReference(int callType, Method method) {
             mCallType = callType;
             mMethod = method;
             mMethod.setAccessible(true);
         }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+
+            MethodReference that = (MethodReference) o;
+            return mCallType == that.mCallType && mMethod.getName().equals(that.mMethod.getName());
+        }
+
+        @Override
+        public int hashCode() {
+            return 31 * mCallType + mMethod.getName().hashCode();
+        }
     }
 
     private static final int CALL_TYPE_NO_ARG = 0;
diff --git a/lifecycle/extensions/src/test/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserverTest.java b/lifecycle/common/src/test/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserverTest.java
similarity index 69%
rename from lifecycle/extensions/src/test/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserverTest.java
rename to lifecycle/common/src/test/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserverTest.java
index 07901ba..faa7e88 100644
--- a/lifecycle/extensions/src/test/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserverTest.java
+++ b/lifecycle/common/src/test/java/android/arch/lifecycle/ReflectiveGenericLifecycleObserverTest.java
@@ -31,6 +31,7 @@
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.Is.is;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -39,6 +40,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.Matchers;
 
 @RunWith(JUnit4.class)
 public class ReflectiveGenericLifecycleObserverTest {
@@ -198,17 +200,17 @@
     @Test
     public void testPrivateObserverMethods() {
         class ObserverWithPrivateMethod implements LifecycleObserver {
-            boolean called = false;
+            boolean mCalled = false;
             @OnLifecycleEvent(ON_START)
             private void started() {
-                called = true;
+                mCalled = true;
             }
         }
 
         ObserverWithPrivateMethod obj = mock(ObserverWithPrivateMethod.class);
         ReflectiveGenericLifecycleObserver observer = new ReflectiveGenericLifecycleObserver(obj);
         observer.onStateChanged(mOwner, ON_START);
-        assertThat(obj.called, is(true));
+        assertThat(obj.mCalled, is(true));
     }
 
     @Test(expected = IllegalArgumentException.class)
@@ -241,4 +243,119 @@
         new ReflectiveGenericLifecycleObserver(observer);
     }
 
+    class BaseClass1 implements LifecycleObserver {
+        @OnLifecycleEvent(ON_START)
+        void foo(LifecycleOwner owner) {
+        }
+    }
+
+    class DerivedClass1 extends BaseClass1 {
+        @OnLifecycleEvent(ON_STOP)
+        void foo(LifecycleOwner owner) {
+        }
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testInvalidSuper1() {
+        new ReflectiveGenericLifecycleObserver(new DerivedClass1());
+    }
+
+    class BaseClass2 implements LifecycleObserver {
+        @OnLifecycleEvent(ON_START)
+        void foo(LifecycleOwner owner) {
+        }
+    }
+
+    class DerivedClass2 extends BaseClass1 {
+        @OnLifecycleEvent(ON_STOP)
+        void foo() {
+        }
+    }
+
+    @Test
+    public void testValidSuper1() {
+        DerivedClass2 obj = mock(DerivedClass2.class);
+        ReflectiveGenericLifecycleObserver observer = new ReflectiveGenericLifecycleObserver(obj);
+        observer.onStateChanged(mock(LifecycleOwner.class), ON_START);
+        verify(obj).foo(Matchers.<LifecycleOwner>any());
+        verify(obj, never()).foo();
+        reset(obj);
+        observer.onStateChanged(mock(LifecycleOwner.class), ON_STOP);
+        verify(obj).foo();
+        verify(obj, never()).foo(Matchers.<LifecycleOwner>any());
+    }
+
+    class BaseClass3 implements LifecycleObserver {
+        @OnLifecycleEvent(ON_START)
+        void foo(LifecycleOwner owner) {
+        }
+    }
+
+    interface Interface3 extends LifecycleObserver {
+        @OnLifecycleEvent(ON_STOP)
+        void foo(LifecycleOwner owner);
+    }
+
+    class DerivedClass3 extends BaseClass3 implements Interface3 {
+        @Override
+        public void foo(LifecycleOwner owner) {
+        }
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testInvalidSuper2() {
+        new ReflectiveGenericLifecycleObserver(new DerivedClass3());
+    }
+
+    class BaseClass4 implements LifecycleObserver {
+        @OnLifecycleEvent(ON_START)
+        void foo(LifecycleOwner owner) {
+        }
+    }
+
+    interface Interface4 extends LifecycleObserver {
+        @OnLifecycleEvent(ON_START)
+        void foo(LifecycleOwner owner);
+    }
+
+    class DerivedClass4 extends BaseClass4 implements Interface4 {
+        @Override
+        @OnLifecycleEvent(ON_START)
+        public void foo(LifecycleOwner owner) {
+        }
+
+        @OnLifecycleEvent(ON_START)
+        public void foo() {
+        }
+    }
+
+    @Test
+    public void testValidSuper2() {
+        DerivedClass4 obj = mock(DerivedClass4.class);
+        ReflectiveGenericLifecycleObserver observer = new ReflectiveGenericLifecycleObserver(obj);
+        observer.onStateChanged(mock(LifecycleOwner.class), ON_START);
+        verify(obj).foo(Matchers.<LifecycleOwner>any());
+        verify(obj).foo();
+    }
+
+    interface InterfaceStart extends LifecycleObserver {
+        @OnLifecycleEvent(ON_START)
+        void foo(LifecycleOwner owner);
+    }
+
+    interface InterfaceStop extends LifecycleObserver {
+        @OnLifecycleEvent(ON_STOP)
+        void foo(LifecycleOwner owner);
+    }
+
+    class DerivedClass5 implements InterfaceStart, InterfaceStop {
+        @Override
+        public void foo(LifecycleOwner owner) {
+        }
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testInvalidSuper3() {
+        new ReflectiveGenericLifecycleObserver(new DerivedClass5());
+    }
 }
diff --git a/room/runtime/src/test/java/android/arch/persistence/room/InvalidationTrackerTest.java b/room/runtime/src/test/java/android/arch/persistence/room/InvalidationTrackerTest.java
index 5bf4dcf..bd9ab4a 100644
--- a/room/runtime/src/test/java/android/arch/persistence/room/InvalidationTrackerTest.java
+++ b/room/runtime/src/test/java/android/arch/persistence/room/InvalidationTrackerTest.java
@@ -258,14 +258,21 @@
                 return index.addAndGet(2) < keyValuePairs.length;
             }
         });
-        Answer<Integer> answer = new Answer<Integer>() {
+        Answer<Integer> intAnswer = new Answer<Integer>() {
             @Override
             public Integer answer(InvocationOnMock invocation) throws Throwable {
                 return keyValuePairs[index.intValue() + (Integer) invocation.getArguments()[0]];
             }
         };
-        when(cursor.getInt(anyInt())).thenAnswer(answer);
-        when(cursor.getLong(anyInt())).thenAnswer(answer);
+        Answer<Long> longAnswer = new Answer<Long>() {
+            @Override
+            public Long answer(InvocationOnMock invocation) throws Throwable {
+                return (long) keyValuePairs[index.intValue()
+                        + (Integer) invocation.getArguments()[0]];
+            }
+        };
+        when(cursor.getInt(anyInt())).thenAnswer(intAnswer);
+        when(cursor.getLong(anyInt())).thenAnswer(longAnswer);
         return cursor;
     }