Disable a few failing tests in tensorflow.v2_tfrt and fix tensorflow/python/ops/ragged:ragged_map_fn_op_test.
PiperOrigin-RevId: 351449157 Change-Id: Ie6486b9c6e13e2f10143c403350aaca84d1b3441
This commit is contained in:
parent
0ead9dec60
commit
b8cd771a05
@ -217,6 +217,8 @@ class XLACompileContextTest(test.TestCase, parameterized.TestCase):
|
||||
class XlaCompileTest(test.TestCase):
|
||||
|
||||
@test_util.run_v2_only
|
||||
@test_util.disable_tfrt(
|
||||
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
|
||||
def test_xla_compile_eager(self):
|
||||
"""Tests that xla.compile raises proper exception when used eagerly."""
|
||||
|
||||
@ -225,6 +227,8 @@ class XlaCompileTest(test.TestCase):
|
||||
|
||||
self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3)
|
||||
|
||||
@test_util.disable_tfrt(
|
||||
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
|
||||
def test_xla_compile_in_function(self):
|
||||
"""Tests that xla.compile works in tf.function."""
|
||||
|
||||
@ -238,6 +242,8 @@ class XlaCompileTest(test.TestCase):
|
||||
|
||||
self.assertEqual(self.evaluate(func_wrapper(1))[0], 2)
|
||||
|
||||
@test_util.disable_tfrt(
|
||||
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
|
||||
def test_xla_compile_write_variable_in_function(self):
|
||||
"""Tests that xla.compile works with variable in tf.function."""
|
||||
a = variable_scope.get_variable(
|
||||
|
@ -248,16 +248,22 @@ class EagerClusterReplicateTest(test_base.DatasetTestBase,
|
||||
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
||||
super(EagerClusterReplicateTest, self).__init__(methodName)
|
||||
self._job_name = "remove_device"
|
||||
self._cached_server1 = server_lib.Server.create_local_server()
|
||||
self._cached_server2 = server_lib.Server.create_local_server()
|
||||
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
|
||||
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
|
||||
self._device0 = "/job:%s/replica:0/task:0/device:CPU:0" % self._job_name
|
||||
self._device1 = "/job:%s/replica:0/task:1/device:CPU:0" % self._job_name
|
||||
self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % self._job_name
|
||||
|
||||
def setUp(self):
|
||||
super(EagerClusterReplicateTest, self).setUp()
|
||||
|
||||
if context.context().use_tfrt:
|
||||
self.skipTest("b/171412104: This test requires distributed support.")
|
||||
|
||||
# TODO(b/171412104): Move create server to __init__ once tfrt support it.
|
||||
self._cached_server1 = server_lib.Server.create_local_server()
|
||||
self._cached_server2 = server_lib.Server.create_local_server()
|
||||
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
|
||||
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
|
||||
|
||||
# Start the local server.
|
||||
local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
|
||||
context.set_server_def(
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.data.experimental.ops import snapshot
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers as core_readers
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
@ -371,6 +372,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testReadOptimizableUsingFlatMap(self):
|
||||
if context.context().use_tfrt:
|
||||
self.skipTest("b/177260096: Flaky test.")
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
# Will be optimized into ShuffleAndRepeat.
|
||||
dataset = dataset.shuffle(10)
|
||||
|
@ -250,14 +250,12 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
|
||||
# Check that the correct line for op creation is printed.
|
||||
self.assertTrue(re.search(r"Stack trace of op's creation", message))
|
||||
self.assertIn("return math_ops.log(-x)", message)
|
||||
if context.executing_eagerly():
|
||||
# The code path for raising error is slightly different under graph mode.
|
||||
self.assertTrue(message.endswith("\n"))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.disable_xla(
|
||||
"There is a small inconsistency in the step at which overflow happens: "
|
||||
"128 (without XLA) and 127 (with XLA).")
|
||||
@test_util.disable_tfrt("b/177261532: TFRT cannot detect overflow yet.")
|
||||
def testOverflowInTfFunction(self):
|
||||
"""Test catching Infinity caused by overflow in a tf.function with while."""
|
||||
check_numerics_callback.enable_check_numerics()
|
||||
|
@ -254,6 +254,7 @@ py_test(
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_tfrt", # TFRT doesn't support custom device yet.
|
||||
],
|
||||
deps = [
|
||||
":context",
|
||||
|
@ -1461,6 +1461,7 @@ class CondV2ContainerTest(test.TestCase):
|
||||
self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
|
||||
|
||||
|
||||
@test_util.disable_tfrt("b/171412104: This test requires distributed support.")
|
||||
class CondV2ColocationGroupAndDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -444,6 +444,8 @@ class CollectiveOpTest(test.TestCase):
|
||||
self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@test_util.run_v2_only
|
||||
@test_util.disable_tfrt(
|
||||
'b/177270918: TFRT has dead lock when executing collective ops.')
|
||||
def testCollectiveGroupSizeMismatch(self):
|
||||
cpus = config.list_physical_devices('CPU')
|
||||
self.assertEqual(len(cpus), 1)
|
||||
|
Loading…
Reference in New Issue
Block a user