From 6a20128c5b9819cc09b3dc948e240d37e06aba4b Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 25 Sep 2019 17:20:07 -0700 Subject: [PATCH] [TF/XLA Bridge] Alias pass-through parameters when the input graph contains no TF reference variables Previously, expensive copies were required for pass-through parameters. Removing those copies is not safe in the presence of TF reference variables in the graph, so we only remove them for cases when the graph does not contain TF reference variables. PiperOrigin-RevId: 271241769 --- tensorflow/compiler/jit/build_xla_ops_pass.cc | 5 ++ .../compiler/jit/build_xla_ops_pass_test.cc | 8 ++ .../jit/encapsulate_subgraphs_pass.cc | 10 +++ .../compiler/jit/encapsulate_subgraphs_pass.h | 3 + .../jit/encapsulate_subgraphs_pass_test.cc | 74 +++++++++++++++++++ tensorflow/compiler/jit/kernels/BUILD | 2 + tensorflow/compiler/jit/kernels/xla_ops.cc | 26 +++++-- tensorflow/compiler/jit/kernels/xla_ops.h | 3 + 8 files changed, 124 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 4b313cfb2bb..32f2d1db813 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -472,6 +472,11 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( /*resources=*/cluster_info.resource_inputs, /*must_compile=*/requires_compilation, cluster_info.function); + + bool has_ref_attr; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_attr)); + xla_compile.operation.node()->AddAttr(kXlaHasReferenceVarsAttr, has_ref_attr); TF_RETURN_IF_ERROR( CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 8b235e349e1..f434feb18a4 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -149,8 +149,10 @@ TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->AddAttr(kXlaHasReferenceVarsAttr, false); call->set_requested_device(kXlaDeviceName); Node* write_op = MakeWrite(root, "write"); + write_op->AddAttr(kXlaHasReferenceVarsAttr, false); root.graph()->AddControlEdge(call, write_op); std::unique_ptr graph; @@ -191,8 +193,10 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); TF_ASSERT_OK(root.DoShapeInference(call)); + call->AddAttr(kXlaHasReferenceVarsAttr, false); Node* write_op = MakeWrite(root, Output(call), "write_result"); + write_op->AddAttr(kXlaHasReferenceVarsAttr, false); auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false)); auto predicated_compilation_key = @@ -226,8 +230,10 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) { TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); call->set_requested_device(kXlaDeviceName); TF_ASSERT_OK(root.DoShapeInference(call)); + call->AddAttr(kXlaHasReferenceVarsAttr, false); Node* write_op = MakeWrite(root, Output(call), "write_result"); + write_op->AddAttr(kXlaHasReferenceVarsAttr, false); std::unique_ptr graph; TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); @@ -250,6 +256,7 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->AddAttr(kXlaHasReferenceVarsAttr, false); std::unique_ptr graph; TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); @@ -278,6 +285,7 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) { TF_ASSERT_OK( MakeXlaCompiledKernel(root.graph(), "cluster_int32", "C", &call)); call->set_requested_device(kXlaDeviceName); + call->AddAttr(kXlaHasReferenceVarsAttr, false); auto var = ops::VarHandleOp(root.WithOpName("var"), DT_INT32, TensorShape({})); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 71f423af4ec..114800d87f3 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -61,6 +62,7 @@ const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars"; void SortControlInputs(GraphDef* gdef) { int64 num_nodes = gdef->node_size(); @@ -1311,6 +1313,14 @@ Status EncapsulateSubgraphsPass::Run( } *options.graph = std::move(graph_out); + TF_ASSIGN_OR_RETURN(absl::flat_hash_set ref_related_nodes, + GetNodesRelatedToRefVariables(**options.graph, flr)); + for (Node* node : (*options.graph)->nodes()) { + bool has_ref_vars = ref_related_nodes.contains(node); + node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars); + VLOG(3) << "Has ref vars = " << has_ref_vars + << ", node: " << node->def().SerializeAsString(); + } return Status::OK(); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 62b752cf40f..50e4149bc08 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -91,6 +91,9 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Name of the attribute defining whether the cluster has reference variables. +extern const char* const kXlaHasReferenceVarsAttr; + // Sorts each node's control inputs by their names. This guarantees that for two // structually equivalent GraphDefs, we get the same traversal ordering on // node's control input fields. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 9573c4d3b93..d3d6cd96f97 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -2581,5 +2581,79 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } +void CreateSubgraphTouchingRefVar(const Scope& s) { + Output variable = + ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(s.WithOpName("read_ref_var"), variable); + Output neg = ops::Negate(s.WithOpName("negate_ref"), read); + Output add = ops::Add(s.WithOpName("add_ref"), neg, neg); + + Output constant = + ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0)); + s.graph()->AddControlEdge(constant.node(), variable.node()); +} + +TEST(EncapsulateSubgraphsTest, RefVariablesMarked) { + Scope root = Scope::NewRootScope().ExitOnError(); + CreateSubgraphTouchingRefVar(root); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + SessionOptions session_options; + session_options.env = Env::Default(); + GraphOptimizationPassOptions options; + options.session_options = &session_options; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + options.flib_def = &library; + options.graph = &graph; + + EncapsulateSubgraphsPass pass; + TF_ASSERT_OK(pass.Run(options)); + + for (const Node* node : graph->nodes()) { + bool has_ref_var; + TF_ASSERT_OK( + GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var)); + EXPECT_TRUE(node->IsSink() || node->IsSource() || has_ref_var) + << "All nodes apart from source and sink can access reference variable"; + } +} + +void CreateSubgraphNotTouchingRefVar(const Scope& s) { + Output constant = + ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0)); + Output neg = ops::Negate(s.WithOpName("negate_normal"), constant); + Output add = ops::Add(s.WithOpName("add_normal"), neg, neg); +} + +TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + CreateSubgraphNotTouchingRefVar(root); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + // TODO(cheshire): reduce boilerplate for creating + // GraphOptimizationPassOptions here and elsewhere, probably using a macro. + SessionOptions session_options; + session_options.env = Env::Default(); + GraphOptimizationPassOptions options; + options.session_options = &session_options; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + options.flib_def = &library; + options.graph = &graph; + + EncapsulateSubgraphsPass pass; + TF_ASSERT_OK(pass.Run(options)); + + for (const Node* node : graph->nodes()) { + bool has_ref_var; + TF_ASSERT_OK( + GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var)); + EXPECT_FALSE(has_ref_var) << "The graph does not have reference variables"; + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index e09dfd2b49c..86b3a54c627 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -9,11 +9,13 @@ XLA_OPS_DEPS = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_activity_proto_cc", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", + "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index d382306f6da..9bf5ee88954 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -18,8 +18,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -268,7 +270,7 @@ static Status BuildCompilationCache(OpKernelContext* ctx, } static Status CompileToLocalExecutable( - OpKernelContext* ctx, const NameAttrList& function, + OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, absl::Span resources, absl::Span constants, bool lazy, xla::LocalClient** client, std::map* variables, @@ -313,8 +315,9 @@ static Status CompileToLocalExecutable( options.shape_representation_fn = platform_info.xla_device_metadata()->shape_representation_fn(); } - // TODO(b/138728225): Set options.alias_passthrough_params for clusters - // without ref variables. + // If reference variables are not present in the graph, we can safely alias + // passthrough parameters without performing a copy. + options.alias_passthrough_params = !has_ref_vars; std::map constant_args; for (int i : constants) { @@ -351,8 +354,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { { Status s = CompileToLocalExecutable( - ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false, - &client, &variables, &kernel, &executable); + ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_, + constants_, /*lazy=*/false, &client, &variables, &kernel, &executable); if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU || platform_info_.device_type().type_string() == DEVICE_GPU)) { // Suggest auto jit if the failure was with GPU or CPU. @@ -451,6 +454,14 @@ bool MustCompileAttr(OpKernelConstruction* ctx) { ctx->GetAttr("must_compile", &must_compile)); return must_compile; } + +bool HasRefVars(OpKernelConstruction* ctx) { + bool has_ref_vars; + OP_REQUIRES_OK_RETURN(ctx, false, + ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars)); + return has_ref_vars; +} + } // namespace XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) @@ -467,7 +478,8 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) resources_(ResourcesVector(ctx)), function_(FunctionAttr(ctx)), platform_info_(PlatformInfoFromContext(ctx)), - must_compile_(MustCompileAttr(ctx)) {} + must_compile_(MustCompileAttr(ctx)), + has_ref_vars_(HasRefVars(ctx)) {} void XlaCompileOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaCompileOp " << def().name() @@ -488,7 +500,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { executable = nullptr; } else { Status status = CompileToLocalExecutable( - ctx, function_, platform_info_, resources_, constants_, + ctx, function_, has_ref_vars_, platform_info_, resources_, constants_, /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); if (must_compile_ || status.code() != error::UNIMPLEMENTED) { OP_REQUIRES_OK(ctx, status); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index bc6829a6c77..3848ac72aac 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -153,6 +153,9 @@ class XlaCompileOp : public OpKernel { const bool must_compile_; + // Whether the graph has TF reference variables. + const bool has_ref_vars_; + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented // error when compiling the cluster this _XlaCompile is supposed to compile. // If `cannot_compile_cluster_` is true then we avoid compiling this cluster