Checks the ndims of weights before indexing in the sparse_softmax_cross_entropy
PiperOrigin-RevId: 156256866
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 40fddd7..a23a805 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -319,6 +319,21 @@
feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3)
+ def testUnknownShapePlaceholderForLogitsLabelsButScalarWeights(self):
+ logits = array_ops.placeholder(dtypes.float32)
+ labels = array_ops.placeholder(dtypes.int32)
+ weights = 1.0
+ with self.test_session() as sess:
+ loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
+ loss_val = sess.run(loss,
+ feed_dict={
+ logits: [[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]],
+ labels: [[2], [0], [1]],
+ })
+ self.assertAlmostEqual((1.0 + 1.0 + 1.0) * 10.0 / 3.0, loss_val, 3)
+
def testNonZeroLossWithPlaceholderForLogitsLabelsAndWeights(self):
logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index fc54553..97078d7 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -690,7 +690,7 @@
# Use dynamic rank.
rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
if (weights_rank is None) or (
- weights_shape.dims[-1].is_compatible_with(1)):
+ weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
weights = control_flow_ops.cond(
math_ops.equal(1, rank_diff),
lambda: array_ops.squeeze(weights, [-1]),