Change tf_xla_py_test to use tf's implementation of py_test.

Otherwise, any pip tests we run becomes a normal tests, and our pip tests do not verify any behaviour for xla_py_tests

PiperOrigin-RevId: 322652875
Change-Id: Icb56f419ab632870f7b0bcee23b57ee66eff0971
This commit is contained in:
Gunhan Gulsoy 2020-07-22 14:16:26 -07:00 committed by TensorFlower Gardener
parent a0b45cb107
commit a225c77008
2 changed files with 8 additions and 3 deletions

View File

@ -8,6 +8,7 @@ load(
"tf_cuda_tests_tags",
"tf_exec_properties",
)
load("//tensorflow:tensorflow.bzl", "py_test")
def all_backends():
b = ["cpu"] + plugins.keys()
@ -121,7 +122,7 @@ def tf_xla_py_test(
updated_name = updated_name[:-5]
updated_name += "_mlir_bridge_test"
native.py_test(
py_test(
name = updated_name,
srcs = srcs,
srcs_version = "PY2AND3",

View File

@ -2198,10 +2198,14 @@ def pywrap_tensorflow_macro(
# Note that this only works on Windows. See the definition of
# //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons.
# 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test.
def py_test(deps = [], data = [], kernels = [], **kwargs):
def py_test(deps = [], data = [], kernels = [], exec_properties = None, **kwargs):
# Python version placeholder
if kwargs.get("python_version", None) == "PY3":
kwargs["tags"] = kwargs.get("tags", []) + ["no_oss_py2"]
if not exec_properties:
exec_properties = tf_exec_properties(kwargs)
native.py_test(
# TODO(jlebar): Ideally we'd use tcmalloc here.,
deps = select({
@ -2212,7 +2216,7 @@ def py_test(deps = [], data = [], kernels = [], **kwargs):
"//conditions:default": kernels,
clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
}),
exec_properties = tf_exec_properties(kwargs),
exec_properties = exec_properties,
**kwargs
)