Enable framework_ops_test test for TFRT.
Also class support for test_util.disable_tfrt(). PiperOrigin-RevId: 308743069 Change-Id: I813ab74fac4c94c2075fc31a677f4aadc7e138f5
This commit is contained in:
parent
f4a02e4f54
commit
7f2dc64ee0
@ -1969,6 +1969,7 @@ py_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Including this as a dependency will result in tests to use TFRT.
|
# Including this as a dependency will result in tests to use TFRT.
|
||||||
|
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
|
||||||
py_library(
|
py_library(
|
||||||
name = "is_tfrt_test_true",
|
name = "is_tfrt_test_true",
|
||||||
srcs = ["framework/is_tfrt_test_true.py"],
|
srcs = ["framework/is_tfrt_test_true.py"],
|
||||||
@ -2378,6 +2379,7 @@ tf_py_test(
|
|||||||
main = "framework/ops_test.py",
|
main = "framework/ops_test.py",
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = ["no_pip"], # test_ops_2 is not available in pip.
|
tags = ["no_pip"], # test_ops_2 is not available in pip.
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
":cond_v2",
|
":cond_v2",
|
||||||
":control_flow_ops",
|
":control_flow_ops",
|
||||||
|
@ -621,7 +621,6 @@ py_library(
|
|||||||
deps = [":test"],
|
deps = [":test"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "benchmarks_test",
|
name = "benchmarks_test",
|
||||||
srcs = ["benchmarks_test.py"],
|
srcs = ["benchmarks_test.py"],
|
||||||
|
@ -90,6 +90,7 @@ class ResourceTest(test_util.TensorFlowTestCase):
|
|||||||
resources.shared_resources()).eval()), 0)
|
resources.shared_resources()).eval()), 0)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||||
class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testShape(self):
|
def testShape(self):
|
||||||
@ -309,6 +310,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
del x
|
del x
|
||||||
self.assertIsNotNone(x_ref.deref())
|
self.assertIsNotNone(x_ref.deref())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph mode is not supported yet.")
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@ -353,6 +355,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(x.indices, [0, 2])
|
self.assertAllEqual(x.indices, [0, 2])
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph mode is not supported yet.")
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
|
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
@ -498,6 +501,7 @@ def _apply_op(g, *args, **kwargs):
|
|||||||
return op.outputs
|
return op.outputs
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||||
class OperationTest(test_util.TensorFlowTestCase):
|
class OperationTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@ -1428,6 +1432,7 @@ class NameTest(test_util.TensorFlowTestCase):
|
|||||||
g.create_op("FloatOutput", [], [dtypes.float32]).name)
|
g.create_op("FloatOutput", [], [dtypes.float32]).name)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Device API are not supported yet.")
|
||||||
class DeviceTest(test_util.TensorFlowTestCase):
|
class DeviceTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testNoDevice(self):
|
def testNoDevice(self):
|
||||||
@ -2008,6 +2013,7 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
|||||||
# Collections are ordered.
|
# Collections are ordered.
|
||||||
self.assertEqual([90, 100], ops.get_collection("key"))
|
self.assertEqual([90, 100], ops.get_collection("key"))
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Functions are not supported yet.")
|
||||||
def test_defun(self):
|
def test_defun(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
|
|
||||||
@ -2114,6 +2120,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
|
|||||||
# e should be dominated by c.
|
# e should be dominated by c.
|
||||||
self.assertEqual(e.op.control_inputs, [])
|
self.assertEqual(e.op.control_inputs, [])
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testEager(self):
|
def testEager(self):
|
||||||
def future():
|
def future():
|
||||||
@ -2434,6 +2441,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self._testGraphElements([a, variable, b])
|
self._testGraphElements([a, variable, b])
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graphs are not supported yet.")
|
||||||
class InitScopeTest(test_util.TensorFlowTestCase):
|
class InitScopeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testClearsControlDependencies(self):
|
def testClearsControlDependencies(self):
|
||||||
@ -2736,6 +2744,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertFalse(self.evaluate(f()))
|
self.assertFalse(self.evaluate(f()))
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graphs are not supported yet.")
|
||||||
class GraphTest(test_util.TensorFlowTestCase):
|
class GraphTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -3213,6 +3222,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
|
|||||||
b = variables.Variable([3.0], name="b")
|
b = variables.Variable([3.0], name="b")
|
||||||
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Functions are not supported yet.")
|
||||||
def testColocateWithVariableInFunction(self):
|
def testColocateWithVariableInFunction(self):
|
||||||
v = variables.Variable(1.)
|
v = variables.Variable(1.)
|
||||||
|
|
||||||
@ -3248,6 +3258,7 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
|
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("Graph is not supported yet.")
|
||||||
def testSuccess(self):
|
def testSuccess(self):
|
||||||
op = ops.Operation(
|
op = ops.Operation(
|
||||||
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
|
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
|
||||||
@ -3421,6 +3432,7 @@ ops.register_tensor_conversion_function(
|
|||||||
|
|
||||||
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
|
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.disable_tfrt("b/154858769")
|
||||||
def testCompositeTensorConversion(self):
|
def testCompositeTensorConversion(self):
|
||||||
"""Tests that a user can register a CompositeTensor converter."""
|
"""Tests that a user can register a CompositeTensor converter."""
|
||||||
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
|
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
|
||||||
|
@ -1788,23 +1788,29 @@ def disable_mlir_bridge(description): # pylint: disable=unused-argument
|
|||||||
# The description is just for documentation purposes.
|
# The description is just for documentation purposes.
|
||||||
def disable_tfrt(unused_description):
|
def disable_tfrt(unused_description):
|
||||||
|
|
||||||
def disable_tfrt_impl(func):
|
def disable_tfrt_impl(cls_or_func):
|
||||||
"""Execute the test method only if tfrt is not enabled."""
|
"""Execute the test only if tfrt is not enabled."""
|
||||||
|
|
||||||
def decorator(func):
|
if tf_inspect.isclass(cls_or_func):
|
||||||
|
if is_tfrt_enabled():
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return cls_or_func
|
||||||
|
else:
|
||||||
|
def decorator(func):
|
||||||
|
|
||||||
def decorated(self, *args, **kwargs):
|
def decorated(self, *args, **kwargs):
|
||||||
if is_tfrt_enabled():
|
if is_tfrt_enabled():
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
if func is not None:
|
if cls_or_func is not None:
|
||||||
return decorator(func)
|
return decorator(cls_or_func)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
return disable_tfrt_impl
|
return disable_tfrt_impl
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user