Fixes possible out-of-bounds access by strided slice.
PiperOrigin-RevId: 215269882
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index f0575de..3e8a4c5 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -149,7 +149,7 @@
// NDIM and T
if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
input_dims == 2 && processing_shape.dims() == 2 &&
- final_shape.dims() == 2) {
+ final_shape.dims() == 2 && new_axis_mask == 0) {
MemCpyFunctor<T> functor;
if (functor.Copy(input, begin, end, result)) {
return;
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index c5547b1..dcc5947 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -615,6 +615,14 @@
_ = checker[:, 0]
_ = checker[:, :, 0]
+ def testBothNewAxisAndShrink(self):
+ with self.test_session(use_gpu=True):
+ ones = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int16)
+ self.assertAllEqual(
+ ones[array_ops.newaxis, :, 0].eval(
+ feed_dict={ones: [[1, 1], [1, 1]]}),
+ [[1, 1]])
+
def testTensorIndexing(self):
with self.test_session(use_gpu=True):
raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],