Merge pull request #42503 from ahmedsabie:T1
PiperOrigin-RevId: 329267324
Change-Id: I8829d03b6d20531502ae9f638864c9a8d82470fa
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 6a992d6..1abb190 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -82,8 +82,8 @@
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
-// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
-// CHECK: "tf.FusedBatchNorm"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}
func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
index f5b2527..326b6b2 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
@@ -40,7 +40,7 @@
(TF_MulOp $t, (TF_MulOp:$mul (TF_RsqrtOp (TF_AddOp $v, (TF_ConstOp $variance_epsilon))), $gamma)),
(TF_SubOp $beta, (TF_MulOp $m, $mul)))>;
-// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
+// Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
// (x - mean) * scale / sqrt(variance + epsilon) + offset
@@ -50,28 +50,6 @@
// (x - mean) * scale / sqrt(variance + epsilon) + offset,
// is then to compute
// (x * multiplier) + (offset - mean * multiplier).
-def : Pattern<
- (TF_FusedBatchNormOp:$root
- $x, $scale, $offset, $mean, $variance,
- F32Attr:$epsilon, $exponential_avg_factor,
- $data_format, FalseBoolAttr:$is_training),
- [(TF_AddOp
- (TF_MulOp
- $x,
- (TF_MulOp:$multiplier
- $scale,
- (TF_RsqrtOp
- (TF_AddOp $variance,
- (TF_ConstOp $epsilon))))),
- (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
- // We already guaranteed that the last four results has no use so it does
- // not matter what value we provide here for replacement.
- /*batch_mean=*/(replaceWithValue $x),
- /*batch_variance=*/(replaceWithValue $x),
- /*reserve_space_1=*/(replaceWithValue $x),
- /*reserve_space_2=*/(replaceWithValue $x)],
- [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
- (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>;
def : Pattern<
(TF_FusedBatchNormV3Op:$root
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 31b281a..c2acad3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -738,6 +738,31 @@
}
};
+struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
+  explicit ConvertFusedBatchNorm(MLIRContext *context)
+      : OpRewritePattern<TF::FusedBatchNormOp>(context) {}
+  // Upgrades tf.FusedBatchNorm to tf.FusedBatchNormV3 with the same operands
+  // and attributes, so only the V3 decomposition pattern has to exist.
+  LogicalResult matchAndRewrite(TF::FusedBatchNormOp v1_op,
+                                PatternRewriter &rewriter) const override {
+    // V3 yields V1's five results plus a trailing reserve_space_3 (f32).
+    auto v3_types = llvm::to_vector<6>(v1_op.getResultTypes());
+    auto f32_type = FloatType::getF32(rewriter.getContext());
+    v3_types.push_back(UnrankedTensorType::get(f32_type));
+
+    OperationState v3_state(v1_op.getLoc(),
+                            TF::FusedBatchNormV3Op::getOperationName(),
+                            v1_op.getOperands(), v3_types, v1_op.getAttrs());
+    Operation *v3_op = rewriter.createOperation(v3_state);
+
+    // reserve_space_3 has no V1 counterpart, so forward all but the last
+    // V3 result to the users of the original op.
+    auto forwarded_results = v3_op->getResults().drop_back();
+    rewriter.replaceOp(v1_op, forwarded_results);
+    return success();
+  }
+};
+
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
// Returns success if all the operations in the `op`'s regions including `op`
@@ -899,6 +924,8 @@
// replaced with a single Conv op with dilation parameter.
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
+
+ patterns.insert<ConvertFusedBatchNorm>(ctx);
TFL::populateWithGenerated(ctx, &patterns);
// TODO(karimnosseir): Split to separate pass probably after
// deciding on long term plan for this optimization.