Checks the ndims of weights before indexing in the sparse_softmax_cross_entropy

PiperOrigin-RevId: 156256866
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 40fddd7..a23a805 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -319,6 +319,21 @@
                           feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3)
 
+  def testUnknownShapePlaceholderForLogitsLabelsButScalarWeights(self):
+    logits = array_ops.placeholder(dtypes.float32)
+    labels = array_ops.placeholder(dtypes.int32)
+    weights = 1.0
+    with self.test_session() as sess:
+      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
+      loss_val = sess.run(loss,
+                          feed_dict={
+                              logits: [[10.0, 0.0, 0.0],
+                                       [0.0, 10.0, 0.0],
+                                       [0.0, 0.0, 10.0]],
+                              labels: [[2], [0], [1]],
+                          })
+      self.assertAlmostEqual((1.0 + 1.0 + 1.0) * 10.0 / 3.0, loss_val, 3)
+
   def testNonZeroLossWithPlaceholderForLogitsLabelsAndWeights(self):
     logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
     labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index fc54553..97078d7 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -690,7 +690,7 @@
     # Use dynamic rank.
     rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
     if (weights_rank is None) or (
-        weights_shape.dims[-1].is_compatible_with(1)):
+        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
       weights = control_flow_ops.cond(
           math_ops.equal(1, rank_diff),
           lambda: array_ops.squeeze(weights, [-1]),