Enable sparse xent op tests in eager mode.
PiperOrigin-RevId: 334272337
Change-Id: I7170277a86650a30354dddd035d6d332e16139d5
commit 816babe424
parent 421a7754e6
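The diff below moves the sparse softmax cross-entropy tests off the graph-only `@test_util.run_deprecated_v1` decorator, swaps `gradient_checker` for `gradient_checker_v2`, and uses `GradientTape` where graph-only gradient extraction was used before. As a minimal sketch (not part of the commit; the class and test names here are illustrative), the run-in-both-modes pattern the commit applies throughout looks roughly like this:

```python
# Illustrative sketch only: ExampleXentTest / testLossMatchesHandComputedValue
# are invented names, not code from this commit.
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test


class ExampleXentTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testLossMatchesHandComputedValue(self):
    # The body runs twice: once under a graph/session and once eagerly.
    loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
        labels=[0], logits=[[1., 2.]])
    # self.evaluate() works in both modes (session.run vs. .numpy()).
    self.assertAllClose(self.evaluate(loss), [1.3132617], atol=1e-5)


if __name__ == "__main__":
  test.main()
```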
@@ -25,6 +25,8 @@ import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.eager import backprop as backprop_lib
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -32,7 +34,7 @@ from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
-from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
@@ -79,33 +81,34 @@ class SparseXentTest(test.TestCase):
     self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
     self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop)
 
-  @test_util.run_deprecated_v1
-  @test_util.disable_xla("XLA cannot assert inside of a kernel.")
-  def testInvalidLabel(self):
+  @test_util.run_gpu_only()
+  def testInvalidLabelGPU(self):
     features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.],
                 [1., 2., 3., 4.]]
     labels = [4, 3, 0, -1]
-
-    if test.is_built_with_gpu_support() and test.is_gpu_available():
-      with self.session(use_gpu=True) as sess:
-        loss, backprop = (
-            gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
-                features, labels))
-        tf_loss, tf_backprop = self.evaluate([loss, backprop])
-        self.assertAllClose(
-            [[np.nan] * 4, [0.25, 0.25, 0.25, -0.75],
-             [-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4],
-            tf_backprop,
-            rtol=1e-3,
-            atol=1e-3)
-        self.assertAllClose(
-            [np.nan, 1.3862, 3.4420, np.nan], tf_loss, rtol=1e-3, atol=1e-3)
-
-    with self.session(use_gpu=False) as sess:
-      loss, backprop = (
-          gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
-      with self.assertRaisesOpError("Received a label value of"):
-        self.evaluate([loss, backprop])
+    loss, backprop = self.evaluate(
+        gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
+    self.assertAllClose([[np.nan] * 4, [0.25, 0.25, 0.25, -0.75],
+                         [-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4],
+                        backprop,
+                        rtol=1e-3,
+                        atol=1e-3)
+    self.assertAllClose([np.nan, 1.3862, 3.4420, np.nan],
+                        loss,
+                        rtol=1e-3,
+                        atol=1e-3)
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
+  @test_util.disable_xla("XLA cannot assert inside of a kernel.")
+  def testInvalidLabelCPU(self):
+    features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.],
+                [1., 2., 3., 4.]]
+    labels = [4, 3, 0, -1]
+    with self.assertRaisesRegex(
+        (errors_impl.InvalidArgumentError, errors_impl.UnknownError),
+        "Received a label value of"):
+      self.evaluate(
+          gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
 
   def testNpXent(self):
     # We create 2 batches of logits for testing.
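For context on the CPU/GPU split above: the CPU kernel reports an out-of-range label as an op error, which under eager execution surfaces immediately as a Python exception, while the GPU kernel returns NaN instead of asserting (which is why only the CPU test disables XLA). A standalone sketch of the CPU behaviour via the public raw-op binding (the exact error text is an assumption and may vary across TF versions):

```python
import tensorflow as tf

features = tf.constant([[1., 2., 3., 4.]])
labels = tf.constant([7])  # out of range for 4 classes

try:
  # Eager execution runs the kernel right away, so the invalid label is
  # reported here rather than at a later session.run().
  tf.raw_ops.SparseSoftmaxCrossEntropyWithLogits(
      features=features, labels=labels)
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # e.g. "Received a label value of 7 which is outside ..."
```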
@@ -153,9 +156,8 @@ class SparseXentTest(test.TestCase):
         nn_ops.sparse_softmax_cross_entropy_with_logits(
             labels=constant_op.constant(0), logits=constant_op.constant(1.0))
 
-  @test_util.run_deprecated_v1
   def testLabelsPlaceholderScalar(self):
-    with self.session(use_gpu=True):
+    with ops_lib.Graph().as_default(), self.session(use_gpu=True):
       labels = array_ops.placeholder(np.int32)
       y = nn_ops.sparse_softmax_cross_entropy_with_logits(
           labels=labels, logits=[[7.]])
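The `ops_lib.Graph().as_default()` wrapper above is needed because placeholders are graph-only; once eager execution is the default, the test body has to opt back into graph construction explicitly. A minimal sketch of the same idea through the public API (not code from the commit):

```python
import tensorflow as tf

# tf.compat.v1.placeholder raises RuntimeError under eager execution, so the
# scalar-label check has to be built inside an explicit graph.
with tf.Graph().as_default():
  labels = tf.compat.v1.placeholder(tf.int32)
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=[[7.]])
```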
@@ -189,7 +191,7 @@ class SparseXentTest(test.TestCase):
   def testEmpty(self):
     self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
 
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testGradient(self):
     with self.session(use_gpu=True) as sess:
       l = constant_op.constant([3, 0, 1], name="l")
@@ -198,22 +200,28 @@ class SparseXentTest(test.TestCase):
           shape=[3, 4],
           dtype=dtypes.float64,
           name="f")
-      x = nn_ops.sparse_softmax_cross_entropy_with_logits(
-          labels=l, logits=f, name="xent")
-      err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
-
-      # Check that no extra computation performed. When only first derivative is
-      # requested, second derivative must not be computed. So when there is no
-      # second derivative, there is no `BatchMatMul` op in the graph.
-      op_names = [
-          op.op_def.name for op in sess.graph.get_operations() if op.op_def
-      ]
-      self.assertNotIn("BatchMatMul", op_names)
-      self.assertNotIn("BatchMatMulV2", op_names)
-
-    self.assertLess(err, 5e-8)
-
-  @test_util.run_deprecated_v1
+
+      def xent(f):
+        # gradient_checker_v2.compute_gradient doesn't take int32/int64.
+        # labels must be of type int32/int64, so passing them separately here.
+        return nn_ops.sparse_softmax_cross_entropy_with_logits(
+            labels=l, logits=f, name="xent")
+
+      theoretical, numerical = gradient_checker_v2.compute_gradient(xent, [f])
+
+      if not context.executing_eagerly():
+        # Check that no extra computation performed. When only first derivative
+        # is requested, second derivative must not be computed. So when there is
+        # no second derivative, there is no `BatchMatMul` op in the graph.
+        op_names = [
+            op.op_def.name for op in sess.graph.get_operations() if op.op_def
+        ]
+        self.assertNotIn("BatchMatMul", op_names)
+        self.assertNotIn("BatchMatMulV2", op_names)
+
+    tol = 5e-8
+    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
+
   def testSecondGradient(self):
     with self.session() as sess:
       l = constant_op.constant([3, 0, 1], name="l")
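The `xent` wrapper above exists because `gradient_checker_v2.compute_gradient` differentiates a Python callable with respect to the tensor arguments it is given and does not accept integer inputs, so the int labels are captured by closure instead of being passed in. A small sketch of the same check through the public `tf.test.compute_gradient` wrapper (values and the tolerance are illustrative, not from the commit):

```python
import numpy as np
import tensorflow as tf

labels = tf.constant([3, 0, 1])
logits = tf.constant(
    [[0.1, 0.2, 0.3, 0.4], [0.1, 0.4, 0.9, 1.6], [0.1, 0.8, 2.7, 6.4]],
    dtype=tf.float64)


def xent(logits):
  # Labels are captured by closure: compute_gradient only differentiates
  # with respect to the float arguments it receives.
  return tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)


# Returns (theoretical, numerical) Jacobians, one pair per input tensor.
theoretical, numerical = tf.test.compute_gradient(xent, [logits])
np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-5)
```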
@@ -222,51 +230,67 @@ class SparseXentTest(test.TestCase):
           shape=[3, 4],
           dtype=dtypes.float64,
           name="f")
-      x = nn_ops.sparse_softmax_cross_entropy_with_logits(
-          labels=l, logits=f, name="xent")
-
-      gradients = gradients_impl.gradients(x, [f])[0]
-      err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
-                                                    [3, 4])
-
-      # Check that second derivative is calculated.
-      # (it is equivalent to being `BatchMatMul` op in the graph because of
-      # implementation of xentropy grad)
-      op_names = [
-          op.op_def.name for op in sess.graph.get_operations() if op.op_def
-      ]
-      self.assertIn("BatchMatMulV2", op_names)
-
-    self.assertLess(err, 5e-8)
+
+      def xent_grad(f):
+        if not context.executing_eagerly():
+          return gradients_impl.gradients(
+              nn_ops.sparse_softmax_cross_entropy_with_logits(
+                  labels=l, logits=f, name="xent"), [f])[0]
+        with backprop_lib.GradientTape() as tape:
+          tape.watch(f)
+          return tape.gradient(
+              nn_ops.sparse_softmax_cross_entropy_with_logits(
+                  labels=l, logits=f, name="xent"), [f])[0]
+
+      theoretical, numerical = gradient_checker_v2.compute_gradient(
+          xent_grad, [f])
+
+      if not context.executing_eagerly():
+        # Check that second derivative is calculated.
+        # (it is equivalent to being `BatchMatMul` op in the graph because of
+        # implementation of xentropy grad)
+        op_names = [
+            op.op_def.name for op in sess.graph.get_operations() if op.op_def
+        ]
+        self.assertIn("BatchMatMulV2", op_names)
+
+    tol = 5e-8
+    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
 
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def _testHighDim(self, features, labels):
     np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
     # manually reshape loss
     np_loss = np.reshape(np_loss, np.array(labels).shape)
-    with self.cached_session(use_gpu=True) as sess:
-      loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
-          labels=labels, logits=features)
-      backprop = loss.op.inputs[0].op.outputs[1]
-      tf_loss, tf_backprop = self.evaluate([loss, backprop])
+    tf_loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
+        labels=labels, logits=features)
+    if not context.executing_eagerly():
+      tf_backprop = tf_loss.op.inputs[0].op.outputs[1]
+    else:
+      with backprop_lib.GradientTape() as tape:
+        features = constant_op.constant(features)
+        tape.watch(features)
+        tf_backprop = tape.gradient(
+            nn_ops.sparse_softmax_cross_entropy_with_logits(
+                labels=labels, logits=features), [features])[0]
+        tf_backprop = array_ops.reshape(tf_backprop, np_backprop.shape)
+
     self.assertAllCloseAccordingToType(np_loss, tf_loss)
     self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
 
-  @test_util.run_deprecated_v1
   def testHighDim(self):
     features = [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]
     labels = [[3], [0]]
     self._testHighDim(features, labels)
 
-  @test_util.run_deprecated_v1
   def testHighDim2(self):
     features = [[[1., 1., 1., 1.], [2., 2., 2., 2.]],
                 [[1., 2., 3., 4.], [5., 6., 7., 8.]]]
     labels = [[3, 2], [0, 3]]
     self._testHighDim(features, labels)
 
-  @test_util.run_deprecated_v1
   def testScalarHandling(self):
-    with self.session(use_gpu=False) as sess:
+    with ops_lib.Graph().as_default(), self.session(use_gpu=False) as sess:
       with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                   ".*labels must be 1-D.*"):
         labels = array_ops.placeholder(dtypes.int32, shape=[None, 1])
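In `_testHighDim` above, the graph-mode branch still reads the `backprop` output directly off the fused op, while the eager branch recomputes it with a `GradientTape`; the two are interchangeable because the gradient of the sparse xent loss with respect to the logits is exactly that backprop tensor (softmax minus one-hot labels). A standalone sketch of the tape-based version (tensor values are illustrative):

```python
import tensorflow as tf

features = tf.constant([[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]])
labels = tf.constant([[3], [0]])

with tf.GradientTape() as tape:
  tape.watch(features)  # constants are not watched automatically
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=features)

# d(loss)/d(logits) == softmax(logits) - one_hot(labels), the same tensor the
# graph-mode test reads from the fused op's second output.
backprop = tape.gradient(loss, features)
print(backprop.shape)  # TensorShape([2, 1, 4])
```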