Merge pull request #31465 from duncanriach:deterministic_bias_add

PiperOrigin-RevId: 275197941
Change-Id: I06928391a845b2cc04d44d84dc203885c04da2c4
commit cc379e4e2e
TensorFlower Gardener committed 2019-10-17 00:50:57 -07:00

10 changed files with 573 additions and 285 deletions

View File: tensorflow/python/kernel_tests/BUILD

@@ -295,16 +295,34 @@ cuda_py_test(
],
)
cuda_py_test(
name = "cudnn_determinism_test",
size = "small",
srcs = ["cudnn_determinism_test.py"],
additional_deps = [
"//third_party/py/numpy",
py_library(
name = "cudnn_deterministic_base",
srcs = ["cudnn_deterministic_base.py"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:client_testlib",
"//tensorflow/python:nn_ops",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "cudnn_deterministic_ops_test",
size = "small",
srcs = ["cudnn_deterministic_ops_test.py"],
additional_deps = [
":cudnn_deterministic_base",
],
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "cudnn_deterministic_test",
size = "small",
srcs = ["cudnn_deterministic_test.py"],
additional_deps = [
":cudnn_deterministic_base",
],
)
@@ -1638,18 +1656,36 @@ cuda_py_test(
],
)
cuda_py_test(
name = "bias_op_test",
size = "medium",
srcs = ["bias_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
py_library(
name = "bias_op_base",
srcs = ["bias_op_base.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "bias_op_deterministic_test",
size = "medium",
srcs = ["bias_op_deterministic_test.py"],
additional_deps = [
":bias_op_base",
],
xla_enable_strict_auto_jit = False,
)
cuda_py_test(
name = "bias_op_test",
size = "medium",
srcs = ["bias_op_test.py"],
additional_deps = [
":bias_op_base",
],
)

View File: tensorflow/python/kernel_tests/bias_op_base.py

@@ -0,0 +1,264 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for BiasAdd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class BiasAddTestBase(test.TestCase):
def _npBias(self, inputs, bias):
assert len(bias.shape) == 1
assert inputs.shape[-1] == bias.shape[0]
return inputs + bias.reshape(([1] *
(len(inputs.shape) - 1)) + [bias.shape[0]])
def testNpBias(self):
self.assertAllClose(
np.array([[11, 22, 33], [41, 52, 63]]),
self._npBias(
np.array([[10, 20, 30], [40, 50, 60]]), np.array([1, 2, 3])))
def _testBias(self, np_inputs, np_bias, use_gpu=False):
np_val = self._npBias(np_inputs, np_bias)
with self.cached_session(use_gpu=use_gpu):
tf_val = nn_ops.bias_add(np_inputs, np_bias).eval()
self.assertAllCloseAccordingToType(np_val, tf_val)
def _AtLeast3d(self, np_value):
# Pad the value with leading size-1 dimensions until it is at least 3-D.
if np_value.ndim < 3:
return np.reshape(np_value, (1,) * (3 - np_value.ndim) + np_value.shape)
return np_value
def _NHWCToNCHW(self, np_value):
# Pad the value with leading size-1 dimensions until it is at least 3-D.
np_value = self._AtLeast3d(np_value)
# Move the last (channel) dimension to position 1.
np_dim = list(range(np_value.ndim))
np_dim_new = list(np_dim[0:1]) + list(np_dim[-1:]) + list(np_dim[1:-1])
return np.transpose(np_value, np_dim_new)
def _NCHWToNHWC(self, np_value):
assert len(np_value.shape) >= 3
np_dim = list(range(np_value.ndim))
# Move the second (channel) dimension back to the last position.
np_dim_new = list(np_dim[0:1]) + list(np_dim[2:]) + list(np_dim[1:2])
return np.transpose(np_value, np_dim_new)
def _testBiasNCHW(self, np_inputs, np_bias, use_gpu):
np_val = self._npBias(np_inputs, np_bias)
np_inputs = self._NHWCToNCHW(np_inputs)
with self.cached_session(use_gpu=use_gpu):
tf_val = nn_ops.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
tf_val = self._NCHWToNHWC(tf_val)
self.assertAllCloseAccordingToType(self._AtLeast3d(np_val), tf_val)
def _testAll(self, np_inputs, np_bias):
self._testBias(np_inputs, np_bias, use_gpu=False)
self._testBiasNCHW(np_inputs, np_bias, use_gpu=False)
if np_inputs.dtype in [np.float16, np.float32, np.float64]:
self._testBias(np_inputs, np_bias, use_gpu=True)
self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)
@test_util.run_deprecated_v1
def testInputDims(self):
with self.assertRaises(ValueError):
nn_ops.bias_add([1, 2], [1])
@test_util.run_deprecated_v1
def testBiasVec(self):
with self.assertRaises(ValueError):
nn_ops.bias_add(
array_ops.reshape([1, 2], shape=[1, 2]),
array_ops.reshape([1, 2], shape=[1, 2]))
@test_util.run_deprecated_v1
def testBiasInputsMatch(self):
with self.assertRaises(ValueError):
nn_ops.bias_add(
array_ops.reshape([1, 2], shape=[1, 2]),
array_ops.reshape([1], shape=[1]))
@test_util.run_deprecated_v1
def testIntTypes(self):
for t in [np.int8, np.int16, np.int32, np.int64]:
self._testAll(
np.array([[10, 20, 30], [40, 50, 60]]).astype(t),
np.array([1, 2, 3]).astype(t))
@test_util.run_deprecated_v1
def testFloatTypes(self):
for t in [np.float16, np.float32, np.float64]:
self._testAll(
np.random.rand(4, 3, 3).astype(t),
np.random.rand(3).astype(t))
@test_util.run_deprecated_v1
def test4DFloatTypes(self):
for t in [np.float16, np.float32, np.float64]:
self._testAll(
np.random.rand(4, 3, 2, 3).astype(t),
np.random.rand(3).astype(t))
self._testAll(
np.random.rand(2048, 4, 4, 4).astype(t),
np.random.rand(4).astype(t))
self._testAll(
np.random.rand(4, 4, 4, 2048).astype(t),
np.random.rand(2048).astype(t))
@test_util.run_deprecated_v1
def test5DFloatTypes(self):
for t in [np.float16, np.float32, np.float64]:
self._testAll(
np.random.rand(4, 3, 2, 3, 4).astype(t),
np.random.rand(4).astype(t))
def _testGradient(self, np_input, bias, dtype, data_format, use_gpu):
with self.cached_session(use_gpu=use_gpu):
if data_format == "NCHW":
np_input = self._NHWCToNCHW(np_input)
input_tensor = constant_op.constant(
np_input, shape=np_input.shape, dtype=dtype)
bias_tensor = constant_op.constant(bias, shape=bias.shape, dtype=dtype)
output_tensor = nn_ops.bias_add(
input_tensor, bias_tensor, data_format=data_format)
tensor_jacob_t, tensor_jacob_n = gradient_checker.compute_gradient(
input_tensor, np_input.shape, output_tensor, np_input.shape)
bias_jacob_t, bias_jacob_n = gradient_checker.compute_gradient(
bias_tensor, bias.shape, output_tensor, np_input.shape)
# Test gradient of BiasAddGrad
bias_add_grad = gradients_impl.gradients(
nn_ops.l2_loss(output_tensor), bias_tensor)[0]
grad_jacob_t, grad_jacob_n = gradient_checker.compute_gradient(
output_tensor, np_input.shape, bias_add_grad, bias.shape)
if dtype == np.float16:
# Compare fp16 analytical gradients to fp32 numerical gradients,
# since fp16 numerical gradients are too imprecise unless great
# care is taken with choosing the inputs and the delta. This is
# a weaker, but pragmatic, check (in particular, it does not test
# the op itself, only its gradient).
input_tensor = constant_op.constant(
np_input, shape=np_input.shape, dtype=np.float32)
bias_tensor = constant_op.constant(
bias, shape=bias.shape, dtype=np.float32)
output_tensor = nn_ops.bias_add(
input_tensor, bias_tensor, data_format=data_format)
_, tensor_jacob_n = gradient_checker.compute_gradient(
input_tensor, np_input.shape, output_tensor, np_input.shape)
_, bias_jacob_n = gradient_checker.compute_gradient(
bias_tensor, bias.shape, output_tensor, np_input.shape)
bias_add_grad = gradients_impl.gradients(
nn_ops.l2_loss(output_tensor), bias_tensor)[0]
_, grad_jacob_n = gradient_checker.compute_gradient(
output_tensor, np_input.shape, bias_add_grad, bias.shape)
threshold = 5e-3
if dtype == dtypes.float64:
threshold = 1e-10
self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold)
self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold)
self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)
@test_util.run_deprecated_v1
def testGradientTensor2D(self):
for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
dtype=dtype.as_numpy_dtype).reshape(3, 2)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testGradientTensor3D(self):
for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
("NCHW", False), ("NCHW", True)]:
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
dtype=dtype.as_numpy_dtype).reshape(1, 3, 2)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testGradientTensor4D(self):
for (data_format, use_gpu) in [("NHWC", False)]:
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.arange(
1.0, 49.0,
dtype=dtype.as_numpy_dtype).reshape([2, 3, 4, 2]).astype(np.float32)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
np_input = np.arange(
1.0, 513.0,
dtype=dtype.as_numpy_dtype).reshape([64, 2, 2,
2]).astype(np.float32)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
np_input = np.arange(
1.0, 513.0,
dtype=dtype.as_numpy_dtype).reshape([2, 2, 2,
64]).astype(np.float32)
self._testGradient(np_input,
np.random.rand(64).astype(dtype.as_numpy_dtype),
dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testGradientTensor5D(self):
for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
("NCHW", False), ("NCHW", True)]:
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.arange(
1.0, 49.0,
dtype=dtype.as_numpy_dtype).reshape([1, 2, 3, 4,
2]).astype(np.float32)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testEmpty(self):
np.random.seed(7)
for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
self._testAll(np.random.randn(*shape), np.random.randn(shape[-1]))
@test_util.run_deprecated_v1
def testEmptyGradient(self):
for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
for shape in (0, 0), (2, 0), (0, 2):
self._testGradient(
np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64,
data_format, use_gpu)
for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
("NCHW", False), ("NCHW", True)]:
for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3):
self._testGradient(
np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64,
data_format, use_gpu)
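
A side note on the layout helpers used throughout this base class: for a rank-4 value, _NHWCToNCHW and _NCHWToNHWC reduce to the following transposes. A minimal NumPy sketch (illustration only, not part of the commit):

import numpy as np

x = np.zeros((4, 5, 6, 3))                # NHWC
nchw = np.transpose(x, (0, 3, 1, 2))      # _NHWCToNCHW: channels to axis 1
assert nchw.shape == (4, 3, 5, 6)
nhwc = np.transpose(nchw, (0, 2, 3, 1))   # _NCHWToNHWC: channels back to last
assert nhwc.shape == x.shape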

View File: tensorflow/python/kernel_tests/bias_op_deterministic_test.py

@@ -0,0 +1,129 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for deterministic BiasAdd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests import bias_op_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
class BiasAddDeterministicTest(bias_op_base.BiasAddTestBase):
def _make_shape_tuple(self, batch_size, channel_count, data_rank, data_dim,
data_layout):
data_dims = data_rank * (data_dim,)
if data_layout == 'channels_first':
shape = (batch_size,) + (channel_count,) + data_dims
elif data_layout == 'channels_last':
shape = (batch_size,) + data_dims + (channel_count,)
else:
raise ValueError('Unknown data format')
return shape
def _data_format_from_data_layout(self, data_layout=None):
if data_layout == 'channels_first':
return 'NCHW'
elif data_layout == 'channels_last':
return 'NHWC'
else:
raise ValueError('Unknown data_layout')
def _random_data_op(self, shape, data_type):
return constant_op.constant(
2 * np.random.random_sample(shape) - 1, dtype=data_type)
def _random_ndarray(self, shape):
return 2 * np.random.random_sample(shape) - 1
def _assert_reproducible(self, operation, feed_dict=None):
feed_dict = feed_dict or {}  # avoid a mutable default argument
with self.cached_session(force_gpu=True):
result_a = operation[0].eval(feed_dict=feed_dict)
result_b = operation[0].eval(feed_dict=feed_dict)
self.assertAllEqual(result_a, result_b)
def _testDeterministicGradientsCase(self, data_layout, data_rank, data_type):
seed = (
hash(data_layout) % 256 + hash(data_rank) % 256 + hash(data_type) % 256)
np.random.seed(seed)
batch_size = 10
channel_count = 8
data_dim = 14
in_shape = self._make_shape_tuple(batch_size, channel_count, data_rank,
data_dim, data_layout)
bias_shape = (channel_count,)
out_shape = in_shape
in_op = self._random_data_op(in_shape, data_type)
bias_op = self._random_data_op(bias_shape, data_type)
data_format = self._data_format_from_data_layout(data_layout)
bias_add_op = nn_ops.bias_add(in_op, bias_op, data_format=data_format)
upstream_gradients = array_ops.placeholder(
data_type, shape=out_shape, name='upstream_gradients')
gradient_injector_op = bias_add_op * upstream_gradients
# The multiplication by upstream_gradients (the gradient injector) is
# needed because the gradient function applies grad_ys as a final scaling
# of the computed gradient rather than passing the upstream gradients
# through the op's gradient-generation graph.
grad_ys = None
bias_gradients_op = gradients_impl.gradients(
gradient_injector_op,
bias_op,
grad_ys=grad_ys,
colocate_gradients_with_ops=True)
for i in range(5):
feed_dict = {upstream_gradients: self._random_ndarray(out_shape)}
self._assert_reproducible(bias_gradients_op, feed_dict=feed_dict)
# TODO(duncanriach): add test coverage for deterministic gradients
# in eager mode
@test_util.run_deprecated_v1
@test_util.run_cuda_only
def testDeterministicGradients(self):
for data_layout in ('channels_first', 'channels_last'):
for data_rank in (1, 2, 3):
for data_type in (dtypes.float16, dtypes.float32, dtypes.float64):
self._testDeterministicGradientsCase(data_layout, data_rank,
data_type)
# TODO(duncanriach): Re-enable the following three tests for the error checks
# after deterministic functionality is implemented at the CUDA kernel level.
def testInputDims(self):
pass
def testBiasVec(self):
pass
def testBiasInputsMatch(self):
pass
if __name__ == '__main__':
# Note that the effect of setting the following environment variable to
# 'true' is not tested. Unless we can find a simpler pattern for testing these
# environment variables, it would require this file to be made into a base
# and then two more test files to be created.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
test.main()
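
For reference, a minimal standalone sketch of the gradient-injector pattern exercised by _testDeterministicGradientsCase (illustration only, not part of the commit; assumes the TF 1.x graph-mode API via tensorflow.compat.v1):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.constant(np.random.random_sample((2, 3)), dtype=tf.float32)
b = tf.constant(np.random.random_sample((3,)), dtype=tf.float32)
y = tf.nn.bias_add(x, b)
upstream = tf.placeholder(tf.float32, shape=(2, 3))
injector = y * upstream  # d(injector)/dy == upstream, consumed by BiasAddGrad
bias_grad = tf.gradients(injector, b)[0]
with tf.Session() as sess:
  # The bias gradient is the column sum of the injected upstream gradients.
  print(sess.run(bias_grad, feed_dict={upstream: np.ones((2, 3), np.float32)}))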

View File: tensorflow/python/kernel_tests/bias_op_test.py

@@ -1,4 +1,4 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,259 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.kernel_tests import bias_op_base
from tensorflow.python.platform import test
class BiasAddTest(test.TestCase):
def _npBias(self, inputs, bias):
assert len(bias.shape) == 1
assert inputs.shape[-1] == bias.shape[0]
return inputs + bias.reshape(([1] * (len(inputs.shape) - 1)) +
[bias.shape[0]])
def testNpBias(self):
self.assertAllClose(
np.array([[11, 22, 33], [41, 52, 63]]),
self._npBias(
np.array([[10, 20, 30], [40, 50, 60]]), np.array([1, 2, 3])))
def _testBias(self, np_inputs, np_bias, use_gpu=False):
np_val = self._npBias(np_inputs, np_bias)
with self.cached_session(use_gpu=use_gpu):
tf_val = nn_ops.bias_add(np_inputs, np_bias).eval()
self.assertAllCloseAccordingToType(np_val, tf_val)
def _AtLeast3d(self, np_value):
# Pad the value with leading size-1 dimensions until it is at least 3-D.
if np_value.ndim < 3:
return np.reshape(np_value, (1,) * (3 - np_value.ndim) + np_value.shape)
return np_value
def _NHWCToNCHW(self, np_value):
# Pad the value with leading size-1 dimensions until it is at least 3-D.
np_value = self._AtLeast3d(np_value)
# Move the last (channel) dimension to position 1.
np_dim = list(range(np_value.ndim))
np_dim_new = list(np_dim[0:1]) + list(np_dim[-1:]) + list(np_dim[1:-1])
return np.transpose(np_value, np_dim_new)
def _NCHWToNHWC(self, np_value):
assert len(np_value.shape) >= 3
np_dim = list(range(np_value.ndim))
# Move the second (channel) dimension back to the last position.
np_dim_new = list(np_dim[0:1]) + list(np_dim[2:]) + list(np_dim[1:2])
return np.transpose(np_value, np_dim_new)
def _testBiasNCHW(self, np_inputs, np_bias, use_gpu):
np_val = self._npBias(np_inputs, np_bias)
np_inputs = self._NHWCToNCHW(np_inputs)
with self.cached_session(use_gpu=use_gpu):
tf_val = nn_ops.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
tf_val = self._NCHWToNHWC(tf_val)
self.assertAllCloseAccordingToType(self._AtLeast3d(np_val), tf_val)
def _testAll(self, np_inputs, np_bias):
self._testBias(np_inputs, np_bias, use_gpu=False)
self._testBiasNCHW(np_inputs, np_bias, use_gpu=False)
if np_inputs.dtype in [np.float16, np.float32, np.float64]:
self._testBias(np_inputs, np_bias, use_gpu=True)
self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)
@test_util.run_deprecated_v1
def testInputDims(self):
with self.assertRaises(ValueError):
nn_ops.bias_add([1, 2], [1])
@test_util.run_deprecated_v1
def testBiasVec(self):
with self.assertRaises(ValueError):
nn_ops.bias_add(
array_ops.reshape(
[1, 2], shape=[1, 2]),
array_ops.reshape(
[1, 2], shape=[1, 2]))
@test_util.run_deprecated_v1
def testBiasInputsMatch(self):
with self.assertRaises(ValueError):
nn_ops.bias_add(
array_ops.reshape(
[1, 2], shape=[1, 2]),
array_ops.reshape(
[1], shape=[1]))
@test_util.run_deprecated_v1
def testIntTypes(self):
for t in [np.int8, np.int16, np.int32, np.int64]:
self._testAll(
np.array([[10, 20, 30], [40, 50, 60]]).astype(t),
np.array([1, 2, 3]).astype(t))
@test_util.run_deprecated_v1
def testFloatTypes(self):
for t in [np.float16, np.float32, np.float64]:
self._testAll(
np.random.rand(4, 3, 3).astype(t), np.random.rand(3).astype(t))
@test_util.run_deprecated_v1
def test4DFloatTypes(self):
for t in [np.float16, np.float32, np.float64]:
self._testAll(
np.random.rand(4, 3, 2, 3).astype(t),
np.random.rand(3).astype(t))
self._testAll(
np.random.rand(2048, 4, 4, 4).astype(t),
np.random.rand(4).astype(t))
self._testAll(
np.random.rand(4, 4, 4, 2048).astype(t),
np.random.rand(2048).astype(t))
@test_util.run_deprecated_v1
def test5DFloatTypes(self):
for t in [np.float16, np.float32, np.float64]:
self._testAll(
np.random.rand(4, 3, 2, 3, 4).astype(t),
np.random.rand(4).astype(t))
def _testGradient(self, np_input, bias, dtype, data_format, use_gpu):
with self.cached_session(use_gpu=use_gpu):
if data_format == "NCHW":
np_input = self._NHWCToNCHW(np_input)
input_tensor = constant_op.constant(
np_input, shape=np_input.shape, dtype=dtype)
bias_tensor = constant_op.constant(bias, shape=bias.shape, dtype=dtype)
output_tensor = nn_ops.bias_add(
input_tensor, bias_tensor, data_format=data_format)
tensor_jacob_t, tensor_jacob_n = gradient_checker.compute_gradient(
input_tensor, np_input.shape, output_tensor, np_input.shape)
bias_jacob_t, bias_jacob_n = gradient_checker.compute_gradient(
bias_tensor, bias.shape, output_tensor, np_input.shape)
# Test gradient of BiasAddGrad
bias_add_grad = gradients_impl.gradients(
nn_ops.l2_loss(output_tensor), bias_tensor)[0]
grad_jacob_t, grad_jacob_n = gradient_checker.compute_gradient(
output_tensor, np_input.shape, bias_add_grad, bias.shape)
if dtype == np.float16:
# Compare fp16 theoretical gradients to fp32 numerical gradients,
# since fp16 numerical gradients are too imprecise unless great
# care is taken with choosing the inputs and the delta. This is
# a weaker check (in particular, it does not test the op itself,
# only its gradient), but it's much better than nothing.
input_tensor = constant_op.constant(
np_input, shape=np_input.shape, dtype=np.float32)
bias_tensor = constant_op.constant(
bias, shape=bias.shape, dtype=np.float32)
output_tensor = nn_ops.bias_add(
input_tensor, bias_tensor, data_format=data_format)
_, tensor_jacob_n = gradient_checker.compute_gradient(input_tensor,
np_input.shape,
output_tensor,
np_input.shape)
_, bias_jacob_n = gradient_checker.compute_gradient(bias_tensor,
bias.shape,
output_tensor,
np_input.shape)
bias_add_grad = gradients_impl.gradients(
nn_ops.l2_loss(output_tensor), bias_tensor)[0]
_, grad_jacob_n = gradient_checker.compute_gradient(output_tensor,
np_input.shape,
bias_add_grad,
bias.shape)
threshold = 5e-3
if dtype == dtypes.float64:
threshold = 1e-10
self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold)
self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold)
self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)
@test_util.run_deprecated_v1
def testGradientTensor2D(self):
for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.array(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
dtype=dtype.as_numpy_dtype).reshape(3, 2)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testGradientTensor3D(self):
for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
("NCHW", False), ("NCHW", True)]:
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
dtype=dtype.as_numpy_dtype).reshape(1, 3, 2)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testGradientTensor4D(self):
for (data_format, use_gpu) in [("NHWC", False)]:
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.arange(
1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
[2, 3, 4, 2]).astype(np.float32)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
np_input = np.arange(
1.0, 513.0, dtype=dtype.as_numpy_dtype).reshape(
[64, 2, 2, 2]).astype(np.float32)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
np_input = np.arange(
1.0, 513.0, dtype=dtype.as_numpy_dtype).reshape(
[2, 2, 2, 64]).astype(np.float32)
self._testGradient(np_input,
np.random.rand(64).astype(dtype.as_numpy_dtype),
dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testGradientTensor5D(self):
for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
("NCHW", False), ("NCHW", True)]:
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
np_input = np.arange(
1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
[1, 2, 3, 4, 2]).astype(np.float32)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
@test_util.run_deprecated_v1
def testEmpty(self):
np.random.seed(7)
for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
self._testAll(np.random.randn(*shape), np.random.randn(shape[-1]))
@test_util.run_deprecated_v1
def testEmptyGradient(self):
for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
for shape in (0, 0), (2, 0), (0, 2):
self._testGradient(
np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64,
data_format, use_gpu)
for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
("NCHW", False), ("NCHW", True)]:
for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3):
self._testGradient(
np.random.randn(*shape),
np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu)
BiasAddTest = bias_op_base.BiasAddTestBase
if __name__ == "__main__":
test.main()

View File: tensorflow/python/kernel_tests/cudnn_deterministic_base.py

@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TF_CUDNN_DETERMINISTIC=true."""
"""Tests for deterministic cuDNN functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import numpy as np
from tensorflow.python.framework import constant_op
@@ -28,12 +27,14 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
# The TF_CUDNN_DETERMINISTIC flag disables autotuning of cuDNN algorithms and
# causes deterministic cuDNN algorithms to be selected when both deterministic
# and non-deterministic algorithms are available. These tests are intended to
# confirm that deterministic algorithms are chosen when
# TF_CUDNN_DETERMINISTIC=true. The configurations tested were confirmed to
# produce non-deterministic results without setting TF_CUDNN_DETERMINISTIC=true
# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or
# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN
# algorithms and cause deterministic cuDNN algorithms to be selected when both
# deterministic and non-deterministic algorithms are available. These tests are
# intended to confirm that deterministic algorithms are chosen when either
# environment variable is set to "true" or "1". The tested configurations were
# first confirmed to produce non-deterministic results when the environment
# variables are not set.
_PADDING = 'SAME'
_STRIDES = [1, 1, 1, 1]
@@ -96,8 +97,3 @@ class ConvolutionTest(test.TestCase):
# TODO(duncanriach): (1) add test to confirm that forward autotuning is
# disabled for cuDNN convolution; (2) add test for deterministic cuDNN
# max-pooling
if __name__ == '__main__':
os.environ['TF_CUDNN_DETERMINISTIC'] = 'true'
test.main()
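
The __main__ block above is removed from the base file; the thin test files that follow set the environment variables instead. In application code, the corresponding usage is simply the following (a sketch, not part of the commit):

import os

# Either variable, set to '1' or 'true', disables cuDNN autotuning and
# selects deterministic algorithms. The flags are read lazily, so set them
# before the first op that consults them runs.
os.environ['TF_DETERMINISTIC_OPS'] = '1'  # or 'TF_CUDNN_DETERMINISTIC'

import tensorflow as tf  # build and run ops only after setting the flag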

View File: tensorflow/python/kernel_tests/cudnn_deterministic_ops_test.py

@@ -0,0 +1,33 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TF_DETERMINISTIC_OPS=1."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.kernel_tests import cudnn_deterministic_base
from tensorflow.python.platform import test
ConvolutionTest = cudnn_deterministic_base.ConvolutionTest
if __name__ == '__main__':
# Note that the effect of setting the following environment variable to
# 'true' is not tested. Unless we can find a simpler pattern for testing these
# environment variables, it would require another test file to be added.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
test.main()

View File: tensorflow/python/kernel_tests/cudnn_deterministic_test.py

@@ -0,0 +1,33 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TF_CUDNN_DETERMINISTIC=1."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.kernel_tests import cudnn_deterministic_base
from tensorflow.python.platform import test
ConvolutionTest = cudnn_deterministic_base.ConvolutionTest
if __name__ == '__main__':
# Note that the effect of setting the following environment variable to
# 'true' is not tested. Unless we can find a simpler pattern for testing these
# environment variables, it would require another test file to be added.
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
test.main()

View File: tensorflow/python/ops/nn_ops.py

@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import numbers
import os
import numpy as np
@@ -2682,6 +2683,19 @@ def conv_transpose(input, # pylint: disable=redefined-builtin
name=name)
def _tf_deterministic_ops():
if _tf_deterministic_ops.value is None:
tf_deterministic_ops = os.environ.get("TF_DETERMINISTIC_OPS")
if tf_deterministic_ops is not None:
tf_deterministic_ops = tf_deterministic_ops.lower()
_tf_deterministic_ops.value = (
tf_deterministic_ops == "true" or tf_deterministic_ops == "1")
return _tf_deterministic_ops.value
_tf_deterministic_ops.value = None
@tf_export("nn.bias_add")
def bias_add(value, bias, data_format=None, name=None):
"""Adds `bias` to `value`.
@@ -2697,11 +2711,19 @@ def bias_add(value, bias, data_format=None, name=None):
bias: A 1-D `Tensor` with size matching the channel dimension of `value`.
Must be the same type as `value` unless `value` is a quantized type,
in which case a different quantized type may be used.
data_format: A string. 'N...C' and 'NC...' are supported.
data_format: A string. 'N...C' and 'NC...' are supported. If `None` (the
default) is specified, then 'N...C' is assumed.
name: A name for the operation (optional).
Returns:
A `Tensor` with the same type as `value`.
Raises:
ValueError if the data format is unrecognized, if `value` has fewer than
two dimensions when `data_format` is 'N...C'/`None`, if `value` has fewer
than three dimensions when `data_format` is 'NC...', if `bias` is not a
1-D `Tensor` (a vector), or if the size of `bias` does not match the size
of the channel dimension of `value`.
"""
with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
if data_format is not None:
@@ -2715,7 +2737,25 @@
if not context.executing_eagerly():
value = ops.convert_to_tensor(value, name="input")
bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
# TODO(duncanriach): Implement deterministic functionality at CUDA kernel
# level.
if _tf_deterministic_ops():
# Note that this code does not implement the same error checks as the
# pre-existing C++ ops.
if data_format == "NCHW":
broadcast_shape_head = [1, array_ops.size(bias)]
broadcast_shape_tail = array_ops.ones(
array_ops.rank(value) - 2, dtype=dtypes.int32)
broadcast_shape = array_ops.concat(
[broadcast_shape_head, broadcast_shape_tail], 0)
return math_ops.add(
value, array_ops.reshape(bias, broadcast_shape), name=name)
else:  # data_format == 'NHWC' or data_format is None
return math_ops.add(value, bias, name=name)
else:
return gen_nn_ops.bias_add(
value, bias, data_format=data_format, name=name)
def bias_add_v1(value, bias, name=None):
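
The NCHW branch above reshapes the bias to [1, C, 1, ..., 1] so that an ordinary (deterministic) add broadcasts it across the batch and spatial dimensions. The effect in plain NumPy (a sketch, not part of the commit):

import numpy as np

value = np.random.random_sample((2, 8, 4, 4))               # NCHW, C == 8
bias = np.random.random_sample((8,))
broadcast_shape = (1, bias.size) + (1,) * (value.ndim - 2)  # (1, 8, 1, 1)
out = value + bias.reshape(broadcast_shape)
# For NHWC (or data_format=None), plain `value + bias` broadcasts over the
# trailing channel axis and no reshape is needed.
assert out.shape == value.shape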

View File: tensorflow/stream_executor/cuda/cuda_dnn.cc

@@ -632,14 +632,18 @@ bool BatchnormSpatialPersistentEnabled() {
// A helper function to decide whether to enable deterministic functionality.
bool RequireDeterminism() {
static bool is_enabled = [] {
bool is_enabled = false;
static bool require_determinism = [] {
bool deterministic_ops = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
/*default_val=*/false,
&deterministic_ops));
bool cudnn_deterministic = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
/*default_val=*/false,
&is_enabled));
return is_enabled;
&cudnn_deterministic));
return deterministic_ops || cudnn_deterministic;
}();
return is_enabled;
return require_determinism;
}
std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
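
RequireDeterminism caches the OR of the two flags in a function-local static, so the environment is consulted exactly once per process. A Python sketch of the equivalent memoized check (hypothetical names; simplified parsing relative to ReadBoolFromEnvVar):

import os

_require_determinism = None  # cached result, computed on first use

def require_determinism():
  global _require_determinism
  if _require_determinism is None:
    def flag(name):
      return os.environ.get(name, '').lower() in ('true', '1')
    _require_determinism = (flag('TF_DETERMINISTIC_OPS') or
                            flag('TF_CUDNN_DETERMINISTIC'))
  return _require_determinism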

View File: tensorflow/tools/pip_package/BUILD

@@ -89,6 +89,8 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
"//tensorflow/python/keras/distribute:multi_worker_testing_utils",
"//tensorflow/python/keras/mixed_precision/experimental:test_util",
"//tensorflow/python/kernel_tests:cudnn_deterministic_base",
"//tensorflow/python/kernel_tests:bias_op_base",
"//tensorflow/python/kernel_tests/random:util",
"//tensorflow/python/kernel_tests/signal:test_util",
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",