Merge pull request #33803 from duncanriach:bias_add_test_eager_mode
PiperOrigin-RevId: 294705769
Change-Id: I96e27068e2a1d770e49439950c6043d353ca4962
commit 3c88f5b302
tensorflow/python/kernel_tests
@@ -20,17 +20,22 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_in_graph_and_eager_modes
 class BiasAddTestBase(test.TestCase):
 
   def _npBias(self, inputs, bias):
@@ -48,7 +53,7 @@ class BiasAddTestBase(test.TestCase):
   def _testBias(self, np_inputs, np_bias, use_gpu=False):
     np_val = self._npBias(np_inputs, np_bias)
     with self.cached_session(use_gpu=use_gpu):
-      tf_val = nn_ops.bias_add(np_inputs, np_bias).eval()
+      tf_val = self.evaluate(nn_ops.bias_add(np_inputs, np_bias))
     self.assertAllCloseAccordingToType(np_val, tf_val)
 
   def _AtLeast3d(self, np_value):
@@ -76,7 +81,8 @@ class BiasAddTestBase(test.TestCase):
     np_val = self._npBias(np_inputs, np_bias)
     np_inputs = self._NHWCToNCHW(np_inputs)
     with self.cached_session(use_gpu=use_gpu):
-      tf_val = nn_ops.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
+      tf_val = self.evaluate(
+          nn_ops.bias_add(np_inputs, np_bias, data_format="NCHW"))
     tf_val = self._NCHWToNHWC(tf_val)
     self.assertAllCloseAccordingToType(self._AtLeast3d(np_val), tf_val)
 
@@ -87,40 +93,40 @@ class BiasAddTestBase(test.TestCase):
     self._testBias(np_inputs, np_bias, use_gpu=True)
     self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)
 
-  @test_util.run_deprecated_v1
+  def _expectedException(self):
+    if context.executing_eagerly():
+      return errors_impl.InvalidArgumentError
+    else:
+      return ValueError
+
   def testInputDims(self):
-    with self.assertRaises(ValueError):
+    with self.assertRaises(self._expectedException()):
       nn_ops.bias_add([1, 2], [1])
 
-  @test_util.run_deprecated_v1
   def testBiasVec(self):
-    with self.assertRaises(ValueError):
+    with self.assertRaises(self._expectedException()):
       nn_ops.bias_add(
           array_ops.reshape([1, 2], shape=[1, 2]),
          array_ops.reshape([1, 2], shape=[1, 2]))
 
-  @test_util.run_deprecated_v1
   def testBiasInputsMatch(self):
-    with self.assertRaises(ValueError):
+    with self.assertRaises(self._expectedException()):
       nn_ops.bias_add(
          array_ops.reshape([1, 2], shape=[1, 2]),
          array_ops.reshape([1], shape=[1]))
 
-  @test_util.run_deprecated_v1
   def testIntTypes(self):
     for t in [np.int8, np.int16, np.int32, np.int64]:
       self._testAll(
          np.array([[10, 20, 30], [40, 50, 60]]).astype(t),
          np.array([1, 2, 3]).astype(t))
 
-  @test_util.run_deprecated_v1
   def testFloatTypes(self):
     for t in [np.float16, np.float32, np.float64]:
       self._testAll(
          np.random.rand(4, 3, 3).astype(t),
          np.random.rand(3).astype(t))
 
-  @test_util.run_deprecated_v1
   def test4DFloatTypes(self):
     for t in [np.float16, np.float32, np.float64]:
       self._testAll(
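[Note] The hunk above replaces the hard-coded ValueError expectation with _expectedException(), because the same invalid input surfaces as a ValueError during graph construction but as an InvalidArgumentError when the op runs eagerly. A minimal sketch of the same pattern using only public TF APIs (the helper name expected_bias_add_error is illustrative, not part of the change):

    import tensorflow as tf

    def expected_bias_add_error():
      # Graph building validates shapes in Python and raises ValueError;
      # eager execution defers to the kernel, which raises InvalidArgumentError.
      if tf.executing_eagerly():
        return tf.errors.InvalidArgumentError
      return ValueError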
@@ -133,32 +139,78 @@ class BiasAddTestBase(test.TestCase):
          np.random.rand(4, 4, 4, 2048).astype(t),
          np.random.rand(2048).astype(t))
 
-  @test_util.run_deprecated_v1
   def test5DFloatTypes(self):
     for t in [np.float16, np.float32, np.float64]:
       self._testAll(
          np.random.rand(4, 3, 2, 3, 4).astype(t),
          np.random.rand(4).astype(t))
 
+  def _random_tensor(self, shape, dtype):
+    return constant_op.constant(2 * np.random.rand(*shape) - 1, dtype=dtype)
+
+  def _computeGradient(self, np_input, bias, dtype, data_format):
+    input_shape = output_shape = np_input.shape
+    bias_shape = bias.shape
+    input_tensor = constant_op.constant(
+        np_input, shape=input_shape, dtype=dtype)
+    bias_tensor = constant_op.constant(bias, shape=bias_shape, dtype=dtype)
+
+    if context.executing_eagerly():
+
+      def bias_add(input_tensor, bias_tensor):
+        return nn_ops.bias_add(
+            input_tensor, bias_tensor, data_format=data_format)
+
+      # The following is a work-around for TF issue 33660. Instead of
+      # calculating the analytical and numerical gradients for both
+      # inputs in a single call to compute_gradient, compute_gradient
+      # is called for each input separately.
+      def bias_add_1(input_tensor):
+        return bias_add(input_tensor, bias_tensor)
+
+      def bias_add_2(bias_tensor):
+        return bias_add(input_tensor, bias_tensor)
+
+      input_jacob_a, input_jacob_n = gradient_checker_v2.compute_gradient(
+          bias_add_1, [input_tensor])
+      bias_jacob_a, bias_jacob_n = gradient_checker_v2.compute_gradient(
+          bias_add_2, [bias_tensor])
+
+      # Test gradient of BiasAddGrad
+      def bias_add_grad_function(upstream_gradients):
+        with backprop.GradientTape() as tape:
+          tape.watch(bias_tensor)
+          bias_add_output = bias_add(input_tensor, bias_tensor)
+          gradient_injector_output = bias_add_output * upstream_gradients
+          return tape.gradient(gradient_injector_output, bias_tensor)
+
+      upstream_tensor = self._random_tensor(output_shape, dtype)
+      grad_jacob_a, grad_jacob_n = gradient_checker_v2.compute_gradient(
+          bias_add_grad_function, [upstream_tensor])
+    else:
+      output_tensor = nn_ops.bias_add(
+          input_tensor, bias_tensor, data_format=data_format)
+      jacobians = gradient_checker.compute_gradient([input_tensor, bias_tensor],
+                                                    [input_shape, bias_shape],
+                                                    output_tensor, output_shape)
+      (input_jacob_a, input_jacob_n), (bias_jacob_a, bias_jacob_n) = jacobians
+      # Test gradient of BiasAddGrad
+      bias_add_grad = gradients_impl.gradients(
+          nn_ops.l2_loss(output_tensor), bias_tensor)[0]
+      grad_jacob_a, grad_jacob_n = gradient_checker.compute_gradient(
+          output_tensor, output_shape, bias_add_grad, bias_shape)
+
+    return ((input_jacob_a, bias_jacob_a, grad_jacob_a),
+            (input_jacob_n, bias_jacob_n, grad_jacob_n))
+
   def _testGradient(self, np_input, bias, dtype, data_format, use_gpu):
     with self.cached_session(use_gpu=use_gpu):
       if data_format == "NCHW":
         np_input = self._NHWCToNCHW(np_input)
-      input_tensor = constant_op.constant(
-          np_input, shape=np_input.shape, dtype=dtype)
-      bias_tensor = constant_op.constant(bias, shape=bias.shape, dtype=dtype)
-      output_tensor = nn_ops.bias_add(
-          input_tensor, bias_tensor, data_format=data_format)
-      tensor_jacob_t, tensor_jacob_n = gradient_checker.compute_gradient(
-          input_tensor, np_input.shape, output_tensor, np_input.shape)
-      bias_jacob_t, bias_jacob_n = gradient_checker.compute_gradient(
-          bias_tensor, bias.shape, output_tensor, np_input.shape)
-
-      # Test gradient of BiasAddGrad
-      bias_add_grad = gradients_impl.gradients(
-          nn_ops.l2_loss(output_tensor), bias_tensor)[0]
-      grad_jacob_t, grad_jacob_n = gradient_checker.compute_gradient(
-          output_tensor, np_input.shape, bias_add_grad, bias.shape)
+      jacob_a, jacob_n = self._computeGradient(np_input, bias, dtype,
+                                               data_format)
+      input_jacob_a, bias_jacob_a, grad_jacob_a = jacob_a
+      input_jacob_n, bias_jacob_n, grad_jacob_n = jacob_n
 
       if dtype == np.float16:
         # Compare fp16 analytical gradients to fp32 numerical gradients,
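[Note] The "work-around for TF issue 33660" comment above refers to checking each input separately with gradient_checker_v2 rather than passing both inputs in one call. A rough sketch of that per-input pattern with the public tf.test.compute_gradient API (function and variable names here are illustrative, not taken from the diff):

    import tensorflow as tf

    def per_input_jacobians(x, b):
      # Work-around pattern: instead of one compute_gradient call over [x, b],
      # check each argument separately while holding the other one fixed.
      f = lambda x_, b_: tf.nn.bias_add(x_, b_)
      x_jacob_a, x_jacob_n = tf.test.compute_gradient(lambda x_: f(x_, b), [x])
      b_jacob_a, b_jacob_n = tf.test.compute_gradient(lambda b_: f(x, b_), [b])
      return (x_jacob_a, b_jacob_a), (x_jacob_n, b_jacob_n)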
@@ -166,30 +218,22 @@ class BiasAddTestBase(test.TestCase):
         # care is taken with choosing the inputs and the delta. This is
         # a weaker, but pragmatic, check (in particular, it does not test
         # the op itself, only its gradient).
-        input_tensor = constant_op.constant(
-            np_input, shape=np_input.shape, dtype=np.float32)
-        bias_tensor = constant_op.constant(
-            bias, shape=bias.shape, dtype=np.float32)
-        output_tensor = nn_ops.bias_add(
-            input_tensor, bias_tensor, data_format=data_format)
-        _, tensor_jacob_n = gradient_checker.compute_gradient(
-            input_tensor, np_input.shape, output_tensor, np_input.shape)
-        _, bias_jacob_n = gradient_checker.compute_gradient(
-            bias_tensor, bias.shape, output_tensor, np_input.shape)
+        _, jacob_n = self._computeGradient(np_input, bias, np.float32,
+                                           data_format)
+        input_jacob_n, bias_jacob_n, grad_jacob_n = jacob_n
 
-        bias_add_grad = gradients_impl.gradients(
-            nn_ops.l2_loss(output_tensor), bias_tensor)[0]
-        _, grad_jacob_n = gradient_checker.compute_gradient(
-            output_tensor, np_input.shape, bias_add_grad, bias.shape)
-
-      threshold = 5e-3
       if dtype == dtypes.float64:
         threshold = 1e-10
-      self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold)
-      self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold)
-      self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)
+      elif np_input.size >= 512:
+        # The 5e-3 threshold seems to have been marginal in these cases, and
+        # small changes in the test were pushing it over the limit.
+        threshold = 5e-2
+      else:
+        threshold = 5e-3
+      self.assertAllClose(input_jacob_a, input_jacob_n, threshold, threshold)
+      self.assertAllClose(bias_jacob_a, bias_jacob_n, threshold, threshold)
+      self.assertAllClose(grad_jacob_a, grad_jacob_n, threshold, threshold)
 
-  @test_util.run_deprecated_v1
   def testGradientTensor2D(self):
     for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
@@ -198,7 +242,6 @@ class BiasAddTestBase(test.TestCase):
         bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
         self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
-  @test_util.run_deprecated_v1
   def testGradientTensor3D(self):
     for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                    ("NCHW", False), ("NCHW", True)]:
@@ -208,7 +251,6 @@ class BiasAddTestBase(test.TestCase):
         bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
         self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
-  @test_util.run_deprecated_v1
   def testGradientTensor4D(self):
     for (data_format, use_gpu) in [("NHWC", False)]:
       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
@@ -230,7 +272,6 @@ class BiasAddTestBase(test.TestCase):
            np.random.rand(64).astype(dtype.as_numpy_dtype),
            dtype, data_format, use_gpu)
 
-  @test_util.run_deprecated_v1
   def testGradientTensor5D(self):
     for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                    ("NCHW", False), ("NCHW", True)]:
@@ -242,13 +283,11 @@ class BiasAddTestBase(test.TestCase):
        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
-  @test_util.run_deprecated_v1
   def testEmpty(self):
     np.random.seed(7)
     for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
       self._testAll(np.random.randn(*shape), np.random.randn(shape[-1]))
 
-  @test_util.run_deprecated_v1
   def testEmptyGradient(self):
     for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
       for shape in (0, 0), (2, 0), (0, 2):
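[Note] Taken together, the hunks above drop the per-test @test_util.run_deprecated_v1 decorators, decorate the class with @test_util.run_all_in_graph_and_eager_modes, and replace Tensor.eval() with self.evaluate(), which works in both modes. A minimal sketch of that pattern (hypothetical test class, not part of this change):

    import tensorflow as tf
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test


    @test_util.run_all_in_graph_and_eager_modes
    class BiasAddSmokeTest(test.TestCase):

      def testBiasAdd(self):
        # self.evaluate() returns a numpy value in eager mode and runs the
        # tensor in the cached session in graph mode, unlike Tensor.eval().
        result = self.evaluate(tf.nn.bias_add([[1., 2.]], [10., 20.]))
        self.assertAllClose(result, [[11., 22.]])


    if __name__ == "__main__":
      test.main()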
@@ -24,6 +24,8 @@ import numpy as np
 
 from absl.testing import parameterized
 
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -37,8 +39,8 @@ from tensorflow.python.platform import test
 class BiasAddDeterministicTest(bias_op_base.BiasAddTestBase,
                                parameterized.TestCase):
 
-  def _make_shape_tuple(self, batch_size, channel_count, data_rank, data_dim,
-                        data_layout):
+  def _makeShapeTuple(self, batch_size, channel_count, data_rank, data_dim,
+                      data_layout):
     data_dims = data_rank * (data_dim,)
     if data_layout == 'channels_first':
       shape = (batch_size,) + (channel_count,) + data_dims
@@ -48,7 +50,7 @@ class BiasAddDeterministicTest(bias_op_base.BiasAddTestBase,
       raise ValueError('Unknown data format')
     return shape
 
-  def _data_format_from_data_layout(self, data_layout=None):
+  def _dataFormatFromDataLayout(self, data_layout=None):
     if data_layout == 'channels_first':
       return 'NCHW'
     elif data_layout == 'channels_last':
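[Note] The two renamed helpers above map a Keras-style layout string to an input shape and to the data_format string that bias_add expects. A self-contained sketch of that logic (module-level functions, purely illustrative; the channels_last branches are inferred from context, since the hunks only show them partially):

    def make_shape_tuple(batch_size, channel_count, data_rank, data_dim, data_layout):
      # channels_first -> (N, C, D1, ..., Dk); channels_last -> (N, D1, ..., Dk, C)
      data_dims = data_rank * (data_dim,)
      if data_layout == 'channels_first':
        return (batch_size, channel_count) + data_dims
      elif data_layout == 'channels_last':
        return (batch_size,) + data_dims + (channel_count,)
      raise ValueError('Unknown data format')


    def data_format_from_data_layout(data_layout=None):
      if data_layout == 'channels_first':
        return 'NCHW'
      elif data_layout == 'channels_last':
        return 'NHWC'
      raise ValueError('Unknown data_layout')

    # Example matching the test's parameters (batch_size=10, channel_count=8,
    # data_rank=2, data_dim=14):
    # make_shape_tuple(10, 8, 2, 14, 'channels_first') -> (10, 8, 14, 14)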
@@ -56,59 +58,82 @@ class BiasAddDeterministicTest(bias_op_base.BiasAddTestBase,
     else:
       raise ValueError('Unknown data_layout')
 
-  def _random_data_op(self, shape, data_type):
-    return constant_op.constant(
-        2 * np.random.random_sample(shape) - 1, dtype=data_type)
-
-  def _random_ndarray(self, shape):
+  def _randomNDArray(self, shape):
     return 2 * np.random.random_sample(shape) - 1
 
-  def _assert_reproducible(self, operation, feed_dict={}):
-    with self.cached_session(force_gpu=True):
-      result_a = operation[0].eval(feed_dict=feed_dict)
-      result_b = operation[0].eval(feed_dict=feed_dict)
-      self.assertAllEqual(result_a, result_b)
+  def _randomDataOp(self, shape, data_type):
+    return constant_op.constant(self._randomNDArray(shape), dtype=data_type)
 
+  # TODO(duncanriach): add test coverage for deterministic gradients
+  #   in eager mode
   @parameterized.named_parameters(
       *test_util.generate_combinations_with_testcase_name(
          # With the selected layer configuration, at least in TensorFlow
          # version 2.0, when data_layout='channels_last', bias_add operates
          # deterministically by default. I don't know if this is true for
          # all layer configurations. These cases are still being tested here,
          # for completeness.
          data_layout=['channels_first', 'channels_last'],
          data_rank=[1, 2, 3],
          data_type=[dtypes.float16, dtypes.float32, dtypes.float64]))
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes
   @test_util.run_cuda_only
   def testDeterministicGradients(self, data_layout, data_rank, data_type):
-    seed = (
-        hash(data_layout) % 256 + hash(data_rank) % 256 + hash(data_type) % 256)
-    np.random.seed(seed)
-    batch_size = 10
-    channel_count = 8
-    data_dim = 14
-    in_shape = self._make_shape_tuple(batch_size, channel_count, data_rank,
-                                      data_dim, data_layout)
-    bias_shape = (channel_count,)
-    out_shape = in_shape
-    in_op = self._random_data_op(in_shape, data_type)
-    bias_op = self._random_data_op(bias_shape, data_type)
-    data_format = self._data_format_from_data_layout(data_layout)
-    bias_add_op = nn_ops.bias_add(in_op, bias_op, data_format=data_format)
-    upstream_gradients = array_ops.placeholder(
-        data_type, shape=out_shape, name='upstream_gradients')
-    gradient_injector_op = bias_add_op * upstream_gradients
-    # The gradient function behaves as if grad_ys is multiplied by the op
-    # gradient result, not passing the upstram gradients through the op's
-    # gradient generation graph. This is the reason for using the
-    # gradient_injector_op
-    grad_ys = None
-    bias_gradients_op = gradients_impl.gradients(
-        gradient_injector_op,
-        bias_op,
-        grad_ys=grad_ys,
-        colocate_gradients_with_ops=True)
-    for i in range(5):
-      feed_dict = {upstream_gradients: self._random_ndarray(out_shape)}
-      self._assert_reproducible(bias_gradients_op, feed_dict=feed_dict)
+    with self.session(force_gpu=True):
+      # Using a cached_session with force_gpu=True does not work at the time
+      # of writing (2019-12-10). Before the @parameterized.named_parameters
+      # decorator was added, this non-cached session context was set outside
+      # the iteration loops for the parameter combinations, and so was re-used.
+      seed = (
+          hash(data_layout) % 256 + hash(data_rank) % 256 +
+          hash(data_type) % 256)
+      np.random.seed(seed)
+      batch_size = 10
+      channel_count = 8
+      data_dim = 14
+      input_shape = self._makeShapeTuple(batch_size, channel_count, data_rank,
+                                         data_dim, data_layout)
+      bias_shape = (channel_count,)
+      output_shape = input_shape
+      input_val = self._randomDataOp(input_shape, data_type)
+      bias_val = self._randomDataOp(bias_shape, data_type)
+      data_format = self._dataFormatFromDataLayout(data_layout)
+      repeat_count = 5
+      if context.executing_eagerly():
+
+        def bias_gradients(local_seed):
+          np.random.seed(local_seed)
+          upstream_gradients = self._randomDataOp(output_shape, data_type)
+          with backprop.GradientTape(persistent=True) as tape:
+            tape.watch(bias_val)
+            bias_add_output = nn_ops.bias_add(
+                input_val, bias_val, data_format=data_format)
+            gradient_injector_output = bias_add_output * upstream_gradients
+          return tape.gradient(gradient_injector_output, bias_val)
+
+        for i in range(repeat_count):
+          local_seed = seed + i  # select different upstream gradients
+          result_a = bias_gradients(local_seed)
+          result_b = bias_gradients(local_seed)
+          self.assertAllEqual(result_a, result_b)
+      else:  # graph mode
+        upstream_gradients = array_ops.placeholder(
+            data_type, shape=output_shape, name='upstream_gradients')
+        bias_add_output = nn_ops.bias_add(
+            input_val, bias_val, data_format=data_format)
+        gradient_injector_output = bias_add_output * upstream_gradients
+        # The gradient function behaves as if grad_ys is multiplied by the op
+        # gradient result, not passing the upstram gradients through the op's
+        # gradient generation graph. This is the reason for using the
+        # gradient injector
+        bias_gradients = gradients_impl.gradients(
+            gradient_injector_output,
+            bias_val,
+            grad_ys=None,
+            colocate_gradients_with_ops=True)[0]
+        for i in range(repeat_count):
+          feed_dict = {upstream_gradients: self._randomNDArray(output_shape)}
+          result_a = bias_gradients.eval(feed_dict=feed_dict)
+          result_b = bias_gradients.eval(feed_dict=feed_dict)
+          self.assertAllEqual(result_a, result_b)
 
   # TODO(duncanriach): Re-enable the following three tests for the error checks
   #   after deterministic functionality is implemented at the CUDA kernel level.
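[Note] The "gradient injector" comments above (in both the graph and eager branches) describe the same trick: since grad_ys is only multiplied into the gradient result rather than fed through the op's gradient graph, the test multiplies the bias_add output by a tensor of upstream gradients and differentiates that product. A rough eager-mode sketch of the pattern (names are illustrative, not from the diff):

    import numpy as np
    import tensorflow as tf

    def injected_bias_gradient(x, b, upstream):
      # Multiply the op output by the injected upstream gradients and
      # differentiate the product, so the injected values actually pass
      # through bias_add's registered gradient function.
      with tf.GradientTape() as tape:
        tape.watch(b)
        product = tf.nn.bias_add(x, b) * upstream
      return tape.gradient(product, b)

    # Determinism check in the spirit of the test above: same inputs, two runs,
    # bitwise-identical gradients expected.
    x = tf.constant(np.random.random((10, 8)), tf.float32)
    b = tf.constant(np.random.random((8,)), tf.float32)
    upstream = tf.constant(np.random.random((10, 8)), tf.float32)
    grad_a = injected_bias_gradient(x, b, upstream)
    grad_b = injected_bias_gradient(x, b, upstream)
    np.testing.assert_array_equal(grad_a.numpy(), grad_b.numpy())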