[veridex] Reflection detection improvements.

- Handle invoke range instructions.
- Implement parameter substitution.

bug: 77513322
Test: m

Change-Id: I30678a73b5bb367e44edd43d7959fc428ff8ad12
diff --git a/tools/veridex/flow_analysis.cc b/tools/veridex/flow_analysis.cc
index 736abb7..e2833bf 100644
--- a/tools/veridex/flow_analysis.cc
+++ b/tools/veridex/flow_analysis.cc
@@ -243,43 +243,7 @@
     case Instruction::INVOKE_STATIC:
     case Instruction::INVOKE_SUPER:
     case Instruction::INVOKE_VIRTUAL: {
-      VeriMethod method = resolver_->GetMethod(instruction.VRegB_35c());
-      uint32_t args[5];
-      instruction.GetVarArgs(args);
-      if (method == VeriClass::forName_) {
-        RegisterValue value = GetRegister(args[0]);
-        last_result_ = RegisterValue(
-            value.GetSource(), value.GetDexFileReference(), VeriClass::class_);
-      } else if (IsGetField(method)) {
-        RegisterValue cls = GetRegister(args[0]);
-        RegisterValue name = GetRegister(args[1]);
-        field_uses_.push_back(std::make_pair(cls, name));
-        last_result_ = GetReturnType(instruction.VRegB_35c());
-      } else if (IsGetMethod(method)) {
-        RegisterValue cls = GetRegister(args[0]);
-        RegisterValue name = GetRegister(args[1]);
-        method_uses_.push_back(std::make_pair(cls, name));
-        last_result_ = GetReturnType(instruction.VRegB_35c());
-      } else if (method == VeriClass::getClass_) {
-        RegisterValue obj = GetRegister(args[0]);
-        const VeriClass* cls = obj.GetType();
-        if (cls != nullptr && cls->GetClassDef() != nullptr) {
-          const DexFile::ClassDef* def = cls->GetClassDef();
-          last_result_ = RegisterValue(
-              RegisterSource::kClass,
-              DexFileReference(&resolver_->GetDexFileOf(*cls), def->class_idx_.index_),
-              VeriClass::class_);
-        } else {
-          last_result_ = RegisterValue(
-              obj.GetSource(), obj.GetDexFileReference(), VeriClass::class_);
-        }
-      } else if (method == VeriClass::loadClass_) {
-        RegisterValue value = GetRegister(args[1]);
-        last_result_ = RegisterValue(
-            value.GetSource(), value.GetDexFileReference(), VeriClass::class_);
-      } else {
-        last_result_ = GetReturnType(instruction.VRegB_35c());
-      }
+      last_result_ = AnalyzeInvoke(instruction, /* is_range */ false);
       break;
     }
 
@@ -288,7 +252,7 @@
     case Instruction::INVOKE_STATIC_RANGE:
     case Instruction::INVOKE_SUPER_RANGE:
     case Instruction::INVOKE_VIRTUAL_RANGE: {
-      last_result_ = GetReturnType(instruction.VRegB_3rc());
+      last_result_ = AnalyzeInvoke(instruction, /* is_range */ true);
       break;
     }
 
@@ -520,6 +484,7 @@
     case Instruction::IPUT_BYTE:
     case Instruction::IPUT_CHAR:
     case Instruction::IPUT_SHORT: {
+      AnalyzeFieldSet(instruction);
       break;
     }
 
@@ -541,6 +506,7 @@
     case Instruction::SPUT_BYTE:
     case Instruction::SPUT_CHAR:
     case Instruction::SPUT_SHORT: {
+      AnalyzeFieldSet(instruction);
       break;
     }
 
@@ -613,7 +579,112 @@
 
 void VeriFlowAnalysis::Run() {
   FindBranches();
+  uint32_t number_of_registers = code_item_accessor_.RegistersSize();
+  uint32_t number_of_parameters = code_item_accessor_.InsSize();
+  std::vector<RegisterValue>& initial_values = *dex_registers_[0].get();
+  for (uint32_t i = 0; i < number_of_parameters; ++i) {
+    initial_values[number_of_registers - number_of_parameters + i] = RegisterValue(
+      RegisterSource::kParameter,
+      i,
+      DexFileReference(&resolver_->GetDexFile(), method_id_),
+      nullptr);
+  }
   AnalyzeCode();
 }
 
+// Returns the dex register holding argument `index` of an invoke: range invokes number their
+// arguments from VRegC(), other invokes read entry `index` of the caller-filled `args` array.
+static uint32_t GetParameterAt(
+    const Instruction& instruction, bool is_range, uint32_t* args, uint32_t index) {
+  return !is_range ? args[index] : instruction.VRegC() + index;
+}
+
+RegisterValue FlowAnalysisCollector::AnalyzeInvoke(const Instruction& instruction, bool is_range) {
+  uint32_t id = is_range ? instruction.VRegB_3rc() : instruction.VRegB_35c();
+  VeriMethod method = resolver_->GetMethod(id);
+  uint32_t args[5];  // Only filled, and only read, for non-range invokes.
+  if (!is_range) {
+    instruction.GetVarArgs(args);
+  }
+
+  if (method == VeriClass::forName_) {
+    // Class.forName. Fetch the first parameter; the result keeps its source, typed as a class.
+    RegisterValue value = GetRegister(GetParameterAt(instruction, is_range, args, 0));
+    return RegisterValue(
+        value.GetSource(), value.GetDexFileReference(), VeriClass::class_);
+  } else if (IsGetField(method)) {
+    // Class.getField or Class.getDeclaredField. Fetch the first parameter for the class, and the
+    // second parameter for the field name.
+    RegisterValue cls = GetRegister(GetParameterAt(instruction, is_range, args, 0));
+    RegisterValue name = GetRegister(GetParameterAt(instruction, is_range, args, 1));
+    uses_.push_back(ReflectAccessInfo(cls, name, /* is_method */ false));
+    return GetReturnType(id);
+  } else if (IsGetMethod(method)) {
+    // Class.getMethod or Class.getDeclaredMethod. Fetch the first parameter for the class, and the
+    // second parameter for the method name.
+    RegisterValue cls = GetRegister(GetParameterAt(instruction, is_range, args, 0));
+    RegisterValue name = GetRegister(GetParameterAt(instruction, is_range, args, 1));
+    uses_.push_back(ReflectAccessInfo(cls, name, /* is_method */ true));
+    return GetReturnType(id);
+  } else if (method == VeriClass::getClass_) {
+    // Get the type of the first parameter.
+    RegisterValue obj = GetRegister(GetParameterAt(instruction, is_range, args, 0));
+    const VeriClass* cls = obj.GetType();
+    if (cls != nullptr && cls->GetClassDef() != nullptr) {
+      const DexFile::ClassDef* def = cls->GetClassDef();
+      return RegisterValue(
+          RegisterSource::kClass,
+          DexFileReference(&resolver_->GetDexFileOf(*cls), def->class_idx_.index_),
+          VeriClass::class_);
+    } else {
+      return RegisterValue(
+          obj.GetSource(), obj.GetDexFileReference(), VeriClass::class_);
+    }
+  } else if (method == VeriClass::loadClass_) {
+    // ClassLoader.loadClass. Fetch the argument at index 1 (index 0 is presumably the receiver).
+    RegisterValue value = GetRegister(GetParameterAt(instruction, is_range, args, 1));
+    return RegisterValue(
+        value.GetSource(), value.GetDexFileReference(), VeriClass::class_);
+  } else {
+    // Return a RegisterValue referencing the method whose type is the return type
+    // of the method.
+    return GetReturnType(id);
+  }
+}
+
+void FlowAnalysisCollector::AnalyzeFieldSet(const Instruction& instruction ATTRIBUTE_UNUSED) {
+  // Intentionally empty: there are no fields that escape reflection uses.
+}
+
+RegisterValue FlowAnalysisSubstitutor::AnalyzeInvoke(const Instruction& instruction,
+                                                     bool is_range) {
+  uint32_t id = is_range ? instruction.VRegB_3rc() : instruction.VRegB_35c();
+  MethodReference method(&resolver_->GetDexFile(), id);
+  // TODO: doesn't work for multidex: `method` is only built against this resolver's dex file.
+  // TODO: doesn't work for overriding (but maybe should be done at a higher level).
+  if (accesses_.find(method) == accesses_.end()) {
+    return GetReturnType(id);
+  }
+  uint32_t args[5];  // Only filled, and only read, for non-range invokes.
+  if (!is_range) {
+    instruction.GetVarArgs(args);
+  }
+  for (const ReflectAccessInfo& info : accesses_.at(method)) {
+    if (info.cls.IsParameter() || info.name.IsParameter()) {
+      RegisterValue cls = info.cls.IsParameter()
+          ? GetRegister(GetParameterAt(instruction, is_range, args, info.cls.GetParameterIndex()))
+          : info.cls;
+      RegisterValue name = info.name.IsParameter()
+          ? GetRegister(GetParameterAt(instruction, is_range, args, info.name.GetParameterIndex()))
+          : info.name;
+      uses_.push_back(ReflectAccessInfo(cls, name, info.is_method));
+    }
+  }
+  return GetReturnType(id);
+}
+
+void FlowAnalysisSubstitutor::AnalyzeFieldSet(const Instruction& instruction ATTRIBUTE_UNUSED) {
+  // Intentionally empty for now. TODO: analyze field sets.
+}
+
 }  // namespace art
diff --git a/tools/veridex/flow_analysis.h b/tools/veridex/flow_analysis.h
index 80ae5fc..62c9916 100644
--- a/tools/veridex/flow_analysis.h
+++ b/tools/veridex/flow_analysis.h
@@ -21,13 +21,11 @@
 #include "dex/dex_file_reference.h"
 #include "dex/method_reference.h"
 #include "hidden_api.h"
+#include "resolver.h"
 #include "veridex.h"
 
 namespace art {
 
-class VeridexClass;
-class VeridexResolver;
-
 /**
  * The source where a dex register comes from.
  */
@@ -45,13 +43,29 @@
  */
 class RegisterValue {
  public:
-  RegisterValue() : source_(RegisterSource::kNone), reference_(nullptr, 0), type_(nullptr) {}
+  RegisterValue() : source_(RegisterSource::kNone),
+                    parameter_index_(0),
+                    reference_(nullptr, 0),
+                    type_(nullptr) {}
   RegisterValue(RegisterSource source, DexFileReference reference, const VeriClass* type)
-      : source_(source), reference_(reference), type_(type) {}
+      : source_(source), parameter_index_(0), reference_(reference), type_(type) {}
+
+  RegisterValue(RegisterSource source,
+                uint32_t parameter_index,
+                DexFileReference reference,
+                const VeriClass* type)
+      : source_(source), parameter_index_(parameter_index), reference_(reference), type_(type) {}
 
   RegisterSource GetSource() const { return source_; }
   DexFileReference GetDexFileReference() const { return reference_; }
   const VeriClass* GetType() const { return type_; }
+  uint32_t GetParameterIndex() const {
+    CHECK(IsParameter());
+    return parameter_index_;
+  }
+  bool IsParameter() const { return source_ == RegisterSource::kParameter; }
+  bool IsClass() const { return source_ == RegisterSource::kClass; }
+  bool IsString() const { return source_ == RegisterSource::kString; }
 
   std::string ToString() const {
     switch (source_) {
@@ -68,6 +82,8 @@
       }
       case RegisterSource::kClass:
         return reference_.dex_file->StringByTypeIdx(dex::TypeIndex(reference_.index));
+      case RegisterSource::kParameter:
+        return std::string("Parameter of ") + reference_.dex_file->PrettyMethod(reference_.index);
       default:
         return "<unknown>";
     }
@@ -75,6 +91,7 @@
 
  private:
   RegisterSource source_;
+  uint32_t parameter_index_;
   DexFileReference reference_;
   const VeriClass* type_;
 };
@@ -85,22 +102,18 @@
 
 class VeriFlowAnalysis {
  public:
-  VeriFlowAnalysis(VeridexResolver* resolver,
-                   const CodeItemDataAccessor& code_item_accessor)
+  VeriFlowAnalysis(VeridexResolver* resolver, const ClassDataItemIterator& it)
       : resolver_(resolver),
-        code_item_accessor_(code_item_accessor),
-        dex_registers_(code_item_accessor.InsnsSizeInCodeUnits()),
-        instruction_infos_(code_item_accessor.InsnsSizeInCodeUnits()) {}
+        method_id_(it.GetMemberIndex()),
+        code_item_accessor_(resolver->GetDexFile(), it.GetMethodCodeItem()),
+        dex_registers_(code_item_accessor_.InsnsSizeInCodeUnits()),
+        instruction_infos_(code_item_accessor_.InsnsSizeInCodeUnits()) {}
 
   void Run();
 
-  const std::vector<std::pair<RegisterValue, RegisterValue>>& GetFieldUses() const {
-    return field_uses_;
-  }
-
-  const std::vector<std::pair<RegisterValue, RegisterValue>>& GetMethodUses() const {
-    return method_uses_;
-  }
+  virtual RegisterValue AnalyzeInvoke(const Instruction& instruction, bool is_range) = 0;
+  virtual void AnalyzeFieldSet(const Instruction& instruction) = 0;
+  virtual ~VeriFlowAnalysis() {}
 
  private:
   // Find all branches in the code.
@@ -124,14 +137,19 @@
       uint32_t dex_register, RegisterSource kind, VeriClass* cls, uint32_t source_id);
   void UpdateRegister(uint32_t dex_register, const RegisterValue& value);
   void UpdateRegister(uint32_t dex_register, const VeriClass* cls);
-  const RegisterValue& GetRegister(uint32_t dex_register);
   void ProcessDexInstruction(const Instruction& inst);
   void SetVisited(uint32_t dex_pc);
-  RegisterValue GetReturnType(uint32_t method_index);
   RegisterValue GetFieldType(uint32_t field_index);
 
+ protected:
+  const RegisterValue& GetRegister(uint32_t dex_register);
+  RegisterValue GetReturnType(uint32_t method_index);
+
   VeridexResolver* resolver_;
-  const CodeItemDataAccessor& code_item_accessor_;
+
+ private:
+  const uint32_t method_id_;
+  CodeItemDataAccessor code_item_accessor_;
 
   // Vector of register values for all branch targets.
   std::vector<std::unique_ptr<std::vector<RegisterValue>>> dex_registers_;
@@ -144,12 +162,59 @@
 
   // The value of invoke instructions, to be fetched when visiting move-result.
   RegisterValue last_result_;
+};
 
-  // List of reflection field uses found.
-  std::vector<std::pair<RegisterValue, RegisterValue>> field_uses_;
+struct ReflectAccessInfo {
+  RegisterValue cls;
+  RegisterValue name;
+  bool is_method;
 
-  // List of reflection method uses found.
-  std::vector<std::pair<RegisterValue, RegisterValue>> method_uses_;
+  ReflectAccessInfo(RegisterValue c, RegisterValue n, bool m) : cls(c), name(n), is_method(m) {}
+
+  bool IsConcrete() const {
+    // We capture RegisterSource::kString for the class, for example in Class.forName.
+    return (cls.IsClass() || cls.IsString()) && name.IsString();
+  }
+};
+
+// Flow analysis that collects all reflection uses found in the analyzed method.
+class FlowAnalysisCollector : public VeriFlowAnalysis {
+ public:
+  FlowAnalysisCollector(VeridexResolver* resolver, const ClassDataItemIterator& it)
+      : VeriFlowAnalysis(resolver, it) {}
+
+  const std::vector<ReflectAccessInfo>& GetUses() const {
+    return uses_;
+  }
+
+  RegisterValue AnalyzeInvoke(const Instruction& instruction, bool is_range) OVERRIDE;
+  void AnalyzeFieldSet(const Instruction& instruction) OVERRIDE;
+
+ private:
+  // Reflection uses recorded by AnalyzeInvoke, concrete and abstract alike.
+  std::vector<ReflectAccessInfo> uses_;
+};
+
+// Substitutes the parameters of abstract reflection uses by the values passed at call sites.
+class FlowAnalysisSubstitutor : public VeriFlowAnalysis {
+ public:
+  FlowAnalysisSubstitutor(VeridexResolver* resolver,
+                          const ClassDataItemIterator& it,
+                          const std::map<MethodReference, std::vector<ReflectAccessInfo>>& accesses)
+      : VeriFlowAnalysis(resolver, it), accesses_(accesses) {}
+
+  const std::vector<ReflectAccessInfo>& GetUses() const {
+    return uses_;
+  }
+
+  RegisterValue AnalyzeInvoke(const Instruction& instruction, bool is_range) OVERRIDE;
+  void AnalyzeFieldSet(const Instruction& instruction) OVERRIDE;
+
+ private:
+  // New reflection uses obtained by substitution, concrete and abstract.
+  std::vector<ReflectAccessInfo> uses_;
+  // The abstract uses we are trying to substitute. Held by reference: must outlive this object.
+  const std::map<MethodReference, std::vector<ReflectAccessInfo>>& accesses_;
 };
 
 }  // namespace art
diff --git a/tools/veridex/precise_hidden_api_finder.cc b/tools/veridex/precise_hidden_api_finder.cc
index 4ae5769..89754c2 100644
--- a/tools/veridex/precise_hidden_api_finder.cc
+++ b/tools/veridex/precise_hidden_api_finder.cc
@@ -29,7 +29,9 @@
 
 namespace art {
 
-void PreciseHiddenApiFinder::Run(const std::vector<std::unique_ptr<VeridexResolver>>& resolvers) {
+void PreciseHiddenApiFinder::RunInternal(
+    const std::vector<std::unique_ptr<VeridexResolver>>& resolvers,
+    const std::function<void(VeridexResolver*, const ClassDataItemIterator&)>& action) {
   for (const std::unique_ptr<VeridexResolver>& resolver : resolvers) {
     const DexFile& dex_file = resolver->GetDexFile();
     size_t class_def_count = dex_file.NumClassDefs();
@@ -47,43 +49,67 @@
         if (code_item == nullptr) {
           continue;
         }
-        CodeItemDataAccessor code_item_accessor(dex_file, code_item);
-        VeriFlowAnalysis ana(resolver.get(), code_item_accessor);
-        ana.Run();
-        if (!ana.GetFieldUses().empty()) {
-          field_uses_[MethodReference(&dex_file, it.GetMemberIndex())] = ana.GetFieldUses();
-        }
-        if (!ana.GetMethodUses().empty()) {
-          method_uses_[MethodReference(&dex_file, it.GetMemberIndex())] = ana.GetMethodUses();
-        }
+        action(resolver.get(), it);
       }
     }
   }
 }
 
+// Sorts the uses found in method `ref` into the concrete or the abstract map, according to
+// ReflectAccessInfo::IsConcrete().
+void PreciseHiddenApiFinder::AddUsesAt(const std::vector<ReflectAccessInfo>& accesses,
+                                       MethodReference ref) {
+  for (const ReflectAccessInfo& use : accesses) {
+    auto& bucket = use.IsConcrete() ? concrete_uses_ : abstract_uses_;
+    // operator[] creates the entry for `ref` the first time a use is filed under it.
+    bucket[ref].push_back(use);
+  }
+}
+
+// Finds reflection uses in all methods, then resolves parameter-dependent ones through callers.
+void PreciseHiddenApiFinder::Run(const std::vector<std::unique_ptr<VeridexResolver>>& resolvers) {
+  // Collect the reflection uses of every method, concrete and abstract.
+  RunInternal(resolvers, [this] (VeridexResolver* resolver, const ClassDataItemIterator& it) {
+    FlowAnalysisCollector collector(resolver, it);
+    collector.Run();
+    AddUsesAt(collector.GetUses(), MethodReference(&resolver->GetDexFile(), it.GetMemberIndex()));
+  });
+
+  // For non-final reflection uses, do a limited fixed point calculation over the code to try
+  // substituting them with final ones. Iterations are capped as one run over the code can be long.
+  static const uint32_t kMaximumIterations = 10;
+  for (uint32_t i = 0; !abstract_uses_.empty() && i < kMaximumIterations; ++i) {
+    // Take the worklist; swapping with an empty map guarantees `abstract_uses_` is empty again.
+    std::map<MethodReference, std::vector<ReflectAccessInfo>> current_uses;
+    current_uses.swap(abstract_uses_);
+    // Capture the worklist by reference: it outlives the call and copying the map is wasteful.
+    RunInternal(resolvers,
+                [this, &current_uses] (VeridexResolver* resolver, const ClassDataItemIterator& it) {
+      FlowAnalysisSubstitutor substitutor(resolver, it, current_uses);
+      substitutor.Run();
+      AddUsesAt(substitutor.GetUses(),
+                MethodReference(&resolver->GetDexFile(), it.GetMemberIndex()));
+    });
+  }
+}
+
 void PreciseHiddenApiFinder::Dump(std::ostream& os, HiddenApiStats* stats) {
   static const char* kPrefix = "       ";
-  std::map<std::string, std::vector<MethodReference>> uses;
-  for (auto kinds : { field_uses_, method_uses_ }) {
-    for (auto it : kinds) {
-      MethodReference ref = it.first;
-      for (const std::pair<RegisterValue, RegisterValue>& info : it.second) {
-        if ((info.first.GetSource() == RegisterSource::kClass ||
-             info.first.GetSource() == RegisterSource::kString) &&
-            info.second.GetSource() == RegisterSource::kString) {
-          std::string cls(info.first.ToString());
-          std::string name(info.second.ToString());
-          std::string full_name = cls + "->" + name;
-          HiddenApiAccessFlags::ApiList api_list = hidden_api_.GetApiList(full_name);
-          if (api_list != HiddenApiAccessFlags::kWhitelist) {
-            uses[full_name].push_back(ref);
-          }
-        }
+  std::map<std::string, std::vector<MethodReference>> named_uses;
+  for (auto it : concrete_uses_) {
+    MethodReference ref = it.first;
+    for (const ReflectAccessInfo& info : it.second) {
+      std::string cls(info.cls.ToString());
+      std::string name(info.name.ToString());
+      std::string full_name = cls + "->" + name;
+      HiddenApiAccessFlags::ApiList api_list = hidden_api_.GetApiList(full_name);
+      if (api_list != HiddenApiAccessFlags::kWhitelist) {
+        named_uses[full_name].push_back(ref);
       }
     }
   }
 
-  for (auto it : uses) {
+  for (auto it : named_uses) {
     ++stats->reflection_count;
     const std::string& full_name = it.first;
     HiddenApiAccessFlags::ApiList api_list = hidden_api_.GetApiList(full_name);
diff --git a/tools/veridex/precise_hidden_api_finder.h b/tools/veridex/precise_hidden_api_finder.h
index 22744a6..1c4d0ae 100644
--- a/tools/veridex/precise_hidden_api_finder.h
+++ b/tools/veridex/precise_hidden_api_finder.h
@@ -45,9 +45,18 @@
   void Dump(std::ostream& os, HiddenApiStats* stats);
 
  private:
+  // Run over all methods of all dex files, and call `action` on each.
+  void RunInternal(
+      const std::vector<std::unique_ptr<VeridexResolver>>& resolvers,
+      const std::function<void(VeridexResolver*, const ClassDataItemIterator&)>& action);
+
+  // Add uses found in method `ref`.
+  void AddUsesAt(const std::vector<ReflectAccessInfo>& accesses, MethodReference ref);
+
   const HiddenApi& hidden_api_;
-  std::map<MethodReference, std::vector<std::pair<RegisterValue, RegisterValue>>> field_uses_;
-  std::map<MethodReference, std::vector<std::pair<RegisterValue, RegisterValue>>> method_uses_;
+
+  std::map<MethodReference, std::vector<ReflectAccessInfo>> concrete_uses_;
+  std::map<MethodReference, std::vector<ReflectAccessInfo>> abstract_uses_;
 };
 
 }  // namespace art