diff --git a/tensorflow/python/compiler/xla/xla_test.py b/tensorflow/python/compiler/xla/xla_test.py index af18abf727a..041a68fd2a6 100644 --- a/tensorflow/python/compiler/xla/xla_test.py +++ b/tensorflow/python/compiler/xla/xla_test.py @@ -217,6 +217,8 @@ class XLACompileContextTest(test.TestCase, parameterized.TestCase): class XlaCompileTest(test.TestCase): @test_util.run_v2_only + @test_util.disable_tfrt( + 'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.') def test_xla_compile_eager(self): """Tests that xla.compile raises proper exception when used eagerly.""" @@ -225,6 +227,8 @@ class XlaCompileTest(test.TestCase): self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3) + @test_util.disable_tfrt( + 'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.') def test_xla_compile_in_function(self): """Tests that xla.compile works in tf.function.""" @@ -238,6 +242,8 @@ class XlaCompileTest(test.TestCase): self.assertEqual(self.evaluate(func_wrapper(1))[0], 2) + @test_util.disable_tfrt( + 'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.') def test_xla_compile_write_variable_in_function(self): """Tests that xla.compile works with variable in tf.function.""" a = variable_scope.get_variable( diff --git a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py index 4995b054011..a1e0210abb6 100644 --- a/tensorflow/python/data/experimental/kernel_tests/replicate_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/replicate_test.py @@ -248,16 +248,22 @@ class EagerClusterReplicateTest(test_base.DatasetTestBase, def __init__(self, methodName="runTest"): # pylint: disable=invalid-name super(EagerClusterReplicateTest, self).__init__(methodName) self._job_name = "remove_device" - self._cached_server1 = server_lib.Server.create_local_server() - self._cached_server2 = server_lib.Server.create_local_server() - self._cached_server1_target = self._cached_server1.target[len("grpc://"):] - self._cached_server2_target = self._cached_server2.target[len("grpc://"):] self._device0 = "/job:%s/replica:0/task:0/device:CPU:0" % self._job_name self._device1 = "/job:%s/replica:0/task:1/device:CPU:0" % self._job_name self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % self._job_name def setUp(self): super(EagerClusterReplicateTest, self).setUp() + + if context.context().use_tfrt: + self.skipTest("b/171412104: This test requires distributed support.") + + # TODO(b/171412104): Move create server to __init__ once tfrt support it. + self._cached_server1 = server_lib.Server.create_local_server() + self._cached_server2 = server_lib.Server.create_local_server() + self._cached_server1_target = self._cached_server1.target[len("grpc://"):] + self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + # Start the local server. local_port = pywrap_tfe.TF_PickUnusedPortOrDie() context.set_server_def( diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index 24307790c05..e0e96786b0c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -30,6 +30,7 @@ from tensorflow.python.data.experimental.ops import snapshot from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.eager import context from tensorflow.python.framework import combinations from tensorflow.python.framework import errors from tensorflow.python.ops import gen_array_ops @@ -371,6 +372,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, @combinations.generate(test_base.default_test_combinations()) def testReadOptimizableUsingFlatMap(self): + if context.context().use_tfrt: + self.skipTest("b/177260096: Flaky test.") dataset = dataset_ops.Dataset.range(100) # Will be optimized into ShuffleAndRepeat. dataset = dataset.shuffle(10) diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py b/tensorflow/python/debug/lib/check_numerics_callback_test.py index 5c0cc6394ac..b2d004d2f45 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback_test.py +++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py @@ -250,14 +250,12 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): # Check that the correct line for op creation is printed. self.assertTrue(re.search(r"Stack trace of op's creation", message)) self.assertIn("return math_ops.log(-x)", message) - if context.executing_eagerly(): - # The code path for raising error is slightly different under graph mode. - self.assertTrue(message.endswith("\n")) @test_util.run_in_graph_and_eager_modes @test_util.disable_xla( "There is a small inconsistency in the step at which overflow happens: " "128 (without XLA) and 127 (with XLA).") + @test_util.disable_tfrt("b/177261532: TFRT cannot detect overflow yet.") def testOverflowInTfFunction(self): """Test catching Infinity caused by overflow in a tf.function with while.""" check_numerics_callback.enable_check_numerics() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 5cfe805c796..33753aad257 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -254,6 +254,7 @@ py_test( tags = [ "no_oss", "no_pip", + "no_tfrt", # TFRT doesn't support custom device yet. ], deps = [ ":context", diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 328977181f8..4ee4ca525aa 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -1461,6 +1461,7 @@ class CondV2ContainerTest(test.TestCase): self.assertEqual(compat.as_bytes(""), container(q5.queue_ref)) +@test_util.disable_tfrt("b/171412104: This test requires distributed support.") class CondV2ColocationGroupAndDeviceTest(test.TestCase, parameterized.TestCase): def setUp(self): diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py index d3988d0806d..5a49c833e92 100644 --- a/tensorflow/python/ops/collective_ops_test.py +++ b/tensorflow/python/ops/collective_ops_test.py @@ -444,6 +444,8 @@ class CollectiveOpTest(test.TestCase): self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5) @test_util.run_v2_only + @test_util.disable_tfrt( + 'b/177270918: TFRT has dead lock when executing collective ops.') def testCollectiveGroupSizeMismatch(self): cpus = config.list_physical_devices('CPU') self.assertEqual(len(cpus), 1)