Disable a few failing tests in tensorflow.v2_tfrt and fix tensorflow/python/ops/ragged:ragged_map_fn_op_test.

PiperOrigin-RevId: 351449157
Change-Id: Ie6486b9c6e13e2f10143c403350aaca84d1b3441
This commit is contained in:
Xiao Yu 2021-01-12 14:23:34 -08:00 committed by TensorFlower Gardener
parent 0ead9dec60
commit b8cd771a05
7 changed files with 24 additions and 7 deletions

View File

@ -217,6 +217,8 @@ class XLACompileContextTest(test.TestCase, parameterized.TestCase):
class XlaCompileTest(test.TestCase):
@test_util.run_v2_only
@test_util.disable_tfrt(
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
def test_xla_compile_eager(self):
"""Tests that xla.compile raises proper exception when used eagerly."""
@ -225,6 +227,8 @@ class XlaCompileTest(test.TestCase):
self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3)
@test_util.disable_tfrt(
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
def test_xla_compile_in_function(self):
"""Tests that xla.compile works in tf.function."""
@ -238,6 +242,8 @@ class XlaCompileTest(test.TestCase):
self.assertEqual(self.evaluate(func_wrapper(1))[0], 2)
@test_util.disable_tfrt(
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
def test_xla_compile_write_variable_in_function(self):
"""Tests that xla.compile works with variable in tf.function."""
a = variable_scope.get_variable(

View File

@ -248,16 +248,22 @@ class EagerClusterReplicateTest(test_base.DatasetTestBase,
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(EagerClusterReplicateTest, self).__init__(methodName)
self._job_name = "remove_device"
self._cached_server1 = server_lib.Server.create_local_server()
self._cached_server2 = server_lib.Server.create_local_server()
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
self._device0 = "/job:%s/replica:0/task:0/device:CPU:0" % self._job_name
self._device1 = "/job:%s/replica:0/task:1/device:CPU:0" % self._job_name
self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % self._job_name
def setUp(self):
super(EagerClusterReplicateTest, self).setUp()
if context.context().use_tfrt:
self.skipTest("b/171412104: This test requires distributed support.")
# TODO(b/171412104): Move create server to __init__ once tfrt support it.
self._cached_server1 = server_lib.Server.create_local_server()
self._cached_server2 = server_lib.Server.create_local_server()
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
# Start the local server.
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
context.set_server_def(

View File

@ -30,6 +30,7 @@ from tensorflow.python.data.experimental.ops import snapshot
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.ops import gen_array_ops
@ -371,6 +372,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
@combinations.generate(test_base.default_test_combinations())
def testReadOptimizableUsingFlatMap(self):
if context.context().use_tfrt:
self.skipTest("b/177260096: Flaky test.")
dataset = dataset_ops.Dataset.range(100)
# Will be optimized into ShuffleAndRepeat.
dataset = dataset.shuffle(10)

View File

@ -250,14 +250,12 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
# Check that the correct line for op creation is printed.
self.assertTrue(re.search(r"Stack trace of op's creation", message))
self.assertIn("return math_ops.log(-x)", message)
if context.executing_eagerly():
# The code path for raising error is slightly different under graph mode.
self.assertTrue(message.endswith("\n"))
@test_util.run_in_graph_and_eager_modes
@test_util.disable_xla(
"There is a small inconsistency in the step at which overflow happens: "
"128 (without XLA) and 127 (with XLA).")
@test_util.disable_tfrt("b/177261532: TFRT cannot detect overflow yet.")
def testOverflowInTfFunction(self):
"""Test catching Infinity caused by overflow in a tf.function with while."""
check_numerics_callback.enable_check_numerics()

View File

@ -254,6 +254,7 @@ py_test(
tags = [
"no_oss",
"no_pip",
"no_tfrt", # TFRT doesn't support custom device yet.
],
deps = [
":context",

View File

@ -1461,6 +1461,7 @@ class CondV2ContainerTest(test.TestCase):
self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
@test_util.disable_tfrt("b/171412104: This test requires distributed support.")
class CondV2ColocationGroupAndDeviceTest(test.TestCase, parameterized.TestCase):
def setUp(self):

View File

@ -444,6 +444,8 @@ class CollectiveOpTest(test.TestCase):
self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5)
@test_util.run_v2_only
@test_util.disable_tfrt(
'b/177270918: TFRT has dead lock when executing collective ops.')
def testCollectiveGroupSizeMismatch(self):
cpus = config.list_physical_devices('CPU')
self.assertEqual(len(cpus), 1)