Enable framework_ops_test test for TFRT.
Also class support for test_util.disable_tfrt(). PiperOrigin-RevId: 308743069 Change-Id: I813ab74fac4c94c2075fc31a677f4aadc7e138f5
This commit is contained in:
parent
f4a02e4f54
commit
7f2dc64ee0
tensorflow/python
@ -1969,6 +1969,7 @@ py_library(
|
||||
)
|
||||
|
||||
# Including this as a dependency will result in tests to use TFRT.
|
||||
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
|
||||
py_library(
|
||||
name = "is_tfrt_test_true",
|
||||
srcs = ["framework/is_tfrt_test_true.py"],
|
||||
@ -2378,6 +2379,7 @@ tf_py_test(
|
||||
main = "framework/ops_test.py",
|
||||
python_version = "PY3",
|
||||
tags = ["no_pip"], # test_ops_2 is not available in pip.
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
":cond_v2",
|
||||
":control_flow_ops",
|
||||
|
@ -621,7 +621,6 @@ py_library(
|
||||
deps = [":test"],
|
||||
)
|
||||
|
||||
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
|
||||
cuda_py_test(
|
||||
name = "benchmarks_test",
|
||||
srcs = ["benchmarks_test.py"],
|
||||
|
@ -90,6 +90,7 @@ class ResourceTest(test_util.TensorFlowTestCase):
|
||||
resources.shared_resources()).eval()), 0)
|
||||
|
||||
|
||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||
class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testShape(self):
|
||||
@ -309,6 +310,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||
del x
|
||||
self.assertIsNotNone(x_ref.deref())
|
||||
|
||||
@test_util.disable_tfrt("Graph mode is not supported yet.")
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@ -353,6 +355,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(x.indices, [0, 2])
|
||||
|
||||
|
||||
@test_util.disable_tfrt("Graph mode is not supported yet.")
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
@ -498,6 +501,7 @@ def _apply_op(g, *args, **kwargs):
|
||||
return op.outputs
|
||||
|
||||
|
||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||
class OperationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@ -1428,6 +1432,7 @@ class NameTest(test_util.TensorFlowTestCase):
|
||||
g.create_op("FloatOutput", [], [dtypes.float32]).name)
|
||||
|
||||
|
||||
@test_util.disable_tfrt("Device API are not supported yet.")
|
||||
class DeviceTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testNoDevice(self):
|
||||
@ -2008,6 +2013,7 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
||||
# Collections are ordered.
|
||||
self.assertEqual([90, 100], ops.get_collection("key"))
|
||||
|
||||
@test_util.disable_tfrt("Functions are not supported yet.")
|
||||
def test_defun(self):
|
||||
with context.eager_mode():
|
||||
|
||||
@ -2114,6 +2120,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
||||
# e should be dominated by c.
|
||||
self.assertEqual(e.op.control_inputs, [])
|
||||
|
||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testEager(self):
|
||||
def future():
|
||||
@ -2434,6 +2441,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
||||
self._testGraphElements([a, variable, b])
|
||||
|
||||
|
||||
@test_util.disable_tfrt("Graphs are not supported yet.")
|
||||
class InitScopeTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testClearsControlDependencies(self):
|
||||
@ -2736,6 +2744,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
|
||||
self.assertFalse(self.evaluate(f()))
|
||||
|
||||
|
||||
@test_util.disable_tfrt("Graphs are not supported yet.")
|
||||
class GraphTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -3213,6 +3222,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
||||
b = variables.Variable([3.0], name="b")
|
||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||
|
||||
@test_util.disable_tfrt("Functions are not supported yet.")
|
||||
def testColocateWithVariableInFunction(self):
|
||||
v = variables.Variable(1.)
|
||||
|
||||
@ -3248,6 +3258,7 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||
def testSuccess(self):
|
||||
op = ops.Operation(
|
||||
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
|
||||
@ -3421,6 +3432,7 @@ ops.register_tensor_conversion_function(
|
||||
|
||||
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.disable_tfrt("b/154858769")
|
||||
def testCompositeTensorConversion(self):
|
||||
"""Tests that a user can register a CompositeTensor converter."""
|
||||
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
|
||||
|
@ -1788,23 +1788,29 @@ def disable_mlir_bridge(description): # pylint: disable=unused-argument
|
||||
# The description is just for documentation purposes.
|
||||
def disable_tfrt(unused_description):
|
||||
|
||||
def disable_tfrt_impl(func):
|
||||
"""Execute the test method only if tfrt is not enabled."""
|
||||
def disable_tfrt_impl(cls_or_func):
|
||||
"""Execute the test only if tfrt is not enabled."""
|
||||
|
||||
def decorator(func):
|
||||
if tf_inspect.isclass(cls_or_func):
|
||||
if is_tfrt_enabled():
|
||||
return None
|
||||
else:
|
||||
return cls_or_func
|
||||
else:
|
||||
def decorator(func):
|
||||
|
||||
def decorated(self, *args, **kwargs):
|
||||
if is_tfrt_enabled():
|
||||
return
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
def decorated(self, *args, **kwargs):
|
||||
if is_tfrt_enabled():
|
||||
return
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
return decorated
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
if cls_or_func is not None:
|
||||
return decorator(cls_or_func)
|
||||
|
||||
return decorator
|
||||
return decorator
|
||||
|
||||
return disable_tfrt_impl
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user