Add an option to enable MLIR bridge for tf_xla_py_test rule

Using this option will run each test with and without MLIR bridge. It uses similar mechanism as xla_enable_strict_auto_jit option.

Using this new option for unary_mlir_ops_test for now and it will be combined with unary_ops_test in a separate change.

PiperOrigin-RevId: 306287817
Change-Id: Ie3b7351b3b711bec5028aa342f0dec50d94fef32
This commit is contained in:
Smit Hinsu 2020-04-13 12:38:17 -07:00 committed by TensorFlower Gardener
parent c0f84b60d8
commit 06c4bb3ef6
7 changed files with 86 additions and 20 deletions

View File

@ -1366,6 +1366,7 @@ tf_xla_py_test(
name = "unary_mlir_ops_test",
size = "medium",
srcs = ["unary_mlir_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -26,6 +26,7 @@ def tf_xla_py_test(
enabled_backends = None,
disabled_backends = None,
use_xla_device = True,
enable_mlir_bridge = False,
**kwargs):
"""Generates py_test targets, one per XLA backend.
@ -55,6 +56,8 @@ def tf_xla_py_test(
use_xla_device: If true then the --test_device argument is set to XLA_CPU
and XLA_GPU for the CPU and GPU tests. Otherwise it is set to CPU and
GPU.
enable_mlir_bridge: If true, then runs the test with and without mlir
bridge enabled.
**kwargs: keyword arguments passed onto the generated py_test() rules.
"""
if enabled_backends == None:
@ -104,19 +107,33 @@ def tf_xla_py_test(
fail("Unknown backend {}".format(backend))
test_tags = tags + backend_tags
native.py_test(
name = test_name,
srcs = srcs,
srcs_version = "PY2AND3",
args = backend_args,
main = "{}.py".format(name) if main == None else main,
data = data + backend_data,
deps = deps + backend_deps,
tags = test_tags,
exec_properties = tf_exec_properties({"tags": test_tags}),
**kwargs
)
test_names.append(test_name)
enable_mlir_bridge_options = [False]
if enable_mlir_bridge:
enable_mlir_bridge_options.append(True)
for mlir_option in enable_mlir_bridge_options:
extra_dep = []
updated_name = test_name
if mlir_option:
extra_dep = ["//tensorflow/python:is_mlir_bridge_test_true"]
if updated_name.endswith("_test"):
updated_name = updated_name[:-5]
updated_name += "_mlir_bridge_test"
native.py_test(
name = updated_name,
srcs = srcs,
srcs_version = "PY2AND3",
args = backend_args,
main = "{}.py".format(name) if main == None else main,
data = data + backend_data,
deps = deps + backend_deps + extra_dep,
tags = test_tags,
exec_properties = tf_exec_properties({"tags": test_tags}),
**kwargs
)
test_names.append(updated_name)
native.test_suite(name = name, tests = test_names)
def generate_backend_suites(backends = []):

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -31,10 +30,6 @@ from tensorflow.python.platform import googletest
class UnaryOpsTest(xla_test.XLATestCase):
"""Test cases for unary operators."""
def __init__(self, method_name='runTest'):
super(UnaryOpsTest, self).__init__(method_name)
context.context().enable_mlir_bridge = True
def _assertOpOutputMatchesExpected(self,
op,
inp,

View File

@ -34,6 +34,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
@ -82,6 +83,8 @@ class XLATestCase(test.TestCase):
def __init__(self, method_name='runTest'):
super(XLATestCase, self).__init__(method_name)
context.context().enable_mlir_bridge = test_util.is_mlir_bridge_enabled()
self.device = FLAGS.test_device
self.has_custom_call = (self.device == 'XLA_CPU')
self._all_tf_types = set([

View File

@ -1965,6 +1965,14 @@ py_library(
srcs_version = "PY2AND3",
)
# Including this as a dependency will result in tests using
# :framework_test_lib to use MLIR.
py_library(
name = "is_mlir_bridge_test_true",
srcs = ["framework/is_mlir_bridge_test_true.py"],
srcs_version = "PY2AND3",
)
# Including this as a dependency will result in tests to use TFRT.
py_library(
name = "is_tfrt_test_true",

View File

@ -0,0 +1,30 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Including this as a dependency will result in tests using MLIR bridge.
This function is defined by default in test_util.py to False. The test_util then
attempts to import this module. If this file is made available through the BUILD
rule, then this function is overridden and will instead cause Tensorflow graphs
to be compiled with MLIR bridge.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def is_mlir_bridge_enabled():
"""Returns true to if MLIR bridge should be enabled for tests."""
return True

View File

@ -94,10 +94,22 @@ def is_xla_enabled():
try:
from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top
except:
from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top, unused-import
except Exception: # pylint: disable=broad-except
pass
# Uses the same mechanism as above to selectively enable MLIR compilation.
def is_mlir_bridge_enabled():
return False
try:
from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import
except Exception: # pylint: disable=broad-except
pass
def _get_object_count_by_type():
return collections.Counter([type(obj).__name__ for obj in gc.get_objects()])