Add XLA token input/output to XlaIf and XlaWhile when necessary.

PiperOrigin-RevId: 212070721
This commit is contained in:
Tong Shen 2018-09-07 18:41:50 -07:00 committed by TensorFlower Gardener
parent 3ea43a044e
commit 3e1b06ee93
13 changed files with 403 additions and 8 deletions

View File

@ -191,6 +191,7 @@ cc_library(
":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
@ -360,6 +361,7 @@ tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
deps = [
":side_effect_util",
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
@ -371,6 +373,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:core_cpu_internal",
@ -632,3 +635,12 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "side_effect_util",
srcs = ["side_effect_util.cc"],
hdrs = ["side_effect_util.h"],
deps = [
"//tensorflow/core:core_cpu",
],
)

View File

@ -178,6 +178,7 @@ tf_kernel_library(
hdrs = ["while_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
@ -195,6 +196,7 @@ tf_kernel_library(
hdrs = ["if_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
has_token_input_output_ = false;
} else {
has_token_input_output_ = !token_input_nodes_.empty();
}
}
// TODO(b/35949885): There is duplication here with the handling of the
@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
XlaCompiler::CompilationResult then_result;
@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
if (ctx->input_type(input_num) == DT_RESOURCE) {
if (has_token_input_output_ && i == num_inputs - 1) {
// Set token input for this "if" op.
std::vector<xla::XlaOp> token_inputs;
for (const string& node_name : token_input_nodes_) {
auto token_or = compiler->GetNodeToken(node_name);
OP_REQUIRES_OK(ctx, token_or.status());
token_inputs.push_back(token_or.ValueOrDie());
}
inputs[i] = xla::AfterAll(b, token_inputs);
} else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
if (has_token_input_output_) {
// Set token output for this "if" op.
xla::XlaOp token_output =
xla::GetTupleElement(outputs, output_types_.size());
auto shape_or = b->GetShape(token_output);
OP_REQUIRES_OK(ctx, shape_or.status());
OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
errors::FailedPrecondition(
"Token output is not token type: ",
xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
}
// Updates the values of any resource variables modified by the conditional
// bodies.

View File

@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel {
DataType cond_type_;
DataTypeVector input_types_;
DataTypeVector output_types_;
bool has_token_input_output_;
std::vector<string> token_input_nodes_;
};
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
cond_name_attr_ = *name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
body_name_attr_ = *name_attr;
if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
has_token_input_output_ = false;
} else {
has_token_input_output_ = !token_input_nodes_.empty();
}
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
body_options.return_updated_values_for_all_resources = true;
body_options.resolve_compile_time_constants = false;
body_options.is_entry_computation = false;
body_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
cond_options.use_tuple_arg = true;
cond_options.resolve_compile_time_constants = false;
cond_options.is_entry_computation = false;
cond_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
if (ctx->input_type(input_num) == DT_RESOURCE) {
if (has_token_input_output_ && i == num_inputs - 1) {
// Set token input for this "while" op.
std::vector<xla::XlaOp> token_inputs;
for (const string& node_name : token_input_nodes_) {
auto token_or = compiler->GetNodeToken(node_name);
OP_REQUIRES_OK(ctx, token_or.status());
token_inputs.push_back(token_or.ValueOrDie());
}
inputs[i] = xla::AfterAll(builder, token_inputs);
} else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::GetTupleElement(while_result, i));
}
}
if (has_token_input_output_) {
// Set token output for this "while" op.
xla::XlaOp token_output =
xla::GetTupleElement(while_result, ctx->num_outputs());
auto shape_or = builder->GetShape(token_output);
OP_REQUIRES_OK(ctx, shape_or.status());
OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
errors::FailedPrecondition(
"Token output is not token type: ",
xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
}
// Updates the values of any resource variables modified by the loop.
for (int i = 0; i < body.resource_updates.size(); ++i) {

View File

@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel {
private:
NameAttrList cond_name_attr_;
NameAttrList body_name_attr_;
bool has_token_input_output_;
std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};

View File

@ -0,0 +1,67 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
std::set<std::string> results;
Node* first_side_effecting_node_on_path = nullptr;
ReverseDFS(g,
[&](Node* n) {
std::vector<string> token_input_nodes;
if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
&token_input_nodes)
.ok() ||
token_input_nodes.empty()) {
return;
}
if (first_side_effecting_node_on_path != nullptr) {
return;
}
first_side_effecting_node_on_path = n;
results.insert(n->name());
},
[&](Node* n) {
if (first_side_effecting_node_on_path == n) {
first_side_effecting_node_on_path = nullptr;
}
},
NodeComparatorName());
return results;
}
bool HasSideEffectingNodes(const Graph& g) {
for (Node* n : g.nodes()) {
std::vector<string> token_input_nodes;
if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
.ok() &&
!token_input_nodes.empty()) {
return true;
}
}
return false;
}
} // namespace tensorflow

View File

@ -0,0 +1,47 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
#include <vector>
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Side-effecting nodes will have this attribute set. Its value is the list of
// node names which this node has side-effect dependencies on.
//
// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute,
// because they always have side-effect.
// If and While nodes may or may not have this attribute, depending on whether
// their bodies have side-effecting nodes.
extern const char kXlaTokenInputNodesAttrName[];
// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a
// node has side-effect dependency on current graph's token input.
extern const char kXlaTokenArgNodeName[];
// Calculates side-effect dependencies for the graph's token output.
// Returns a set of node names representing these dependencies.
std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
// Returns whether a graph contains side-effecting nodes.
bool HasSideEffectingNodes(const Graph& g);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
"Invalid resource type in XLAShapeForArgument()");
}
}
case XlaCompiler::Argument::kToken: {
*xla_shape = xla::ShapeUtil::MakeTokenShape();
return Status::OK();
}
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Invalid argument type in XLAShapeForArgument()");
}
@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments(
}
break;
case XlaCompiler::Argument::kParameter: {
case XlaCompiler::Argument::kParameter:
case XlaCompiler::Argument::kToken: {
input_mapping->push_back(i);
break;
}
@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_handle(arg_handles[i]);
}
break;
case XlaCompiler::Argument::kToken: {
arg_expression.set_handle(arg_handles[i]);
break;
}
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
return errors::Internal(
@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
std::vector<XlaCompiler::Argument> real_args(args);
int token_input_index = -1;
if (options.add_token_input_output) {
// Add extra token input.
token_input_index = real_args.size();
XlaCompiler::Argument token_arg;
token_arg.kind = XlaCompiler::Argument::kToken;
real_args.push_back(token_arg);
}
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
TF_RETURN_IF_ERROR(
BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
&arg_cores, &arg_expressions, &result->input_mapping,
&result->xla_input_shapes, options.is_entry_computation));
TF_RETURN_IF_ERROR(BuildArguments(
*graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
&arg_expressions, &result->input_mapping, &result->xla_input_shapes,
options.is_entry_computation));
context->set_args(std::move(arg_expressions));
PushNodeTokenMapping();
// Use std::set instead of std::unordered_set to ensure determinism.
std::set<std::string> output_node_token_inputs;
if (token_input_index != -1) {
// Original token comes from input.
auto arg_expression = context->args()[token_input_index];
TF_RETURN_IF_ERROR(
SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
// Calculate token inputs for output token.
output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
// If there's no side-effecting op in the graph, use token input as token
// output.
if (output_node_token_inputs.empty()) {
output_node_token_inputs.insert(kXlaTokenArgNodeName);
}
} else if (options.is_entry_computation) {
// Original token is manually created.
if (HasSideEffectingNodes(*graph)) {
TF_RETURN_IF_ERROR(
SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
}
}
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
flib_runtime_, NextStepId()));
if (token_input_index != -1) {
// Add extra token output.
std::vector<xla::XlaOp> token_inputs;
for (const auto& node_name : output_node_token_inputs) {
auto token_or = GetNodeToken(node_name);
TF_RETURN_IF_ERROR(token_or.status());
token_inputs.push_back(token_or.ValueOrDie());
}
TF_RETURN_IF_ERROR(
context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
}
TF_RETURN_IF_ERROR(PopNodeTokenMapping());
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
args, arg_cores, context->retvals(), context->resources(),
real_args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency(
return Status::OK();
}
void XlaCompiler::PushNodeTokenMapping() {
node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
}
Status XlaCompiler::PopNodeTokenMapping() {
if (node_token_mapping_stack_.empty()) {
return errors::FailedPrecondition(
"Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
"empty.");
}
node_token_mapping_stack_.pop();
return Status::OK();
}
Status XlaCompiler::SetNodeToken(const string& node_name,
const xla::XlaOp& op) {
if (node_token_mapping_stack_.empty()) {
return errors::FailedPrecondition(
"Calling SetNodeToken() when node_token_mapping_stack_ is "
"empty.");
}
auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
if (!insert_result.second) {
return errors::FailedPrecondition("Token mapping already exists for node ",
node_name);
}
return Status::OK();
}
xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
if (node_token_mapping_stack_.empty()) {
return errors::FailedPrecondition(
"Calling GetNodeToken() when node_token_mapping_stack_ is "
"empty.");
}
auto iter = node_token_mapping_stack_.top().find(node_name);
if (iter == node_token_mapping_stack_.top().end()) {
return errors::FailedPrecondition("Cannot find token mapping for node ",
node_name);
}
return iter->second;
}
} // namespace tensorflow

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#include <stack>
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@ -106,6 +109,9 @@ class XlaCompiler {
// Argument is a run-time parameter.
kParameter,
// Argument is an XLA token.
kToken,
};
Kind kind = kInvalid;
@ -179,6 +185,9 @@ class XlaCompiler {
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
// True when we should add XLA input & output to the graph/function.
bool add_token_input_output = false;
};
struct OutputDescription {
@ -384,6 +393,11 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
void PushNodeTokenMapping();
Status PopNodeTokenMapping();
Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
private:
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
@ -448,6 +462,15 @@ class XlaCompiler {
std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
// This is used to store <node name, token output> mapping. Side-effecting
// ops call SetNodeToken() to record its token output, so later side-effecting
// ops can use GetNodeToken() to get it and use it as token input.
//
// It's a stack because we need a mapping like this for each level of nested
// CompileGraph() call. In CompileGraph(), we will push a new mapping to the
// stack, and pop the mapping before returning.
std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

View File

@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
}
}
class DummySideEffectingOp : public XlaOpKernel {
public:
explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
name(), xla::CreateToken(ctx->builder())));
}
};
REGISTER_OP("DummySideEffectingOp");
REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
TEST_F(XlaCompilerTest, TokenInputAndOutput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
NodeDef side_effecting_op;
side_effecting_op.set_name("DummySideEffectingOp");
side_effecting_op.set_op("DummySideEffectingOp");
AddNodeAttr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
Status status;
graph->AddNode(side_effecting_op, &status);
TF_ASSERT_OK(status);
EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
const std::vector<XlaCompiler::Argument> empty_args;
{
// The case for entry computation: we don't add token input/output. Instead,
// we use CreateToken HLO to create the entry token.
XlaCompiler::CompileOptions options;
options.is_entry_computation = true;
options.add_token_input_output = false;
XlaCompiler compiler(DefaultOptions());
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
empty_args, &result));
EXPECT_EQ(result.xla_input_shapes.size(), 0);
EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
}
{
// The case for non-entry computation (e.g. while loop body). We add token
// input/output.
XlaCompiler::CompileOptions options;
options.is_entry_computation = false;
options.add_token_input_output = true;
XlaCompiler compiler(DefaultOptions());
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
empty_args, &result));
EXPECT_EQ(result.xla_input_shapes.size(), 1);
EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
EXPECT_TRUE(xla::ShapeUtil::IsToken(
xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
}
}
} // namespace
} // namespace tensorflow

View File

@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
return Status::OK();
}
Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
VLOG(1) << "Adding retval index " << retvals_.size()
<< " with token to XLA computation";
XlaExpression e;
e.set_handle(token);
// We use DT_INVALID because there is no TF DataType which corresponds to XLA
// token. XlaCompiler handles this case separately, so putting it here is OK.
retvals_.push_back(Retval{DT_INVALID, TensorShape(), e});
return Status::OK();
}
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(

View File

@ -89,6 +89,9 @@ class XlaContext : public ResourceBase {
// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
// As for Retval, but for return values that are XLA tokens.
Status AppendTokenRetval(const xla::XlaOp& token);
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.