Merge pull request #31465 from duncanriach:deterministic_bias_add

PiperOrigin-RevId: 275197941
Change-Id: I06928391a845b2cc04d44d84dc203885c04da2c4
Commit: cc379e4e2e
tensorflow/python/kernel_tests/BUILD

@@ -295,16 +295,34 @@ cuda_py_test(
     ],
 )
 
-cuda_py_test(
-    name = "cudnn_determinism_test",
-    size = "small",
-    srcs = ["cudnn_determinism_test.py"],
-    additional_deps = [
-        "//third_party/py/numpy",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:nn_ops",
-    ],
-)
+py_library(
+    name = "cudnn_deterministic_base",
+    srcs = ["cudnn_deterministic_base.py"],
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:nn_ops",
+        "//third_party/py/numpy",
+    ],
+)
 
+cuda_py_test(
+    name = "cudnn_deterministic_ops_test",
+    size = "small",
+    srcs = ["cudnn_deterministic_ops_test.py"],
+    additional_deps = [
+        ":cudnn_deterministic_base",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
+cuda_py_test(
+    name = "cudnn_deterministic_test",
+    size = "small",
+    srcs = ["cudnn_deterministic_test.py"],
+    additional_deps = [
+        ":cudnn_deterministic_base",
+    ],
+)
@@ -1638,18 +1656,36 @@ cuda_py_test(
     ],
 )
 
-cuda_py_test(
-    name = "bias_op_test",
-    size = "medium",
-    srcs = ["bias_op_test.py"],
-    additional_deps = [
-        "//third_party/py/numpy",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:gradients",
-        "//tensorflow/python:nn_grad",
-        "//tensorflow/python:nn_ops",
-    ],
-)
+py_library(
+    name = "bias_op_base",
+    srcs = ["bias_op_base.py"],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:nn_grad",
+        "//tensorflow/python:nn_ops",
+        "//third_party/py/numpy",
+    ],
+)
 
+cuda_py_test(
+    name = "bias_op_deterministic_test",
+    size = "medium",
+    srcs = ["bias_op_deterministic_test.py"],
+    additional_deps = [
+        ":bias_op_base",
+    ],
+    xla_enable_strict_auto_jit = False,
+)
+
+cuda_py_test(
+    name = "bias_op_test",
+    size = "medium",
+    srcs = ["bias_op_test.py"],
+    additional_deps = [
+        ":bias_op_base",
+    ],
+)
tensorflow/python/kernel_tests/bias_op_base.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for BiasAdd."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class BiasAddTestBase(test.TestCase):

  def _npBias(self, inputs, bias):
    assert len(bias.shape) == 1
    assert inputs.shape[-1] == bias.shape[0]
    return inputs + bias.reshape(([1] *
                                  (len(inputs.shape) - 1)) + [bias.shape[0]])

  def testNpBias(self):
    self.assertAllClose(
        np.array([[11, 22, 33], [41, 52, 63]]),
        self._npBias(
            np.array([[10, 20, 30], [40, 50, 60]]), np.array([1, 2, 3])))

  def _testBias(self, np_inputs, np_bias, use_gpu=False):
    np_val = self._npBias(np_inputs, np_bias)
    with self.cached_session(use_gpu=use_gpu):
      tf_val = nn_ops.bias_add(np_inputs, np_bias).eval()
    self.assertAllCloseAccordingToType(np_val, tf_val)

  def _AtLeast3d(self, np_value):
    # Pad the input shape out to at least three dimensions.
    if np_value.ndim < 3:
      return np.reshape(np_value, (1,) * (3 - np_value.ndim) + np_value.shape)
    return np_value

  def _NHWCToNCHW(self, np_value):
    # Pad the input shape out to at least three dimensions.
    np_value = self._AtLeast3d(np_value)
    # Move the last dimension into the second position.
    np_dim = list(range(np_value.ndim))
    np_dim_new = list(np_dim[0:1]) + list(np_dim[-1:]) + list(np_dim[1:-1])
    return np.transpose(np_value, np_dim_new)

  def _NCHWToNHWC(self, np_value):
    assert len(np_value.shape) >= 3
    np_dim = list(range(np_value.ndim))
    # Move the second dimension into the last position.
    np_dim_new = list(np_dim[0:1]) + list(np_dim[2:]) + list(np_dim[1:2])
    return np.transpose(np_value, np_dim_new)

  def _testBiasNCHW(self, np_inputs, np_bias, use_gpu):
    np_val = self._npBias(np_inputs, np_bias)
    np_inputs = self._NHWCToNCHW(np_inputs)
    with self.cached_session(use_gpu=use_gpu):
      tf_val = nn_ops.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
    tf_val = self._NCHWToNHWC(tf_val)
    self.assertAllCloseAccordingToType(self._AtLeast3d(np_val), tf_val)

  def _testAll(self, np_inputs, np_bias):
    self._testBias(np_inputs, np_bias, use_gpu=False)
    self._testBiasNCHW(np_inputs, np_bias, use_gpu=False)
    if np_inputs.dtype in [np.float16, np.float32, np.float64]:
      self._testBias(np_inputs, np_bias, use_gpu=True)
      self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)

  @test_util.run_deprecated_v1
  def testInputDims(self):
    with self.assertRaises(ValueError):
      nn_ops.bias_add([1, 2], [1])

  @test_util.run_deprecated_v1
  def testBiasVec(self):
    with self.assertRaises(ValueError):
      nn_ops.bias_add(
          array_ops.reshape([1, 2], shape=[1, 2]),
          array_ops.reshape([1, 2], shape=[1, 2]))

  @test_util.run_deprecated_v1
  def testBiasInputsMatch(self):
    with self.assertRaises(ValueError):
      nn_ops.bias_add(
          array_ops.reshape([1, 2], shape=[1, 2]),
          array_ops.reshape([1], shape=[1]))

  @test_util.run_deprecated_v1
  def testIntTypes(self):
    for t in [np.int8, np.int16, np.int32, np.int64]:
      self._testAll(
          np.array([[10, 20, 30], [40, 50, 60]]).astype(t),
          np.array([1, 2, 3]).astype(t))

  @test_util.run_deprecated_v1
  def testFloatTypes(self):
    for t in [np.float16, np.float32, np.float64]:
      self._testAll(
          np.random.rand(4, 3, 3).astype(t),
          np.random.rand(3).astype(t))

  @test_util.run_deprecated_v1
  def test4DFloatTypes(self):
    for t in [np.float16, np.float32, np.float64]:
      self._testAll(
          np.random.rand(4, 3, 2, 3).astype(t),
          np.random.rand(3).astype(t))
      self._testAll(
          np.random.rand(2048, 4, 4, 4).astype(t),
          np.random.rand(4).astype(t))
      self._testAll(
          np.random.rand(4, 4, 4, 2048).astype(t),
          np.random.rand(2048).astype(t))

  @test_util.run_deprecated_v1
  def test5DFloatTypes(self):
    for t in [np.float16, np.float32, np.float64]:
      self._testAll(
          np.random.rand(4, 3, 2, 3, 4).astype(t),
          np.random.rand(4).astype(t))

  def _testGradient(self, np_input, bias, dtype, data_format, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      if data_format == "NCHW":
        np_input = self._NHWCToNCHW(np_input)
      input_tensor = constant_op.constant(
          np_input, shape=np_input.shape, dtype=dtype)
      bias_tensor = constant_op.constant(bias, shape=bias.shape, dtype=dtype)
      output_tensor = nn_ops.bias_add(
          input_tensor, bias_tensor, data_format=data_format)
      tensor_jacob_t, tensor_jacob_n = gradient_checker.compute_gradient(
          input_tensor, np_input.shape, output_tensor, np_input.shape)
      bias_jacob_t, bias_jacob_n = gradient_checker.compute_gradient(
          bias_tensor, bias.shape, output_tensor, np_input.shape)

      # Test gradient of BiasAddGrad
      bias_add_grad = gradients_impl.gradients(
          nn_ops.l2_loss(output_tensor), bias_tensor)[0]
      grad_jacob_t, grad_jacob_n = gradient_checker.compute_gradient(
          output_tensor, np_input.shape, bias_add_grad, bias.shape)

      if dtype == np.float16:
        # Compare fp16 analytical gradients to fp32 numerical gradients,
        # since fp16 numerical gradients are too imprecise unless great
        # care is taken with choosing the inputs and the delta. This is
        # a weaker, but pragmatic, check (in particular, it does not test
        # the op itself, only its gradient).
        input_tensor = constant_op.constant(
            np_input, shape=np_input.shape, dtype=np.float32)
        bias_tensor = constant_op.constant(
            bias, shape=bias.shape, dtype=np.float32)
        output_tensor = nn_ops.bias_add(
            input_tensor, bias_tensor, data_format=data_format)
        _, tensor_jacob_n = gradient_checker.compute_gradient(
            input_tensor, np_input.shape, output_tensor, np_input.shape)
        _, bias_jacob_n = gradient_checker.compute_gradient(
            bias_tensor, bias.shape, output_tensor, np_input.shape)

        bias_add_grad = gradients_impl.gradients(
            nn_ops.l2_loss(output_tensor), bias_tensor)[0]
        _, grad_jacob_n = gradient_checker.compute_gradient(
            output_tensor, np_input.shape, bias_add_grad, bias.shape)

      threshold = 5e-3
      if dtype == dtypes.float64:
        threshold = 1e-10
      self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold)
      self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold)
      self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)

  @test_util.run_deprecated_v1
  def testGradientTensor2D(self):
    for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
        np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                            dtype=dtype.as_numpy_dtype).reshape(3, 2)
        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
        self._testGradient(np_input, bias, dtype, data_format, use_gpu)

  @test_util.run_deprecated_v1
  def testGradientTensor3D(self):
    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                   ("NCHW", False), ("NCHW", True)]:
      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
        np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                            dtype=dtype.as_numpy_dtype).reshape(1, 3, 2)
        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
        self._testGradient(np_input, bias, dtype, data_format, use_gpu)

  @test_util.run_deprecated_v1
  def testGradientTensor4D(self):
    for (data_format, use_gpu) in [("NHWC", False)]:
      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
        np_input = np.arange(
            1.0, 49.0,
            dtype=dtype.as_numpy_dtype).reshape([2, 3, 4, 2]).astype(np.float32)
        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
        np_input = np.arange(
            1.0, 513.0,
            dtype=dtype.as_numpy_dtype).reshape([64, 2, 2,
                                                 2]).astype(np.float32)
        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
        np_input = np.arange(
            1.0, 513.0,
            dtype=dtype.as_numpy_dtype).reshape([2, 2, 2,
                                                 64]).astype(np.float32)
        self._testGradient(np_input,
                           np.random.rand(64).astype(dtype.as_numpy_dtype),
                           dtype, data_format, use_gpu)

  @test_util.run_deprecated_v1
  def testGradientTensor5D(self):
    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                   ("NCHW", False), ("NCHW", True)]:
      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
        np_input = np.arange(
            1.0, 49.0,
            dtype=dtype.as_numpy_dtype).reshape([1, 2, 3, 4,
                                                 2]).astype(np.float32)
        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
        self._testGradient(np_input, bias, dtype, data_format, use_gpu)

  @test_util.run_deprecated_v1
  def testEmpty(self):
    np.random.seed(7)
    for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
      self._testAll(np.random.randn(*shape), np.random.randn(shape[-1]))

  @test_util.run_deprecated_v1
  def testEmptyGradient(self):
    for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
      for shape in (0, 0), (2, 0), (0, 2):
        self._testGradient(
            np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64,
            data_format, use_gpu)

    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                   ("NCHW", False), ("NCHW", True)]:
      for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3):
        self._testGradient(
            np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64,
            data_format, use_gpu)
tensorflow/python/kernel_tests/bias_op_deterministic_test.py (new file, 129 lines)
@@ -0,0 +1,129 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for deterministic BiasAdd."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests import bias_op_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test


class BiasAddDeterministicTest(bias_op_base.BiasAddTestBase):

  def _make_shape_tuple(self, batch_size, channel_count, data_rank, data_dim,
                        data_layout):
    data_dims = data_rank * (data_dim,)
    if data_layout == 'channels_first':
      shape = (batch_size,) + (channel_count,) + data_dims
    elif data_layout == 'channels_last':
      shape = (batch_size,) + data_dims + (channel_count,)
    else:
      raise ValueError('Unknown data format')
    return shape

  def _data_format_from_data_layout(self, data_layout=None):
    if data_layout == 'channels_first':
      return 'NCHW'
    elif data_layout == 'channels_last':
      return 'NHWC'
    else:
      raise ValueError('Unknown data_layout')

  def _random_data_op(self, shape, data_type):
    return constant_op.constant(
        2 * np.random.random_sample(shape) - 1, dtype=data_type)

  def _random_ndarray(self, shape):
    return 2 * np.random.random_sample(shape) - 1

  def _assert_reproducible(self, operation, feed_dict={}):
    with self.cached_session(force_gpu=True):
      result_a = operation[0].eval(feed_dict=feed_dict)
      result_b = operation[0].eval(feed_dict=feed_dict)
      self.assertAllEqual(result_a, result_b)

  def _testDeterministicGradientsCase(self, data_layout, data_rank, data_type):
    seed = (
        hash(data_layout) % 256 + hash(data_rank) % 256 + hash(data_type) % 256)
    np.random.seed(seed)
    batch_size = 10
    channel_count = 8
    data_dim = 14
    in_shape = self._make_shape_tuple(batch_size, channel_count, data_rank,
                                      data_dim, data_layout)
    bias_shape = (channel_count,)
    out_shape = in_shape
    in_op = self._random_data_op(in_shape, data_type)
    bias_op = self._random_data_op(bias_shape, data_type)
    data_format = self._data_format_from_data_layout(data_layout)
    bias_add_op = nn_ops.bias_add(in_op, bias_op, data_format=data_format)
    upstream_gradients = array_ops.placeholder(
        data_type, shape=out_shape, name='upstream_gradients')
    gradient_injector_op = bias_add_op * upstream_gradients
    # The gradient function behaves as if grad_ys is multiplied by the op
    # gradient result, not passing the upstream gradients through the op's
    # gradient generation graph. This is the reason for using the
    # gradient_injector_op.
    grad_ys = None
    bias_gradients_op = gradients_impl.gradients(
        gradient_injector_op,
        bias_op,
        grad_ys=grad_ys,
        colocate_gradients_with_ops=True)
    for i in range(5):
      feed_dict = {upstream_gradients: self._random_ndarray(out_shape)}
      self._assert_reproducible(bias_gradients_op, feed_dict=feed_dict)

  # TODO(duncanriach): add test coverage for deterministic gradients
  # in eager mode
  @test_util.run_deprecated_v1
  @test_util.run_cuda_only
  def testDeterministicGradients(self):
    for data_layout in ('channels_first', 'channels_last'):
      for data_rank in (1, 2, 3):
        for data_type in (dtypes.float16, dtypes.float32, dtypes.float64):
          self._testDeterministicGradientsCase(data_layout, data_rank,
                                               data_type)

  # TODO(duncanriach): Re-enable the following three tests for the error checks
  # after deterministic functionality is implemented at the CUDA kernel level.
  def testInputDims(self):
    pass

  def testBiasVec(self):
    pass

  def testBiasInputsMatch(self):
    pass


if __name__ == '__main__':
  # Note that the effect of setting the following environment variable to
  # 'true' is not tested. Unless we can find a simpler pattern for testing
  # these environment variables, it would require this file to be made into
  # a base and then two more test files to be created.
  os.environ['TF_DETERMINISTIC_OPS'] = '1'
  test.main()
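A note on the gradient-injector pattern used in _testDeterministicGradientsCase above: tf.gradients treats grad_ys as a multiplier applied to the gradient function's result rather than as data flowing through BiasAddGrad's input, so the test multiplies the op output by a placeholder and differentiates the product instead. The following standalone sketch (not part of the commit; shapes and values are illustrative) shows the same pattern:

# Standalone sketch of the gradient-injector pattern (assumes TF 1.x-style
# graph mode via the compat.v1 API; shapes are illustrative).
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.constant(np.random.rand(4, 3), dtype=tf.float32)
b = tf.constant(np.random.rand(3), dtype=tf.float32)
y = tf.nn.bias_add(x, b)

# Feeding this placeholder injects real upstream gradients into BiasAddGrad,
# instead of passing them as grad_ys, which would merely scale the result.
upstream = tf.placeholder(tf.float32, shape=(4, 3))
injector = y * upstream
bias_grad = tf.gradients(injector, b)[0]

with tf.Session() as sess:
  feed = {upstream: np.random.rand(4, 3).astype(np.float32)}
  result_a = sess.run(bias_grad, feed_dict=feed)
  result_b = sess.run(bias_grad, feed_dict=feed)
  # With TF_DETERMINISTIC_OPS=1 set before startup, these match bitwise.
  assert (result_a == result_b).all()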
tensorflow/python/kernel_tests/bias_op_test.py

@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,259 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import nn_ops
-import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
+from tensorflow.python.kernel_tests import bias_op_base
 from tensorflow.python.platform import test
 
-
-class BiasAddTest(test.TestCase):
-
-  def _npBias(self, inputs, bias):
-    assert len(bias.shape) == 1
-    assert inputs.shape[-1] == bias.shape[0]
-    return inputs + bias.reshape(([1] * (len(inputs.shape) - 1)) +
-                                 [bias.shape[0]])
-
-  def testNpBias(self):
-    self.assertAllClose(
-        np.array([[11, 22, 33], [41, 52, 63]]),
-        self._npBias(
-            np.array([[10, 20, 30], [40, 50, 60]]), np.array([1, 2, 3])))
-
-  def _testBias(self, np_inputs, np_bias, use_gpu=False):
-    np_val = self._npBias(np_inputs, np_bias)
-    with self.cached_session(use_gpu=use_gpu):
-      tf_val = nn_ops.bias_add(np_inputs, np_bias).eval()
-    self.assertAllCloseAccordingToType(np_val, tf_val)
-
-  def _AtLeast3d(self, np_value):
-    # Pad the input shape out to at least three dimensions.
-    if np_value.ndim < 3:
-      return np.reshape(np_value, (1,) * (3 - np_value.ndim) + np_value.shape)
-    return np_value
-
-  def _NHWCToNCHW(self, np_value):
-    # Pad the input shape out to at least three dimensions.
-    np_value = self._AtLeast3d(np_value)
-    # Move the last dimension into the second position.
-    np_dim = list(range(np_value.ndim))
-    np_dim_new = list(np_dim[0:1]) + list(np_dim[-1:]) + list(np_dim[1:-1])
-    return np.transpose(np_value, np_dim_new)
-
-  def _NCHWToNHWC(self, np_value):
-    assert len(np_value.shape) >= 3
-    np_dim = list(range(np_value.ndim))
-    # Move the second dimension into the last position.
-    np_dim_new = list(np_dim[0:1]) + list(np_dim[2:]) + list(np_dim[1:2])
-    return np.transpose(np_value, np_dim_new)
-
-  def _testBiasNCHW(self, np_inputs, np_bias, use_gpu):
-    np_val = self._npBias(np_inputs, np_bias)
-    np_inputs = self._NHWCToNCHW(np_inputs)
-    with self.cached_session(use_gpu=use_gpu):
-      tf_val = nn_ops.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
-    tf_val = self._NCHWToNHWC(tf_val)
-    self.assertAllCloseAccordingToType(self._AtLeast3d(np_val), tf_val)
-
-  def _testAll(self, np_inputs, np_bias):
-    self._testBias(np_inputs, np_bias, use_gpu=False)
-    self._testBiasNCHW(np_inputs, np_bias, use_gpu=False)
-    if np_inputs.dtype in [np.float16, np.float32, np.float64]:
-      self._testBias(np_inputs, np_bias, use_gpu=True)
-      self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)
-
-  @test_util.run_deprecated_v1
-  def testInputDims(self):
-    with self.assertRaises(ValueError):
-      nn_ops.bias_add([1, 2], [1])
-
-  @test_util.run_deprecated_v1
-  def testBiasVec(self):
-    with self.assertRaises(ValueError):
-      nn_ops.bias_add(
-          array_ops.reshape(
-              [1, 2], shape=[1, 2]),
-          array_ops.reshape(
-              [1, 2], shape=[1, 2]))
-
-  @test_util.run_deprecated_v1
-  def testBiasInputsMatch(self):
-    with self.assertRaises(ValueError):
-      nn_ops.bias_add(
-          array_ops.reshape(
-              [1, 2], shape=[1, 2]),
-          array_ops.reshape(
-              [1], shape=[1]))
-
-  @test_util.run_deprecated_v1
-  def testIntTypes(self):
-    for t in [np.int8, np.int16, np.int32, np.int64]:
-      self._testAll(
-          np.array([[10, 20, 30], [40, 50, 60]]).astype(t),
-          np.array([1, 2, 3]).astype(t))
-
-  @test_util.run_deprecated_v1
-  def testFloatTypes(self):
-    for t in [np.float16, np.float32, np.float64]:
-      self._testAll(
-          np.random.rand(4, 3, 3).astype(t), np.random.rand(3).astype(t))
-
-  @test_util.run_deprecated_v1
-  def test4DFloatTypes(self):
-    for t in [np.float16, np.float32, np.float64]:
-      self._testAll(
-          np.random.rand(4, 3, 2, 3).astype(t),
-          np.random.rand(3).astype(t))
-      self._testAll(
-          np.random.rand(2048, 4, 4, 4).astype(t),
-          np.random.rand(4).astype(t))
-      self._testAll(
-          np.random.rand(4, 4, 4, 2048).astype(t),
-          np.random.rand(2048).astype(t))
-
-  @test_util.run_deprecated_v1
-  def test5DFloatTypes(self):
-    for t in [np.float16, np.float32, np.float64]:
-      self._testAll(
-          np.random.rand(4, 3, 2, 3, 4).astype(t),
-          np.random.rand(4).astype(t))
-
-  def _testGradient(self, np_input, bias, dtype, data_format, use_gpu):
-    with self.cached_session(use_gpu=use_gpu):
-      if data_format == "NCHW":
-        np_input = self._NHWCToNCHW(np_input)
-      input_tensor = constant_op.constant(
-          np_input, shape=np_input.shape, dtype=dtype)
-      bias_tensor = constant_op.constant(bias, shape=bias.shape, dtype=dtype)
-      output_tensor = nn_ops.bias_add(
-          input_tensor, bias_tensor, data_format=data_format)
-      tensor_jacob_t, tensor_jacob_n = gradient_checker.compute_gradient(
-          input_tensor, np_input.shape, output_tensor, np_input.shape)
-      bias_jacob_t, bias_jacob_n = gradient_checker.compute_gradient(
-          bias_tensor, bias.shape, output_tensor, np_input.shape)
-
-      # Test gradient of BiasAddGrad
-      bias_add_grad = gradients_impl.gradients(
-          nn_ops.l2_loss(output_tensor), bias_tensor)[0]
-      grad_jacob_t, grad_jacob_n = gradient_checker.compute_gradient(
-          output_tensor, np_input.shape, bias_add_grad, bias.shape)
-
-      if dtype == np.float16:
-        # Compare fp16 theoretical gradients to fp32 numerical gradients,
-        # since fp16 numerical gradients are too imprecise unless great
-        # care is taken with choosing the inputs and the delta. This is
-        # a weaker check (in particular, it does not test the op itself,
-        # only its gradient), but it's much better than nothing.
-        input_tensor = constant_op.constant(
-            np_input, shape=np_input.shape, dtype=np.float32)
-        bias_tensor = constant_op.constant(
-            bias, shape=bias.shape, dtype=np.float32)
-        output_tensor = nn_ops.bias_add(
-            input_tensor, bias_tensor, data_format=data_format)
-        _, tensor_jacob_n = gradient_checker.compute_gradient(input_tensor,
-                                                              np_input.shape,
-                                                              output_tensor,
-                                                              np_input.shape)
-        _, bias_jacob_n = gradient_checker.compute_gradient(bias_tensor,
-                                                            bias.shape,
-                                                            output_tensor,
-                                                            np_input.shape)
-
-        bias_add_grad = gradients_impl.gradients(
-            nn_ops.l2_loss(output_tensor), bias_tensor)[0]
-        _, grad_jacob_n = gradient_checker.compute_gradient(output_tensor,
-                                                            np_input.shape,
-                                                            bias_add_grad,
-                                                            bias.shape)
-
-      threshold = 5e-3
-      if dtype == dtypes.float64:
-        threshold = 1e-10
-      self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold)
-      self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold)
-      self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)
-
-  @test_util.run_deprecated_v1
-  def testGradientTensor2D(self):
-    for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
-      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
-        np_input = np.array(
-            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
-            dtype=dtype.as_numpy_dtype).reshape(3, 2)
-        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
-        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
-
-  @test_util.run_deprecated_v1
-  def testGradientTensor3D(self):
-    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
-                                   ("NCHW", False), ("NCHW", True)]:
-      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
-        np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
-                            dtype=dtype.as_numpy_dtype).reshape(1, 3, 2)
-        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
-        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
-
-  @test_util.run_deprecated_v1
-  def testGradientTensor4D(self):
-    for (data_format, use_gpu) in [("NHWC", False)]:
-      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
-        np_input = np.arange(
-            1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
-                [2, 3, 4, 2]).astype(np.float32)
-        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
-        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
-        np_input = np.arange(
-            1.0, 513.0, dtype=dtype.as_numpy_dtype).reshape(
-                [64, 2, 2, 2]).astype(np.float32)
-        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
-        np_input = np.arange(
-            1.0, 513.0, dtype=dtype.as_numpy_dtype).reshape(
-                [2, 2, 2, 64]).astype(np.float32)
-        self._testGradient(np_input,
-                           np.random.rand(64).astype(dtype.as_numpy_dtype),
-                           dtype, data_format, use_gpu)
-
-  @test_util.run_deprecated_v1
-  def testGradientTensor5D(self):
-    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
-                                   ("NCHW", False), ("NCHW", True)]:
-      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
-        np_input = np.arange(
-            1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
-                [1, 2, 3, 4, 2]).astype(np.float32)
-        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
-        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
-
-  @test_util.run_deprecated_v1
-  def testEmpty(self):
-    np.random.seed(7)
-    for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
-      self._testAll(np.random.randn(*shape), np.random.randn(shape[-1]))
-
-  @test_util.run_deprecated_v1
-  def testEmptyGradient(self):
-    for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
-      for shape in (0, 0), (2, 0), (0, 2):
-        self._testGradient(
-            np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64,
-            data_format, use_gpu)
-
-    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
-                                   ("NCHW", False), ("NCHW", True)]:
-      for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3):
-        self._testGradient(
-            np.random.randn(*shape),
-            np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu)
-
+BiasAddTest = bias_op_base.BiasAddTestBase
 
 if __name__ == "__main__":
   test.main()
tensorflow/python/kernel_tests/cudnn_deterministic_base.py
(renamed from tensorflow/python/kernel_tests/cudnn_determinism_test.py)

@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for TF_CUDNN_DETERMINISTIC=true."""
+"""Tests for deterministic cuDNN functionality."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import collections
-import os
 import numpy as np
 
 from tensorflow.python.framework import constant_op
@@ -28,12 +27,14 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
-# The TF_CUDNN_DETERMINISTIC flag disables autotuning of cuDNN algorithms and
-# causes deterministic cuDNN algorithms to be selected when both deterministic
-# and non-deterministic algorithms are available. These tests are intended to
-# confirm that deterministic algorithms are chosen when
-# TF_CUDNN_DETERMINISTIC=true. The configurations tested were confirmed to
-# produce non-deterministic results without setting TF_CUDNN_DETERMINISTIC=true
+# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or
+# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN
+# algorithms and cause deterministic cuDNN algorithms to be selected when both
+# deterministic and non-deterministic algorithms are available. These tests are
+# intended to confirm that deterministic algorithms are chosen when either
+# environment variable is set to "true" or "1". The tested configurations were
+# first confirmed to produce non-deterministic results when the environment
+# variables are not set.
 
 _PADDING = 'SAME'
 _STRIDES = [1, 1, 1, 1]
@@ -96,8 +97,3 @@ class ConvolutionTest(test.TestCase):
 # TODO(duncanriach): (1) add test to confirm that forward autotuning is
 # disabled for cuDNN convolution; (2) add test for deterministic cuDNN
 # max-pooling
-
-
-if __name__ == '__main__':
-  os.environ['TF_CUDNN_DETERMINISTIC'] = 'true'
-  test.main()
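For illustration only (not part of the commit), here is a sketch of the kind of reproducibility check these convolution tests automate. It assumes a CUDA GPU and uses the TF 2.x eager API for brevity; the environment variable must be set before TensorFlow executes any op:

# Illustrative reproducibility check (assumes a CUDA GPU; TF 2.x API).
import os
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'  # set before TF runs any op

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(1, 16, 16, 3), dtype=tf.float32)
filters = tf.constant(np.random.rand(3, 3, 3, 8), dtype=tf.float32)

def input_gradient():
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.conv2d(x, filters, strides=[1, 1, 1, 1], padding='SAME')
    loss = tf.reduce_sum(y * y)
  return tape.gradient(loss, x).numpy()

# Deterministic cuDNN algorithms make repeated backprop bitwise-identical.
assert (input_gradient() == input_gradient()).all()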
tensorflow/python/kernel_tests/cudnn_deterministic_ops_test.py (new file, 33 lines)

@@ -0,0 +1,33 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TF_DETERMINISTIC_OPS=1."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.kernel_tests import cudnn_deterministic_base
from tensorflow.python.platform import test

ConvolutionTest = cudnn_deterministic_base.ConvolutionTest

if __name__ == '__main__':
  # Note that the effect of setting the following environment variable to
  # 'true' is not tested. Unless we can find a simpler pattern for testing
  # these environment variables, it would require another test file to be
  # added.
  os.environ['TF_DETERMINISTIC_OPS'] = '1'
  test.main()
tensorflow/python/kernel_tests/cudnn_deterministic_test.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TF_CUDNN_DETERMINISTIC=1."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.kernel_tests import cudnn_deterministic_base
from tensorflow.python.platform import test

ConvolutionTest = cudnn_deterministic_base.ConvolutionTest

if __name__ == '__main__':
  # Note that the effect of setting the following environment variable to
  # 'true' is not tested. Unless we can find a simpler pattern for testing
  # these environment variables, it would require another test file to be
  # added.
  os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
  test.main()
tensorflow/python/ops/nn_ops.py

@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 import numbers
+import os
 
 import numpy as np
 
@@ -2682,6 +2683,19 @@ def conv_transpose(input,  # pylint: disable=redefined-builtin
       name=name)
 
 
+def _tf_deterministic_ops():
+  if _tf_deterministic_ops.value is None:
+    tf_deterministic_ops = os.environ.get("TF_DETERMINISTIC_OPS")
+    if tf_deterministic_ops is not None:
+      tf_deterministic_ops = tf_deterministic_ops.lower()
+    _tf_deterministic_ops.value = (
+        tf_deterministic_ops == "true" or tf_deterministic_ops == "1")
+  return _tf_deterministic_ops.value
+
+
+_tf_deterministic_ops.value = None
+
+
 @tf_export("nn.bias_add")
 def bias_add(value, bias, data_format=None, name=None):
   """Adds `bias` to `value`.
@@ -2697,11 +2711,19 @@ def bias_add(value, bias, data_format=None, name=None):
     bias: A 1-D `Tensor` with size matching the channel dimension of `value`.
       Must be the same type as `value` unless `value` is a quantized type,
       in which case a different quantized type may be used.
-    data_format: A string. 'N...C' and 'NC...' are supported.
+    data_format: A string. 'N...C' and 'NC...' are supported. If `None` (the
+      default) is specified then 'N...C' is assumed.
     name: A name for the operation (optional).
 
   Returns:
     A `Tensor` with the same type as `value`.
+
+  Raises:
+    ValueError if data format is unrecognized, if `value` has fewer than two
+      dimensions when `data_format` is 'N...C'/`None`, or `value` has fewer
+      than three dimensions when `data_format` is 'NC...', if `bias` does not
+      have exactly one dimension (is a vector), or if the size of `bias`
+      does not match the size of the channel dimension of `value`.
   """
   with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
     if data_format is not None:
@@ -2715,7 +2737,25 @@ def bias_add(value, bias, data_format=None, name=None):
     if not context.executing_eagerly():
       value = ops.convert_to_tensor(value, name="input")
       bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
-    return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
+
+    # TODO(duncanriach): Implement deterministic functionality at CUDA kernel
+    # level.
+    if _tf_deterministic_ops():
+      # Note that this code does not implement the same error checks as the
+      # pre-existing C++ ops.
+      if data_format == "NCHW":
+        broadcast_shape_head = [1, array_ops.size(bias)]
+        broadcast_shape_tail = array_ops.ones(
+            array_ops.rank(value) - 2, dtype=dtypes.int32)
+        broadcast_shape = array_ops.concat(
+            [broadcast_shape_head, broadcast_shape_tail], 0)
+        return math_ops.add(
+            value, array_ops.reshape(bias, broadcast_shape), name=name)
+      else:  # data_format == 'NHWC' or data_format == None
+        return math_ops.add(value, bias, name=name)
+    else:
+      return gen_nn_ops.bias_add(
+          value, bias, data_format=data_format, name=name)
 
 
 def bias_add_v1(value, bias, name=None):
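For reference, a NumPy sketch (not part of the commit; shapes are illustrative) of what the deterministic branch above computes: for NCHW the bias is reshaped so that plain broadcasting adds it along the channel axis, while for NHWC (or data_format=None) no reshape is needed because the channel dimension is already trailing:

import numpy as np

# NCHW: the bias must broadcast along axis 1 (the channel axis), so it is
# reshaped to [1, C] + [1] * (rank - 2), mirroring broadcast_shape above.
value = np.random.rand(2, 8, 14, 14)  # N, C, H, W (illustrative shape)
bias = np.random.rand(8)

broadcast_shape = [1, bias.size] + [1] * (value.ndim - 2)  # -> [1, 8, 1, 1]
out_nchw = value + bias.reshape(broadcast_shape)

# NHWC: no reshape needed; broadcasting aligns against the trailing axis.
value_nhwc = value.transpose(0, 2, 3, 1)
out_nhwc = value_nhwc + bias
assert np.allclose(out_nchw.transpose(0, 2, 3, 1), out_nhwc)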
tensorflow/stream_executor/cuda/cuda_dnn.cc

@@ -632,14 +632,18 @@ bool BatchnormSpatialPersistentEnabled() {
 
 // A helper function to decide whether to enable deterministic functionality.
 bool RequireDeterminism() {
-  static bool is_enabled = [] {
-    bool is_enabled = false;
+  static bool require_determinism = [] {
+    bool deterministic_ops = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                               /*default_val=*/false,
+                                               &deterministic_ops));
+    bool cudnn_deterministic = false;
     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
                                                /*default_val=*/false,
-                                               &is_enabled));
-    return is_enabled;
+                                               &cudnn_deterministic));
+    return deterministic_ops || cudnn_deterministic;
   }();
-  return is_enabled;
+  return require_determinism;
 }
 
 std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
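Both this C++ helper (through its statically initialized lambda) and the Python _tf_deterministic_ops helper earlier cache the environment lookup the first time they run, which is why the variables must be set before the first relevant op executes. A small standalone sketch (hypothetical _flag_enabled helper, using the same function-attribute idiom) demonstrating that later changes to the environment are ignored:

import os

def _flag_enabled():
  # Hypothetical helper mirroring _tf_deterministic_ops: the environment is
  # read once and the parsed result is cached on a function attribute.
  if _flag_enabled.value is None:
    raw = os.environ.get("TF_DETERMINISTIC_OPS", "").lower()
    _flag_enabled.value = raw in ("true", "1")
  return _flag_enabled.value

_flag_enabled.value = None

os.environ["TF_DETERMINISTIC_OPS"] = "1"
print(_flag_enabled())  # True
os.environ["TF_DETERMINISTIC_OPS"] = "0"
print(_flag_enabled())  # Still True: later changes never reach the cache.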
tensorflow/tools/pip_package/BUILD

@@ -89,6 +89,8 @@ COMMON_PIP_DEPS = [
     "//tensorflow/python/keras/distribute:distribute_strategy_test_lib",
     "//tensorflow/python/keras/distribute:multi_worker_testing_utils",
     "//tensorflow/python/keras/mixed_precision/experimental:test_util",
+    "//tensorflow/python/kernel_tests:cudnn_deterministic_base",
+    "//tensorflow/python/kernel_tests:bias_op_base",
     "//tensorflow/python/kernel_tests/random:util",
     "//tensorflow/python/kernel_tests/signal:test_util",
     "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",