From 541a5f754d150de3e3d4f0854ddaff557332eaae Mon Sep 17 00:00:00 2001 From: Michael Delorimier Date: Mon, 1 Feb 2021 18:34:56 -0800 Subject: [PATCH] Added check that the main TensorFlow python library depends on and links XLA. Specifically, check that `python/eager:def_function` depends on `compiler/jit:xla_kernel_creator`. For this to work, I had to fix `check_deps` so it uses `data` is used as well as `deps`. PiperOrigin-RevId: 355072016 Change-Id: I6a32c0c8f5f75619100c5a1af09dba1081c75f2c --- tensorflow/python/eager/BUILD | 10 ++++++++++ tensorflow/tensorflow.bzl | 22 +++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index e42bc930539..c0a6525337f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -1,3 +1,4 @@ +load("//tensorflow:tensorflow.bzl", "check_deps") load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load @@ -818,6 +819,15 @@ py_library( ], ) +# Check that the main TensorFlow python library depends on and links XLA. +check_deps( + name = "def_function_check_deps", + required_deps = [ + "//tensorflow/compiler/jit:xla_kernel_creator", + ], + deps = [":def_function"], +) + py_library( name = "def_function", srcs = ["def_function.py"], diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d09a1d9fade..e4d9031ef95 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1766,15 +1766,19 @@ def tf_custom_op_library_additional_deps_impl(): # and the tf_collected_deps of the dependencies of this target. def _collect_deps_aspect_impl(target, ctx): direct, transitive = [], [] + all_deps = [] if hasattr(ctx.rule.attr, "deps"): - for dep in ctx.rule.attr.deps: - direct.append(dep.label) - if hasattr(dep, "tf_collected_deps"): - transitive.append(dep.tf_collected_deps) + all_deps += ctx.rule.attr.deps + if hasattr(ctx.rule.attr, "data"): + all_deps += ctx.rule.attr.data + for dep in all_deps: + direct.append(dep.label) + if hasattr(dep, "tf_collected_deps"): + transitive.append(dep.tf_collected_deps) return struct(tf_collected_deps = depset(direct = direct, transitive = transitive)) collect_deps_aspect = aspect( - attr_aspects = ["deps"], + attr_aspects = ["deps", "data"], implementation = _collect_deps_aspect_impl, ) @@ -1782,10 +1786,10 @@ def _dep_label(dep): label = dep.label return label.package + ":" + label.name -# This rule checks that the transitive dependencies of targets listed -# in the 'deps' attribute don't depend on the targets listed in -# the 'disallowed_deps' attribute, but do depend on the targets listed in the -# 'required_deps' attribute. +# This rule checks that transitive dependencies don't depend on the targets +# listed in the 'disallowed_deps' attribute, but do depend on the targets listed +# in the 'required_deps' attribute. Dependencies considered are targets in the +# 'deps' attribute or the 'data' attribute. def _check_deps_impl(ctx): required_deps = ctx.attr.required_deps disallowed_deps = ctx.attr.disallowed_deps