Add XLA token input/output to XlaIf and XlaWhile when necessary.
PiperOrigin-RevId: 212070721
parent 3ea43a044e
commit 3e1b06ee93
tensorflow/compiler/tf2xla/BUILD

@@ -191,6 +191,7 @@ cc_library(
         ":functionalize_control_flow",
         ":host_compute_metadata_proto",
         ":sharding_util",
+        ":side_effect_util",
         ":tf2xla_util",
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/xla:literal",
@@ -360,6 +361,7 @@ tf_cc_test(
     name = "xla_compiler_test",
     srcs = ["xla_compiler_test.cc"],
     deps = [
+        ":side_effect_util",
         ":xla_compiler",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
@@ -371,6 +373,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/service:cpu_plugin",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:core_cpu_internal",
@@ -632,3 +635,12 @@ tf_cc_test(
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "side_effect_util",
+    srcs = ["side_effect_util.cc"],
+    hdrs = ["side_effect_util.h"],
+    deps = [
+        "//tensorflow/core:core_cpu",
+    ],
+)
tensorflow/compiler/tf2xla/kernels/BUILD

@@ -178,6 +178,7 @@ tf_kernel_library(
     hdrs = ["while_op.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
@@ -195,6 +196,7 @@ tf_kernel_library(
     hdrs = ["if_op.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
tensorflow/compiler/tf2xla/kernels/if_op.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/kernels/if_op.h"
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+  if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+    has_token_input_output_ = false;
+  } else {
+    has_token_input_output_ = !token_input_nodes_.empty();
+  }
 }
 
 // TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
   options.resolve_compile_time_constants = false;
   options.return_updated_values_for_all_resources = true;
   options.is_entry_computation = false;
+  options.add_token_input_output = has_token_input_output_;
   XlaCompiler* compiler = ctx->compiler();
 
   XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
   std::vector<xla::XlaOp> inputs(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
     int input_num = then_result.input_mapping[i] + 1;
-    if (ctx->input_type(input_num) == DT_RESOURCE) {
+    if (has_token_input_output_ && i == num_inputs - 1) {
+      // Set token input for this "if" op.
+      std::vector<xla::XlaOp> token_inputs;
+      for (const string& node_name : token_input_nodes_) {
+        auto token_or = compiler->GetNodeToken(node_name);
+        OP_REQUIRES_OK(ctx, token_or.status());
+        token_inputs.push_back(token_or.ValueOrDie());
+      }
+      inputs[i] = xla::AfterAll(b, token_inputs);
+    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
       XlaResource* resource;
       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
     }
     ctx->SetOutput(i, output_handle);
   }
+  if (has_token_input_output_) {
+    // Set token output for this "if" op.
+    xla::XlaOp token_output =
+        xla::GetTupleElement(outputs, output_types_.size());
+    auto shape_or = b->GetShape(token_output);
+    OP_REQUIRES_OK(ctx, shape_or.status());
+    OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+                errors::FailedPrecondition(
+                    "Token output is not token type: ",
+                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+  }
 
   // Updates the values of any resource variables modified by the conditional
   // bodies.
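The GetNodeToken / AfterAll / SetNodeToken calls above are the whole token contract for a side-effecting kernel: merge the tokens of every node named in _xla_token_input_nodes into a single token input, emit the op, then publish this node's own token under its name. A minimal sketch of that contract in isolation (the kernel is hypothetical; only GetNodeToken, SetNodeToken and xla::AfterAll are the real APIs from this change):

// Sketch: the token protocol a side-effecting XlaOpKernel follows.
void Compile(XlaOpKernelContext* ctx) override {
  xla::XlaBuilder* b = ctx->builder();
  // Merge the token outputs of every node we must be ordered after.
  std::vector<xla::XlaOp> token_inputs;
  for (const string& node_name : token_input_nodes_) {  // from the attr
    auto token_or = ctx->compiler()->GetNodeToken(node_name);
    OP_REQUIRES_OK(ctx, token_or.status());
    token_inputs.push_back(token_or.ValueOrDie());
  }
  xla::XlaOp token_in = xla::AfterAll(b, token_inputs);
  // ... build the side-effecting HLO, threading token_in through it ...
  xla::XlaOp token_out = token_in;  // placeholder for the op's real token
  // Publish our token so downstream side-effecting ops can order after us.
  OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(name(), token_out));
}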
tensorflow/compiler/tf2xla/kernels/if_op.h

@@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel {
   DataType cond_type_;
   DataTypeVector input_types_;
   DataTypeVector output_types_;
+  bool has_token_input_output_;
+  std::vector<string> token_input_nodes_;
 };
 
 }  // namespace tensorflow
tensorflow/compiler/tf2xla/kernels/while_op.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/kernels/while_op.h"
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
   cond_name_attr_ = *name_attr;
   OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
   body_name_attr_ = *name_attr;
+  if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+    has_token_input_output_ = false;
+  } else {
+    has_token_input_output_ = !token_input_nodes_.empty();
+  }
 }
 
 void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   body_options.return_updated_values_for_all_resources = true;
   body_options.resolve_compile_time_constants = false;
   body_options.is_entry_computation = false;
+  body_options.add_token_input_output = has_token_input_output_;
   XlaCompiler::CompilationResult body;
   OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                 arguments, &body));
@@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   cond_options.use_tuple_arg = true;
   cond_options.resolve_compile_time_constants = false;
   cond_options.is_entry_computation = false;
+  cond_options.add_token_input_output = has_token_input_output_;
   XlaCompiler::CompilationResult cond;
   OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                 arguments, &cond));
@@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   std::vector<xla::XlaOp> inputs(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
     int input_num = body.input_mapping[i];
-    if (ctx->input_type(input_num) == DT_RESOURCE) {
+    if (has_token_input_output_ && i == num_inputs - 1) {
+      // Set token input for this "while" op.
+      std::vector<xla::XlaOp> token_inputs;
+      for (const string& node_name : token_input_nodes_) {
+        auto token_or = compiler->GetNodeToken(node_name);
+        OP_REQUIRES_OK(ctx, token_or.status());
+        token_inputs.push_back(token_or.ValueOrDie());
+      }
+      inputs[i] = xla::AfterAll(builder, token_inputs);
+    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
       XlaResource* resource;
       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
                     xla::GetTupleElement(while_result, i));
     }
   }
+  if (has_token_input_output_) {
+    // Set token output for this "while" op.
+    xla::XlaOp token_output =
+        xla::GetTupleElement(while_result, ctx->num_outputs());
+    auto shape_or = builder->GetShape(token_output);
+    OP_REQUIRES_OK(ctx, shape_or.status());
+    OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+                errors::FailedPrecondition(
+                    "Token output is not token type: ",
+                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+  }
 
   // Updates the values of any resource variables modified by the loop.
   for (int i = 0; i < body.resource_updates.size(); ++i) {
tensorflow/compiler/tf2xla/kernels/while_op.h

@@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel {
  private:
   NameAttrList cond_name_attr_;
   NameAttrList body_name_attr_;
+  bool has_token_input_output_;
+  std::vector<string> token_input_nodes_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
 };
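Both kernels enter token mode only when graph functionalization has stamped the _xla_token_input_nodes attribute onto the node; otherwise the constructors above leave has_token_input_output_ false. A minimal sketch of marking a node this way (the NodeDef is hypothetical; the same pattern appears in the new xla_compiler_test.cc below):

#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/framework/node_def_util.h"

// Mark a hypothetical While node as depending on the enclosing graph's
// token argument, so XlaWhileOp picks up token_input_nodes_ on construction.
NodeDef while_node;
while_node.set_name("my_while");
while_node.set_op("While");
AddNodeAttr(kXlaTokenInputNodesAttrName,
            std::vector<string>{kXlaTokenArgNodeName}, &while_node);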
tensorflow/compiler/tf2xla/side_effect_util.cc (new file, 67 lines)

@@ -0,0 +1,67 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/side_effect_util.h"

#include "tensorflow/core/graph/algorithm.h"

namespace tensorflow {

const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";

const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";

std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
  std::set<std::string> results;
  Node* first_side_effecting_node_on_path = nullptr;
  ReverseDFS(g,
             [&](Node* n) {
               std::vector<string> token_input_nodes;
               if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
                                &token_input_nodes)
                        .ok() ||
                   token_input_nodes.empty()) {
                 return;
               }

               if (first_side_effecting_node_on_path != nullptr) {
                 return;
               }

               first_side_effecting_node_on_path = n;
               results.insert(n->name());
             },
             [&](Node* n) {
               if (first_side_effecting_node_on_path == n) {
                 first_side_effecting_node_on_path = nullptr;
               }
             },
             NodeComparatorName());
  return results;
}

bool HasSideEffectingNodes(const Graph& g) {
  for (Node* n : g.nodes()) {
    std::vector<string> token_input_nodes;
    if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
            .ok() &&
        !token_input_nodes.empty()) {
      return true;
    }
  }
  return false;
}

}  // namespace tensorflow
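A note on the traversal above: first_side_effecting_node_on_path marks "the walk is currently underneath a node already recorded", so from each chain of side-effecting nodes only the one nearest the graph's outputs is kept. If side-effecting node A feeds side-effecting node B which feeds the output, only B enters the result set; B's own token input already orders it after A, so also wiring A into the output token would be redundant.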
tensorflow/compiler/tf2xla/side_effect_util.h (new file, 47 lines)

@@ -0,0 +1,47 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_

#include <vector>

#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Side-effecting nodes will have this attribute set. Its value is the list of
// node names which this node has side-effect dependencies on.
//
// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute,
// because they always have side-effect.
// If and While nodes may or may not have this attribute, depending on whether
// their bodies have side-effecting nodes.
extern const char kXlaTokenInputNodesAttrName[];

// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a
// node has side-effect dependency on current graph's token input.
extern const char kXlaTokenArgNodeName[];

// Calculates side-effect dependencies for the graph's token output.
// Returns a set of node names representing these dependencies.
std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);

// Returns whether a graph contains side-effecting nodes.
bool HasSideEffectingNodes(const Graph& g);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
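A short sketch of how a caller might consult these helpers (graph construction elided; both functions are exactly the ones declared above, while the inspection routine itself is hypothetical):

#include <set>
#include <string>

#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/platform/logging.h"

void InspectSideEffects(const tensorflow::Graph& graph) {
  if (!tensorflow::HasSideEffectingNodes(graph)) {
    return;  // No token plumbing needed for this graph.
  }
  // Node names whose tokens must feed the graph's token output. The result
  // is a std::set, so iteration order is deterministic.
  std::set<std::string> frontier =
      tensorflow::CalculateTokenInputsForOutputToken(graph);
  for (const std::string& name : frontier) {
    VLOG(2) << "token output depends on node " << name;
  }
}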
tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
             "Invalid resource type in XLAShapeForArgument()");
       }
     }
+    case XlaCompiler::Argument::kToken: {
+      *xla_shape = xla::ShapeUtil::MakeTokenShape();
+      return Status::OK();
+    }
     case XlaCompiler::Argument::kInvalid:
       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
   }
@@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments(
       }
 
       break;
-    case XlaCompiler::Argument::kParameter: {
+    case XlaCompiler::Argument::kParameter:
+    case XlaCompiler::Argument::kToken: {
       input_mapping->push_back(i);
       break;
     }
@@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments(
         arg_expression.set_handle(arg_handles[i]);
       }
       break;
+    case XlaCompiler::Argument::kToken: {
+      arg_expression.set_handle(arg_handles[i]);
+      break;
+    }
     case XlaCompiler::Argument::kConstant:
     case XlaCompiler::Argument::kInvalid:
       return errors::Internal(
@@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       &options_.shape_representation_fn);
   core::ScopedUnref context_unref(context);
 
+  std::vector<XlaCompiler::Argument> real_args(args);
+  int token_input_index = -1;
+  if (options.add_token_input_output) {
+    // Add extra token input.
+    token_input_index = real_args.size();
+
+    XlaCompiler::Argument token_arg;
+    token_arg.kind = XlaCompiler::Argument::kToken;
+    real_args.push_back(token_arg);
+  }
+
   std::vector<XlaExpression> arg_expressions;
   std::vector<int> arg_cores;
-  TF_RETURN_IF_ERROR(
-      BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
-                     &arg_cores, &arg_expressions, &result->input_mapping,
-                     &result->xla_input_shapes, options.is_entry_computation));
+  TF_RETURN_IF_ERROR(BuildArguments(
+      *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+      &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
+      options.is_entry_computation));
   context->set_args(std::move(arg_expressions));
 
+  PushNodeTokenMapping();
+  // Use std::set instead of std::unordered_set to ensure determinism.
+  std::set<std::string> output_node_token_inputs;
+  if (token_input_index != -1) {
+    // Original token comes from input.
+    auto arg_expression = context->args()[token_input_index];
+    TF_RETURN_IF_ERROR(
+        SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
+
+    // Calculate token inputs for output token.
+    output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
+
+    // If there's no side-effecting op in the graph, use token input as token
+    // output.
+    if (output_node_token_inputs.empty()) {
+      output_node_token_inputs.insert(kXlaTokenArgNodeName);
+    }
+  } else if (options.is_entry_computation) {
+    // Original token is manually created.
+    if (HasSideEffectingNodes(*graph)) {
+      TF_RETURN_IF_ERROR(
+          SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
+    }
+  }
+
   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
                                   flib_runtime_, NextStepId()));
+  if (token_input_index != -1) {
+    // Add extra token output.
+    std::vector<xla::XlaOp> token_inputs;
+    for (const auto& node_name : output_node_token_inputs) {
+      auto token_or = GetNodeToken(node_name);
+      TF_RETURN_IF_ERROR(token_or.status());
+      token_inputs.push_back(token_or.ValueOrDie());
+    }
+    TF_RETURN_IF_ERROR(
+        context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
+  }
+  TF_RETURN_IF_ERROR(PopNodeTokenMapping());
 
   int num_nonconst_outputs;
   int num_computation_outputs;
   result->computation = std::make_shared<xla::XlaComputation>();
   result->outputs.resize(context->retvals().size());
   TF_RETURN_IF_ERROR(BuildComputation(
-      args, arg_cores, context->retvals(), context->resources(),
+      real_args, arg_cores, context->retvals(), context->resources(),
       options.return_updated_values_for_all_resources,
       options.always_return_tuple, &builder, result->computation.get(),
       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
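Net effect of this hunk on a subcomputation compiled with add_token_input_output set: one extra kToken argument and one extra token retval, each appended after the user-visible ones. Illustrative signatures (hypothetical shapes, not tool output):

// before: (arg0: f32[2], arg1: s32[3])          -> (ret0, ret1)
// after:  (arg0: f32[2], arg1: s32[3], token[]) -> (ret0, ret1, token[])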
@@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency(
   return Status::OK();
 }
 
+void XlaCompiler::PushNodeTokenMapping() {
+  node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
+}
+
+Status XlaCompiler::PopNodeTokenMapping() {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  node_token_mapping_stack_.pop();
+  return Status::OK();
+}
+
+Status XlaCompiler::SetNodeToken(const string& node_name,
+                                 const xla::XlaOp& op) {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling SetNodeToken() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
+  if (!insert_result.second) {
+    return errors::FailedPrecondition("Token mapping already exists for node ",
+                                      node_name);
+  }
+  return Status::OK();
+}
+
+xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling GetNodeToken() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  auto iter = node_token_mapping_stack_.top().find(node_name);
+  if (iter == node_token_mapping_stack_.top().end()) {
+    return errors::FailedPrecondition("Cannot find token mapping for node ",
+                                      node_name);
+  }
+  return iter->second;
+}
+
 }  // namespace tensorflow
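The stack discipline above mirrors nested CompileGraph() calls: each level pushes a fresh mapping and pops it before returning. A minimal sketch of the intended call pattern, inside a function returning Status (compiler, send_token and the node name are hypothetical; TF_ASSIGN_OR_RETURN is assumed available via xla/status_macros.h):

compiler.PushNodeTokenMapping();

// A side-effecting op records its token output under its node name...
TF_RETURN_IF_ERROR(compiler.SetNodeToken("send_node", send_token));
// ...and a later op that must be ordered after it looks the token up.
TF_ASSIGN_OR_RETURN(xla::XlaOp token, compiler.GetNodeToken("send_node"));

// Recording the same node twice is a FailedPrecondition error.
Status duplicate = compiler.SetNodeToken("send_node", send_token);

// Every Push must be matched by a Pop before returning.
TF_RETURN_IF_ERROR(compiler.PopNodeTokenMapping());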
tensorflow/compiler/tf2xla/xla_compiler.h

@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 
+#include <stack>
+
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@ class XlaCompiler {
 
     // Argument is a run-time parameter.
     kParameter,
+
+    // Argument is an XLA token.
+    kToken,
   };
 
   Kind kind = kInvalid;
@@ -179,6 +185,9 @@ class XlaCompiler {
     // True when compiling the entry computation, false for subcomputations
     // (while, call, etc.)
     bool is_entry_computation = true;
+
+    // True when we should add XLA token input & output to the graph/function.
+    bool add_token_input_output = false;
   };
 
   struct OutputDescription {
@@ -384,6 +393,11 @@ class XlaCompiler {
   xla::Client* client() const { return options_.client; }
   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
 
+  void PushNodeTokenMapping();
+  Status PopNodeTokenMapping();
+  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+  xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
  private:
   // Sets the function body `fbody` to the one registered as `function`.
   Status FindFunctionBody(const NameAttrList& function,
@@ -448,6 +462,15 @@ class XlaCompiler {
 
   std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
 
+  // This is used to store <node name, token output> mapping. Side-effecting
+  // ops call SetNodeToken() to record its token output, so later side-effecting
+  // ops can use GetNodeToken() to get it and use it as token input.
+  //
+  // It's a stack because we need a mapping like this for each level of nested
+  // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
+  // stack, and pop the mapping before returning.
+  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };
tensorflow/compiler/tf2xla/xla_compiler_test.cc

@@ -20,10 +20,12 @@ limitations under the License.
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -32,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
   }
 }
 
+class DummySideEffectingOp : public XlaOpKernel {
+ public:
+  explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
+                            name(), xla::CreateToken(ctx->builder())));
+  }
+};
+
+REGISTER_OP("DummySideEffectingOp");
+
+REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
+
+TEST_F(XlaCompilerTest, TokenInputAndOutput) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  NodeDef side_effecting_op;
+  side_effecting_op.set_name("DummySideEffectingOp");
+  side_effecting_op.set_op("DummySideEffectingOp");
+  AddNodeAttr(kXlaTokenInputNodesAttrName,
+              std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+  Status status;
+  graph->AddNode(side_effecting_op, &status);
+  TF_ASSERT_OK(status);
+  EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
+
+  const std::vector<XlaCompiler::Argument> empty_args;
+  {
+    // The case for entry computation: we don't add token input/output.
+    // Instead, we use CreateToken HLO to create the entry token.
+    XlaCompiler::CompileOptions options;
+    options.is_entry_computation = true;
+    options.add_token_input_output = false;
+    XlaCompiler compiler(DefaultOptions());
+
+    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+    CopyGraph(*graph, graph_copy.get());
+    XlaCompiler::CompilationResult result;
+    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+                                       empty_args, &result));
+    EXPECT_EQ(result.xla_input_shapes.size(), 0);
+    EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
+  }
+  {
+    // The case for non-entry computation (e.g. while loop body). We add token
+    // input/output.
+    XlaCompiler::CompileOptions options;
+    options.is_entry_computation = false;
+    options.add_token_input_output = true;
+    XlaCompiler compiler(DefaultOptions());
+
+    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+    CopyGraph(*graph, graph_copy.get());
+    XlaCompiler::CompilationResult result;
+    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+                                       empty_args, &result));
+    EXPECT_EQ(result.xla_input_shapes.size(), 1);
+    EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
+    EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
+    EXPECT_TRUE(xla::ShapeUtil::IsToken(
+        xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
tensorflow/compiler/tf2xla/xla_context.cc

@@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
   return Status::OK();
 }
 
+Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
+  VLOG(1) << "Adding retval index " << retvals_.size()
+          << " with token to XLA computation";
+  XlaExpression e;
+  e.set_handle(token);
+  // We use DT_INVALID because there is no TF DataType which corresponds to XLA
+  // token. XlaCompiler handles this case separately, so putting it here is OK.
+  retvals_.push_back(Retval{DT_INVALID, TensorShape(), e});
+  return Status::OK();
+}
+
 xla::XlaBuilder* XlaContext::builder() { return builder_; }
 
 Status XlaContext::CreateResource(
tensorflow/compiler/tf2xla/xla_context.h

@@ -89,6 +89,9 @@ class XlaContext : public ResourceBase {
   // As for Retval, but for return values that are resource handles.
   Status AddResourceRetval(int retval_index, XlaResource* resource);
 
+  // As for Retval, but for return values that are XLA tokens.
+  Status AppendTokenRetval(const xla::XlaOp& token);
+
   // Creates a resource with resource `kind` and initial value `handle`. `name`
   // is a descriptive name for use in error messages. See the `XlaResource`
   // constructor for a description of the remaining arguments.