Adds a number of loss functions and reduction utilities to calculate quantities such as L1 and L2 loss.

Change: 115671496
This commit is contained in:
A. Unique TensorFlower 2016-02-26 08:23:44 -08:00 committed by TensorFlower Gardener
parent e0d5b85a16
commit dff7bb7c73
5 changed files with 805 additions and 0 deletions

View File

@ -17,6 +17,8 @@ py_library(
"python/layers/layers.py",
"python/layers/regularizers.py",
"python/layers/summaries.py",
"python/ops/__init__.py",
"python/ops/loss_ops.py",
],
srcs_version = "PY2AND3",
)
@ -57,6 +59,18 @@ py_test(
],
)
py_test(
name = "loss_ops_test",
srcs = glob(["python/ops/loss_ops_test.py"]),
srcs_version = "PY2AND3",
deps = [
":layers_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "summaries_test",
srcs = glob(["python/layers/summaries_test.py"]),

View File

@ -66,3 +66,4 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.layers.python.framework.tensor_util import *
from tensorflow.contrib.layers.python.layers import *
from tensorflow.contrib.layers.python.ops import *

View File

@ -0,0 +1,22 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A module containing TensorFlow ops whose API may change in the future."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.layers.python.ops.loss_ops import *

View File

@ -0,0 +1,347 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""## Loss operations for use in neural networks.
The loss ops measure error for use in neural networks. These losses
can be used for measuring accuracy of a network in a regression task
or for regularization purposes (e.g., weight decay).
These loss ops are, by design, minimal, enabling flexibility in how
their output can be used.
@@reduce_batch_sum
@@reduce_batch_mean
@@absolute_loss
@@squared_loss
@@sum_squared_loss
@@mean_absolute_loss
@@mean_squared_loss
@@root_mean_squared_loss
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
__all__ = ["reduce_batch_sum", "reduce_batch_mean", "absolute_loss",
"squared_loss", "sum_squared_loss", "mean_absolute_loss",
"mean_squared_loss", "root_mean_squared_loss"]
def _reduce_batch(x, reduce_fn, name=None):
"""Given a tensor `x`, calls reduce_fn to reduce it across dimensions.
Given a tensor with number of dimensions > 1, _reduce_batch will reduce the
tensor across all dimensions except for dimension 0. As an example, given a
tensor of shape [batch_size, d1, d2], this function will reduce across
dimensions d1 and d2, returning a tensor of shape [batch_size].
Tensors of dimension 1 are returned as-is, while tensors of dimension 0
raise a ValueError.
Args:
x: A `Tensor` with dimension > 0.
reduce_fn: A math_ops reduce function that takes arguments of
`x`, `reduction_indices`, and `name`.
name: A name for the operation (optional).
Returns:
A `Tensor` with values reduced by reduce_fn across all dimensions > 0.
Raises:
ValueError: If `x` has dimension 0.
"""
x = ops.convert_to_tensor(x, name="x")
with ops.op_scope([x], name, "reduce_batch"):
ndims = x.get_shape().ndims
if ndims == 0:
raise ValueError("Cannot reduce a scalar into batches.")
elif ndims == 1:
return x # Don't include a useless reduction.
elif ndims:
reduction_indices = range(1, ndims)
shape = [x.get_shape().dims[0]]
else:
reduction_indices = math_ops.range(1, array_ops.size(array_ops.shape(x)))
shape = [None] # We don't know much about the shape, but it is rank 1.
result = reduce_fn(x, reduction_indices=reduction_indices)
# Give a shape hint in case we have extra information.
result.set_shape(shape)
return result
def reduce_batch_sum(x, name=None):
"""Given a tensor `x`, sums across all dimensions except dimension 0.
Given a tensor with the number of dimensions > 1, reduce_batch_sum
will sum across all dimensions except for dimension 0. This function
is useful for summing the loss (error) across all examples in a
batch when training. As an example, given a tensor of shape
[batch_size, d1, d2], this function will sum across dimensions d1
and d2, returning a tensor of shape [batch_size].
Tensors of dimension 1 are returned as-is, while tensors of dimension 0
raise a ValueError.
Args:
x: A `Tensor` with dimension > 0.
name: A name for the operation (optional).
Returns:
A `Tensor` with values summed across all dimensions > 0.
Raises:
ValueError: If `x` has dimension 0.
"""
return _reduce_batch(x, math_ops.reduce_sum, name)
def reduce_batch_mean(x, name=None):
"""Given a tensor `x`, returns the mean across all dimensions except dim 0.
Given a tensor with the number of dimensions > 1, reduce_batch_mean
will calculate the mean across all dimensions except for dimension
0. This function is useful for calculating the mean loss (error)
across all examples in a batch when training. As an example, given a
tensor of shape [batch_size, d1, d2], this function will calculate
the mean across dimensions d1 and d2, returning a tensor of shape
[batch_size].
Tensors of dimension 1 are returned as-is.
Args:
x: A `Tensor` with dimension > 0.
name: A name for the operation (optional).
Returns:
A `Tensor` with values averaged across all dimensions > 0.
Raises:
ValueError: If `x` has dimension 0.
"""
return _reduce_batch(x, math_ops.reduce_mean, name)
def absolute_loss(predicted, target, name=None):
"""Computes and returns the per-example absolute loss.
Computes the per-example absolute value of the difference between
the target and predicted tensors. The tensors must have the same
shape.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
predicted = ops.convert_to_tensor(predicted, name="predicted")
target = ops.convert_to_tensor(target, name="target")
predicted.get_shape().assert_is_compatible_with(target.get_shape())
return math_ops.abs(target - predicted, name=scope)
def squared_loss(predicted, target, name=None):
"""Computes and returns the per-example squared loss.
Computes the per-example squared difference between the target and
predicted tensors. The tensors must have the same shape.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
with ops.op_scope([predicted, target], name, "squared_loss") as scope:
predicted = ops.convert_to_tensor(predicted, name="predicted")
target = ops.convert_to_tensor(target, name="target")
predicted.get_shape().assert_is_compatible_with(target.get_shape())
return math_ops.square(target - predicted, name=scope)
def sum_squared_loss(predicted, target, name=None):
# pylint: disable=line-too-long
"""Calculates 1/2 the sum of the squared loss across batches.
Computes the squared difference between the target and predicted
tensors, sums across all dimensions except dimension 0, and divides
by 2:
losses = reduce_batch_sum(squared_loss(predicted, target)) / 2.0
where `losses` is a tensor with dimensions [batch_size].
The tensors must have the same shape.
This function is equivalent to typical formulations of L2 loss, and similar
to TensorFlow's l2_loss function. It differs from the l2_loss function
by allowing the caller to specify both the predicted and target tensors.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
A `[batch_size]` tensor of squared losses summed across all dimensions
except dimension 0, divided by 2.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
# pylint: enable=line-too-long
with ops.op_scope(
[predicted, target],
name,
"sum_squared_loss") as scope:
return math_ops.div(reduce_batch_sum(squared_loss(predicted, target)),
2.0,
name=scope)
def mean_absolute_loss(predicted, target, name=None):
"""Calculates the mean absolute loss across batches.
Computes the absolute difference between the target and predicted
tensors, averaged across all dimensions except dimension 0:
losses = reduce_batch_mean(absolute_loss(predicted, target))
where `losses` is a tensor with dimensions [batch_size].
The tensors must have the same shape.
This loss function is a form of L1 loss.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
A `[batch_size]` tensor of absolute differences, averaged across all
dimensions except dimension 0.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
with ops.op_scope([predicted, target], name, "mean_absolute_loss") as scope:
return reduce_batch_mean(absolute_loss(predicted, target), name=scope)
def mean_squared_loss(predicted, target, name=None):
"""Calculates the mean squared loss across batches.
Computes the squared difference between the target and predicted
tensors, and averages across all dimensions except dimension 0:
losses = reduce_batch_mean(squared_loss(predicted, target))
where `losses` is a tensor with dimensions [batch_size].
The tensors must have the same shape.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
A `[batch_size]` tensor of squared differences, averaged across
all dimensions except dimension 0.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
with ops.op_scope([predicted, target], name, "mean_squared_loss") as scope:
return reduce_batch_mean(squared_loss(predicted, target), name=scope)
def root_mean_squared_loss(predicted, target, name=None):
"""Calculates the root mean squared loss across batches.
Computes the root mean squared loss between the target and predicted
tensors, which is the square root of the mean squared differences
between the predicted and target tensors:
losses = sqrt(mean_squared_loss(predicted, target))
where `losses` is a tensor with dimensions [batch_size].
The tensors must have the same shape.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
name: A name for the operation (optional).
Returns:
A `[batch_size]` tensor of the root mean squared differences.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
with ops.op_scope([predicted, target],
name,
"root_mean_squared_loss") as scope:
return math_ops.sqrt(mean_squared_loss(predicted, target),
name=scope)

View File

@ -0,0 +1,421 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.layers.python.ops.loss_ops."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class ReduceBatchSumTest(tf.test.TestCase):
def testDimensionNone(self):
with self.test_session():
input_array = np.array([
[1.0, 2.0],
[-1.0, -2.0]
], dtype=np.float32)
placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec")
expected_result = np.array([3.0, -3.0])
actual_result = tf.contrib.layers.reduce_batch_sum(placeholder_vec)
self.assertEqual(actual_result.get_shape().as_list(), [None])
self.assertAllClose(expected_result, actual_result.eval(feed_dict={
placeholder_vec: input_array
}))
def testDimension0(self):
with self.test_session():
input_vec = tf.constant(2.0)
with self.assertRaises(ValueError):
tf.contrib.layers.reduce_batch_sum(input_vec)
def testDimension1(self):
with self.test_session():
input_vec = tf.constant([1.0, 2.0])
expected_result = np.array([1.0, 2.0])
actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
self.assertAllClose(expected_result, actual_result.eval())
def testDimension2(self):
with self.test_session():
input_vec = tf.constant([
[1.0, 2.0],
[-1.0, -2.0]
])
expected_result = np.array([3.0, -3.0])
actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
self.assertAllClose(expected_result, actual_result.eval())
def testReturnShape(self):
with self.test_session():
input_vec = tf.constant([
[1.0, 2.0],
[-1.0, -2.0]
])
expected_result = np.array([3.0, -3.0])
actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
self.assertShapeEqual(expected_result, actual_result)
def testDimensionN(self):
with self.test_session():
input_vec = tf.constant([
[
[1.0, 2.0],
[3.0, 4.0]
],
[
[5.0, 6.0],
[7.0, 8.0]
]
])
expected_result = np.array([10.0, 26.0])
actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
self.assertAllClose(expected_result, actual_result.eval())
class ReduceBatchMeanTest(tf.test.TestCase):
def testDimensionNone(self):
with self.test_session():
input_array = np.array([
[1.0, 2.0],
[-1.0, -2.0]
], dtype=np.float32)
placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec")
expected_result = np.array([1.5, -1.5])
actual_result = tf.contrib.layers.reduce_batch_mean(placeholder_vec)
self.assertEqual(actual_result.get_shape().as_list(), [None])
self.assertAllClose(expected_result, actual_result.eval(feed_dict={
placeholder_vec: input_array
}))
def testDimension0(self):
with self.test_session():
input_vec = tf.constant(2.0)
with self.assertRaises(ValueError):
tf.contrib.layers.reduce_batch_mean(input_vec)
def testDimension1(self):
with self.test_session():
input_vec = tf.constant([1.0, 2.0])
expected_result = np.array([1.0, 2.0])
actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
self.assertAllClose(expected_result, actual_result.eval())
def testDimension2(self):
with self.test_session():
input_vec = tf.constant([
[1.0, 2.0],
[-1.0, -2.0]
])
expected_result = np.array([1.5, -1.5])
actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
self.assertAllClose(expected_result, actual_result.eval())
def testReturnShape(self):
with self.test_session():
input_vec = tf.constant([
[1.0, 2.0],
[-1.0, -2.0]
])
expected_result = np.array([3.0, -3.0])
actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
self.assertShapeEqual(expected_result, actual_result)
def testDimensionN(self):
with self.test_session():
input_vec = tf.constant([
[
[1.0, 2.0],
[3.0, 4.0]
],
[
[5.0, 6.0],
[7.0, 8.0]
]
])
expected_result = np.array([2.5, 6.5])
actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
self.assertAllClose(expected_result, actual_result.eval())
class AbsoluteLossTest(tf.test.TestCase):
def _getTestVectors(self):
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
name="predicted")
expected_loss = np.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2)
return target, predicted, expected_loss
def testAbsoluteLoss(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.absolute_loss(predicted, target)
self.assertAllClose(expected_loss, result.eval())
def testAbsoluteLossReturnShape(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.absolute_loss(predicted, target)
self.assertShapeEqual(expected_loss, result)
def testInvalidShapesValueError(self):
with self.test_session():
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
name="incompatible_shape")
with self.assertRaises(ValueError):
tf.contrib.layers.absolute_loss(incompatible_shape, target)
def testAbsoluteLossGradient(self):
with self.test_session():
target, predicted, _ = self._getTestVectors()
result = tf.contrib.layers.absolute_loss(predicted, target)
x_shape = [2, 2]
err = tf.test.compute_gradient_error(target, x_shape, result, x_shape)
err_tolerance = 1e-4
self.assertLess(err, err_tolerance)
class SquaredLossTest(tf.test.TestCase):
def _getTestVectors(self):
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
name="predicted")
expected_loss = np.array([0.01, 0.04, 0.09, 0.16]).reshape(2, 2)
return target, predicted, expected_loss
def testSquaredLoss(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.squared_loss(predicted, target)
self.assertAllClose(expected_loss, result.eval())
def testSquaredLossReturnShape(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.squared_loss(predicted, target)
self.assertShapeEqual(expected_loss, result)
def testInvalidShapesValueError(self):
with self.test_session():
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
name="incompatible_shape")
with self.assertRaises(ValueError):
tf.contrib.layers.squared_loss(incompatible_shape, target)
def testSquaredLossGradient(self):
with self.test_session():
target, predicted, _ = self._getTestVectors()
result = tf.contrib.layers.squared_loss(predicted, target)
x_shape = [2, 2]
err = tf.test.compute_gradient_error(target, x_shape, result, x_shape)
err_tolerance = 1e-3
self.assertLess(err, err_tolerance)
class SumSquaredLossTest(tf.test.TestCase):
def _getTestVectors(self):
target = tf.constant([[0.0, 1.0],
[3.0, 2.0]],
shape=[2, 2],
name="target")
predicted = tf.constant([[3.0, -2.0],
[1.0, 2.0]],
shape=[2, 2],
name="predicted")
expected_loss = np.array([9.0, 2.0])
return target, predicted, expected_loss
def testSumSquaredLoss(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.sum_squared_loss(predicted, target)
self.assertAllClose(expected_loss, result.eval())
def testSumSquaredLossReturnShape(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.sum_squared_loss(predicted, target)
self.assertShapeEqual(expected_loss, result)
def testInvalidShapesValueError(self):
with self.test_session():
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
name="incompatible_shape")
with self.assertRaises(ValueError):
tf.contrib.layers.sum_squared_loss(incompatible_shape, target)
def testSumSquaredLossGradient(self):
with self.test_session():
target, predicted, _ = self._getTestVectors()
result = tf.contrib.layers.sum_squared_loss(predicted, target)
x_shape = [2, 2]
result_shape = [2]
err = tf.test.compute_gradient_error(target, x_shape,
result, result_shape)
err_tolerance = 1e-3
self.assertLess(err, err_tolerance)
class MeanAbsoluteLossTest(tf.test.TestCase):
def _getTestVectors(self):
target = tf.constant([[0.0, 1.0, 2.0],
[3.0, 2.0, 4.0]],
shape=[2, 3],
name="target")
predicted = tf.constant([[3.0, -3.0, 0.0],
[1.0, 2.0, 0.0]],
shape=[2, 3],
name="predicted")
expected_loss = np.array([3.0, 2.0])
return target, predicted, expected_loss
def testMeanAbsoluteLoss(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.mean_absolute_loss(predicted, target)
self.assertAllClose(expected_loss, result.eval())
def testMeanAbsoluteLossReturnShape(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.mean_absolute_loss(predicted, target)
self.assertShapeEqual(expected_loss, result)
def testInvalidShapesValueError(self):
with self.test_session():
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
name="incompatible_shape")
with self.assertRaises(ValueError):
tf.contrib.layers.mean_absolute_loss(incompatible_shape, target)
def testMeanAbsoluteLossGradient(self):
with self.test_session():
target, predicted, _ = self._getTestVectors()
result = tf.contrib.layers.mean_absolute_loss(predicted, target)
x_shape = [2, 3]
result_shape = [2]
err = tf.test.compute_gradient_error(target, x_shape,
result, result_shape)
err_tolerance = 1e-3
self.assertLess(err, err_tolerance)
class MeanSquaredLossTest(tf.test.TestCase):
def _getTestVectors(self):
target = tf.constant([[0.0, 1.0, 2.0],
[3.0, 2.0, 4.0]],
shape=[2, 3],
name="target")
predicted = tf.constant([[3.0, -3.0, 0.0],
[1.0, 2.0, 0.0]],
shape=[2, 3],
name="predicted")
expected_loss = np.array([9.666667, 6.666667])
return target, predicted, expected_loss
def testMeanSquaredLoss(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.mean_squared_loss(predicted, target)
self.assertAllClose(expected_loss, result.eval())
def testMeanSquaredLossReturnShape(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.mean_squared_loss(predicted, target)
self.assertShapeEqual(expected_loss, result)
def testInvalidShapesValueError(self):
with self.test_session():
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
name="incompatible_shape")
with self.assertRaises(ValueError):
tf.contrib.layers.mean_squared_loss(incompatible_shape, target)
def testMeanSquaredLossGradient(self):
with self.test_session():
target, predicted, _ = self._getTestVectors()
result = tf.contrib.layers.mean_squared_loss(predicted, target)
x_shape = [2, 3]
result_shape = [2]
err = tf.test.compute_gradient_error(target, x_shape,
result, result_shape)
err_tolerance = 1e-3
self.assertLess(err, err_tolerance)
class RootMeanSquaredLossTest(tf.test.TestCase):
def _getTestVectors(self):
target = tf.constant([[0.0, 1.0, 2.0],
[3.0, 2.0, 4.0]],
shape=[2, 3],
name="target")
predicted = tf.constant([[3.0, -3.0, 0.0],
[1.0, 2.0, 0.0]],
shape=[2, 3],
name="predicted")
expected_loss = np.array([3.109126, 2.5819889])
return target, predicted, expected_loss
def testRootMeanSquaredLoss(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
self.assertAllClose(expected_loss, result.eval())
def testRootMeanSquaredLossReturnShape(self):
with self.test_session():
target, predicted, expected_loss = self._getTestVectors()
result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
self.assertShapeEqual(expected_loss, result)
def testInvalidShapesValueError(self):
with self.test_session():
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
name="incompatible_shape")
with self.assertRaises(ValueError):
tf.contrib.layers.root_mean_squared_loss(incompatible_shape, target)
def testRootMeanSquaredLossGradient(self):
with self.test_session():
target, predicted, _ = self._getTestVectors()
result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
x_shape = [2, 3]
result_shape = [2]
err = tf.test.compute_gradient_error(target, x_shape,
result, result_shape)
err_tolerance = 1e-3
self.assertLess(err, err_tolerance)
if __name__ == "__main__":
tf.test.main()