Enable more TFRT tests.
PiperOrigin-RevId: 312230367 Change-Id: Icc82c7ce424a1db2ca3cf2eabc1e5932fec7b6a7
This commit is contained in:
parent
3c6dadd17f
commit
f7d038cc3b
@ -3,7 +3,7 @@
|
|||||||
# ":platform" - Low-level and platform-specific Python code.
|
# ":platform" - Low-level and platform-specific Python code.
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "py_strict_library")
|
load("//tensorflow:tensorflow.bzl", "py_strict_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||||
@ -26,6 +26,9 @@ load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible")
|
|||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
|
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
|
||||||
|
|
||||||
|
# buildifier: disable=same-origin-load
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule")
|
load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule")
|
||||||
load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused
|
load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused
|
||||||
@ -2071,6 +2074,7 @@ tf_py_test(
|
|||||||
srcs = ["framework/constant_op_test.py"],
|
srcs = ["framework/constant_op_test.py"],
|
||||||
main = "framework/constant_op_test.py",
|
main = "framework/constant_op_test.py",
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
":constant_op",
|
":constant_op",
|
||||||
],
|
],
|
||||||
|
@ -108,15 +108,15 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
self._apply(defun=False)
|
self._apply(defun=False)
|
||||||
|
|
||||||
@test_util.disable_tfrt(
|
@test_util.disable_tfrt(
|
||||||
'TFE_ContextGetExecutorForThread not implemented for tfrt')
|
'TFE_ContextGetExecutorForThread not implemented b/156188669')
|
||||||
def test_apply_async(self):
|
def test_apply_async(self):
|
||||||
self._apply(defun=False, execution_mode=context.ASYNC)
|
self._apply(defun=False, execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
||||||
def test_apply_with_defun(self):
|
def test_apply_with_defun(self):
|
||||||
self._apply(defun=True)
|
self._apply(defun=True)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
||||||
def test_apply_with_defun_async(self):
|
def test_apply_with_defun_async(self):
|
||||||
self._apply(defun=True, execution_mode=context.ASYNC)
|
self._apply(defun=True, execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ class ResNet50Test(tf.test.TestCase):
|
|||||||
def test_train(self):
|
def test_train(self):
|
||||||
self._test_train()
|
self._test_train()
|
||||||
|
|
||||||
@test_util.disable_tfrt('b/155260334')
|
@test_util.disable_tfrt('TFE_ContextGetExecutorForThread missing b/156188669')
|
||||||
def test_train_async(self):
|
def test_train_async(self):
|
||||||
self._test_train(execution_mode=context.ASYNC)
|
self._test_train(execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@ -329,7 +329,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
defun=False,
|
defun=False,
|
||||||
execution_mode=context.ASYNC)
|
execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
||||||
def benchmark_eager_apply_with_defun(self):
|
def benchmark_eager_apply_with_defun(self):
|
||||||
self._benchmark_eager_apply(
|
self._benchmark_eager_apply(
|
||||||
'eager_apply_with_defun',
|
'eager_apply_with_defun',
|
||||||
@ -389,7 +389,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
defun=False,
|
defun=False,
|
||||||
execution_mode=context.ASYNC)
|
execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
||||||
def benchmark_eager_train_with_defun(self):
|
def benchmark_eager_train_with_defun(self):
|
||||||
self._benchmark_eager_train(
|
self._benchmark_eager_train(
|
||||||
'eager_train_with_defun', MockIterator,
|
'eager_train_with_defun', MockIterator,
|
||||||
@ -408,7 +408,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
|
|||||||
resnet50_test_util.device_and_data_format(),
|
resnet50_test_util.device_and_data_format(),
|
||||||
defun=False)
|
defun=False)
|
||||||
|
|
||||||
@test_util.disable_tfrt('Graph is not supported yet.')
|
@test_util.disable_tfrt('Graph is not supported yet. b/156187905')
|
||||||
def benchmark_eager_train_datasets_with_defun(self):
|
def benchmark_eager_train_datasets_with_defun(self):
|
||||||
|
|
||||||
def make_iterator(tensors):
|
def make_iterator(tensors):
|
||||||
|
@ -618,7 +618,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
self._benchmark_tfe_py_execute_matmul(
|
self._benchmark_tfe_py_execute_matmul(
|
||||||
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_defun_matmul_2_by_2_GPU(self):
|
def benchmark_defun_matmul_2_by_2_GPU(self):
|
||||||
if not context.num_gpus():
|
if not context.num_gpus():
|
||||||
return
|
return
|
||||||
@ -639,7 +639,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
num_iters=self._num_iters_2_by_2,
|
num_iters=self._num_iters_2_by_2,
|
||||||
execution_mode=context.ASYNC)
|
execution_mode=context.ASYNC)
|
||||||
|
|
||||||
@test_util.disable_tfrt("function not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_nested_defun_matmul_2_by_2(self):
|
def benchmark_nested_defun_matmul_2_by_2(self):
|
||||||
m = self._m_2_by_2.cpu()
|
m = self._m_2_by_2.cpu()
|
||||||
self._benchmark_nested_defun_matmul(
|
self._benchmark_nested_defun_matmul(
|
||||||
@ -687,7 +687,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
self._benchmark_tfe_py_execute_matmul(
|
self._benchmark_tfe_py_execute_matmul(
|
||||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||||
|
|
||||||
@test_util.disable_tfrt("function not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_defun_matmul_100_by_784_CPU(self):
|
def benchmark_defun_matmul_100_by_784_CPU(self):
|
||||||
with context.device(CPU):
|
with context.device(CPU):
|
||||||
m = self._m_100_by_784.cpu()
|
m = self._m_100_by_784.cpu()
|
||||||
@ -815,35 +815,35 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
func()
|
func()
|
||||||
self._run(func, 3000)
|
self._run(func, 3000)
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_matmul_256_by_2096_CPU(self):
|
def benchmark_forwardprop_matmul_256_by_2096_CPU(self):
|
||||||
self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096))
|
self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self):
|
def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self):
|
||||||
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096))
|
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self):
|
def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self):
|
||||||
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096))
|
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self):
|
def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self):
|
||||||
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096))
|
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_matmul_100_by_784_CPU(self):
|
def benchmark_forwardprop_matmul_100_by_784_CPU(self):
|
||||||
self._benchmark_forwardprop_matmul_CPU(shape=(100, 784))
|
self._benchmark_forwardprop_matmul_CPU(shape=(100, 784))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self):
|
def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self):
|
||||||
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784))
|
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self):
|
def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self):
|
||||||
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784))
|
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784))
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self):
|
def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self):
|
||||||
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784))
|
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784))
|
||||||
|
|
||||||
@ -1097,7 +1097,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
|
||||||
self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)
|
self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_defun_without_signature(self):
|
def benchmark_defun_without_signature(self):
|
||||||
|
|
||||||
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
||||||
@ -1109,7 +1109,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
cache_computation = lambda: defined(t, t, t, t, t, t, t, t)
|
cache_computation = lambda: defined(t, t, t, t, t, t, t, t)
|
||||||
self._run(cache_computation, 30000)
|
self._run(cache_computation, 30000)
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_defun_without_signature_and_with_kwargs(self):
|
def benchmark_defun_without_signature_and_with_kwargs(self):
|
||||||
|
|
||||||
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
||||||
@ -1122,7 +1122,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)
|
return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)
|
||||||
self._run(cache_computation, 30000)
|
self._run(cache_computation, 30000)
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_defun_with_signature(self):
|
def benchmark_defun_with_signature(self):
|
||||||
|
|
||||||
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
||||||
@ -1135,7 +1135,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
signature_computation = lambda: defined(t, t, t, t, t, t, t, t)
|
signature_computation = lambda: defined(t, t, t, t, t, t, t, t)
|
||||||
self._run(signature_computation, 30000)
|
self._run(signature_computation, 30000)
|
||||||
|
|
||||||
@test_util.disable_tfrt("defun not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmark_defun_with_signature_and_kwargs(self):
|
def benchmark_defun_with_signature_and_kwargs(self):
|
||||||
|
|
||||||
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
def func(t1, t2, t3, t4, t5, t6, t7, t8):
|
||||||
@ -1305,11 +1305,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
resources.append(resource_variable_ops.ResourceVariable(self._m_2))
|
resources.append(resource_variable_ops.ResourceVariable(self._m_2))
|
||||||
self._run(lambda: add_all(resources), num_iters)
|
self._run(lambda: add_all(resources), num_iters)
|
||||||
|
|
||||||
@test_util.disable_tfrt("funtion not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmarkFunctionWithFiveResourceInputs(self):
|
def benchmarkFunctionWithFiveResourceInputs(self):
|
||||||
self._benchmarkFunctionWithResourceInputs(5, 1000)
|
self._benchmarkFunctionWithResourceInputs(5, 1000)
|
||||||
|
|
||||||
@test_util.disable_tfrt("funtion not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmarkFunctionWithFiveHundredResourceInputs(self):
|
def benchmarkFunctionWithFiveHundredResourceInputs(self):
|
||||||
self._benchmarkFunctionWithResourceInputs(500, 100)
|
self._benchmarkFunctionWithResourceInputs(500, 100)
|
||||||
|
|
||||||
@ -1344,15 +1344,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
|||||||
with context.device(CPU):
|
with context.device(CPU):
|
||||||
self._run(benchmark_fn, 10)
|
self._run(benchmark_fn, 10)
|
||||||
|
|
||||||
@test_util.disable_tfrt("funtion not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmarkTenThousandResourceReadsInCondInInnerFunc(self):
|
def benchmarkTenThousandResourceReadsInCondInInnerFunc(self):
|
||||||
self._benchmarkResourceReadsInCondInInnerFunc(10000)
|
self._benchmarkResourceReadsInCondInInnerFunc(10000)
|
||||||
|
|
||||||
@test_util.disable_tfrt("funtion not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmarkHundredResourceReadsInCondInInnerFunc(self):
|
def benchmarkHundredResourceReadsInCondInInnerFunc(self):
|
||||||
self._benchmarkResourceReadsInCondInInnerFunc(100)
|
self._benchmarkResourceReadsInCondInInnerFunc(100)
|
||||||
|
|
||||||
@test_util.disable_tfrt("funtion not supported")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def benchmarkTenResourceReadsInCondInInnerFunc(self):
|
def benchmarkTenResourceReadsInCondInInnerFunc(self):
|
||||||
self._benchmarkResourceReadsInCondInInnerFunc(10)
|
self._benchmarkResourceReadsInCondInInnerFunc(10)
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ class ResourceTest(test_util.TensorFlowTestCase):
|
|||||||
resources.shared_resources()).eval()), 0)
|
resources.shared_resources()).eval()), 0)
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testShape(self):
|
def testShape(self):
|
||||||
@ -311,7 +311,8 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
del x
|
del x
|
||||||
self.assertIsNotNone(x_ref.deref())
|
self.assertIsNotNone(x_ref.deref())
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graph mode is not supported yet.")
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@ -356,7 +357,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(x.indices, [0, 2])
|
self.assertAllEqual(x.indices, [0, 2])
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graph mode is not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
|
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
@ -502,7 +503,7 @@ def _apply_op(g, *args, **kwargs):
|
|||||||
return op.outputs
|
return op.outputs
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
class OperationTest(test_util.TensorFlowTestCase):
|
class OperationTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@ -1445,7 +1446,7 @@ class NameTest(test_util.TensorFlowTestCase):
|
|||||||
g.create_op("FloatOutput", [], [dtypes.float32]).name)
|
g.create_op("FloatOutput", [], [dtypes.float32]).name)
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Device API are not supported yet.")
|
@test_util.disable_tfrt("Device API are not supported yet. b/156188344")
|
||||||
class DeviceTest(test_util.TensorFlowTestCase):
|
class DeviceTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testNoDevice(self):
|
def testNoDevice(self):
|
||||||
@ -2026,7 +2027,7 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
|||||||
# Collections are ordered.
|
# Collections are ordered.
|
||||||
self.assertEqual([90, 100], ops.get_collection("key"))
|
self.assertEqual([90, 100], ops.get_collection("key"))
|
||||||
|
|
||||||
@test_util.disable_tfrt("Functions are not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def test_defun(self):
|
def test_defun(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
|
|
||||||
@ -2133,7 +2134,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
|||||||
# e should be dominated by c.
|
# e should be dominated by c.
|
||||||
self.assertEqual(e.op.control_inputs, [])
|
self.assertEqual(e.op.control_inputs, [])
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testEager(self):
|
def testEager(self):
|
||||||
def future():
|
def future():
|
||||||
@ -2454,7 +2455,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self._testGraphElements([a, variable, b])
|
self._testGraphElements([a, variable, b])
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graphs are not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
class InitScopeTest(test_util.TensorFlowTestCase):
|
class InitScopeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testClearsControlDependencies(self):
|
def testClearsControlDependencies(self):
|
||||||
@ -2757,7 +2758,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertFalse(self.evaluate(f()))
|
self.assertFalse(self.evaluate(f()))
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graphs are not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
class GraphTest(test_util.TensorFlowTestCase):
|
class GraphTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -3235,7 +3236,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
b = variables.Variable([3.0], name="b")
|
b = variables.Variable([3.0], name="b")
|
||||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||||
|
|
||||||
@test_util.disable_tfrt("Functions are not supported yet.")
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
def testColocateWithVariableInFunction(self):
|
def testColocateWithVariableInFunction(self):
|
||||||
v = variables.Variable(1.)
|
v = variables.Variable(1.)
|
||||||
|
|
||||||
|
@ -864,6 +864,7 @@ cuda_py_test(
|
|||||||
srcs = ["resource_variable_ops_test.py"],
|
srcs = ["resource_variable_ops_test.py"],
|
||||||
# TODO(b/128347673): Re-enable.
|
# TODO(b/128347673): Re-enable.
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
@ -87,6 +87,7 @@ cuda_py_test(
|
|||||||
name = "random_ops_test",
|
name = "random_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["random_ops_test.py"],
|
srcs = ["random_ops_test.py"],
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
@ -101,6 +102,7 @@ cuda_py_test(
|
|||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["stateless_random_ops_test.py"],
|
srcs = ["stateless_random_ops_test.py"],
|
||||||
shard_count = 2,
|
shard_count = 2,
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
@ -336,6 +336,8 @@ class RandomUniformTest(RandomOpTestCommon):
|
|||||||
self.assertLess(error.max(), 5 * std)
|
self.assertLess(error.max(), 5 * std)
|
||||||
|
|
||||||
# Check that minval = maxval is fine iff we're producing no numbers
|
# Check that minval = maxval is fine iff we're producing no numbers
|
||||||
|
@test_util.disable_tfrt(
|
||||||
|
"TFE_TensorHandleToNumpy not implemented yet. b/156191611")
|
||||||
def testUniformIntsDegenerate(self):
|
def testUniformIntsDegenerate(self):
|
||||||
for dt in dtypes.int32, dtypes.int64:
|
for dt in dtypes.int32, dtypes.int64:
|
||||||
def sample(n):
|
def sample(n):
|
||||||
|
@ -154,44 +154,54 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
**kwds),
|
**kwds),
|
||||||
functools.partial(random_ops.random_poisson, shape=(10,), **kwds))
|
functools.partial(random_ops.random_poisson, shape=(10,), **kwds))
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testMatchFloat(self):
|
def testMatchFloat(self):
|
||||||
self._test_match(self._float_cases())
|
self._test_match(self._float_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testMatchInt(self):
|
def testMatchInt(self):
|
||||||
self._test_match(self._int_cases())
|
self._test_match(self._int_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testMatchMultinomial(self):
|
def testMatchMultinomial(self):
|
||||||
self._test_match(self._multinomial_cases())
|
self._test_match(self._multinomial_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testMatchGamma(self):
|
def testMatchGamma(self):
|
||||||
self._test_match(self._gamma_cases())
|
self._test_match(self._gamma_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testMatchPoisson(self):
|
def testMatchPoisson(self):
|
||||||
self._test_match(self._poisson_cases())
|
self._test_match(self._poisson_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testDeterminismFloat(self):
|
def testDeterminismFloat(self):
|
||||||
self._test_determinism(
|
self._test_determinism(
|
||||||
self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
|
self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testDeterminismInt(self):
|
def testDeterminismInt(self):
|
||||||
self._test_determinism(
|
self._test_determinism(
|
||||||
self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
|
self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testDeterminismMultinomial(self):
|
def testDeterminismMultinomial(self):
|
||||||
self._test_determinism(self._multinomial_cases())
|
self._test_determinism(self._multinomial_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testDeterminismGamma(self):
|
def testDeterminismGamma(self):
|
||||||
self._test_determinism(self._gamma_cases())
|
self._test_determinism(self._gamma_cases())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testDeterminismPoisson(self):
|
def testDeterminismPoisson(self):
|
||||||
self._test_determinism(self._poisson_cases())
|
self._test_determinism(self._poisson_cases())
|
||||||
|
@ -57,6 +57,8 @@ from tensorflow.python.training import training_util
|
|||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt(
|
||||||
|
"Trying to assign variable with wrong dtype. b/156200342")
|
||||||
@test_util.with_control_flow_v2
|
@test_util.with_control_flow_v2
|
||||||
class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
@ -332,6 +334,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||||||
g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0]
|
g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0]
|
||||||
self.assertAllEqual(g.shape.as_list(), [1, 2])
|
self.assertAllEqual(g.shape.as_list(), [1, 2])
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testGradientCondInWhileLoop(self):
|
def testGradientCondInWhileLoop(self):
|
||||||
v = resource_variable_ops.ResourceVariable(initial_value=1.0)
|
v = resource_variable_ops.ResourceVariable(initial_value=1.0)
|
||||||
@ -965,6 +968,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||||||
assign = var.assign(np.zeros(shape=[2, 2]))
|
assign = var.assign(np.zeros(shape=[2, 2]))
|
||||||
self.evaluate(assign)
|
self.evaluate(assign)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
@test_util.disable_xla("XLA doesn't allow changing shape at assignment, as "
|
@test_util.disable_xla("XLA doesn't allow changing shape at assignment, as "
|
||||||
"dictated by tf2xla/xla_resource.cc:SetTypeAndShape")
|
"dictated by tf2xla/xla_resource.cc:SetTypeAndShape")
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
@ -1327,6 +1331,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||||||
|
|
||||||
# TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create
|
# TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create
|
||||||
# EagerTensor constants with TensorProto inputs.
|
# EagerTensor constants with TensorProto inputs.
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
def testVariantInitializer(self):
|
def testVariantInitializer(self):
|
||||||
variant_shape_and_type_data = self.create_variant_shape_and_type_data()
|
variant_shape_and_type_data = self.create_variant_shape_and_type_data()
|
||||||
@ -1520,6 +1525,7 @@ class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
|
|||||||
context.LogicalDeviceConfiguration(),
|
context.LogicalDeviceConfiguration(),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Multiple device support. b/154956430")
|
||||||
def testAllowedDevices(self):
|
def testAllowedDevices(self):
|
||||||
device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
|
device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||||
device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
|
device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
|
||||||
|
Loading…
Reference in New Issue
Block a user