parent
a24cfc8d09
commit
093f2278aa
@ -39,27 +39,13 @@ VALID_FFT_RANKS = (1, 2, 3)
|
|||||||
|
|
||||||
class BaseFFTOpsTest(test.TestCase):
|
class BaseFFTOpsTest(test.TestCase):
|
||||||
|
|
||||||
# TODO(b/128833319,b/123860949): Default use_placeholder=False in the below
|
def _compare(self, x, rank, fft_length=None, use_placeholder=False,
|
||||||
# methods after Eigen kernels are more precise or once XLA testing can disable
|
rtol=1e-4, atol=1e-4):
|
||||||
# constant folding. Alternatively, use placeholders by default for all these
|
|
||||||
# tests.
|
|
||||||
def _compare(self,
|
|
||||||
x,
|
|
||||||
rank,
|
|
||||||
fft_length=None,
|
|
||||||
use_placeholder=True,
|
|
||||||
rtol=1e-4,
|
|
||||||
atol=1e-4):
|
|
||||||
self._compareForward(x, rank, fft_length, use_placeholder, rtol, atol)
|
self._compareForward(x, rank, fft_length, use_placeholder, rtol, atol)
|
||||||
self._compareBackward(x, rank, fft_length, use_placeholder, rtol, atol)
|
self._compareBackward(x, rank, fft_length, use_placeholder, rtol, atol)
|
||||||
|
|
||||||
def _compareForward(self,
|
def _compareForward(self, x, rank, fft_length=None, use_placeholder=False,
|
||||||
x,
|
rtol=1e-4, atol=1e-4):
|
||||||
rank,
|
|
||||||
fft_length=None,
|
|
||||||
use_placeholder=False,
|
|
||||||
rtol=1e-4,
|
|
||||||
atol=1e-4):
|
|
||||||
x_np = self._npFFT(x, rank, fft_length)
|
x_np = self._npFFT(x, rank, fft_length)
|
||||||
if use_placeholder:
|
if use_placeholder:
|
||||||
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
|
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
|
||||||
@ -69,13 +55,8 @@ class BaseFFTOpsTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
|
self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
def _compareBackward(self,
|
def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False,
|
||||||
x,
|
rtol=1e-4, atol=1e-4):
|
||||||
rank,
|
|
||||||
fft_length=None,
|
|
||||||
use_placeholder=False,
|
|
||||||
rtol=1e-4,
|
|
||||||
atol=1e-4):
|
|
||||||
x_np = self._npIFFT(x, rank, fft_length)
|
x_np = self._npIFFT(x, rank, fft_length)
|
||||||
if use_placeholder:
|
if use_placeholder:
|
||||||
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
|
x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
|
||||||
|
@ -1928,7 +1928,7 @@ def tf_py_test(
|
|||||||
# kernels compiled with XLA.
|
# kernels compiled with XLA.
|
||||||
if xla_enable_strict_auto_jit:
|
if xla_enable_strict_auto_jit:
|
||||||
xla_enabled = True
|
xla_enabled = True
|
||||||
xla_test_true_list += [clean_dep("//tensorflow/python:is_xla_test_true")]
|
xla_test_true_list += ["//tensorflow/python:is_xla_test_true"]
|
||||||
if xla_enabled:
|
if xla_enabled:
|
||||||
additional_deps = additional_deps + tf_additional_xla_deps_py()
|
additional_deps = additional_deps + tf_additional_xla_deps_py()
|
||||||
if grpc_enabled:
|
if grpc_enabled:
|
||||||
@ -1972,37 +1972,19 @@ def gpu_py_test(
|
|||||||
xla_enable_strict_auto_jit = False,
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False):
|
grpc_enabled = False):
|
||||||
|
# TODO(b/122522101): Don't ignore xla_enable_strict_auto_jit and enable additional
|
||||||
|
# XLA tests once enough compute resources are available.
|
||||||
|
_ignored = [xla_enable_strict_auto_jit]
|
||||||
if main == None:
|
if main == None:
|
||||||
main = name + ".py"
|
main = name + ".py"
|
||||||
for config in ["cpu", "gpu"]:
|
for config in ["cpu", "gpu"]:
|
||||||
test_name = name
|
test_name = name
|
||||||
test_tags = tags
|
test_tags = tags
|
||||||
suffix = ""
|
|
||||||
if config == "gpu":
|
if config == "gpu":
|
||||||
suffix += "_gpu"
|
test_name += "_gpu"
|
||||||
test_tags = test_tags + tf_gpu_tests_tags()
|
test_tags = test_tags + tf_gpu_tests_tags()
|
||||||
|
|
||||||
# We don't care about testing XLA autojit on CPU for now.
|
|
||||||
if xla_enable_strict_auto_jit and config != "cpu":
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = name + "_xla" + suffix,
|
name = test_name,
|
||||||
size = size,
|
|
||||||
srcs = srcs,
|
|
||||||
additional_deps = additional_deps,
|
|
||||||
args = args,
|
|
||||||
data = data,
|
|
||||||
flaky = flaky,
|
|
||||||
grpc_enabled = grpc_enabled,
|
|
||||||
kernels = kernels,
|
|
||||||
main = main,
|
|
||||||
shard_count = shard_count,
|
|
||||||
tags = test_tags,
|
|
||||||
xla_enabled = xla_enabled,
|
|
||||||
xla_enable_strict_auto_jit = True,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = name + suffix,
|
|
||||||
size = size,
|
size = size,
|
||||||
srcs = srcs,
|
srcs = srcs,
|
||||||
additional_deps = additional_deps,
|
additional_deps = additional_deps,
|
||||||
@ -2113,22 +2095,10 @@ def gpu_py_tests(
|
|||||||
xla_enable_strict_auto_jit = False,
|
xla_enable_strict_auto_jit = False,
|
||||||
xla_enabled = False,
|
xla_enabled = False,
|
||||||
grpc_enabled = False):
|
grpc_enabled = False):
|
||||||
|
# TODO(b/122522101): Don't ignore xla_enable_strict_auto_jit and enable additional
|
||||||
|
# XLA tests once enough compute resources are available.
|
||||||
|
_ignored = [xla_enable_strict_auto_jit]
|
||||||
test_tags = tags + tf_gpu_tests_tags()
|
test_tags = tags + tf_gpu_tests_tags()
|
||||||
if xla_enable_strict_auto_jit:
|
|
||||||
py_tests(
|
|
||||||
name = name,
|
|
||||||
size = size,
|
|
||||||
srcs = srcs,
|
|
||||||
additional_deps = additional_deps,
|
|
||||||
data = data,
|
|
||||||
grpc_enabled = grpc_enabled,
|
|
||||||
kernels = kernels,
|
|
||||||
prefix = prefix + "gpu_xla",
|
|
||||||
shard_count = shard_count,
|
|
||||||
tags = test_tags,
|
|
||||||
xla_enabled = xla_enabled,
|
|
||||||
xla_enable_strict_auto_jit = True,
|
|
||||||
)
|
|
||||||
py_tests(
|
py_tests(
|
||||||
name = name,
|
name = name,
|
||||||
size = size,
|
size = size,
|
||||||
|
@ -67,7 +67,6 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
|
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
|
||||||
"//tensorflow/python:cond_v2",
|
"//tensorflow/python:cond_v2",
|
||||||
"//tensorflow/python:distributed_framework_test_lib",
|
"//tensorflow/python:distributed_framework_test_lib",
|
||||||
"//tensorflow/python:is_xla_test_true",
|
|
||||||
"//tensorflow/python:meta_graph_testdata",
|
"//tensorflow/python:meta_graph_testdata",
|
||||||
"//tensorflow/python:spectral_ops_test_util",
|
"//tensorflow/python:spectral_ops_test_util",
|
||||||
"//tensorflow/python:util_example_parser_configuration",
|
"//tensorflow/python:util_example_parser_configuration",
|
||||||
|
@ -154,9 +154,7 @@ def main():
|
|||||||
|
|
||||||
missing_dependencies = []
|
missing_dependencies = []
|
||||||
# File extensions and endings to ignore
|
# File extensions and endings to ignore
|
||||||
ignore_extensions = [
|
ignore_extensions = ["_test", "_test.py", "_test_gpu", "_test_gpu.py"]
|
||||||
"_test", "_test.py", "_test_gpu", "_test_gpu.py", "_test_xla_gpu"
|
|
||||||
]
|
|
||||||
|
|
||||||
ignored_files_count = 0
|
ignored_files_count = 0
|
||||||
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)
|
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)
|
||||||
|
Loading…
Reference in New Issue
Block a user