Enable framework_ops_test test for TFRT.

Also class support for test_util.disable_tfrt().

PiperOrigin-RevId: 308743069
Change-Id: I813ab74fac4c94c2075fc31a677f4aadc7e138f5
This commit is contained in:
Kibeom Kim 2020-04-27 18:56:35 -07:00 committed by TensorFlower Gardener
parent f4a02e4f54
commit 7f2dc64ee0
4 changed files with 32 additions and 13 deletions

View File

@ -1969,6 +1969,7 @@ py_library(
) )
# Including this as a dependency will result in tests to use TFRT. # Including this as a dependency will result in tests to use TFRT.
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
py_library( py_library(
name = "is_tfrt_test_true", name = "is_tfrt_test_true",
srcs = ["framework/is_tfrt_test_true.py"], srcs = ["framework/is_tfrt_test_true.py"],
@ -2378,6 +2379,7 @@ tf_py_test(
main = "framework/ops_test.py", main = "framework/ops_test.py",
python_version = "PY3", python_version = "PY3",
tags = ["no_pip"], # test_ops_2 is not available in pip. tags = ["no_pip"], # test_ops_2 is not available in pip.
tfrt_enabled = True,
deps = [ deps = [
":cond_v2", ":cond_v2",
":control_flow_ops", ":control_flow_ops",

View File

@ -621,7 +621,6 @@ py_library(
deps = [":test"], deps = [":test"],
) )
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
cuda_py_test( cuda_py_test(
name = "benchmarks_test", name = "benchmarks_test",
srcs = ["benchmarks_test.py"], srcs = ["benchmarks_test.py"],

View File

@ -90,6 +90,7 @@ class ResourceTest(test_util.TensorFlowTestCase):
resources.shared_resources()).eval()), 0) resources.shared_resources()).eval()), 0)
@test_util.disable_tfrt("Graph is not supported yet.")
class TensorAndShapeTest(test_util.TensorFlowTestCase): class TensorAndShapeTest(test_util.TensorFlowTestCase):
def testShape(self): def testShape(self):
@ -309,6 +310,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
del x del x
self.assertIsNotNone(x_ref.deref()) self.assertIsNotNone(x_ref.deref())
@test_util.disable_tfrt("Graph mode is not supported yet.")
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase): class IndexedSlicesTest(test_util.TensorFlowTestCase):
@ -353,6 +355,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices, [0, 2]) self.assertAllEqual(x.indices, [0, 2])
@test_util.disable_tfrt("Graph mode is not supported yet.")
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
parameterized.TestCase): parameterized.TestCase):
@ -498,6 +501,7 @@ def _apply_op(g, *args, **kwargs):
return op.outputs return op.outputs
@test_util.disable_tfrt("Graph is not supported yet.")
class OperationTest(test_util.TensorFlowTestCase): class OperationTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@ -1428,6 +1432,7 @@ class NameTest(test_util.TensorFlowTestCase):
g.create_op("FloatOutput", [], [dtypes.float32]).name) g.create_op("FloatOutput", [], [dtypes.float32]).name)
@test_util.disable_tfrt("Device API are not supported yet.")
class DeviceTest(test_util.TensorFlowTestCase): class DeviceTest(test_util.TensorFlowTestCase):
def testNoDevice(self): def testNoDevice(self):
@ -2008,6 +2013,7 @@ class CollectionTest(test_util.TensorFlowTestCase):
# Collections are ordered. # Collections are ordered.
self.assertEqual([90, 100], ops.get_collection("key")) self.assertEqual([90, 100], ops.get_collection("key"))
@test_util.disable_tfrt("Functions are not supported yet.")
def test_defun(self): def test_defun(self):
with context.eager_mode(): with context.eager_mode():
@ -2114,6 +2120,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
# e should be dominated by c. # e should be dominated by c.
self.assertEqual(e.op.control_inputs, []) self.assertEqual(e.op.control_inputs, [])
@test_util.disable_tfrt("Graph is not supported yet.")
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testEager(self): def testEager(self):
def future(): def future():
@ -2434,6 +2441,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
self._testGraphElements([a, variable, b]) self._testGraphElements([a, variable, b])
@test_util.disable_tfrt("Graphs are not supported yet.")
class InitScopeTest(test_util.TensorFlowTestCase): class InitScopeTest(test_util.TensorFlowTestCase):
def testClearsControlDependencies(self): def testClearsControlDependencies(self):
@ -2736,6 +2744,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
self.assertFalse(self.evaluate(f())) self.assertFalse(self.evaluate(f()))
@test_util.disable_tfrt("Graphs are not supported yet.")
class GraphTest(test_util.TensorFlowTestCase): class GraphTest(test_util.TensorFlowTestCase):
def setUp(self): def setUp(self):
@ -3213,6 +3222,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
b = variables.Variable([3.0], name="b") b = variables.Variable([3.0], name="b")
self.assertEqual([b"loc:@a"], b.op.colocation_groups()) self.assertEqual([b"loc:@a"], b.op.colocation_groups())
@test_util.disable_tfrt("Functions are not supported yet.")
def testColocateWithVariableInFunction(self): def testColocateWithVariableInFunction(self):
v = variables.Variable(1.) v = variables.Variable(1.)
@ -3248,6 +3258,7 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
@test_util.disable_tfrt("Graph is not supported yet.")
def testSuccess(self): def testSuccess(self):
op = ops.Operation( op = ops.Operation(
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
@ -3421,6 +3432,7 @@ ops.register_tensor_conversion_function(
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase): class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
@test_util.disable_tfrt("b/154858769")
def testCompositeTensorConversion(self): def testCompositeTensorConversion(self):
"""Tests that a user can register a CompositeTensor converter.""" """Tests that a user can register a CompositeTensor converter."""
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]])) x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))

View File

@ -1788,23 +1788,29 @@ def disable_mlir_bridge(description): # pylint: disable=unused-argument
# The description is just for documentation purposes. # The description is just for documentation purposes.
def disable_tfrt(unused_description): def disable_tfrt(unused_description):
def disable_tfrt_impl(func): def disable_tfrt_impl(cls_or_func):
"""Execute the test method only if tfrt is not enabled.""" """Execute the test only if tfrt is not enabled."""
def decorator(func): if tf_inspect.isclass(cls_or_func):
if is_tfrt_enabled():
return None
else:
return cls_or_func
else:
def decorator(func):
def decorated(self, *args, **kwargs): def decorated(self, *args, **kwargs):
if is_tfrt_enabled(): if is_tfrt_enabled():
return return
else: else:
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return decorated return decorated
if func is not None: if cls_or_func is not None:
return decorator(func) return decorator(cls_or_func)
return decorator return decorator
return disable_tfrt_impl return disable_tfrt_impl