Enable framework_ops_test test for TFRT.

Also class support for test_util.disable_tfrt().

PiperOrigin-RevId: 308743069
Change-Id: I813ab74fac4c94c2075fc31a677f4aadc7e138f5
This commit is contained in:
Kibeom Kim 2020-04-27 18:56:35 -07:00 committed by TensorFlower Gardener
parent f4a02e4f54
commit 7f2dc64ee0
4 changed files with 32 additions and 13 deletions
tensorflow/python

View File

@ -1969,6 +1969,7 @@ py_library(
)
# Including this as a dependency will result in tests to use TFRT.
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
py_library(
name = "is_tfrt_test_true",
srcs = ["framework/is_tfrt_test_true.py"],
@ -2378,6 +2379,7 @@ tf_py_test(
main = "framework/ops_test.py",
python_version = "PY3",
tags = ["no_pip"], # test_ops_2 is not available in pip.
tfrt_enabled = True,
deps = [
":cond_v2",
":control_flow_ops",

View File

@ -621,7 +621,6 @@ py_library(
deps = [":test"],
)
# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead.
cuda_py_test(
name = "benchmarks_test",
srcs = ["benchmarks_test.py"],

View File

@ -90,6 +90,7 @@ class ResourceTest(test_util.TensorFlowTestCase):
resources.shared_resources()).eval()), 0)
@test_util.disable_tfrt("Graph is not supported yet.")
class TensorAndShapeTest(test_util.TensorFlowTestCase):
def testShape(self):
@ -309,6 +310,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
del x
self.assertIsNotNone(x_ref.deref())
@test_util.disable_tfrt("Graph mode is not supported yet.")
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase):
@ -353,6 +355,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices, [0, 2])
@test_util.disable_tfrt("Graph mode is not supported yet.")
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@ -498,6 +501,7 @@ def _apply_op(g, *args, **kwargs):
return op.outputs
@test_util.disable_tfrt("Graph is not supported yet.")
class OperationTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
@ -1428,6 +1432,7 @@ class NameTest(test_util.TensorFlowTestCase):
g.create_op("FloatOutput", [], [dtypes.float32]).name)
@test_util.disable_tfrt("Device API are not supported yet.")
class DeviceTest(test_util.TensorFlowTestCase):
def testNoDevice(self):
@ -2008,6 +2013,7 @@ class CollectionTest(test_util.TensorFlowTestCase):
# Collections are ordered.
self.assertEqual([90, 100], ops.get_collection("key"))
@test_util.disable_tfrt("Functions are not supported yet.")
def test_defun(self):
with context.eager_mode():
@ -2114,6 +2120,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
# e should be dominated by c.
self.assertEqual(e.op.control_inputs, [])
@test_util.disable_tfrt("Graph is not supported yet.")
@test_util.run_in_graph_and_eager_modes
def testEager(self):
def future():
@ -2434,6 +2441,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
self._testGraphElements([a, variable, b])
@test_util.disable_tfrt("Graphs are not supported yet.")
class InitScopeTest(test_util.TensorFlowTestCase):
def testClearsControlDependencies(self):
@ -2736,6 +2744,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
self.assertFalse(self.evaluate(f()))
@test_util.disable_tfrt("Graphs are not supported yet.")
class GraphTest(test_util.TensorFlowTestCase):
def setUp(self):
@ -3213,6 +3222,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
b = variables.Variable([3.0], name="b")
self.assertEqual([b"loc:@a"], b.op.colocation_groups())
@test_util.disable_tfrt("Functions are not supported yet.")
def testColocateWithVariableInFunction(self):
v = variables.Variable(1.)
@ -3248,6 +3258,7 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
@test_util.disable_tfrt("Graph is not supported yet.")
def testSuccess(self):
op = ops.Operation(
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
@ -3421,6 +3432,7 @@ ops.register_tensor_conversion_function(
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
@test_util.disable_tfrt("b/154858769")
def testCompositeTensorConversion(self):
"""Tests that a user can register a CompositeTensor converter."""
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))

View File

@ -1788,23 +1788,29 @@ def disable_mlir_bridge(description): # pylint: disable=unused-argument
# The description is just for documentation purposes.
def disable_tfrt(unused_description):
def disable_tfrt_impl(func):
"""Execute the test method only if tfrt is not enabled."""
def disable_tfrt_impl(cls_or_func):
"""Execute the test only if tfrt is not enabled."""
def decorator(func):
if tf_inspect.isclass(cls_or_func):
if is_tfrt_enabled():
return None
else:
return cls_or_func
else:
def decorator(func):
def decorated(self, *args, **kwargs):
if is_tfrt_enabled():
return
else:
return func(self, *args, **kwargs)
def decorated(self, *args, **kwargs):
if is_tfrt_enabled():
return
else:
return func(self, *args, **kwargs)
return decorated
return decorated
if func is not None:
return decorator(func)
if cls_or_func is not None:
return decorator(cls_or_func)
return decorator
return decorator
return disable_tfrt_impl