Add xla target for tensorflow tests that request them.
Change: 145856327
This commit is contained in:
parent
9d7899f992
commit
ce2a102760
4
configure
vendored
4
configure
vendored
@ -175,10 +175,10 @@ done
|
||||
|
||||
if [ "$TF_ENABLE_XLA" == "1" ]; then
|
||||
# Update Bazel build configuration.
|
||||
perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
|
||||
sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
|
||||
else
|
||||
# Update Bazel build configuration.
|
||||
perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
|
||||
sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
|
||||
fi
|
||||
|
||||
|
||||
|
@ -7,7 +7,6 @@ load("//tensorflow:tensorflow.bzl", "if_not_mobile")
|
||||
# configure may change the following lines
|
||||
WITH_GCP_SUPPORT = False
|
||||
WITH_HDFS_SUPPORT = False
|
||||
WITH_XLA_SUPPORT = False
|
||||
WITH_JEMALLOC = True
|
||||
|
||||
# Appends a suffix to a list of deps.
|
||||
@ -242,15 +241,3 @@ def tf_additional_cloud_kernel_deps():
|
||||
#if WITH_GCP_SUPPORT:
|
||||
# deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"])
|
||||
return deps
|
||||
|
||||
def tf_additional_plugin_deps():
|
||||
deps = []
|
||||
if WITH_XLA_SUPPORT:
|
||||
deps.append("//tensorflow/compiler/jit")
|
||||
return deps
|
||||
|
||||
def tf_additional_license_deps():
|
||||
licenses = []
|
||||
if WITH_XLA_SUPPORT:
|
||||
licenses.append("@llvm//:LICENSE.TXT")
|
||||
return licenses
|
||||
|
@ -2,8 +2,25 @@
|
||||
# The functions in this file might be referred by tensorflow.bzl. They have to
|
||||
# be separate to avoid cyclic references.
|
||||
|
||||
WITH_XLA_SUPPORT = False
|
||||
|
||||
def tf_cuda_tests_tags():
|
||||
return ["local"]
|
||||
|
||||
def tf_sycl_tests_tags():
|
||||
return ["local"]
|
||||
|
||||
def tf_additional_plugin_deps():
|
||||
deps = []
|
||||
if WITH_XLA_SUPPORT:
|
||||
deps.append("//tensorflow/compiler/jit")
|
||||
return deps
|
||||
|
||||
def tf_additional_xla_deps_py():
|
||||
return []
|
||||
|
||||
def tf_additional_license_deps():
|
||||
licenses = []
|
||||
if WITH_XLA_SUPPORT:
|
||||
licenses.append("@llvm//:LICENSE.TXT")
|
||||
return licenses
|
||||
|
@ -23,7 +23,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_plugin_deps")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
|
||||
load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
|
||||
|
||||
py_library(
|
||||
|
@ -12,6 +12,7 @@ load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
"tf_sycl_tests_tags",
|
||||
"tf_additional_xla_deps_py",
|
||||
)
|
||||
load(
|
||||
"@local_config_cuda//cuda:build_defs.bzl",
|
||||
@ -789,7 +790,10 @@ def py_test(deps=[], **kwargs):
|
||||
**kwargs)
|
||||
|
||||
def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
|
||||
tags=[], shard_count=1, additional_deps=[], flaky=0):
|
||||
tags=[], shard_count=1, additional_deps=[], flaky=0,
|
||||
xla_enabled=False):
|
||||
if xla_enabled:
|
||||
additional_deps += tf_additional_xla_deps_py()
|
||||
native.py_test(
|
||||
name=name,
|
||||
size=size,
|
||||
@ -811,7 +815,8 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
|
||||
srcs_version="PY2AND3")
|
||||
|
||||
def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
|
||||
shard_count=1, additional_deps=[], tags=[], flaky=0):
|
||||
shard_count=1, additional_deps=[], tags=[], flaky=0,
|
||||
xla_enabled=False):
|
||||
test_tags = tags + tf_cuda_tests_tags()
|
||||
tf_py_test(name=name,
|
||||
size=size,
|
||||
@ -822,10 +827,12 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
|
||||
tags=test_tags,
|
||||
shard_count=shard_count,
|
||||
additional_deps=additional_deps,
|
||||
flaky=flaky)
|
||||
flaky=flaky,
|
||||
xla_enabled=xla_enabled)
|
||||
|
||||
def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
|
||||
shard_count=1, additional_deps=[], tags=[], flaky=0):
|
||||
shard_count=1, additional_deps=[], tags=[], flaky=0,
|
||||
xla_enabled=False):
|
||||
test_tags = tags + tf_sycl_tests_tags()
|
||||
tf_py_test(name=name,
|
||||
size=size,
|
||||
@ -836,7 +843,8 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
|
||||
tags=test_tags,
|
||||
shard_count=shard_count,
|
||||
additional_deps=additional_deps,
|
||||
flaky=flaky)
|
||||
flaky=flaky,
|
||||
xla_enabled=xla_enabled)
|
||||
|
||||
def py_tests(name,
|
||||
srcs,
|
||||
@ -845,7 +853,8 @@ def py_tests(name,
|
||||
data=[],
|
||||
tags=[],
|
||||
shard_count=1,
|
||||
prefix=""):
|
||||
prefix="",
|
||||
xla_enabled=False):
|
||||
for src in srcs:
|
||||
test_name = src.split("/")[-1].split(".")[0]
|
||||
if prefix:
|
||||
@ -857,13 +866,15 @@ def py_tests(name,
|
||||
tags=tags,
|
||||
shard_count=shard_count,
|
||||
data=data,
|
||||
additional_deps=additional_deps)
|
||||
additional_deps=additional_deps,
|
||||
xla_enabled=xla_enabled)
|
||||
|
||||
def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[],
|
||||
shard_count=1, tags=[], prefix=""):
|
||||
shard_count=1, tags=[], prefix="", xla_enabled=False):
|
||||
test_tags = tags + tf_cuda_tests_tags()
|
||||
py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps,
|
||||
data=data, tags=test_tags, shard_count=shard_count,prefix=prefix)
|
||||
data=data, tags=test_tags, shard_count=shard_count,prefix=prefix,
|
||||
xla_enabled=xla_enabled)
|
||||
|
||||
# Creates a genrule named <name> for running tools/proto_text's generator to
|
||||
# make the proto_text functions, for the protos passed in <srcs>.
|
||||
|
@ -4,7 +4,7 @@
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_license_deps")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
|
||||
|
||||
# This returns a list of headers of all public header libraries (e.g.,
|
||||
# framework, lib), and all of the transitive dependencies of those
|
||||
|
Loading…
Reference in New Issue
Block a user