Adds a number of loss functions and reduction utilities to calculate quantities such as L1 and L2 loss.
Change: 115671496
This commit is contained in:
parent e0d5b85a16
commit dff7bb7c73
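To make the semantics of the new ops concrete before the diff, here is a small NumPy sketch of what each one computes. The array values mirror the new unit tests further down; this is an illustrative model of the math, not the committed implementation (which follows in loss_ops.py).

import numpy as np

# Per-example losses are elementwise over [batch_size, d1, ..., dn].
predicted = np.array([[1.1, -0.2], [3.3, 1.6]])
target = np.array([[1.0, 0.0], [3.0, 2.0]])

absolute = np.abs(target - predicted)    # absolute_loss
squared = np.square(target - predicted)  # squared_loss

# The batch reductions collapse every dimension except dimension 0.
axes = tuple(range(1, squared.ndim))
sum_sq = squared.sum(axis=axes) / 2.0    # sum_squared_loss (half the L2 loss)
mean_abs = absolute.mean(axis=axes)      # mean_absolute_loss (a form of L1)
mean_sq = squared.mean(axis=axes)        # mean_squared_loss
rmse = np.sqrt(mean_sq)                  # root_mean_squared_loss

print(mean_sq)  # -> [0.025 0.125], one value per example in the batch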
tensorflow/contrib/layers
@@ -17,6 +17,8 @@ py_library(
         "python/layers/layers.py",
         "python/layers/regularizers.py",
         "python/layers/summaries.py",
+        "python/ops/__init__.py",
+        "python/ops/loss_ops.py",
     ],
     srcs_version = "PY2AND3",
 )
@@ -57,6 +59,18 @@ py_test(
     ],
 )
 
+py_test(
+    name = "loss_ops_test",
+    srcs = glob(["python/ops/loss_ops_test.py"]),
+    srcs_version = "PY2AND3",
+    deps = [
+        ":layers_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 py_test(
     name = "summaries_test",
     srcs = glob(["python/layers/summaries_test.py"]),
@@ -66,3 +66,4 @@ from __future__ import print_function
 # pylint: disable=unused-import,wildcard-import
 from tensorflow.contrib.layers.python.framework.tensor_util import *
 from tensorflow.contrib.layers.python.layers import *
+from tensorflow.contrib.layers.python.ops import *
22  tensorflow/contrib/layers/python/ops/__init__.py  Normal file
@@ -0,0 +1,22 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A module containing TensorFlow ops whose API may change in the future."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.layers.python.ops.loss_ops import *
347  tensorflow/contrib/layers/python/ops/loss_ops.py  Normal file
@@ -0,0 +1,347 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""## Loss operations for use in neural networks.

The loss ops measure error for use in neural networks. These losses
can be used for measuring accuracy of a network in a regression task
or for regularization purposes (e.g., weight decay).

These loss ops are, by design, minimal, enabling flexibility in how
their output can be used.

@@reduce_batch_sum
@@reduce_batch_mean

@@absolute_loss
@@squared_loss

@@sum_squared_loss
@@mean_absolute_loss
@@mean_squared_loss
@@root_mean_squared_loss
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


__all__ = ["reduce_batch_sum", "reduce_batch_mean", "absolute_loss",
           "squared_loss", "sum_squared_loss", "mean_absolute_loss",
           "mean_squared_loss", "root_mean_squared_loss"]


def _reduce_batch(x, reduce_fn, name=None):
  """Given a tensor `x`, calls reduce_fn to reduce it across dimensions.

  Given a tensor with number of dimensions > 1, _reduce_batch will reduce the
  tensor across all dimensions except for dimension 0. As an example, given a
  tensor of shape [batch_size, d1, d2], this function will reduce across
  dimensions d1 and d2, returning a tensor of shape [batch_size].

  Tensors of dimension 1 are returned as-is, while tensors of dimension 0
  raise a ValueError.

  Args:
    x: A `Tensor` with dimension > 0.
    reduce_fn: A math_ops reduce function that takes arguments of
      `x`, `reduction_indices`, and `name`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with values reduced by reduce_fn across all dimensions > 0.

  Raises:
    ValueError: If `x` has dimension 0.
  """
  x = ops.convert_to_tensor(x, name="x")
  with ops.op_scope([x], name, "reduce_batch"):
    ndims = x.get_shape().ndims
    if ndims == 0:
      raise ValueError("Cannot reduce a scalar into batches.")
    elif ndims == 1:
      return x  # Don't include a useless reduction.
    elif ndims:
      reduction_indices = range(1, ndims)
      shape = [x.get_shape().dims[0]]
    else:
      reduction_indices = math_ops.range(1, array_ops.size(array_ops.shape(x)))
      shape = [None]  # We don't know much about the shape, but it is rank 1.
    result = reduce_fn(x, reduction_indices=reduction_indices)

    # Give a shape hint in case we have extra information.
    result.set_shape(shape)
    return result


def reduce_batch_sum(x, name=None):
  """Given a tensor `x`, sums across all dimensions except dimension 0.

  Given a tensor with the number of dimensions > 1, reduce_batch_sum
  will sum across all dimensions except for dimension 0. This function
  is useful for summing the loss (error) across all examples in a
  batch when training. As an example, given a tensor of shape
  [batch_size, d1, d2], this function will sum across dimensions d1
  and d2, returning a tensor of shape [batch_size].

  Tensors of dimension 1 are returned as-is, while tensors of dimension 0
  raise a ValueError.

  Args:
    x: A `Tensor` with dimension > 0.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with values summed across all dimensions > 0.

  Raises:
    ValueError: If `x` has dimension 0.
  """
  return _reduce_batch(x, math_ops.reduce_sum, name)


def reduce_batch_mean(x, name=None):
  """Given a tensor `x`, returns the mean across all dimensions except dim 0.

  Given a tensor with the number of dimensions > 1, reduce_batch_mean
  will calculate the mean across all dimensions except for dimension
  0. This function is useful for calculating the mean loss (error)
  across all examples in a batch when training. As an example, given a
  tensor of shape [batch_size, d1, d2], this function will calculate
  the mean across dimensions d1 and d2, returning a tensor of shape
  [batch_size].

  Tensors of dimension 1 are returned as-is.

  Args:
    x: A `Tensor` with dimension > 0.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with values averaged across all dimensions > 0.

  Raises:
    ValueError: If `x` has dimension 0.
  """
  return _reduce_batch(x, math_ops.reduce_mean, name)


def absolute_loss(predicted, target, name=None):
  """Computes and returns the per-example absolute loss.

  Computes the per-example absolute value of the difference between
  the target and predicted tensors. The tensors must have the same
  shape.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    A `[batch_size, dim_1, ..., dim_n]` tensor of per-example absolute losses.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.
  """
  with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
    predicted = ops.convert_to_tensor(predicted, name="predicted")
    target = ops.convert_to_tensor(target, name="target")
    predicted.get_shape().assert_is_compatible_with(target.get_shape())
    return math_ops.abs(target - predicted, name=scope)


def squared_loss(predicted, target, name=None):
  """Computes and returns the per-example squared loss.

  Computes the per-example squared difference between the target and
  predicted tensors. The tensors must have the same shape.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    A `[batch_size, dim_1, ..., dim_n]` tensor of per-example squared losses.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.
  """
  with ops.op_scope([predicted, target], name, "squared_loss") as scope:
    predicted = ops.convert_to_tensor(predicted, name="predicted")
    target = ops.convert_to_tensor(target, name="target")
    predicted.get_shape().assert_is_compatible_with(target.get_shape())
    return math_ops.square(target - predicted, name=scope)


def sum_squared_loss(predicted, target, name=None):
  # pylint: disable=line-too-long
  """Calculates 1/2 the sum of the squared loss across batches.

  Computes the squared difference between the target and predicted
  tensors, sums across all dimensions except dimension 0, and divides
  by 2:

      losses = reduce_batch_sum(squared_loss(predicted, target)) / 2.0

  where `losses` is a tensor with dimensions [batch_size].

  The tensors must have the same shape.

  This function is equivalent to typical formulations of L2 loss, and similar
  to TensorFlow's l2_loss function. It differs from the l2_loss function
  by allowing the caller to specify both the predicted and target tensors.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    A `[batch_size]` tensor of squared losses summed across all dimensions
    except dimension 0, divided by 2.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.
  """
  # pylint: enable=line-too-long
  with ops.op_scope(
      [predicted, target],
      name,
      "sum_squared_loss") as scope:
    return math_ops.div(reduce_batch_sum(squared_loss(predicted, target)),
                        2.0,
                        name=scope)


def mean_absolute_loss(predicted, target, name=None):
  """Calculates the mean absolute loss across batches.

  Computes the absolute difference between the target and predicted
  tensors, averaged across all dimensions except dimension 0:

      losses = reduce_batch_mean(absolute_loss(predicted, target))

  where `losses` is a tensor with dimensions [batch_size].

  The tensors must have the same shape.

  This loss function is a form of L1 loss.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    A `[batch_size]` tensor of absolute differences, averaged across all
    dimensions except dimension 0.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.
  """
  with ops.op_scope([predicted, target], name, "mean_absolute_loss") as scope:
    return reduce_batch_mean(absolute_loss(predicted, target), name=scope)


def mean_squared_loss(predicted, target, name=None):
  """Calculates the mean squared loss across batches.

  Computes the squared difference between the target and predicted
  tensors, and averages across all dimensions except dimension 0:

      losses = reduce_batch_mean(squared_loss(predicted, target))

  where `losses` is a tensor with dimensions [batch_size].

  The tensors must have the same shape.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    A `[batch_size]` tensor of squared differences, averaged across
    all dimensions except dimension 0.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.
  """
  with ops.op_scope([predicted, target], name, "mean_squared_loss") as scope:
    return reduce_batch_mean(squared_loss(predicted, target), name=scope)


def root_mean_squared_loss(predicted, target, name=None):
  """Calculates the root mean squared loss across batches.

  Computes the root mean squared loss between the target and predicted
  tensors, which is the square root of the mean squared differences
  between the predicted and target tensors:

      losses = sqrt(mean_squared_loss(predicted, target))

  where `losses` is a tensor with dimensions [batch_size].

  The tensors must have the same shape.

  Args:
    predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
      of predicted values.
    target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
      target values. The shape of the target tensor should match the
      `predicted` tensor.
    name: A name for the operation (optional).

  Returns:
    A `[batch_size]` tensor of the root mean squared differences.

  Raises:
    ValueError: If `predicted` and `target` shapes do not match.
  """
  with ops.op_scope([predicted, target],
                    name,
                    "root_mean_squared_loss") as scope:
    return math_ops.sqrt(mean_squared_loss(predicted, target),
                         name=scope)
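For reference, a minimal usage sketch of the ops added above, written against the same 2015-era graph API that the tests below exercise. The tensor values come from those tests; the session setup is illustrative and not part of the commit.

import tensorflow as tf

predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2])
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2])

# The batch losses return one value per example ([batch_size]), which
# composes with an ordinary reduce_mean into a scalar training objective.
loss = tf.contrib.layers.mean_squared_loss(predicted, target)
objective = tf.reduce_mean(loss)

with tf.Session() as sess:
  print(sess.run(loss))       # per-example losses, shape [2]
  print(sess.run(objective))  # scalar suitable for an optimizer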
421  tensorflow/contrib/layers/python/ops/loss_ops_test.py  Normal file
@@ -0,0 +1,421 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for contrib.layers.python.ops.loss_ops."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class ReduceBatchSumTest(tf.test.TestCase):

  def testDimensionNone(self):
    with self.test_session():
      input_array = np.array([
          [1.0, 2.0],
          [-1.0, -2.0]
      ], dtype=np.float32)
      placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec")
      expected_result = np.array([3.0, -3.0])
      actual_result = tf.contrib.layers.reduce_batch_sum(placeholder_vec)
      self.assertEqual(actual_result.get_shape().as_list(), [None])
      self.assertAllClose(expected_result, actual_result.eval(feed_dict={
          placeholder_vec: input_array
      }))

  def testDimension0(self):
    with self.test_session():
      input_vec = tf.constant(2.0)
      with self.assertRaises(ValueError):
        tf.contrib.layers.reduce_batch_sum(input_vec)

  def testDimension1(self):
    with self.test_session():
      input_vec = tf.constant([1.0, 2.0])
      expected_result = np.array([1.0, 2.0])
      actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
      self.assertAllClose(expected_result, actual_result.eval())

  def testDimension2(self):
    with self.test_session():
      input_vec = tf.constant([
          [1.0, 2.0],
          [-1.0, -2.0]
      ])
      expected_result = np.array([3.0, -3.0])
      actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
      self.assertAllClose(expected_result, actual_result.eval())

  def testReturnShape(self):
    with self.test_session():
      input_vec = tf.constant([
          [1.0, 2.0],
          [-1.0, -2.0]
      ])
      expected_result = np.array([3.0, -3.0])
      actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
      self.assertShapeEqual(expected_result, actual_result)

  def testDimensionN(self):
    with self.test_session():
      input_vec = tf.constant([
          [
              [1.0, 2.0],
              [3.0, 4.0]
          ],
          [
              [5.0, 6.0],
              [7.0, 8.0]
          ]
      ])
      expected_result = np.array([10.0, 26.0])
      actual_result = tf.contrib.layers.reduce_batch_sum(input_vec)
      self.assertAllClose(expected_result, actual_result.eval())


class ReduceBatchMeanTest(tf.test.TestCase):

  def testDimensionNone(self):
    with self.test_session():
      input_array = np.array([
          [1.0, 2.0],
          [-1.0, -2.0]
      ], dtype=np.float32)
      placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec")
      expected_result = np.array([1.5, -1.5])
      actual_result = tf.contrib.layers.reduce_batch_mean(placeholder_vec)
      self.assertEqual(actual_result.get_shape().as_list(), [None])
      self.assertAllClose(expected_result, actual_result.eval(feed_dict={
          placeholder_vec: input_array
      }))

  def testDimension0(self):
    with self.test_session():
      input_vec = tf.constant(2.0)
      with self.assertRaises(ValueError):
        tf.contrib.layers.reduce_batch_mean(input_vec)

  def testDimension1(self):
    with self.test_session():
      input_vec = tf.constant([1.0, 2.0])
      expected_result = np.array([1.0, 2.0])
      actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
      self.assertAllClose(expected_result, actual_result.eval())

  def testDimension2(self):
    with self.test_session():
      input_vec = tf.constant([
          [1.0, 2.0],
          [-1.0, -2.0]
      ])
      expected_result = np.array([1.5, -1.5])
      actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
      self.assertAllClose(expected_result, actual_result.eval())

  def testReturnShape(self):
    with self.test_session():
      input_vec = tf.constant([
          [1.0, 2.0],
          [-1.0, -2.0]
      ])
      expected_result = np.array([3.0, -3.0])
      actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
      self.assertShapeEqual(expected_result, actual_result)

  def testDimensionN(self):
    with self.test_session():
      input_vec = tf.constant([
          [
              [1.0, 2.0],
              [3.0, 4.0]
          ],
          [
              [5.0, 6.0],
              [7.0, 8.0]
          ]
      ])
      expected_result = np.array([2.5, 6.5])
      actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
      self.assertAllClose(expected_result, actual_result.eval())


class AbsoluteLossTest(tf.test.TestCase):

  def _getTestVectors(self):
    target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
    predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
                            name="predicted")
    expected_loss = np.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2)
    return target, predicted, expected_loss

  def testAbsoluteLoss(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.absolute_loss(predicted, target)
      self.assertAllClose(expected_loss, result.eval())

  def testAbsoluteLossReturnShape(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.absolute_loss(predicted, target)
      self.assertShapeEqual(expected_loss, result)

  def testInvalidShapesValueError(self):
    with self.test_session():
      target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
      incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
                                       name="incompatible_shape")
      with self.assertRaises(ValueError):
        tf.contrib.layers.absolute_loss(incompatible_shape, target)

  def testAbsoluteLossGradient(self):
    with self.test_session():
      target, predicted, _ = self._getTestVectors()
      result = tf.contrib.layers.absolute_loss(predicted, target)
      x_shape = [2, 2]
      err = tf.test.compute_gradient_error(target, x_shape, result, x_shape)
      err_tolerance = 1e-4
      self.assertLess(err, err_tolerance)


class SquaredLossTest(tf.test.TestCase):

  def _getTestVectors(self):
    target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
    predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
                            name="predicted")
    expected_loss = np.array([0.01, 0.04, 0.09, 0.16]).reshape(2, 2)
    return target, predicted, expected_loss

  def testSquaredLoss(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.squared_loss(predicted, target)
      self.assertAllClose(expected_loss, result.eval())

  def testSquaredLossReturnShape(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.squared_loss(predicted, target)
      self.assertShapeEqual(expected_loss, result)

  def testInvalidShapesValueError(self):
    with self.test_session():
      target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
      incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
                                       name="incompatible_shape")
      with self.assertRaises(ValueError):
        tf.contrib.layers.squared_loss(incompatible_shape, target)

  def testSquaredLossGradient(self):
    with self.test_session():
      target, predicted, _ = self._getTestVectors()
      result = tf.contrib.layers.squared_loss(predicted, target)
      x_shape = [2, 2]
      err = tf.test.compute_gradient_error(target, x_shape, result, x_shape)
      err_tolerance = 1e-3
      self.assertLess(err, err_tolerance)


class SumSquaredLossTest(tf.test.TestCase):

  def _getTestVectors(self):
    target = tf.constant([[0.0, 1.0],
                          [3.0, 2.0]],
                         shape=[2, 2],
                         name="target")
    predicted = tf.constant([[3.0, -2.0],
                             [1.0, 2.0]],
                            shape=[2, 2],
                            name="predicted")
    expected_loss = np.array([9.0, 2.0])
    return target, predicted, expected_loss

  def testSumSquaredLoss(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.sum_squared_loss(predicted, target)
      self.assertAllClose(expected_loss, result.eval())

  def testSumSquaredLossReturnShape(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.sum_squared_loss(predicted, target)
      self.assertShapeEqual(expected_loss, result)

  def testInvalidShapesValueError(self):
    with self.test_session():
      target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
      incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
                                       name="incompatible_shape")
      with self.assertRaises(ValueError):
        tf.contrib.layers.sum_squared_loss(incompatible_shape, target)

  def testSumSquaredLossGradient(self):
    with self.test_session():
      target, predicted, _ = self._getTestVectors()
      result = tf.contrib.layers.sum_squared_loss(predicted, target)
      x_shape = [2, 2]
      result_shape = [2]
      err = tf.test.compute_gradient_error(target, x_shape,
                                           result, result_shape)
      err_tolerance = 1e-3
      self.assertLess(err, err_tolerance)


class MeanAbsoluteLossTest(tf.test.TestCase):

  def _getTestVectors(self):
    target = tf.constant([[0.0, 1.0, 2.0],
                          [3.0, 2.0, 4.0]],
                         shape=[2, 3],
                         name="target")
    predicted = tf.constant([[3.0, -3.0, 0.0],
                             [1.0, 2.0, 0.0]],
                            shape=[2, 3],
                            name="predicted")
    expected_loss = np.array([3.0, 2.0])
    return target, predicted, expected_loss

  def testMeanAbsoluteLoss(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.mean_absolute_loss(predicted, target)
      self.assertAllClose(expected_loss, result.eval())

  def testMeanAbsoluteLossReturnShape(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.mean_absolute_loss(predicted, target)
      self.assertShapeEqual(expected_loss, result)

  def testInvalidShapesValueError(self):
    with self.test_session():
      target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
      incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
                                       name="incompatible_shape")
      with self.assertRaises(ValueError):
        tf.contrib.layers.mean_absolute_loss(incompatible_shape, target)

  def testMeanAbsoluteLossGradient(self):
    with self.test_session():
      target, predicted, _ = self._getTestVectors()
      result = tf.contrib.layers.mean_absolute_loss(predicted, target)
      x_shape = [2, 3]
      result_shape = [2]
      err = tf.test.compute_gradient_error(target, x_shape,
                                           result, result_shape)
      err_tolerance = 1e-3
      self.assertLess(err, err_tolerance)


class MeanSquaredLossTest(tf.test.TestCase):

  def _getTestVectors(self):
    target = tf.constant([[0.0, 1.0, 2.0],
                          [3.0, 2.0, 4.0]],
                         shape=[2, 3],
                         name="target")
    predicted = tf.constant([[3.0, -3.0, 0.0],
                             [1.0, 2.0, 0.0]],
                            shape=[2, 3],
                            name="predicted")
    expected_loss = np.array([9.666667, 6.666667])
    return target, predicted, expected_loss

  def testMeanSquaredLoss(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.mean_squared_loss(predicted, target)
      self.assertAllClose(expected_loss, result.eval())

  def testMeanSquaredLossReturnShape(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.mean_squared_loss(predicted, target)
      self.assertShapeEqual(expected_loss, result)

  def testInvalidShapesValueError(self):
    with self.test_session():
      target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
      incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
                                       name="incompatible_shape")
      with self.assertRaises(ValueError):
        tf.contrib.layers.mean_squared_loss(incompatible_shape, target)

  def testMeanSquaredLossGradient(self):
    with self.test_session():
      target, predicted, _ = self._getTestVectors()
      result = tf.contrib.layers.mean_squared_loss(predicted, target)
      x_shape = [2, 3]
      result_shape = [2]
      err = tf.test.compute_gradient_error(target, x_shape,
                                           result, result_shape)
      err_tolerance = 1e-3
      self.assertLess(err, err_tolerance)


class RootMeanSquaredLossTest(tf.test.TestCase):

  def _getTestVectors(self):
    target = tf.constant([[0.0, 1.0, 2.0],
                          [3.0, 2.0, 4.0]],
                         shape=[2, 3],
                         name="target")
    predicted = tf.constant([[3.0, -3.0, 0.0],
                             [1.0, 2.0, 0.0]],
                            shape=[2, 3],
                            name="predicted")
    expected_loss = np.array([3.109126, 2.5819889])
    return target, predicted, expected_loss

  def testRootMeanSquaredLoss(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
      self.assertAllClose(expected_loss, result.eval())

  def testRootMeanSquaredLossReturnShape(self):
    with self.test_session():
      target, predicted, expected_loss = self._getTestVectors()
      result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
      self.assertShapeEqual(expected_loss, result)

  def testInvalidShapesValueError(self):
    with self.test_session():
      target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
      incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
                                       name="incompatible_shape")
      with self.assertRaises(ValueError):
        tf.contrib.layers.root_mean_squared_loss(incompatible_shape, target)

  def testRootMeanSquaredLossGradient(self):
    with self.test_session():
      target, predicted, _ = self._getTestVectors()
      result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
      x_shape = [2, 3]
      result_shape = [2]
      err = tf.test.compute_gradient_error(target, x_shape,
                                           result, result_shape)
      err_tolerance = 1e-3
      self.assertLess(err, err_tolerance)


if __name__ == "__main__":
  tf.test.main()
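Assuming a standard Bazel workspace, the new suite declared in the BUILD change above should run with the usual test command:

bazel test //tensorflow/contrib/layers:loss_ops_test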