[TF/XLA Bridge] Alias pass-through parameters when the input graph contains no TF reference variables
Previously, expensive copies were required for pass-through parameters. Removing those copies is not safe in the presence of TF reference variables in the graph, so we only remove them for cases when the graph does not contain TF reference variables. PiperOrigin-RevId: 271241769
This commit is contained in:
parent
b7ce325a0d
commit
6a20128c5b
@ -472,6 +472,11 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
/*resources=*/cluster_info.resource_inputs,
|
||||
/*must_compile=*/requires_compilation,
|
||||
cluster_info.function);
|
||||
|
||||
bool has_ref_attr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_attr));
|
||||
xla_compile.operation.node()->AddAttr(kXlaHasReferenceVarsAttr, has_ref_attr);
|
||||
TF_RETURN_IF_ERROR(
|
||||
CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
|
||||
|
||||
|
@ -149,8 +149,10 @@ TEST_F(BuildXlaOpsTest, ControlDepsPreserved) {
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
|
||||
Node* call;
|
||||
TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
|
||||
call->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
call->set_requested_device(kXlaDeviceName);
|
||||
Node* write_op = MakeWrite(root, "write");
|
||||
write_op->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
root.graph()->AddControlEdge(call, write_op);
|
||||
|
||||
std::unique_ptr<Graph> graph;
|
||||
@ -191,8 +193,10 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
|
||||
Node* call;
|
||||
TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
|
||||
TF_ASSERT_OK(root.DoShapeInference(call));
|
||||
call->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
|
||||
Node* write_op = MakeWrite(root, Output(call), "write_result");
|
||||
write_op->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
|
||||
auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false));
|
||||
auto predicated_compilation_key =
|
||||
@ -226,8 +230,10 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) {
|
||||
TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
|
||||
call->set_requested_device(kXlaDeviceName);
|
||||
TF_ASSERT_OK(root.DoShapeInference(call));
|
||||
call->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
|
||||
Node* write_op = MakeWrite(root, Output(call), "write_result");
|
||||
write_op->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
|
||||
@ -250,6 +256,7 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
|
||||
Node* call;
|
||||
TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
|
||||
call->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));
|
||||
@ -278,6 +285,7 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
|
||||
TF_ASSERT_OK(
|
||||
MakeXlaCompiledKernel(root.graph(), "cluster_int32", "C", &call));
|
||||
call->set_requested_device(kXlaDeviceName);
|
||||
call->AddAttr(kXlaHasReferenceVarsAttr, false);
|
||||
|
||||
auto var =
|
||||
ops::VarHandleOp(root.WithOpName("var"), DT_INT32, TensorShape({}));
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@ -61,6 +62,7 @@ const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
|
||||
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
|
||||
const char* const kXlaHostTransferSequencerAttr =
|
||||
"_xla_host_transfer_sequencer";
|
||||
const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";
|
||||
|
||||
void SortControlInputs(GraphDef* gdef) {
|
||||
int64 num_nodes = gdef->node_size();
|
||||
@ -1311,6 +1313,14 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
}
|
||||
|
||||
*options.graph = std::move(graph_out);
|
||||
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
|
||||
GetNodesRelatedToRefVariables(**options.graph, flr));
|
||||
for (Node* node : (*options.graph)->nodes()) {
|
||||
bool has_ref_vars = ref_related_nodes.contains(node);
|
||||
node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
|
||||
VLOG(3) << "Has ref vars = " << has_ref_vars
|
||||
<< ", node: " << node->def().SerializeAsString();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -91,6 +91,9 @@ extern const char* const kXlaNumConstantArgsAttr;
|
||||
// Name of the attribute containing the number of resource variable arguments.
|
||||
extern const char* const kXlaNumResourceArgsAttr;
|
||||
|
||||
// Name of the attribute defining whether the cluster has reference variables.
|
||||
extern const char* const kXlaHasReferenceVarsAttr;
|
||||
|
||||
// Sorts each node's control inputs by their names. This guarantees that for two
|
||||
// structually equivalent GraphDefs, we get the same traversal ordering on
|
||||
// node's control input fields.
|
||||
|
@ -2581,5 +2581,79 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
void CreateSubgraphTouchingRefVar(const Scope& s) {
|
||||
Output variable =
|
||||
ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
|
||||
Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
|
||||
Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
|
||||
Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
|
||||
|
||||
Output constant =
|
||||
ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
|
||||
s.graph()->AddControlEdge(constant.node(), variable.node());
|
||||
}
|
||||
|
||||
TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
CreateSubgraphTouchingRefVar(root);
|
||||
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
GraphOptimizationPassOptions options;
|
||||
options.session_options = &session_options;
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
options.flib_def = &library;
|
||||
options.graph = &graph;
|
||||
|
||||
EncapsulateSubgraphsPass pass;
|
||||
TF_ASSERT_OK(pass.Run(options));
|
||||
|
||||
for (const Node* node : graph->nodes()) {
|
||||
bool has_ref_var;
|
||||
TF_ASSERT_OK(
|
||||
GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
|
||||
EXPECT_TRUE(node->IsSink() || node->IsSource() || has_ref_var)
|
||||
<< "All nodes apart from source and sink can access reference variable";
|
||||
}
|
||||
}
|
||||
|
||||
void CreateSubgraphNotTouchingRefVar(const Scope& s) {
|
||||
Output constant =
|
||||
ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
|
||||
Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
|
||||
Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
|
||||
}
|
||||
|
||||
TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
CreateSubgraphNotTouchingRefVar(root);
|
||||
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
// TODO(cheshire): reduce boilerplate for creating
|
||||
// GraphOptimizationPassOptions here and elsewhere, probably using a macro.
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
GraphOptimizationPassOptions options;
|
||||
options.session_options = &session_options;
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
options.flib_def = &library;
|
||||
options.graph = &graph;
|
||||
|
||||
EncapsulateSubgraphsPass pass;
|
||||
TF_ASSERT_OK(pass.Run(options));
|
||||
|
||||
for (const Node* node : graph->nodes()) {
|
||||
bool has_ref_var;
|
||||
TF_ASSERT_OK(
|
||||
GetNodeAttr(node->attrs(), kXlaHasReferenceVarsAttr, &has_ref_var));
|
||||
EXPECT_FALSE(has_ref_var) << "The graph does not have reference variables";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -9,11 +9,13 @@ XLA_OPS_DEPS = [
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/compiler/jit:compilation_passes",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:xla_activity_listener",
|
||||
"//tensorflow/compiler/jit:xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/jit:xla_compilation_cache",
|
||||
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/jit:xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
|
@ -18,8 +18,10 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
@ -268,7 +270,7 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
static Status CompileToLocalExecutable(
|
||||
OpKernelContext* ctx, const NameAttrList& function,
|
||||
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
|
||||
const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
|
||||
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
|
||||
std::map<int, OptionalTensor>* variables,
|
||||
@ -313,8 +315,9 @@ static Status CompileToLocalExecutable(
|
||||
options.shape_representation_fn =
|
||||
platform_info.xla_device_metadata()->shape_representation_fn();
|
||||
}
|
||||
// TODO(b/138728225): Set options.alias_passthrough_params for clusters
|
||||
// without ref variables.
|
||||
// If reference variables are not present in the graph, we can safely alias
|
||||
// passthrough parameters without performing a copy.
|
||||
options.alias_passthrough_params = !has_ref_vars;
|
||||
|
||||
std::map<int, Tensor> constant_args;
|
||||
for (int i : constants) {
|
||||
@ -351,8 +354,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
{
|
||||
Status s = CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
|
||||
&client, &variables, &kernel, &executable);
|
||||
ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
|
||||
constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
|
||||
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
|
||||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
|
||||
// Suggest auto jit if the failure was with GPU or CPU.
|
||||
@ -451,6 +454,14 @@ bool MustCompileAttr(OpKernelConstruction* ctx) {
|
||||
ctx->GetAttr("must_compile", &must_compile));
|
||||
return must_compile;
|
||||
}
|
||||
|
||||
bool HasRefVars(OpKernelConstruction* ctx) {
|
||||
bool has_ref_vars;
|
||||
OP_REQUIRES_OK_RETURN(ctx, false,
|
||||
ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
|
||||
return has_ref_vars;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
|
||||
@ -467,7 +478,8 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
|
||||
resources_(ResourcesVector(ctx)),
|
||||
function_(FunctionAttr(ctx)),
|
||||
platform_info_(PlatformInfoFromContext(ctx)),
|
||||
must_compile_(MustCompileAttr(ctx)) {}
|
||||
must_compile_(MustCompileAttr(ctx)),
|
||||
has_ref_vars_(HasRefVars(ctx)) {}
|
||||
|
||||
void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaCompileOp " << def().name()
|
||||
@ -488,7 +500,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
executable = nullptr;
|
||||
} else {
|
||||
Status status = CompileToLocalExecutable(
|
||||
ctx, function_, platform_info_, resources_, constants_,
|
||||
ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
|
||||
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
|
||||
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
|
||||
OP_REQUIRES_OK(ctx, status);
|
||||
|
@ -153,6 +153,9 @@ class XlaCompileOp : public OpKernel {
|
||||
|
||||
const bool must_compile_;
|
||||
|
||||
// Whether the graph has TF reference variables.
|
||||
const bool has_ref_vars_;
|
||||
|
||||
// cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
|
||||
// error when compiling the cluster this _XlaCompile is supposed to compile.
|
||||
// If `cannot_compile_cluster_` is true then we avoid compiling this cluster
|
||||
|
Loading…
x
Reference in New Issue
Block a user