Added check that the main TensorFlow python library depends on and links XLA. Specifically, check that python/eager:def_function depends on compiler/jit:xla_kernel_creator.

For this to work, I had to fix `check_deps` so it uses `data` is used as well as `deps`.

PiperOrigin-RevId: 355072016
Change-Id: I6a32c0c8f5f75619100c5a1af09dba1081c75f2c
This commit is contained in:
Michael Delorimier 2021-02-01 18:34:56 -08:00 committed by TensorFlower Gardener
parent 5031630885
commit 541a5f754d
2 changed files with 23 additions and 9 deletions

View File

@ -1,3 +1,4 @@
load("//tensorflow:tensorflow.bzl", "check_deps")
load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
@ -818,6 +819,15 @@ py_library(
], ],
) )
# Check that the main TensorFlow python library depends on and links XLA.
check_deps(
name = "def_function_check_deps",
required_deps = [
"//tensorflow/compiler/jit:xla_kernel_creator",
],
deps = [":def_function"],
)
py_library( py_library(
name = "def_function", name = "def_function",
srcs = ["def_function.py"], srcs = ["def_function.py"],

View File

@ -1766,15 +1766,19 @@ def tf_custom_op_library_additional_deps_impl():
# and the tf_collected_deps of the dependencies of this target. # and the tf_collected_deps of the dependencies of this target.
def _collect_deps_aspect_impl(target, ctx): def _collect_deps_aspect_impl(target, ctx):
direct, transitive = [], [] direct, transitive = [], []
all_deps = []
if hasattr(ctx.rule.attr, "deps"): if hasattr(ctx.rule.attr, "deps"):
for dep in ctx.rule.attr.deps: all_deps += ctx.rule.attr.deps
direct.append(dep.label) if hasattr(ctx.rule.attr, "data"):
if hasattr(dep, "tf_collected_deps"): all_deps += ctx.rule.attr.data
transitive.append(dep.tf_collected_deps) for dep in all_deps:
direct.append(dep.label)
if hasattr(dep, "tf_collected_deps"):
transitive.append(dep.tf_collected_deps)
return struct(tf_collected_deps = depset(direct = direct, transitive = transitive)) return struct(tf_collected_deps = depset(direct = direct, transitive = transitive))
collect_deps_aspect = aspect( collect_deps_aspect = aspect(
attr_aspects = ["deps"], attr_aspects = ["deps", "data"],
implementation = _collect_deps_aspect_impl, implementation = _collect_deps_aspect_impl,
) )
@ -1782,10 +1786,10 @@ def _dep_label(dep):
label = dep.label label = dep.label
return label.package + ":" + label.name return label.package + ":" + label.name
# This rule checks that the transitive dependencies of targets listed # This rule checks that transitive dependencies don't depend on the targets
# in the 'deps' attribute don't depend on the targets listed in # listed in the 'disallowed_deps' attribute, but do depend on the targets listed
# the 'disallowed_deps' attribute, but do depend on the targets listed in the # in the 'required_deps' attribute. Dependencies considered are targets in the
# 'required_deps' attribute. # 'deps' attribute or the 'data' attribute.
def _check_deps_impl(ctx): def _check_deps_impl(ctx):
required_deps = ctx.attr.required_deps required_deps = ctx.attr.required_deps
disallowed_deps = ctx.attr.disallowed_deps disallowed_deps = ctx.attr.disallowed_deps