Enable more TFRT tests.

PiperOrigin-RevId: 312230367
Change-Id: Icc82c7ce424a1db2ca3cf2eabc1e5932fec7b6a7
This commit is contained in:
Kibeom Kim 2020-05-19 00:20:53 -07:00 committed by TensorFlower Gardener
parent 3c6dadd17f
commit f7d038cc3b
9 changed files with 64 additions and 38 deletions

View File

@ -3,7 +3,7 @@
# ":platform" - Low-level and platform-specific Python code. # ":platform" - Low-level and platform-specific Python code.
load("//tensorflow:tensorflow.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "py_strict_library")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
@ -26,6 +26,9 @@ load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule") load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule")
load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused
@ -2071,6 +2074,7 @@ tf_py_test(
srcs = ["framework/constant_op_test.py"], srcs = ["framework/constant_op_test.py"],
main = "framework/constant_op_test.py", main = "framework/constant_op_test.py",
python_version = "PY3", python_version = "PY3",
tfrt_enabled = True,
deps = [ deps = [
":constant_op", ":constant_op",
], ],

View File

@ -108,15 +108,15 @@ class ResNet50Test(tf.test.TestCase):
self._apply(defun=False) self._apply(defun=False)
@test_util.disable_tfrt( @test_util.disable_tfrt(
'TFE_ContextGetExecutorForThread not implemented for tfrt') 'TFE_ContextGetExecutorForThread not implemented b/156188669')
def test_apply_async(self): def test_apply_async(self):
self._apply(defun=False, execution_mode=context.ASYNC) self._apply(defun=False, execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.') @test_util.disable_tfrt('Graph is not supported yet. b/156187905')
def test_apply_with_defun(self): def test_apply_with_defun(self):
self._apply(defun=True) self._apply(defun=True)
@test_util.disable_tfrt('Graph is not supported yet.') @test_util.disable_tfrt('Graph is not supported yet. b/156187905')
def test_apply_with_defun_async(self): def test_apply_with_defun_async(self):
self._apply(defun=True, execution_mode=context.ASYNC) self._apply(defun=True, execution_mode=context.ASYNC)
@ -217,7 +217,7 @@ class ResNet50Test(tf.test.TestCase):
def test_train(self): def test_train(self):
self._test_train() self._test_train()
@test_util.disable_tfrt('b/155260334') @test_util.disable_tfrt('TFE_ContextGetExecutorForThread missing b/156188669')
def test_train_async(self): def test_train_async(self):
self._test_train(execution_mode=context.ASYNC) self._test_train(execution_mode=context.ASYNC)
@ -329,7 +329,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
defun=False, defun=False,
execution_mode=context.ASYNC) execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.') @test_util.disable_tfrt('Graph is not supported yet. b/156187905')
def benchmark_eager_apply_with_defun(self): def benchmark_eager_apply_with_defun(self):
self._benchmark_eager_apply( self._benchmark_eager_apply(
'eager_apply_with_defun', 'eager_apply_with_defun',
@ -389,7 +389,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
defun=False, defun=False,
execution_mode=context.ASYNC) execution_mode=context.ASYNC)
@test_util.disable_tfrt('Graph is not supported yet.') @test_util.disable_tfrt('Graph is not supported yet. b/156187905')
def benchmark_eager_train_with_defun(self): def benchmark_eager_train_with_defun(self):
self._benchmark_eager_train( self._benchmark_eager_train(
'eager_train_with_defun', MockIterator, 'eager_train_with_defun', MockIterator,
@ -408,7 +408,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
resnet50_test_util.device_and_data_format(), resnet50_test_util.device_and_data_format(),
defun=False) defun=False)
@test_util.disable_tfrt('Graph is not supported yet.') @test_util.disable_tfrt('Graph is not supported yet. b/156187905')
def benchmark_eager_train_datasets_with_defun(self): def benchmark_eager_train_datasets_with_defun(self):
def make_iterator(tensors): def make_iterator(tensors):

View File

@ -618,7 +618,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
self._benchmark_tfe_py_execute_matmul( self._benchmark_tfe_py_execute_matmul(
m, transpose_b=False, num_iters=self._num_iters_2_by_2) m, transpose_b=False, num_iters=self._num_iters_2_by_2)
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_matmul_2_by_2_GPU(self): def benchmark_defun_matmul_2_by_2_GPU(self):
if not context.num_gpus(): if not context.num_gpus():
return return
@ -639,7 +639,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
num_iters=self._num_iters_2_by_2, num_iters=self._num_iters_2_by_2,
execution_mode=context.ASYNC) execution_mode=context.ASYNC)
@test_util.disable_tfrt("function not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_nested_defun_matmul_2_by_2(self): def benchmark_nested_defun_matmul_2_by_2(self):
m = self._m_2_by_2.cpu() m = self._m_2_by_2.cpu()
self._benchmark_nested_defun_matmul( self._benchmark_nested_defun_matmul(
@ -687,7 +687,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
self._benchmark_tfe_py_execute_matmul( self._benchmark_tfe_py_execute_matmul(
m, transpose_b=True, num_iters=self._num_iters_100_by_784) m, transpose_b=True, num_iters=self._num_iters_100_by_784)
@test_util.disable_tfrt("function not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_matmul_100_by_784_CPU(self): def benchmark_defun_matmul_100_by_784_CPU(self):
with context.device(CPU): with context.device(CPU):
m = self._m_100_by_784.cpu() m = self._m_100_by_784.cpu()
@ -815,35 +815,35 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
func() func()
self._run(func, 3000) self._run(func, 3000)
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_matmul_256_by_2096_CPU(self): def benchmark_forwardprop_matmul_256_by_2096_CPU(self):
self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096)) self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self): def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self):
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096)) self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self): def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self):
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096)) self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self): def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self):
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096)) self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_matmul_100_by_784_CPU(self): def benchmark_forwardprop_matmul_100_by_784_CPU(self):
self._benchmark_forwardprop_matmul_CPU(shape=(100, 784)) self._benchmark_forwardprop_matmul_CPU(shape=(100, 784))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self): def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self):
self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784)) self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self): def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self):
self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784)) self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784))
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self): def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self):
self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784)) self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784))
@ -1097,7 +1097,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
m = resource_variable_ops.ResourceVariable(self._m_2_by_2) m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2) self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2)
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_without_signature(self): def benchmark_defun_without_signature(self):
def func(t1, t2, t3, t4, t5, t6, t7, t8): def func(t1, t2, t3, t4, t5, t6, t7, t8):
@ -1109,7 +1109,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
cache_computation = lambda: defined(t, t, t, t, t, t, t, t) cache_computation = lambda: defined(t, t, t, t, t, t, t, t)
self._run(cache_computation, 30000) self._run(cache_computation, 30000)
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_without_signature_and_with_kwargs(self): def benchmark_defun_without_signature_and_with_kwargs(self):
def func(t1, t2, t3, t4, t5, t6, t7, t8): def func(t1, t2, t3, t4, t5, t6, t7, t8):
@ -1122,7 +1122,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t) return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t)
self._run(cache_computation, 30000) self._run(cache_computation, 30000)
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_with_signature(self): def benchmark_defun_with_signature(self):
def func(t1, t2, t3, t4, t5, t6, t7, t8): def func(t1, t2, t3, t4, t5, t6, t7, t8):
@ -1135,7 +1135,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
signature_computation = lambda: defined(t, t, t, t, t, t, t, t) signature_computation = lambda: defined(t, t, t, t, t, t, t, t)
self._run(signature_computation, 30000) self._run(signature_computation, 30000)
@test_util.disable_tfrt("defun not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_with_signature_and_kwargs(self): def benchmark_defun_with_signature_and_kwargs(self):
def func(t1, t2, t3, t4, t5, t6, t7, t8): def func(t1, t2, t3, t4, t5, t6, t7, t8):
@ -1305,11 +1305,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
resources.append(resource_variable_ops.ResourceVariable(self._m_2)) resources.append(resource_variable_ops.ResourceVariable(self._m_2))
self._run(lambda: add_all(resources), num_iters) self._run(lambda: add_all(resources), num_iters)
@test_util.disable_tfrt("funtion not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmarkFunctionWithFiveResourceInputs(self): def benchmarkFunctionWithFiveResourceInputs(self):
self._benchmarkFunctionWithResourceInputs(5, 1000) self._benchmarkFunctionWithResourceInputs(5, 1000)
@test_util.disable_tfrt("funtion not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmarkFunctionWithFiveHundredResourceInputs(self): def benchmarkFunctionWithFiveHundredResourceInputs(self):
self._benchmarkFunctionWithResourceInputs(500, 100) self._benchmarkFunctionWithResourceInputs(500, 100)
@ -1344,15 +1344,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
with context.device(CPU): with context.device(CPU):
self._run(benchmark_fn, 10) self._run(benchmark_fn, 10)
@test_util.disable_tfrt("funtion not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmarkTenThousandResourceReadsInCondInInnerFunc(self): def benchmarkTenThousandResourceReadsInCondInInnerFunc(self):
self._benchmarkResourceReadsInCondInInnerFunc(10000) self._benchmarkResourceReadsInCondInInnerFunc(10000)
@test_util.disable_tfrt("funtion not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmarkHundredResourceReadsInCondInInnerFunc(self): def benchmarkHundredResourceReadsInCondInInnerFunc(self):
self._benchmarkResourceReadsInCondInInnerFunc(100) self._benchmarkResourceReadsInCondInInnerFunc(100)
@test_util.disable_tfrt("funtion not supported") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmarkTenResourceReadsInCondInInnerFunc(self): def benchmarkTenResourceReadsInCondInInnerFunc(self):
self._benchmarkResourceReadsInCondInInnerFunc(10) self._benchmarkResourceReadsInCondInInnerFunc(10)

View File

@ -91,7 +91,7 @@ class ResourceTest(test_util.TensorFlowTestCase):
resources.shared_resources()).eval()), 0) resources.shared_resources()).eval()), 0)
@test_util.disable_tfrt("Graph is not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
class TensorAndShapeTest(test_util.TensorFlowTestCase): class TensorAndShapeTest(test_util.TensorFlowTestCase):
def testShape(self): def testShape(self):
@ -311,7 +311,8 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
del x del x
self.assertIsNotNone(x_ref.deref()) self.assertIsNotNone(x_ref.deref())
@test_util.disable_tfrt("Graph mode is not supported yet.")
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase): class IndexedSlicesTest(test_util.TensorFlowTestCase):
@ -356,7 +357,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices, [0, 2]) self.assertAllEqual(x.indices, [0, 2])
@test_util.disable_tfrt("Graph mode is not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
parameterized.TestCase): parameterized.TestCase):
@ -502,7 +503,7 @@ def _apply_op(g, *args, **kwargs):
return op.outputs return op.outputs
@test_util.disable_tfrt("Graph is not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
class OperationTest(test_util.TensorFlowTestCase): class OperationTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@ -1445,7 +1446,7 @@ class NameTest(test_util.TensorFlowTestCase):
g.create_op("FloatOutput", [], [dtypes.float32]).name) g.create_op("FloatOutput", [], [dtypes.float32]).name)
@test_util.disable_tfrt("Device API are not supported yet.") @test_util.disable_tfrt("Device API are not supported yet. b/156188344")
class DeviceTest(test_util.TensorFlowTestCase): class DeviceTest(test_util.TensorFlowTestCase):
def testNoDevice(self): def testNoDevice(self):
@ -2026,7 +2027,7 @@ class CollectionTest(test_util.TensorFlowTestCase):
# Collections are ordered. # Collections are ordered.
self.assertEqual([90, 100], ops.get_collection("key")) self.assertEqual([90, 100], ops.get_collection("key"))
@test_util.disable_tfrt("Functions are not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def test_defun(self): def test_defun(self):
with context.eager_mode(): with context.eager_mode():
@ -2133,7 +2134,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
# e should be dominated by c. # e should be dominated by c.
self.assertEqual(e.op.control_inputs, []) self.assertEqual(e.op.control_inputs, [])
@test_util.disable_tfrt("Graph is not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testEager(self): def testEager(self):
def future(): def future():
@ -2454,7 +2455,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
self._testGraphElements([a, variable, b]) self._testGraphElements([a, variable, b])
@test_util.disable_tfrt("Graphs are not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
class InitScopeTest(test_util.TensorFlowTestCase): class InitScopeTest(test_util.TensorFlowTestCase):
def testClearsControlDependencies(self): def testClearsControlDependencies(self):
@ -2757,7 +2758,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
self.assertFalse(self.evaluate(f())) self.assertFalse(self.evaluate(f()))
@test_util.disable_tfrt("Graphs are not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
class GraphTest(test_util.TensorFlowTestCase): class GraphTest(test_util.TensorFlowTestCase):
def setUp(self): def setUp(self):
@ -3235,7 +3236,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
b = variables.Variable([3.0], name="b") b = variables.Variable([3.0], name="b")
self.assertEqual([b"loc:@a"], b.op.colocation_groups()) self.assertEqual([b"loc:@a"], b.op.colocation_groups())
@test_util.disable_tfrt("Functions are not supported yet.") @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def testColocateWithVariableInFunction(self): def testColocateWithVariableInFunction(self):
v = variables.Variable(1.) v = variables.Variable(1.)

View File

@ -864,6 +864,7 @@ cuda_py_test(
srcs = ["resource_variable_ops_test.py"], srcs = ["resource_variable_ops_test.py"],
# TODO(b/128347673): Re-enable. # TODO(b/128347673): Re-enable.
tags = ["no_windows"], tags = ["no_windows"],
tfrt_enabled = True,
deps = [ deps = [
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",

View File

@ -87,6 +87,7 @@ cuda_py_test(
name = "random_ops_test", name = "random_ops_test",
size = "medium", size = "medium",
srcs = ["random_ops_test.py"], srcs = ["random_ops_test.py"],
tfrt_enabled = True,
deps = [ deps = [
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
@ -101,6 +102,7 @@ cuda_py_test(
size = "medium", size = "medium",
srcs = ["stateless_random_ops_test.py"], srcs = ["stateless_random_ops_test.py"],
shard_count = 2, shard_count = 2,
tfrt_enabled = True,
deps = [ deps = [
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",

View File

@ -336,6 +336,8 @@ class RandomUniformTest(RandomOpTestCommon):
self.assertLess(error.max(), 5 * std) self.assertLess(error.max(), 5 * std)
# Check that minval = maxval is fine iff we're producing no numbers # Check that minval = maxval is fine iff we're producing no numbers
@test_util.disable_tfrt(
"TFE_TensorHandleToNumpy not implemented yet. b/156191611")
def testUniformIntsDegenerate(self): def testUniformIntsDegenerate(self):
for dt in dtypes.int32, dtypes.int64: for dt in dtypes.int32, dtypes.int64:
def sample(n): def sample(n):

View File

@ -154,44 +154,54 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
**kwds), **kwds),
functools.partial(random_ops.random_poisson, shape=(10,), **kwds)) functools.partial(random_ops.random_poisson, shape=(10,), **kwds))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testMatchFloat(self): def testMatchFloat(self):
self._test_match(self._float_cases()) self._test_match(self._float_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testMatchInt(self): def testMatchInt(self):
self._test_match(self._int_cases()) self._test_match(self._int_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testMatchMultinomial(self): def testMatchMultinomial(self):
self._test_match(self._multinomial_cases()) self._test_match(self._multinomial_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testMatchGamma(self): def testMatchGamma(self):
self._test_match(self._gamma_cases()) self._test_match(self._gamma_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testMatchPoisson(self): def testMatchPoisson(self):
self._test_match(self._poisson_cases()) self._test_match(self._poisson_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testDeterminismFloat(self): def testDeterminismFloat(self):
self._test_determinism( self._test_determinism(
self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testDeterminismInt(self): def testDeterminismInt(self):
self._test_determinism( self._test_determinism(
self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testDeterminismMultinomial(self): def testDeterminismMultinomial(self):
self._test_determinism(self._multinomial_cases()) self._test_determinism(self._multinomial_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testDeterminismGamma(self): def testDeterminismGamma(self):
self._test_determinism(self._gamma_cases()) self._test_determinism(self._gamma_cases())
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testDeterminismPoisson(self): def testDeterminismPoisson(self):
self._test_determinism(self._poisson_cases()) self._test_determinism(self._poisson_cases())

View File

@ -57,6 +57,8 @@ from tensorflow.python.training import training_util
from tensorflow.python.util import compat from tensorflow.python.util import compat
@test_util.disable_tfrt(
"Trying to assign variable with wrong dtype. b/156200342")
@test_util.with_control_flow_v2 @test_util.with_control_flow_v2
class ResourceVariableOpsTest(test_util.TensorFlowTestCase, class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase): parameterized.TestCase):
@ -332,6 +334,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0] g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0]
self.assertAllEqual(g.shape.as_list(), [1, 2]) self.assertAllEqual(g.shape.as_list(), [1, 2])
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testGradientCondInWhileLoop(self): def testGradientCondInWhileLoop(self):
v = resource_variable_ops.ResourceVariable(initial_value=1.0) v = resource_variable_ops.ResourceVariable(initial_value=1.0)
@ -965,6 +968,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
assign = var.assign(np.zeros(shape=[2, 2])) assign = var.assign(np.zeros(shape=[2, 2]))
self.evaluate(assign) self.evaluate(assign)
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
@test_util.disable_xla("XLA doesn't allow changing shape at assignment, as " @test_util.disable_xla("XLA doesn't allow changing shape at assignment, as "
"dictated by tf2xla/xla_resource.cc:SetTypeAndShape") "dictated by tf2xla/xla_resource.cc:SetTypeAndShape")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
@ -1327,6 +1331,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
# TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create # TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create
# EagerTensor constants with TensorProto inputs. # EagerTensor constants with TensorProto inputs.
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
@test_util.run_in_graph_and_eager_modes() @test_util.run_in_graph_and_eager_modes()
def testVariantInitializer(self): def testVariantInitializer(self):
variant_shape_and_type_data = self.create_variant_shape_and_type_data() variant_shape_and_type_data = self.create_variant_shape_and_type_data()
@ -1520,6 +1525,7 @@ class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
context.LogicalDeviceConfiguration(), context.LogicalDeviceConfiguration(),
]) ])
@test_util.disable_tfrt("Multiple device support. b/154956430")
def testAllowedDevices(self): def testAllowedDevices(self):
device0 = "/job:localhost/replica:0/task:0/device:CPU:0" device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
device1 = "/job:localhost/replica:0/task:0/device:CPU:1" device1 = "/job:localhost/replica:0/task:0/device:CPU:1"