Merge pull request #42503 from ahmedsabie:T1

PiperOrigin-RevId: 329267324
Change-Id: I8829d03b6d20531502ae9f638864c9a8d82470fa
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 6a992d6..1abb190 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -82,8 +82,8 @@
 //              offset - mean * scale * rsqrt(variance + epsilon)
 // CHECK:  %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
 
-// CHECK:  %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-// CHECK:  "tf.FusedBatchNorm"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK:  %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK:  "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
 }
 
 func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
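
Note on the test change above: the pass now first upgrades inference-mode tf.FusedBatchNorm
to tf.FusedBatchNormV3, so the CHECK lines match the V3 op. A minimal sketch of that rewrite
(value names and the epsilon/attribute values are illustrative, not taken from this patch):

  // Before: five results (y, batch_mean, batch_variance, reserve_space_1/2).
  %y, %bm, %bv, %r1, %r2 = "tf.FusedBatchNorm"(%x, %scale, %offset, %mean, %var)
      {epsilon = 1.0e-03 : f32, is_training = false}
      : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
      -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)

  // After: same operands and attributes, plus a sixth reserve_space_3 result of
  // unranked f32 type; the first five results replace the old op's uses.
  %y, %bm, %bv, %r1, %r2, %r3 = "tf.FusedBatchNormV3"(%x, %scale, %offset, %mean, %var)
      {epsilon = 1.0e-03 : f32, is_training = false}
      : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
      -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)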
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
index f5b2527..326b6b2 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
@@ -40,7 +40,7 @@
     (TF_MulOp $t, (TF_MulOp:$mul (TF_RsqrtOp (TF_AddOp $v, (TF_ConstOp $variance_epsilon))), $gamma)),
     (TF_SubOp $beta, (TF_MulOp $m, $mul)))>;
 
-// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
+// Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
 // operations. Specifically, performs the following calculation:
 //
 //   (x - mean) * scale / sqrt(variance + epsilon) + offset
@@ -50,28 +50,6 @@
 //   (x - mean) * scale / sqrt(variance + epsilon) + offset,
 // is then to compute
 //   (x * multiplier) + (offset - mean * multiplier).
-def : Pattern<
-    (TF_FusedBatchNormOp:$root
-        $x, $scale, $offset, $mean, $variance,
-        F32Attr:$epsilon, $exponential_avg_factor,
-        $data_format, FalseBoolAttr:$is_training),
-    [(TF_AddOp
-        (TF_MulOp
-            $x,
-            (TF_MulOp:$multiplier
-                $scale,
-                (TF_RsqrtOp
-                    (TF_AddOp $variance,
-                              (TF_ConstOp $epsilon))))),
-        (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
-     // We already guaranteed that the last four results has no use so it does
-     // not matter what value we provide here for replacement.
-     /*batch_mean=*/(replaceWithValue $x),
-     /*batch_variance=*/(replaceWithValue $x),
-     /*reserve_space_1=*/(replaceWithValue $x),
-     /*reserve_space_2=*/(replaceWithValue $x)],
-    [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
-     (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>;
 
 def : Pattern<
     (TF_FusedBatchNormV3Op:$root
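
For reference, the algebra behind the generated pattern (the same simplification the
comment above describes). Starting from the inference-mode batch norm

  y = (x - mean) * scale / sqrt(variance + epsilon) + offset

and letting

  multiplier = scale * rsqrt(variance + epsilon)

the computation factors into a single multiply and add on x:

  y = x * multiplier + (offset - mean * multiplier)

so mean, variance, scale, and offset fold into two channel-shaped terms that later
passes can constant-fold when those inputs are known at compile time.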
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 31b281a..c2acad3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -738,6 +738,31 @@
   }
 };
 
+struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
+  explicit ConvertFusedBatchNorm(MLIRContext *context)
+      : OpRewritePattern<TF::FusedBatchNormOp>(context) {}
+
+  LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
+                                PatternRewriter &rewriter) const override {
+    auto new_result_types =
+        llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
+    // FusedBatchNormV3 has an extra reserve_space_3 result; append its type.
+    new_result_types.push_back(
+        UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
+
+    OperationState new_state(tf_fused_batch_norm_op.getLoc(),
+                             TF::FusedBatchNormV3Op::getOperationName(),
+                             tf_fused_batch_norm_op.getOperands(),
+                             new_result_types,
+                             tf_fused_batch_norm_op.getAttrs());
+    Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
+
+    rewriter.replaceOp(tf_fused_batch_norm_op,
+                       tf_fused_batch_norm_op_v3->getResults().drop_back());
+    return success();
+  }
+};
+
 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
 
 // Returns success if all the operations in the `op`'s regions including `op`
@@ -899,6 +924,8 @@
   // replaced with a single Conv op with dilation parameter.
   patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
+
+  patterns.insert<ConvertFusedBatchNorm>(ctx);
   TFL::populateWithGenerated(ctx, &patterns);
   // TODO(karimnosseir): Split to separate pass probably after
   // deciding on long term plan for this optimization.
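
Taken together: the ConvertFusedBatchNorm pattern registered above upgrades
tf.FusedBatchNorm to tf.FusedBatchNormV3, and the remaining generated TableGen pattern
then decomposes the V3 op. A sketch of the resulting arithmetic for the inference case,
assuming the channel-dimension broadcasts shown (value names are illustrative):

  %eps = "tf.Const"() {value = dense<1.0e-03> : tensor<f32>} : () -> tensor<f32>
  %vpe = "tf.Add"(%variance, %eps) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  %rsq = "tf.Rsqrt"(%vpe) : (tensor<8xf32>) -> tensor<8xf32>
  // multiplier = scale * rsqrt(variance + epsilon)
  %mul = "tf.Mul"(%scale, %rsq) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
  %xm  = "tf.Mul"(%x, %mul) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %mm  = "tf.Mul"(%mean, %mul) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
  %off = "tf.Sub"(%offset, %mm) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
  // y = x * multiplier + (offset - mean * multiplier)
  %y   = "tf.Add"(%xm, %off) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>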