From 5ccc994e7e855f730c49e9b1bfeffc92b3928d13 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Thu, 12 Jan 2017 13:46:06 -0800 Subject: [PATCH] Losses should support scalar logits/labels. Change: 144366312 --- tensorflow/python/kernel_tests/losses_test.py | 7 +++++++ tensorflow/python/ops/losses/losses_impl.py | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index d784ec43b5a..125d353df38 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -766,6 +766,13 @@ class MeanSquaredErrorTest(test.TestCase): losses.mean_squared_error( self._predictions, self._predictions, weights=None) + def testScalar(self): + with self.test_session(): + self.assertEqual( + 0.0, + losses.mean_squared_error(predictions=constant_op.constant(0), + labels=constant_op.constant(0)).eval()) + def testAllCorrectNoLossWeight(self): loss = losses.mean_squared_error(self._predictions, self._predictions) with self.test_session(): diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index c23d046d70d..486e25afc71 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -119,8 +119,11 @@ def _num_present(losses, weights, per_batch=False): """ # If weights is a scalar, its easy to compute: if weights.get_shape().ndims == 0: - batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses), - [0], [1]), []) + if losses.get_shape().ndims == 0: + batch_size = 1 + else: + batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses), + [0], [1]), []) num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)), math_ops.to_float(batch_size)) num_per_batch = array_ops.where(math_ops.equal(weights, 0),