ART: Clean up IndirectReferenceTable

Introduce constants and move some functions into the
IndirectReferenceTable class.

Slightly change IndirectRef encoding to be more obvious (and slighly
more optimized when decoding).

Bug: 32125344
Test: m test-art-host
Change-Id: I05819eccb733b611de582fb8d7151f1a110c305a
diff --git a/runtime/check_jni.cc b/runtime/check_jni.cc
index a1ce30b..5399dc5 100644
--- a/runtime/check_jni.cc
+++ b/runtime/check_jni.cc
@@ -277,7 +277,7 @@
     if (!Runtime::Current()->GetHeap()->IsValidObjectAddress(o.Ptr())) {
       Runtime::Current()->GetHeap()->DumpSpaces(LOG_STREAM(ERROR));
       AbortF("field operation on invalid %s: %p",
-             ToStr<IndirectRefKind>(GetIndirectRefKind(java_object)).c_str(),
+             GetIndirectRefKindString(IndirectReferenceTable::GetIndirectRefKind(java_object)),
              java_object);
       return false;
     }
@@ -632,17 +632,17 @@
   bool CheckReferenceKind(IndirectRefKind expected_kind, Thread* self, jobject obj) {
     IndirectRefKind found_kind;
     if (expected_kind == kLocal) {
-      found_kind = GetIndirectRefKind(obj);
+      found_kind = IndirectReferenceTable::GetIndirectRefKind(obj);
       if (found_kind == kHandleScopeOrInvalid && self->HandleScopeContains(obj)) {
         found_kind = kLocal;
       }
     } else {
-      found_kind = GetIndirectRefKind(obj);
+      found_kind = IndirectReferenceTable::GetIndirectRefKind(obj);
     }
     if (obj != nullptr && found_kind != expected_kind) {
       AbortF("expected reference of kind %s but found %s: %p",
-             ToStr<IndirectRefKind>(expected_kind).c_str(),
-             ToStr<IndirectRefKind>(GetIndirectRefKind(obj)).c_str(),
+             GetIndirectRefKindString(expected_kind),
+             GetIndirectRefKindString(IndirectReferenceTable::GetIndirectRefKind(obj)),
              obj);
       return false;
     }
@@ -773,7 +773,7 @@
       // Either java_object is invalid or is a cleared weak.
       IndirectRef ref = reinterpret_cast<IndirectRef>(java_object);
       bool okay;
-      if (GetIndirectRefKind(ref) != kWeakGlobal) {
+      if (IndirectReferenceTable::GetIndirectRefKind(ref) != kWeakGlobal) {
         okay = false;
       } else {
         obj = soa.Vm()->DecodeWeakGlobal(soa.Self(), ref);
@@ -781,8 +781,10 @@
       }
       if (!okay) {
         AbortF("%s is an invalid %s: %p (%p)",
-               what, ToStr<IndirectRefKind>(GetIndirectRefKind(java_object)).c_str(),
-               java_object, obj.Ptr());
+               what,
+               GetIndirectRefKindString(IndirectReferenceTable::GetIndirectRefKind(java_object)),
+               java_object,
+               obj.Ptr());
         return false;
       }
     }
@@ -790,8 +792,10 @@
     if (!Runtime::Current()->GetHeap()->IsValidObjectAddress(obj.Ptr())) {
       Runtime::Current()->GetHeap()->DumpSpaces(LOG_STREAM(ERROR));
       AbortF("%s is an invalid %s: %p (%p)",
-             what, ToStr<IndirectRefKind>(GetIndirectRefKind(java_object)).c_str(),
-             java_object, obj.Ptr());
+             what,
+             GetIndirectRefKindString(IndirectReferenceTable::GetIndirectRefKind(java_object)),
+             java_object,
+             obj.Ptr());
       return false;
     }
 
@@ -1116,8 +1120,9 @@
     if (UNLIKELY(!Runtime::Current()->GetHeap()->IsValidObjectAddress(a.Ptr()))) {
       Runtime::Current()->GetHeap()->DumpSpaces(LOG_STREAM(ERROR));
       AbortF("jarray is an invalid %s: %p (%p)",
-             ToStr<IndirectRefKind>(GetIndirectRefKind(java_array)).c_str(),
-             java_array, a.Ptr());
+             GetIndirectRefKindString(IndirectReferenceTable::GetIndirectRefKind(java_array)),
+             java_array,
+             a.Ptr());
       return false;
     } else if (!a->IsArrayInstance()) {
       AbortF("jarray argument has non-array type: %s", a->PrettyTypeOf().c_str());
diff --git a/runtime/indirect_reference_table.cc b/runtime/indirect_reference_table.cc
index 7389c73..b48e711 100644
--- a/runtime/indirect_reference_table.cc
+++ b/runtime/indirect_reference_table.cc
@@ -84,6 +84,34 @@
 IndirectReferenceTable::~IndirectReferenceTable() {
 }
 
+void IndirectReferenceTable::ConstexprChecks() {
+  // Use this for some assertions. They can't be put into the header as C++ wants the class
+  // to be complete.
+
+  // Check kind.
+  static_assert((EncodeIndirectRefKind(kLocal) & (~kKindMask)) == 0, "Kind encoding error");
+  static_assert((EncodeIndirectRefKind(kGlobal) & (~kKindMask)) == 0, "Kind encoding error");
+  static_assert((EncodeIndirectRefKind(kWeakGlobal) & (~kKindMask)) == 0, "Kind encoding error");
+  static_assert(DecodeIndirectRefKind(EncodeIndirectRefKind(kLocal)) == kLocal,
+                "Kind encoding error");
+  static_assert(DecodeIndirectRefKind(EncodeIndirectRefKind(kGlobal)) == kGlobal,
+                "Kind encoding error");
+  static_assert(DecodeIndirectRefKind(EncodeIndirectRefKind(kWeakGlobal)) == kWeakGlobal,
+                "Kind encoding error");
+
+  // Check serial.
+  static_assert(DecodeSerial(EncodeSerial(0u)) == 0u, "Serial encoding error");
+  static_assert(DecodeSerial(EncodeSerial(1u)) == 1u, "Serial encoding error");
+  static_assert(DecodeSerial(EncodeSerial(2u)) == 2u, "Serial encoding error");
+  static_assert(DecodeSerial(EncodeSerial(3u)) == 3u, "Serial encoding error");
+
+  // Table index.
+  static_assert(DecodeIndex(EncodeIndex(0u)) == 0u, "Index encoding error");
+  static_assert(DecodeIndex(EncodeIndex(1u)) == 1u, "Index encoding error");
+  static_assert(DecodeIndex(EncodeIndex(2u)) == 2u, "Index encoding error");
+  static_assert(DecodeIndex(EncodeIndex(3u)) == 3u, "Index encoding error");
+}
+
 bool IndirectReferenceTable::IsValid() const {
   return table_mem_map_.get() != nullptr;
 }
diff --git a/runtime/indirect_reference_table.h b/runtime/indirect_reference_table.h
index 363280a..c0355de 100644
--- a/runtime/indirect_reference_table.h
+++ b/runtime/indirect_reference_table.h
@@ -22,6 +22,7 @@
 #include <iosfwd>
 #include <string>
 
+#include "base/bit_utils.h"
 #include "base/logging.h"
 #include "base/mutex.h"
 #include "gc_root.h"
@@ -114,21 +115,15 @@
  * For convenience these match up with enum jobjectRefType from jni.h.
  */
 enum IndirectRefKind {
-  kHandleScopeOrInvalid = 0,  // <<stack indirect reference table or invalid reference>>
-  kLocal         = 1,  // <<local reference>>
-  kGlobal        = 2,  // <<global reference>>
-  kWeakGlobal    = 3   // <<weak global reference>>
+  kHandleScopeOrInvalid = 0,           // <<stack indirect reference table or invalid reference>>
+  kLocal                = 1,           // <<local reference>>
+  kGlobal               = 2,           // <<global reference>>
+  kWeakGlobal           = 3,           // <<weak global reference>>
+  kLastKind             = kWeakGlobal
 };
 std::ostream& operator<<(std::ostream& os, const IndirectRefKind& rhs);
 const char* GetIndirectRefKindString(const IndirectRefKind& kind);
 
-/*
- * Determine what kind of indirect reference this is.
- */
-static inline IndirectRefKind GetIndirectRefKind(IndirectRef iref) {
-  return static_cast<IndirectRefKind>(reinterpret_cast<uintptr_t>(iref) & 0x03);
-}
-
 /* use as initial value for "cookie", and when table has only one segment */
 static const uint32_t IRT_FIRST_SEGMENT = 0;
 
@@ -198,7 +193,8 @@
 // Try to choose kIRTPrevCount so that sizeof(IrtEntry) is a power of 2.
 // Contains multiple entries but only one active one, this helps us detect use after free errors
 // since the serial stored in the indirect ref wont match.
-static const size_t kIRTPrevCount = kIsDebugBuild ? 7 : 3;
+static constexpr size_t kIRTPrevCount = kIsDebugBuild ? 7 : 3;
+
 class IrtEntry {
  public:
   void Add(ObjPtr<mirror::Object> obj) REQUIRES_SHARED(Locks::mutator_lock_);
@@ -220,6 +216,7 @@
 };
 static_assert(sizeof(IrtEntry) == (1 + kIRTPrevCount) * sizeof(uint32_t),
               "Unexpected sizeof(IrtEntry)");
+static_assert(IsPowerOfTwo(sizeof(IrtEntry)), "Unexpected sizeof(IrtEntry)");
 
 class IrtIterator {
  public:
@@ -362,22 +359,59 @@
   // Release pages past the end of the table that may have previously held references.
   void Trim() REQUIRES_SHARED(Locks::mutator_lock_);
 
- private:
-  // Extract the table index from an indirect reference.
-  static uint32_t ExtractIndex(IndirectRef iref) {
-    uintptr_t uref = reinterpret_cast<uintptr_t>(iref);
-    return (uref >> 2) & 0xffff;
+  // Determine what kind of indirect reference this is. Opposite of EncodeIndirectRefKind.
+  ALWAYS_INLINE static inline IndirectRefKind GetIndirectRefKind(IndirectRef iref) {
+    return DecodeIndirectRefKind(reinterpret_cast<uintptr_t>(iref));
   }
 
-  /*
-   * The object pointer itself is subject to relocation in some GC
-   * implementations, so we shouldn't really be using it here.
-   */
-  IndirectRef ToIndirectRef(uint32_t tableIndex) const {
-    DCHECK_LT(tableIndex, 65536U);
-    uint32_t serialChunk = table_[tableIndex].GetSerial();
-    uintptr_t uref = (serialChunk << 20) | (tableIndex << 2) | kind_;
-    return reinterpret_cast<IndirectRef>(uref);
+ private:
+  static constexpr size_t kSerialBits = MinimumBitsToStore(kIRTPrevCount);
+  static constexpr uint32_t kShiftedSerialMask = (1u << kSerialBits) - 1;
+
+  static constexpr size_t kKindBits = MinimumBitsToStore(
+      static_cast<uint32_t>(IndirectRefKind::kLastKind));
+  static constexpr uint32_t kKindMask = (1u << kKindBits) - 1;
+
+  static constexpr uintptr_t EncodeIndex(uint32_t table_index) {
+    static_assert(sizeof(IndirectRef) == sizeof(uintptr_t), "Unexpected IndirectRef size");
+    DCHECK_LE(MinimumBitsToStore(table_index), BitSizeOf<uintptr_t>() - kSerialBits - kKindBits);
+    return (static_cast<uintptr_t>(table_index) << kKindBits << kSerialBits);
+  }
+  static constexpr uint32_t DecodeIndex(uintptr_t uref) {
+    return static_cast<uint32_t>((uref >> kKindBits) >> kSerialBits);
+  }
+
+  static constexpr uintptr_t EncodeIndirectRefKind(IndirectRefKind kind) {
+    return static_cast<uintptr_t>(kind);
+  }
+  static constexpr IndirectRefKind DecodeIndirectRefKind(uintptr_t uref) {
+    return static_cast<IndirectRefKind>(uref & kKindMask);
+  }
+
+  static constexpr uintptr_t EncodeSerial(uint32_t serial) {
+    DCHECK_LE(MinimumBitsToStore(serial), kSerialBits);
+    return serial << kKindBits;
+  }
+  static constexpr uint32_t DecodeSerial(uintptr_t uref) {
+    return static_cast<uint32_t>(uref >> kKindBits) & kShiftedSerialMask;
+  }
+
+  constexpr uintptr_t EncodeIndirectRef(uint32_t table_index, uint32_t serial) const {
+    DCHECK_LT(table_index, max_entries_);
+    return EncodeIndex(table_index) | EncodeSerial(serial) | EncodeIndirectRefKind(kind_);
+  }
+
+  static void ConstexprChecks();
+
+  // Extract the table index from an indirect reference.
+  ALWAYS_INLINE static uint32_t ExtractIndex(IndirectRef iref) {
+    return DecodeIndex(reinterpret_cast<uintptr_t>(iref));
+  }
+
+  IndirectRef ToIndirectRef(uint32_t table_index) const {
+    DCHECK_LT(table_index, max_entries_);
+    uint32_t serial = table_[table_index].GetSerial();
+    return reinterpret_cast<IndirectRef>(EncodeIndirectRef(table_index, serial));
   }
 
   // Abort if check_jni is not enabled. Otherwise, just log as an error.
diff --git a/runtime/java_vm_ext.cc b/runtime/java_vm_ext.cc
index 9b4327f..f1f9de8 100644
--- a/runtime/java_vm_ext.cc
+++ b/runtime/java_vm_ext.cc
@@ -680,7 +680,7 @@
   // This only applies in the case where MayAccessWeakGlobals goes from false to true. In the other
   // case, it may be racy, this is benign since DecodeWeakGlobalLocked does the correct behavior
   // if MayAccessWeakGlobals is false.
-  DCHECK_EQ(GetIndirectRefKind(ref), kWeakGlobal);
+  DCHECK_EQ(IndirectReferenceTable::GetIndirectRefKind(ref), kWeakGlobal);
   if (LIKELY(MayAccessWeakGlobalsUnlocked(self))) {
     return weak_globals_.SynchronizedGet(ref);
   }
@@ -699,7 +699,7 @@
 }
 
 ObjPtr<mirror::Object> JavaVMExt::DecodeWeakGlobalDuringShutdown(Thread* self, IndirectRef ref) {
-  DCHECK_EQ(GetIndirectRefKind(ref), kWeakGlobal);
+  DCHECK_EQ(IndirectReferenceTable::GetIndirectRefKind(ref), kWeakGlobal);
   DCHECK(Runtime::Current()->IsShuttingDown(self));
   if (self != nullptr) {
     return DecodeWeakGlobal(self, ref);
@@ -712,7 +712,7 @@
 }
 
 bool JavaVMExt::IsWeakGlobalCleared(Thread* self, IndirectRef ref) {
-  DCHECK_EQ(GetIndirectRefKind(ref), kWeakGlobal);
+  DCHECK_EQ(IndirectReferenceTable::GetIndirectRefKind(ref), kWeakGlobal);
   MutexLock mu(self, *Locks::jni_weak_globals_lock_);
   while (UNLIKELY(!MayAccessWeakGlobals(self))) {
     weak_globals_add_condition_.WaitHoldingLocks(self);
diff --git a/runtime/jni_internal.cc b/runtime/jni_internal.cc
index 3839e08..0217a67 100644
--- a/runtime/jni_internal.cc
+++ b/runtime/jni_internal.cc
@@ -2374,7 +2374,7 @@
 
     // Do we definitely know what kind of reference this is?
     IndirectRef ref = reinterpret_cast<IndirectRef>(java_object);
-    IndirectRefKind kind = GetIndirectRefKind(ref);
+    IndirectRefKind kind = IndirectReferenceTable::GetIndirectRefKind(ref);
     switch (kind) {
     case kLocal:
       return JNILocalRefType;
diff --git a/runtime/reflection.cc b/runtime/reflection.cc
index 661012c..f88309b 100644
--- a/runtime/reflection.cc
+++ b/runtime/reflection.cc
@@ -911,7 +911,7 @@
 // Will need to be fixed if there's cases where it's not.
 void UpdateReference(Thread* self, jobject obj, ObjPtr<mirror::Object> result) {
   IndirectRef ref = reinterpret_cast<IndirectRef>(obj);
-  IndirectRefKind kind = GetIndirectRefKind(ref);
+  IndirectRefKind kind = IndirectReferenceTable::GetIndirectRefKind(ref);
   if (kind == kLocal) {
     self->GetJniEnv()->locals.Update(obj, result);
   } else if (kind == kHandleScopeOrInvalid) {
diff --git a/runtime/thread.cc b/runtime/thread.cc
index e47ccc0..ace5e67 100644
--- a/runtime/thread.cc
+++ b/runtime/thread.cc
@@ -1860,7 +1860,7 @@
     return nullptr;
   }
   IndirectRef ref = reinterpret_cast<IndirectRef>(obj);
-  IndirectRefKind kind = GetIndirectRefKind(ref);
+  IndirectRefKind kind = IndirectReferenceTable::GetIndirectRefKind(ref);
   ObjPtr<mirror::Object> result;
   bool expect_null = false;
   // The "kinds" below are sorted by the frequency we expect to encounter them.
@@ -1902,7 +1902,7 @@
 bool Thread::IsJWeakCleared(jweak obj) const {
   CHECK(obj != nullptr);
   IndirectRef ref = reinterpret_cast<IndirectRef>(obj);
-  IndirectRefKind kind = GetIndirectRefKind(ref);
+  IndirectRefKind kind = IndirectReferenceTable::GetIndirectRefKind(ref);
   CHECK_EQ(kind, kWeakGlobal);
   return tlsPtr_.jni_env->vm->IsWeakGlobalCleared(const_cast<Thread*>(this), ref);
 }