diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b6f5e11b856..1ee25813320 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1366,6 +1366,7 @@ tf_xla_py_test( name = "unary_mlir_ops_test", size = "medium", srcs = ["unary_mlir_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 6a3f97d6d08..19a1d62cddd 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -26,6 +26,7 @@ def tf_xla_py_test( enabled_backends = None, disabled_backends = None, use_xla_device = True, + enable_mlir_bridge = False, **kwargs): """Generates py_test targets, one per XLA backend. @@ -55,6 +56,8 @@ def tf_xla_py_test( use_xla_device: If true then the --test_device argument is set to XLA_CPU and XLA_GPU for the CPU and GPU tests. Otherwise it is set to CPU and GPU. + enable_mlir_bridge: If true, then runs the test with and without mlir + bridge enabled. **kwargs: keyword arguments passed onto the generated py_test() rules. """ if enabled_backends == None: @@ -104,19 +107,33 @@ def tf_xla_py_test( fail("Unknown backend {}".format(backend)) test_tags = tags + backend_tags - native.py_test( - name = test_name, - srcs = srcs, - srcs_version = "PY2AND3", - args = backend_args, - main = "{}.py".format(name) if main == None else main, - data = data + backend_data, - deps = deps + backend_deps, - tags = test_tags, - exec_properties = tf_exec_properties({"tags": test_tags}), - **kwargs - ) - test_names.append(test_name) + + enable_mlir_bridge_options = [False] + if enable_mlir_bridge: + enable_mlir_bridge_options.append(True) + + for mlir_option in enable_mlir_bridge_options: + extra_dep = [] + updated_name = test_name + if mlir_option: + extra_dep = ["//tensorflow/python:is_mlir_bridge_test_true"] + if updated_name.endswith("_test"): + updated_name = updated_name[:-5] + updated_name += "_mlir_bridge_test" + + native.py_test( + name = updated_name, + srcs = srcs, + srcs_version = "PY2AND3", + args = backend_args, + main = "{}.py".format(name) if main == None else main, + data = data + backend_data, + deps = deps + backend_deps + extra_dep, + tags = test_tags, + exec_properties = tf_exec_properties({"tags": test_tags}), + **kwargs + ) + test_names.append(updated_name) native.test_suite(name = name, tests = test_names) def generate_backend_suites(backends = []): diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py index 92be8e04b71..4238877c761 100644 --- a/tensorflow/compiler/tests/unary_mlir_ops_test.py +++ b/tensorflow/compiler/tests/unary_mlir_ops_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test -from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -31,10 +30,6 @@ from tensorflow.python.platform import googletest class UnaryOpsTest(xla_test.XLATestCase): """Test cases for unary operators.""" - def __init__(self, method_name='runTest'): - super(UnaryOpsTest, self).__init__(method_name) - context.context().enable_mlir_bridge = True - def _assertOpOutputMatchesExpected(self, op, inp, diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index d6e02ecc827..f5f63cb60aa 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -34,6 +34,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import flags @@ -82,6 +83,8 @@ class XLATestCase(test.TestCase): def __init__(self, method_name='runTest'): super(XLATestCase, self).__init__(method_name) + context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled() + self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU') self._all_tf_types = set([ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4fc970d98fb..207fde4c76c 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1965,6 +1965,14 @@ py_library( srcs_version = "PY2AND3", ) +# Including this as a dependency will result in tests using +# :framework_test_lib to use MLIR. +py_library( + name = "is_mlir_bridge_test_true", + srcs = ["framework/is_mlir_bridge_test_true.py"], + srcs_version = "PY2AND3", +) + # Including this as a dependency will result in tests to use TFRT. py_library( name = "is_tfrt_test_true", diff --git a/tensorflow/python/framework/is_mlir_bridge_test_true.py b/tensorflow/python/framework/is_mlir_bridge_test_true.py new file mode 100644 index 00000000000..9ef94cd8222 --- /dev/null +++ b/tensorflow/python/framework/is_mlir_bridge_test_true.py @@ -0,0 +1,30 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Including this as a dependency will result in tests using MLIR bridge. + +This function is defined by default in test_util.py to False. The test_util then +attempts to import this module. If this file is made available through the BUILD +rule, then this function is overridden and will instead cause Tensorflow graphs +to be compiled with MLIR bridge. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def is_mlir_bridge_enabled(): + """Returns true to if MLIR bridge should be enabled for tests.""" + return True diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 65b59c72e8a..9b423bf10c5 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -94,10 +94,22 @@ def is_xla_enabled(): try: - from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top -except: + from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top, unused-import +except Exception: # pylint: disable=broad-except pass + +# Uses the same mechanism as above to selectively enable MLIR compilation. +def is_mlir_bridge_enabled(): + return False + + +try: + from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import +except Exception: # pylint: disable=broad-except + pass + + def _get_object_count_by_type(): return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])