Automatically set up user aliasing in tf2xla when a resource update is present.

- When a resource update is present, automatically alias the corresponding input and output.
- Also fix an issue where the input/output alias config is overwritten.
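
For context, the new calling convention drops the explicit user_aliases span and turns aliasing on through a single compile option. A minimal sketch of the call site inside a function returning Status (mirroring the ConvertGraphToXla change below; compiler, graph, and args stand for state set up elsewhere):

  // Sketch only, not part of this diff: enable automatic aliasing of
  // resource updates instead of passing user_aliases explicitly.
  XlaCompiler::CompileOptions options;
  options.alias_resource_update = true;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(compiler.CompileGraph(
      options, "tfcompile", std::move(graph), args, &result));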

PiperOrigin-RevId: 294984983
Change-Id: I45e96513dfeaa91f523db63837355b698bd2fb85
Yunxing Dai 2020-02-13 13:18:31 -08:00 committed by TensorFlower Gardener
parent 469e56eeff
commit 70d8aa322c
14 changed files with 206 additions and 106 deletions

View File

@ -242,13 +242,10 @@ Status CreateXlaArgs(const Graph& graph,
return Status::OK();
}
void PopulateXlaArgsAndXlaAlias(
const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases) {
void PopulateXlaArgs(const tf2xla::Config& config,
std::vector<XlaCompiler::Argument>* xla_args) {
// Populate arguments with resource variables from the config. The variables
// get turned into inputs and outputs.
int64 input_num = xla_args->size();
int64 output_num = config.fetch_size();
for (const tf2xla::Variable& variable : config.variable()) {
XlaCompiler::Argument arg;
arg.type = variable.type();
@ -258,17 +255,9 @@ void PopulateXlaArgsAndXlaAlias(
arg.resource_kind = XlaResource::kVariable;
arg.initialized = true;
xla_args->push_back(std::move(arg));
}
}
if (!variable.readonly()) {
// We want to alias the input and output of the variable, so the updates
// are carried out in-place.
xla_aliases->push_back({/*output_index=*/{output_num},
/*param_number=*/input_num, /*param_index=*/{}});
++output_num;
}
++input_num;
}
}
Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
std::unique_ptr<Graph>* graph) {
TF_RETURN_IF_ERROR(ValidateConfig(config));

View File

@ -29,10 +29,9 @@ namespace tensorflow {
Status CreateXlaArgs(const Graph& graph,
std::vector<XlaCompiler::Argument>* xla_args);
// Populate xla_args and xla_aliases for the given XLA config.
void PopulateXlaArgsAndXlaAlias(
const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases);
// Populate xla_args for the given XLA config.
void PopulateXlaArgs(const tf2xla::Config& config,
std::vector<XlaCompiler::Argument>* xla_args);
// InitGraph creates a graph based on the graph_def, that may then be converted
// to an xla::XlaComputation via ConvertGraphToXla.

View File

@ -63,9 +63,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
std::vector<XlaCompiler::Argument> xla_args;
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
std::vector<xla::XlaBuilder::InputOutputAlias> xla_aliases;
PopulateXlaArgsAndXlaAlias(config, &xla_args, &xla_aliases);
PopulateXlaArgs(config, &xla_args);
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
@ -75,12 +73,15 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
compiler_options.allow_cpu_custom_calls = true;
compiler_options.custom_fake_quant_op_calls =
config.conversion_options().custom_fake_quant_op_calls();
XlaCompiler compiler(compiler_options);
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
"tfcompile", std::move(graph),
xla_args, xla_aliases, &result));
XlaCompiler::CompileOptions options;
options.alias_resource_update = true;
TF_RETURN_IF_ERROR(compiler.CompileGraph(
options, "tfcompile", std::move(graph), xla_args, &result));
*computation = std::move(*result.computation);
int num_const_results = 0;

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
@ -162,9 +163,9 @@ Status BuildComputation(
std::unique_ptr<xla::XlaOp> token_output,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
bool is_entry_computation, bool return_updated_values_for_all_resources,
bool always_return_tuple, xla::XlaBuilder* builder,
xla::XlaComputation* computation, int* num_computation_outputs,
int* num_nonconst_outputs,
bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
xla::XlaBuilder* builder, xla::XlaComputation* computation,
int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
xla::Shape* output_shape) {
@ -284,6 +285,7 @@ Status BuildComputation(
!grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
@ -291,6 +293,21 @@ Status BuildComputation(
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
if (is_entry_computation && always_return_tuple &&
arg.resource_kind != XlaResource::kTensorArray &&
alias_resource_update) {
// Assuming tuple arg and results are used.
int64 output_index = elems.size();
if (use_tuple_arg) {
builder->SetUpAlias(/*output_index=*/{output_index},
/*param_number=*/0,
/*param_index=*/{update.input_index});
} else {
builder->SetUpAlias(/*output_index=*/{output_index},
/*param_number=*/update.input_index,
/*param_index=*/{});
}
}
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
}
@ -750,7 +767,7 @@ Status XlaCompiler::CompileFunction(
VLOG(1) << "====================================================";
TF_RETURN_IF_ERROR(
CompileGraph(options, function_id, std::move(graph), args, {}, result));
CompileGraph(options, function_id, std::move(graph), args, result));
VLOG(1) << "====================================================";
cache_[{function_id, arg_vector}] = *result;
@ -1192,8 +1209,7 @@ Status XlaCompiler::CompileSingleOp(
}
FixupSourceAndSinkEdges(graph.get());
return CompileGraph(options, node_def.name(), std::move(graph), args, {},
result);
return CompileGraph(options, node_def.name(), std::move(graph), args, result);
}
namespace {
@ -1291,7 +1307,6 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
Status XlaCompiler::CompileGraph(
const XlaCompiler::CompileOptions& options, string const& name,
std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
@ -1344,12 +1359,6 @@ Status XlaCompiler::CompileGraph(
&result->xla_input_shapes, options.is_entry_computation));
context->set_args(std::move(arg_expressions));
// Propagate any aliases given to us by the user.
for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
builder.SetUpAlias(alias.output_index, alias.param_number,
alias.param_index);
}
PushNodeTokenMapping();
// Use std::set instead of std::unordered_set to ensure determinism.
std::set<std::string> output_node_token_inputs;
@ -1402,7 +1411,8 @@ Status XlaCompiler::CompileGraph(
: ShapeRepresentationFn{},
options.is_entry_computation,
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
options.always_return_tuple, options.use_tuple_arg,
options.alias_resource_update, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
&result->resource_updates, &result->xla_output_shape));

View File

@ -213,6 +213,12 @@ class XlaCompiler {
// True when we should add XLA input & output to the graph/function.
bool add_token_input_output = false;
// Resource updates are converted into XLA inputs / outputs. The two
// buffers are aliased with each other if this option is true.
//
// Currently only supports TPU.
bool alias_resource_update = false;
};
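
To make the flag's effect concrete, here is an illustrative sketch (not part of this change; the helper name SetUpResourceAlias is hypothetical) of the aliasing decision BuildComputation makes for each resource update when alias_resource_update is set:

  #include "tensorflow/compiler/xla/client/xla_builder.h"

  // Hypothetical helper, shown for illustration only: alias one resource
  // update's output tuple element back onto the parameter that fed it.
  void SetUpResourceAlias(xla::XlaBuilder* builder, bool use_tuple_arg,
                          int64_t output_index, int64_t input_index) {
    if (use_tuple_arg) {
      // All arguments are packed into parameter 0, so the update's source
      // is an element of that tuple.
      builder->SetUpAlias(/*output_index=*/{output_index},
                          /*param_number=*/0,
                          /*param_index=*/{input_index});
    } else {
      // Each argument is its own parameter; alias the whole parameter buffer.
      builder->SetUpAlias(/*output_index=*/{output_index},
                          /*param_number=*/input_index,
                          /*param_index=*/{});
    }
  }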
struct OutputDescription {
@ -367,7 +373,6 @@ class XlaCompiler {
Status CompileGraph(
const CompileOptions& options, string const& name,
std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
CompilationResult* result);
// Compiles a single Op, given by `node_def`, into an

View File

@ -183,9 +183,9 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(
XlaCompiler::CompileOptions(), "add", std::move(graph),
/*args=*/{}, /*user_aliases=*/{}, &result));
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph),
/*args=*/{}, &result));
TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
}
@ -215,8 +215,7 @@ TEST_F(XlaCompilerTest, Simple) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
// Tests that the generated computation works.
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@ -267,7 +266,7 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, /*user_aliases=*/{}, &result));
args, &result));
// Tests that the generated computation works.
xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@ -319,8 +318,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
XlaCompiler::CompileOptions compile_options;
compile_options.return_updated_values_for_all_resources = true;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args,
/*user_aliases=*/{}, &result));
args, &result));
xla::Shape transposed =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
// Check that the return shapes are correctly transposed.
@ -366,8 +364,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
XlaCompiler::CompileOptions compile_options;
compile_options.return_updated_values_for_all_resources = true;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args,
/*user_aliases=*/{}, &result));
args, &result));
EXPECT_EQ(fast_mem_arg_count, 1);
}
@ -414,8 +411,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
xla::Shape transposed =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
// Check that the return shapes are correctly transposed.
@ -456,8 +452,7 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
xla::Shape transposed =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
// Check that the return shapes are correctly transposed.
@ -507,7 +502,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) {
compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, /*user_aliases=*/{}, &result));
args, &result));
EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
}
@ -537,9 +532,9 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
XlaCompiler compiler(DefaultOptions());
XlaCompiler::CompilationResult result;
Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(),
"reshape", std::move(graph), args,
/*user_aliases=*/{}, &result);
Status status =
compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
std::move(graph), args, &result);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(
absl::StrContains(status.error_message(), "depends on a parameter"))
@ -581,8 +576,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
XlaCompiler::CompileOptions compile_options;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
std::move(graph_copy), args,
/*user_aliases=*/{}, &result));
std::move(graph_copy), args, &result));
ASSERT_EQ(2, result.outputs.size());
EXPECT_FALSE(result.outputs[0].is_constant);
@ -667,8 +661,7 @@ TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
XlaCompiler::CompileOptions compile_options;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
ASSERT_EQ(2, result.outputs.size());
EXPECT_FALSE(result.outputs[1].is_constant);
@ -707,8 +700,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
EXPECT_EQ(1, resource->Get());
@ -744,8 +736,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
XlaCompiler compiler(options);
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
std::move(graph), args,
/*user_aliases=*/{}, &results[i]));
std::move(graph), args, &results[i]));
}
for (int64 i = 1; i < test_count; ++i) {
@ -811,8 +802,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
ASSERT_EQ(1, result.resource_updates.size());
const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
@ -871,8 +861,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
EXPECT_EQ(0, result.resource_updates.size());
}
@ -904,8 +893,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
EXPECT_EQ(1, result.resource_updates.size());
}
@ -980,8 +968,7 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
}
// Tests CompileFunction with a local function lookup failing, fails with
@ -1064,8 +1051,7 @@ TEST_F(XlaCompilerTest, Variables) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
RunAndCheckVariablesComputation(client_, result);
}
@ -1101,7 +1087,7 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
auto compile_options = XlaCompiler::CompileOptions();
compile_options.always_return_tuple = false;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
args, /*user_aliases=*/{}, &result));
args, &result));
EXPECT_TRUE(xla::ShapeUtil::Equal(
result.xla_output_shape,
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
@ -1138,8 +1124,7 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
xla::Shape result_shape =
xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
@ -1169,8 +1154,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
// Tests that the generated computation works.
xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
@ -1220,8 +1204,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
RunAndCheckVariablesComputation(client_, result);
}
@ -1273,7 +1256,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, /*user_aliases=*/{}, &result));
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
@ -1344,7 +1327,7 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, /*user_aliases=*/{}, &result));
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
@ -1425,8 +1408,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
std::vector<XlaCompiler::Argument> args;
XlaCompiler::CompilationResult result;
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, /*user_aliases=*/{},
&result);
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
@ -1451,8 +1433,7 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
XlaCompiler::CompilationResult result;
XlaCompiler compiler(DefaultOptions());
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
std::move(graph), args, /*user_aliases=*/{},
&result);
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(absl::StrContains(status.error_message(),
"is not in the list of allowed values"))
@ -1478,8 +1459,7 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args,
/*user_aliases=*/{}, &result));
std::move(graph_copy), args, &result));
}
}
@ -1530,7 +1510,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
args, /*user_aliases=*/{}, &result));
args, &result));
EXPECT_EQ(result.xla_input_shapes.size(), 1);
EXPECT_TRUE(result.xla_output_shape.IsTuple());
EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
@ -1548,7 +1528,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
args, /*user_aliases=*/{}, &result));
args, &result));
EXPECT_EQ(result.xla_input_shapes.size(), 2);
EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
EXPECT_TRUE(result.xla_output_shape.IsTuple());
@ -1620,8 +1600,7 @@ TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
ASSERT_EQ(result.outputs.size(), 2);
const XlaCompiler::OutputDescription& output0 = result.outputs[0];
ASSERT_TRUE(output0.is_tensor_list);
@ -1710,8 +1689,7 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
compile_options.return_updated_values_for_all_resources = true;
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
ASSERT_EQ(result.outputs.size(), 3);
const XlaCompiler::OutputDescription& output1 = result.outputs[1];
ASSERT_EQ(output1.input_index, 1);
@ -1772,8 +1750,7 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
std::move(graph), args,
/*user_aliases=*/{}, &result));
std::move(graph), args, &result));
// Tests that we set sharding on the root TUPLE instruction.
const auto& hlo_module_proto = result.computation->proto();
@ -1829,8 +1806,8 @@ TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
XlaCompiler::CompilationResult result;
auto options = XlaCompiler::CompileOptions();
TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
/*user_aliases=*/{}, &result));
TF_ASSERT_OK(
compiler.CompileGraph(options, "test", std::move(graph), args, &result));
xla::Literal literal0 =
xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});

View File

@ -332,6 +332,12 @@ class XlaBuilder {
// Adds a new input/output alias. Since the input/output shape information is
// not available until the computation is built, any error in the arguments of
// this API will be detected only at computation Build() time.
//
// Note: the aliasing API is 'may-alias': only a buffer that is donated at
// runtime will be aliased with the output. If a buffer is not donated at
// runtime, XLA inserts a copy to prevent buffer clobbering.
//
// Only works on the TPU backend.
void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index) {
input_output_aliases_.push_back({output_index, param_number, param_index});
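
As a usage sketch (a pared-down variant of the XRT test added further down in this commit), aliasing output tuple element {0} onto element {0} of parameter 0 looks like this:

  // Sketch only, not part of this diff. Whether the buffers are actually
  // shared depends on the parameter buffer being donated at runtime;
  // otherwise XLA inserts a copy, as noted above.
  xla::XlaBuilder builder("alias_example");
  xla::Shape elem = xla::ShapeUtil::MakeShape(xla::S64, {2});
  xla::Shape param_shape = xla::ShapeUtil::MakeTupleShape({elem, elem});
  auto param = xla::Parameter(&builder, 0, param_shape, "param");
  auto sum = xla::Add(xla::GetTupleElement(param, 0),
                      xla::GetTupleElement(param, 1));
  xla::Tuple(&builder, {sum});
  builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{0});
  xla::XlaComputation computation = builder.Build().ValueOrDie();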

View File

@ -335,7 +335,7 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
if (parent() && parent()->has_entry_computation() &&
parent()->entry_computation() == this) {
if (!Shape::Equal()(new_root_instruction->shape(),
if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
root_instruction_->shape())) {
// Rebuild input output alias config now that we have a new output shape.
parent()->input_output_alias_config() =

View File

@ -64,7 +64,8 @@ class HloInputOutputAliasConfig {
// Sets up alias config from `output_index` to `param_index` at
// `param_number`.
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index, AliasKind kind);
const ShapeIndex& param_index,
AliasKind kind = AliasKind::kUserAlias);
// Returns the kind of alias for the given parameter number and parameter
// index. If no alias exists, AliasKind::kNoAlias is returned.

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
namespace xla {

View File

@ -1780,8 +1780,12 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
}
TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
*module, [this](const Shape& shape) {
*module, [this](const Shape& shape) -> int64 {
if (target_metadata_->IsLayoutSensitive()) {
return target_metadata_->ShapeSize(shape);
} else {
return 0;
}
}));
TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));

View File

@ -214,6 +214,8 @@ class TargetVerifierMetadata {
virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
virtual bool IsLayoutSensitive() const = 0;
TargetVerifierMetadata() {}
virtual ~TargetVerifierMetadata() {}
@ -245,6 +247,8 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata {
layout_sensitive_, allow_mixed_precision_, shape_size_function_);
}
bool IsLayoutSensitive() const override { return layout_sensitive_; }
private:
bool layout_sensitive_;
bool allow_mixed_precision_;

View File

@ -215,8 +215,12 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
// Input output aliasing is only supported on TPU.
#if defined(XLA_TEST_BACKEND_TPU)
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
#endif
auto arg = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR0<int>(0), LiteralUtil::CreateR1<float>({1.1f})});

View File

@ -1477,6 +1477,105 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
xla::Shape shape =
xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
{element_shape, element_shape, element_shape, element_shape});
xla::XlaBuilder builder("ReuseBuffer");
auto param = xla::Parameter(&builder, 0, shape, "param");
auto p0 = xla::GetTupleElement(param, 0);
auto p1 = xla::GetTupleElement(param, 1);
auto add = xla::Add(p0, p1);
auto sub = xla::Sub(p0, p1);
xla::Tuple(&builder, {add, sub, p0, p1});
// Flip the tuple literals in the input handle.
builder.SetUpAlias({1}, 0, {0});
builder.SetUpAlias({0}, 0, {1});
auto computation = builder.Build().ValueOrDie();
auto literal0 = xla::LiteralUtil::CreateR1<int64>({1, 2});
auto literal1 = xla::LiteralUtil::CreateR1<int64>({5, 9});
auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
xrt::XLAAllocation param_alloc;
*param_alloc.mutable_value() = literal.ToProto();
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = shape.ToProto();
*shapes->mutable_result() = return_shape.ToProto();
StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
e.set_release_input_handles(false);
e.set_release_compilation_handle(true);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
XrtClientSession session(root);
auto e_config =
ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
auto c_data =
ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, c_data);
auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
param_alloc.SerializeAsString());
auto param_handle = ops::XRTAllocate(root, param_value);
TF_ASSERT_OK(root.status());
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({param_handle}, &outputs));
int64 alloc_handle = outputs[0].scalar<int64>()();
// Note that we release the result handle immediately, but since we aliased
// the output buffers onto the input allocation ones (held in alloc_handle),
// we can fetch the result from there.
auto result =
ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
auto read_back = ops::XRTReadLiteral(root, result);
auto release = ops::XRTReleaseAllocationHandle(
root.WithControlDependencies(read_back), result);
TF_ASSERT_OK(root.status());
TF_EXPECT_OK(
session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
auto exec_literal_parts = exec_literal.DecomposeTuple();
ASSERT_EQ(exec_literal_parts.size(), 4);
EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
// Now we read back the original input handle values, which at this point
// should contain the result of the XLA computation.
auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
TF_ASSERT_OK(root.status());
auto release_handle = ops::XRTReleaseAllocationHandle(
root.WithControlDependencies(read_handle), Input(alloc_handle));
TF_ASSERT_OK(root.status());
TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
{release_handle}, &outputs));
xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
auto expected_literal0 = xla::LiteralUtil::CreateR1<int64>({6, 11});
auto expected_literal1 = xla::LiteralUtil::CreateR1<int64>({-4, -7});
// The first element of the tuple returned by the computation would be the
// add (expected_literal0), but since we flipped the buffers, the sub
// (expected_literal1) comes first.
auto expected_literal =
xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
xrt::XLAAllocation p0;
*p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();