Losses should support scalar logits/labels.
Change: 144366312
This commit is contained in:
parent
3689c21345
commit
5ccc994e7e
@ -766,6 +766,13 @@ class MeanSquaredErrorTest(test.TestCase):
|
|||||||
losses.mean_squared_error(
|
losses.mean_squared_error(
|
||||||
self._predictions, self._predictions, weights=None)
|
self._predictions, self._predictions, weights=None)
|
||||||
|
|
||||||
|
def testScalar(self):
|
||||||
|
with self.test_session():
|
||||||
|
self.assertEqual(
|
||||||
|
0.0,
|
||||||
|
losses.mean_squared_error(predictions=constant_op.constant(0),
|
||||||
|
labels=constant_op.constant(0)).eval())
|
||||||
|
|
||||||
def testAllCorrectNoLossWeight(self):
|
def testAllCorrectNoLossWeight(self):
|
||||||
loss = losses.mean_squared_error(self._predictions, self._predictions)
|
loss = losses.mean_squared_error(self._predictions, self._predictions)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
@ -119,6 +119,9 @@ def _num_present(losses, weights, per_batch=False):
|
|||||||
"""
|
"""
|
||||||
# If weights is a scalar, its easy to compute:
|
# If weights is a scalar, its easy to compute:
|
||||||
if weights.get_shape().ndims == 0:
|
if weights.get_shape().ndims == 0:
|
||||||
|
if losses.get_shape().ndims == 0:
|
||||||
|
batch_size = 1
|
||||||
|
else:
|
||||||
batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
|
batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
|
||||||
[0], [1]), [])
|
[0], [1]), [])
|
||||||
num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
|
num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
|
||||||
|
Loading…
Reference in New Issue
Block a user