Losses should support scalar logits/labels.

Change: 144366312
A. Unique TensorFlower authored on 2017-01-12 13:46:06 -08:00; committed by TensorFlower Gardener
parent 3689c21345
commit 5ccc994e7e
2 changed files with 12 additions and 2 deletions
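
What the change enables, as a minimal usage sketch (TF 1.x era, assuming the tf.losses endpoint that wraps losses.mean_squared_error; this snippet is illustrative and not part of the commit):

import tensorflow as tf

# Scalar labels and predictions: previously _num_present assumed losses had a
# batch dimension to slice, so a 0-d loss tensor broke the weighted reduction.
predictions = tf.constant(0.0)
labels = tf.constant(0.0)
loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)

with tf.Session() as sess:
  print(sess.run(loss))  # 0.0, matching the new testScalar case below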


@@ -766,6 +766,13 @@ class MeanSquaredErrorTest(test.TestCase):
         losses.mean_squared_error(
             self._predictions, self._predictions, weights=None)
 
+  def testScalar(self):
+    with self.test_session():
+      self.assertEqual(
+          0.0,
+          losses.mean_squared_error(predictions=constant_op.constant(0),
+                                    labels=constant_op.constant(0)).eval())
+
   def testAllCorrectNoLossWeight(self):
     loss = losses.mean_squared_error(self._predictions, self._predictions)
     with self.test_session():


@@ -119,8 +119,11 @@ def _num_present(losses, weights, per_batch=False):
   """
   # If weights is a scalar, its easy to compute:
   if weights.get_shape().ndims == 0:
-    batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
-                                                   [0], [1]), [])
+    if losses.get_shape().ndims == 0:
+      batch_size = 1
+    else:
+      batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
+                                                     [0], [1]), [])
     num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
                                  math_ops.to_float(batch_size))
     num_per_batch = array_ops.where(math_ops.equal(weights, 0),
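
A minimal sketch (NumPy stand-in, not the TensorFlow code) of how the scalar-weight branch of _num_present counts present elements after this change; the helper name and the per_batch handling here are simplifications:

import numpy as np

def _num_present_scalar_weight(losses, weights, per_batch=False):
  # With a 0-d losses tensor there is no batch dimension to slice, so the
  # change treats it as a single-element batch.
  losses = np.asarray(losses, dtype=np.float32)
  batch_size = 1 if losses.ndim == 0 else losses.shape[0]
  num_per_batch = losses.size / float(batch_size)
  # A zero scalar weight means nothing is present.
  num_per_batch = 0.0 if weights == 0 else num_per_batch
  per_batch_counts = np.ones(batch_size) * num_per_batch
  return per_batch_counts if per_batch else per_batch_counts.sum()

_num_present_scalar_weight(np.float32(0.0), 1.0)   # 1.0: a scalar loss counts as one element
_num_present_scalar_weight(np.zeros((4, 3)), 1.0)  # 12.0
_num_present_scalar_weight(np.zeros((4, 3)), 0.0)  # 0.0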