Enable sparse xent op tests in eager mode.

PiperOrigin-RevId: 334272337
Change-Id: I7170277a86650a30354dddd035d6d332e16139d5
Penporn Koanantakool 2020-09-28 17:32:43 -07:00 committed by TensorFlower Gardener
parent 421a7754e6
commit 816babe424
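
Not part of the diff, but for context: the graph-only gradient_checker.compute_gradient_error calls are replaced with gradient_checker_v2.compute_gradient (exported as tf.test.compute_gradient), which takes a Python callable and therefore also runs under eager execution. A minimal sketch of that v2 pattern, with made-up logits/labels values:

import numpy as np
import tensorflow as tf

from tensorflow.python.ops import gradient_checker_v2

labels = tf.constant([3, 0, 1])  # Integer labels; captured by the closure below.
logits = tf.constant([[1., 2., 3., 4.], [5., 6., 7., 8.], [1., 1., 1., 1.]],
                     dtype=tf.float64)

def xent(logits):
  # Only the differentiable argument is passed through the checker.
  return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                        logits=logits)

# Works eagerly as well as in graphs; no session or placeholder is needed.
theoretical, numerical = gradient_checker_v2.compute_gradient(xent, [logits])
np.testing.assert_allclose(theoretical[0], numerical[0], rtol=1e-6, atol=1e-6)

The integer labels are captured by the closure because compute_gradient only accepts the differentiable floating-point inputs; the updated tests below use the same trick.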


@@ -25,6 +25,8 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop as backprop_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -32,7 +34,7 @@ from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -79,33 +81,34 @@ class SparseXentTest(test.TestCase):
self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop)
@test_util.run_deprecated_v1
@test_util.disable_xla("XLA cannot assert inside of a kernel.")
def testInvalidLabel(self):
@test_util.run_gpu_only()
def testInvalidLabelGPU(self):
features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.],
[1., 2., 3., 4.]]
labels = [4, 3, 0, -1]
loss, backprop = self.evaluate(
gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
self.assertAllClose([[np.nan] * 4, [0.25, 0.25, 0.25, -0.75],
[-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4],
backprop,
rtol=1e-3,
atol=1e-3)
self.assertAllClose([np.nan, 1.3862, 3.4420, np.nan],
loss,
rtol=1e-3,
atol=1e-3)
if test.is_built_with_gpu_support() and test.is_gpu_available():
with self.session(use_gpu=True) as sess:
loss, backprop = (
gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
features, labels))
tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllClose(
[[np.nan] * 4, [0.25, 0.25, 0.25, -0.75],
[-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4],
tf_backprop,
rtol=1e-3,
atol=1e-3)
self.assertAllClose(
[np.nan, 1.3862, 3.4420, np.nan], tf_loss, rtol=1e-3, atol=1e-3)
with self.session(use_gpu=False) as sess:
loss, backprop = (
@test_util.run_in_graph_and_eager_modes(use_gpu=False)
@test_util.disable_xla("XLA cannot assert inside of a kernel.")
def testInvalidLabelCPU(self):
features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.],
[1., 2., 3., 4.]]
labels = [4, 3, 0, -1]
with self.assertRaisesRegex(
(errors_impl.InvalidArgumentError, errors_impl.UnknownError),
"Received a label value of"):
self.evaluate(
gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
with self.assertRaisesOpError("Received a label value of"):
self.evaluate([loss, backprop])
def testNpXent(self):
# We create 2 batches of logits for testing.
@@ -153,9 +156,8 @@ class SparseXentTest(test.TestCase):
nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=constant_op.constant(0), logits=constant_op.constant(1.0))
@test_util.run_deprecated_v1
def testLabelsPlaceholderScalar(self):
with self.session(use_gpu=True):
with ops_lib.Graph().as_default(), self.session(use_gpu=True):
labels = array_ops.placeholder(np.int32)
y = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=[[7.]])
@@ -189,7 +191,7 @@ class SparseXentTest(test.TestCase):
def testEmpty(self):
self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
@test_util.run_deprecated_v1
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testGradient(self):
with self.session(use_gpu=True) as sess:
l = constant_op.constant([3, 0, 1], name="l")
@@ -198,22 +200,28 @@ class SparseXentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
# Check that no extra computation is performed. When only the first derivative
# is requested, the second derivative must not be computed. So when there is
# no second derivative, there is no `BatchMatMul` op in the graph.
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertNotIn("BatchMatMul", op_names)
self.assertNotIn("BatchMatMulV2", op_names)
def xent(f):
# gradient_checker_v2.compute_gradient doesn't take int32/int64 inputs, and
# labels must be int32/int64, so the labels are captured by the closure here.
return nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")
self.assertLess(err, 5e-8)
theoretical, numerical = gradient_checker_v2.compute_gradient(xent, [f])
if not context.executing_eagerly():
# Check that no extra computation is performed. When only the first derivative
# is requested, the second derivative must not be computed. So when there is
# no second derivative, there is no `BatchMatMul` op in the graph.
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertNotIn("BatchMatMul", op_names)
self.assertNotIn("BatchMatMulV2", op_names)
tol = 5e-8
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
@test_util.run_deprecated_v1
def testSecondGradient(self):
with self.session() as sess:
l = constant_op.constant([3, 0, 1], name="l")
@@ -222,51 +230,67 @@ class SparseXentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
x = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent")
gradients = gradients_impl.gradients(x, [f])[0]
err = gradient_checker.compute_gradient_error(f, [3, 4], gradients,
[3, 4])
def xent_grad(f):
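# First derivative of the xent loss w.r.t. the logits; compute_gradient below
# checks its gradient, i.e. the second derivative. Graph mode can call
# gradients_impl.gradients directly, while eager mode has to go through a
# GradientTape.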
if not context.executing_eagerly():
return gradients_impl.gradients(
nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent"), [f])[0]
with backprop_lib.GradientTape() as tape:
tape.watch(f)
return tape.gradient(
nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=l, logits=f, name="xent"), [f])[0]
# Check that the second derivative is calculated.
# (It is equivalent to a `BatchMatMul` op being present in the graph because
# of how the xentropy grad is implemented.)
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertIn("BatchMatMulV2", op_names)
theoretical, numerical = gradient_checker_v2.compute_gradient(
xent_grad, [f])
self.assertLess(err, 5e-8)
if not context.executing_eagerly():
# Check that the second derivative is calculated.
# (It is equivalent to a `BatchMatMul` op being present in the graph because
# of how the xentropy grad is implemented.)
op_names = [
op.op_def.name for op in sess.graph.get_operations() if op.op_def
]
self.assertIn("BatchMatMulV2", op_names)
tol = 5e-8
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def _testHighDim(self, features, labels):
np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
# manually reshape loss
np_loss = np.reshape(np_loss, np.array(labels).shape)
with self.cached_session(use_gpu=True) as sess:
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=features)
backprop = loss.op.inputs[0].op.outputs[1]
tf_loss, tf_backprop = self.evaluate([loss, backprop])
tf_loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=features)
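# In graph mode the fused kernel exposes the backprop tensor as the op's
# second output, so it can be read straight off the graph; under eager
# execution there is no graph to inspect, so the backprop is recovered with a
# GradientTape instead.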
if not context.executing_eagerly():
tf_backprop = tf_loss.op.inputs[0].op.outputs[1]
else:
with backprop_lib.GradientTape() as tape:
features = constant_op.constant(features)
tape.watch(features)
tf_backprop = tape.gradient(
nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=features), [features])[0]
tf_backprop = array_ops.reshape(tf_backprop, np_backprop.shape)
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@test_util.run_deprecated_v1
def testHighDim(self):
features = [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]
labels = [[3], [0]]
self._testHighDim(features, labels)
@test_util.run_deprecated_v1
def testHighDim2(self):
features = [[[1., 1., 1., 1.], [2., 2., 2., 2.]],
[[1., 2., 3., 4.], [5., 6., 7., 8.]]]
labels = [[3, 2], [0, 3]]
self._testHighDim(features, labels)
@test_util.run_deprecated_v1
def testScalarHandling(self):
with self.session(use_gpu=False) as sess:
with ops_lib.Graph().as_default(), self.session(use_gpu=False) as sess:
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
".*labels must be 1-D.*"):
labels = array_ops.placeholder(dtypes.int32, shape=[None, 1])