Losses should support scalar logits/labels.
Change: 144366312
This commit is contained in:
parent
3689c21345
commit
5ccc994e7e
@ -766,6 +766,13 @@ class MeanSquaredErrorTest(test.TestCase):
|
||||
losses.mean_squared_error(
|
||||
self._predictions, self._predictions, weights=None)
|
||||
|
||||
def testScalar(self):
|
||||
with self.test_session():
|
||||
self.assertEqual(
|
||||
0.0,
|
||||
losses.mean_squared_error(predictions=constant_op.constant(0),
|
||||
labels=constant_op.constant(0)).eval())
|
||||
|
||||
def testAllCorrectNoLossWeight(self):
|
||||
loss = losses.mean_squared_error(self._predictions, self._predictions)
|
||||
with self.test_session():
|
||||
|
@ -119,8 +119,11 @@ def _num_present(losses, weights, per_batch=False):
|
||||
"""
|
||||
# If weights is a scalar, its easy to compute:
|
||||
if weights.get_shape().ndims == 0:
|
||||
batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
|
||||
[0], [1]), [])
|
||||
if losses.get_shape().ndims == 0:
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
|
||||
[0], [1]), [])
|
||||
num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
|
||||
math_ops.to_float(batch_size))
|
||||
num_per_batch = array_ops.where(math_ops.equal(weights, 0),
|
||||
|
Loading…
Reference in New Issue
Block a user