Remove //third_party/tensorflow/c/c_api_no_xla's dependency on XLA. This was accomplished by removing the dependency from kernel_and_device to xla_kernel_creator. Now //third_party/tensorflow/core/common_runtime/eager:core explicitly depends on xla_kernel_creator since users of this :core expect it to include XLA.

Because changes like this could accidentally cause TensorFlow to lose its link to XLA, I've added a check that `//third_party/tensorflow/python/eager:def_function` depends on XLA.

PiperOrigin-RevId: 354426190
Change-Id: Ie6e687fa49e3fdb62146d79f38b72e7862054083
This commit is contained in:
Michael Delorimier 2021-01-28 16:54:57 -08:00 committed by TensorFlower Gardener
parent bfc7ac4832
commit 1aa07125d5
5 changed files with 43 additions and 15 deletions

View File

@ -730,6 +730,7 @@ bzl_library(
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
"//third_party/mkl:build_defs_bzl",
"//third_party/mkl_dnn:build_defs_bzl",
"@bazel_skylib//lib:new_sets",
"@bazel_skylib//rules:common_settings",
"@local_config_cuda//cuda:build_defs_bzl",
"@local_config_rocm//rocm:build_defs_bzl",

View File

@ -4,6 +4,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"check_deps",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
@ -174,6 +175,13 @@ tf_cuda_library(
}),
)
# Check that c_api_no_xla does not depend on xla.
check_deps(
name = "c_api_no_xla_check_deps",
disallowed_deps = ["//tensorflow/compiler/jit:xla_kernel_creator"],
deps = [":c_api_no_xla"],
)
tf_cuda_library(
name = "c_api_no_xla",
srcs = [

View File

@ -265,7 +265,7 @@ cc_library(
deps = [
":tfr_decompose_ctx",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:core_no_xla",
"//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",

View File

@ -23,6 +23,16 @@ package(
# TODO(b/152902651): Remove this file once all circular dependencies are resolved.
tf_cuda_library(
name = "core",
visibility = ["//tensorflow:internal"],
deps = [
":core_no_xla",
"//tensorflow/compiler/jit:xla_kernel_creator",
],
alwayslink = 1,
)
tf_cuda_library(
name = "core_no_xla",
srcs = [
"core.cc",
],
@ -422,10 +432,7 @@ tf_cuda_library(
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:windows": KERNEL_AND_DEVICE_DEPS,
"//conditions:default": KERNEL_AND_DEVICE_DEPS + [
"//tensorflow/compiler/jit:xla_kernel_creator",
],
"//conditions:default": KERNEL_AND_DEVICE_DEPS,
}),
)

View File

@ -45,6 +45,7 @@ load(
"if_mkl_open_source_only",
"if_mkldnn_threadpool",
)
load("@bazel_skylib//lib:new_sets.bzl", "sets")
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
# version for the shared libraries, can
@ -1783,20 +1784,27 @@ def _dep_label(dep):
# This rule checks that the transitive dependencies of targets listed
# in the 'deps' attribute don't depend on the targets listed in
# the 'disallowed_deps' attribute.
# the 'disallowed_deps' attribute, but do depend on the targets listed in the
# 'required_deps' attribute.
def _check_deps_impl(ctx):
required_deps = ctx.attr.required_deps
disallowed_deps = ctx.attr.disallowed_deps
for input_dep in ctx.attr.deps:
if not hasattr(input_dep, "tf_collected_deps"):
continue
for dep in input_dep.tf_collected_deps.to_list():
for disallowed_dep in disallowed_deps:
if dep == disallowed_dep.label:
fail(
_dep_label(input_dep) + " cannot depend on " + _dep_label(
disallowed_dep,
),
)
collected_deps = sets.make(input_dep.tf_collected_deps.to_list())
for disallowed_dep in disallowed_deps:
if sets.contains(collected_deps, disallowed_dep.label):
fail(
_dep_label(input_dep) + " cannot depend on " +
_dep_label(disallowed_dep),
)
for required_dep in required_deps:
if not sets.contains(collected_deps, required_dep.label):
fail(
_dep_label(input_dep) + " must depend on " +
_dep_label(required_dep),
)
return struct()
check_deps = rule(
@ -1808,7 +1816,11 @@ check_deps = rule(
allow_files = True,
),
"disallowed_deps": attr.label_list(
mandatory = True,
default = [],
allow_files = True,
),
"required_deps": attr.label_list(
default = [],
allow_files = True,
),
},