From 5ccc994e7e855f730c49e9b1bfeffc92b3928d13 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 12 Jan 2017 13:46:06 -0800
Subject: [PATCH] Losses should support scalar logits/labels. Change: 144366312

---
 tensorflow/python/kernel_tests/losses_test.py | 7 +++++++
 tensorflow/python/ops/losses/losses_impl.py   | 7 +++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index d784ec43b5a..125d353df38 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -766,6 +766,13 @@ class MeanSquaredErrorTest(test.TestCase):
         losses.mean_squared_error(
             self._predictions, self._predictions, weights=None)
 
+  def testScalar(self):
+    with self.test_session():
+      self.assertEqual(
+          0.0,
+          losses.mean_squared_error(predictions=constant_op.constant(0),
+                                    labels=constant_op.constant(0)).eval())
+
   def testAllCorrectNoLossWeight(self):
     loss = losses.mean_squared_error(self._predictions, self._predictions)
     with self.test_session():
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index c23d046d70d..486e25afc71 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -119,8 +119,11 @@ def _num_present(losses, weights, per_batch=False):
   """
   # If weights is a scalar, its easy to compute:
   if weights.get_shape().ndims == 0:
-    batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
-                                                   [0], [1]), [])
+    if losses.get_shape().ndims == 0:
+      batch_size = 1
+    else:
+      batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
+                                                     [0], [1]), [])
     num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
                                  math_ops.to_float(batch_size))
     num_per_batch = array_ops.where(math_ops.equal(weights, 0),