Allow using the naive and equality memory-management algorithms when sizes of
intermediate tensors are given as BHWC.

PiperOrigin-RevId: 254232299
diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD
index ecec453..fe5f5ed 100644
--- a/tensorflow/lite/delegates/gpu/common/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/BUILD
@@ -35,6 +35,7 @@
     srcs = ["memory_management.cc"],
     hdrs = ["memory_management.h"],
     deps = [
+        ":shape",
         ":status",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
@@ -130,6 +131,7 @@
     srcs = ["shape.cc"],
     hdrs = ["shape.h"],
     deps = [
+        "@com_google_absl//absl/hash",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc
index 4dcd798..a5d5fc9 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc
@@ -417,5 +417,20 @@
   return OkStatus();
 }
 
+Status AssignObjectsToTensors(
+    const std::vector<TensorUsageRecord<BHWC>>& usage_records,
+    const MemoryStrategy& strategy, ObjectsAssignment<BHWC>* assignment) {
+  switch (strategy) {
+    case MemoryStrategy::NAIVE:
+      return NaiveAssignment<BHWC>(usage_records, assignment);
+    case MemoryStrategy::EQUALITY:
+      return EqualityAssignment<BHWC>(usage_records, assignment);
+    default:
+      return InternalError(
+          "MemoryStrategy is not supported with current tensor size type.");
+  }
+  return OkStatus();
+}
+
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h
index 38af893..d3fec0a 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management.h
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 
 namespace tflite {
@@ -28,7 +29,7 @@
 
 using TaskId = size_t;
 
-// Record, containing tensor size and IDs of the first and the last task,
+// Record, containing tensor size/shape and IDs of the first and the last task,
 // that use this tensor as input or output. For example: tensor #3 with size
 // tensor_size=65536 is first introduced in program #2 (first_task=2) and used
 // for the last time in program #7 (last_task=7).
@@ -82,6 +83,13 @@
     const std::vector<TensorUsageRecord<size_t>>& usage_records,
     const MemoryStrategy& strategy, ObjectsAssignment<size_t>* assignment);
 
+// Calculates the assignment of shared objects to given tensors, including
+// objects' sizes. Initial tensor sizes are given as BHWC. This function is
+// intended for use with GPU textures.
+Status AssignObjectsToTensors(
+    const std::vector<TensorUsageRecord<BHWC>>& usage_records,
+    const MemoryStrategy& strategy, ObjectsAssignment<BHWC>* assignment);
+
 }  // namespace gpu
 }  // namespace tflite
 
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
index 9e25ab9..34cc684 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
@@ -149,6 +149,39 @@
   EXPECT_THAT(assignment.object_sizes, ElementsAre(32, 64, 8, 8));
 }
 
+TEST(Model, BHWCRecords) {
+  std::vector<TensorUsageRecord<BHWC>> usage_records{
+      {/*size=*/BHWC(1, 1, 2, 8), /*first=*/0, /*last=*/1},
+      {/*size=*/BHWC(1, 1, 2, 8), /*first=*/1, /*last=*/2},
+      {/*size=*/BHWC(1, 1, 1, 16), /*first=*/2, /*last=*/4},
+      {/*size=*/BHWC(1, 1, 2, 8), /*first=*/3, /*last=*/5},
+      {/*size=*/BHWC(1, 1, 8, 2), /*first=*/4, /*last=*/5},
+      {/*size=*/BHWC(1, 1, 2, 8), /*first=*/5, /*last=*/7},
+      {/*size=*/BHWC(1, 16, 1, 1), /*first=*/6, /*last=*/8},
+      {/*size=*/BHWC(16, 1, 1, 1), /*first=*/7, /*last=*/8},
+      {/*size=*/BHWC(1, 1, 1, 16), /*first=*/8, /*last=*/9}};
+
+  ObjectsAssignment<BHWC> assignment;
+  ASSERT_TRUE(
+      AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment)
+          .ok());
+  EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8));
+  EXPECT_THAT(
+      assignment.object_sizes,
+      ElementsAre(BHWC(1, 1, 2, 8), BHWC(1, 1, 2, 8), BHWC(1, 1, 1, 16),
+                  BHWC(1, 1, 2, 8), BHWC(1, 1, 8, 2), BHWC(1, 1, 2, 8),
+                  BHWC(1, 16, 1, 1), BHWC(16, 1, 1, 1), BHWC(1, 1, 1, 16)));
+
+  ASSERT_TRUE(AssignObjectsToTensors(usage_records, MemoryStrategy::EQUALITY,
+                                     &assignment)
+                  .ok());
+  EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 2, 1, 3, 0, 4, 5, 2));
+  EXPECT_THAT(
+      assignment.object_sizes,
+      ElementsAre(BHWC(1, 1, 2, 8), BHWC(1, 1, 2, 8), BHWC(1, 1, 1, 16),
+                  BHWC(1, 1, 8, 2), BHWC(1, 16, 1, 1), BHWC(16, 1, 1, 1)));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/shape.h b/tensorflow/lite/delegates/gpu/common/shape.h
index f18e696..f1fb040 100644
--- a/tensorflow/lite/delegates/gpu/common/shape.h
+++ b/tensorflow/lite/delegates/gpu/common/shape.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
 
 #include <sys/types.h>
+
 #include <algorithm>
 #include <array>
 #include <functional>
@@ -25,6 +26,8 @@
 #include <utility>
 #include <vector>
 
+#include "absl/hash/hash.h"
+
 namespace tflite {
 namespace gpu {
 
@@ -531,6 +534,15 @@
       StrongShape::set(source.axis(i), source.get(i));
     }
   }
+
+  // AbslHash function for using in flat hash containers.
+  template <typename H>
+  friend H AbslHashValue(H hash_state, const StrongShape& strong_shape) {
+    for (size_t i = 0; i < strong_shape.size(); ++i) {
+      hash_state = H::combine(std::move(hash_state), strong_shape.get(i));
+    }
+    return hash_state;
+  }
 };
 
 template <Layout T>