diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc
index 6d20faddf83..814ebe39e6d 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc
@@ -242,13 +242,10 @@ Status CreateXlaArgs(const Graph& graph,
   return Status::OK();
 }
 
-void PopulateXlaArgsAndXlaAlias(
-    const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
-    std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases) {
+void PopulateXlaArgs(const tf2xla::Config& config,
+                     std::vector<XlaCompiler::Argument>* xla_args) {
   // Populate arguments with resource variables from the config. The variables
   // get turned into inputs and outputs.
-  int64 input_num = xla_args->size();
-  int64 output_num = config.fetch_size();
   for (const tf2xla::Variable& variable : config.variable()) {
     XlaCompiler::Argument arg;
     arg.type = variable.type();
@@ -258,17 +255,9 @@ void PopulateXlaArgsAndXlaAlias(
     arg.resource_kind = XlaResource::kVariable;
     arg.initialized = true;
     xla_args->push_back(std::move(arg));
-
-    if (!variable.readonly()) {
-      // We want to alias the input and output of the variable, so the updates
-      // are carried out in-place.
-      xla_aliases->push_back({/*output_index=*/{output_num},
-                              /*param_number=*/input_num, /*param_index=*/{}});
-      ++output_num;
-    }
-    ++input_num;
   }
 }
+
 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                  std::unique_ptr<Graph>* graph) {
   TF_RETURN_IF_ERROR(ValidateConfig(config));
diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.h b/tensorflow/compiler/tf2xla/graph_compiler_util.h
index ac6f72c9cbb..61fd1565295 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler_util.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler_util.h
@@ -29,10 +29,9 @@ namespace tensorflow {
 Status CreateXlaArgs(const Graph& graph,
                      std::vector<XlaCompiler::Argument>* xla_args);
 
-// Populate xla_args and xla_aliases for the given XLA config.
-void PopulateXlaArgsAndXlaAlias(
-    const tf2xla::Config& config, std::vector<XlaCompiler::Argument>* xla_args,
-    std::vector<xla::XlaBuilder::InputOutputAlias>* xla_aliases);
+// Populate xla_args for the given XLA config.
+void PopulateXlaArgs(const tf2xla::Config& config,
+                     std::vector<XlaCompiler::Argument>* xla_args);
 
 // InitGraph creates a graph based on the graph_def, that may then be converted
 // to an xla::XlaComputation via ConvertGraphToXla.
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 78343e66724..9ced6e682fc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -63,9 +63,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   std::vector<XlaCompiler::Argument> xla_args;
   TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
 
-  std::vector<xla::XlaBuilder::InputOutputAlias> xla_aliases;
-  PopulateXlaArgsAndXlaAlias(config, &xla_args, &xla_aliases);
-
+  PopulateXlaArgs(config, &xla_args);
   // Compile the graph into an XLA computation.
   XlaCompiler::Options compiler_options;
   compiler_options.client = client;
@@ -75,12 +73,15 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   compiler_options.allow_cpu_custom_calls = true;
   compiler_options.custom_fake_quant_op_calls =
       config.conversion_options().custom_fake_quant_op_calls();
+
   XlaCompiler compiler(compiler_options);
   XlaCompiler::CompilationResult result;
-  TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
-                                           "tfcompile", std::move(graph),
-                                           xla_args, xla_aliases, &result));
+
+  XlaCompiler::CompileOptions options;
+  options.alias_resource_update = true;
+  TF_RETURN_IF_ERROR(compiler.CompileGraph(
+      options, "tfcompile", std::move(graph), xla_args, &result));
   *computation = std::move(*result.computation);
 
   int num_const_results = 0;
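
[Editorial note, not part of the patch] After this change the tfcompile path no longer hands XlaCompiler an explicit alias list; a caller simply sets alias_resource_update on the compile options and lets BuildComputation create the input/output aliases for updated resource variables. A minimal sketch of another caller doing the same, under that assumption; the wrapper function and its name are hypothetical:

    #include "tensorflow/compiler/tf2xla/xla_compiler.h"

    namespace tensorflow {

    // Sketch only: compile a graph and ask the compiler to alias the buffers
    // of updated resource variables instead of passing explicit
    // xla::XlaBuilder::InputOutputAlias entries.
    Status CompileWithAliasedResourceUpdates(
        XlaCompiler* compiler, std::unique_ptr<Graph> graph,
        absl::Span<const XlaCompiler::Argument> args,
        XlaCompiler::CompilationResult* result) {
      XlaCompiler::CompileOptions options;
      options.alias_resource_update = true;
      return compiler->CompileGraph(options, "aliased_compile",
                                    std::move(graph), args, result);
    }

    }  // namespace tensorflow
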
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d69e40e93e2..8e44d3d4255 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
@@ -162,9 +163,9 @@ Status BuildComputation(
     std::unique_ptr<xla::XlaOp> token_output,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     bool is_entry_computation, bool return_updated_values_for_all_resources,
-    bool always_return_tuple, xla::XlaBuilder* builder,
-    xla::XlaComputation* computation, int* num_computation_outputs,
-    int* num_nonconst_outputs,
+    bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
+    xla::XlaBuilder* builder, xla::XlaComputation* computation,
+    int* num_computation_outputs, int* num_nonconst_outputs,
     std::vector<XlaCompiler::OutputDescription>* outputs,
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
     xla::Shape* output_shape) {
@@ -284,6 +285,7 @@ Status BuildComputation(
           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
           arg.tensor_array_gradients.count(grad.first) == 0;
     }
+
     if (return_updated_values_for_all_resources || modified) {
       resource_updates->emplace_back();
       XlaCompiler::ResourceUpdate& update = resource_updates->back();
@@ -291,6 +293,21 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
+      if (is_entry_computation && always_return_tuple &&
+          arg.resource_kind != XlaResource::kTensorArray &&
+          alias_resource_update) {
+        // Assumes tuple args and results are used.
+        int64 output_index = elems.size();
+        if (use_tuple_arg) {
+          builder->SetUpAlias(/*output_index=*/{output_index},
+                              /*param_number=*/0,
+                              /*param_index=*/{update.input_index});
+        } else {
+          builder->SetUpAlias(/*output_index=*/{output_index},
+                              /*param_number=*/update.input_index,
+                              /*param_index=*/{});
+        }
+      }
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
       }
@@ -750,7 +767,7 @@ Status XlaCompiler::CompileFunction(
 
   VLOG(1) << "====================================================";
   TF_RETURN_IF_ERROR(
-      CompileGraph(options, function_id, std::move(graph), args, {}, result));
+      CompileGraph(options, function_id, std::move(graph), args, result));
   VLOG(1) << "====================================================";
 
   cache_[{function_id, arg_vector}] = *result;
@@ -1192,8 +1209,7 @@ Status XlaCompiler::CompileSingleOp(
   }
   FixupSourceAndSinkEdges(graph.get());
 
-  return CompileGraph(options, node_def.name(), std::move(graph), args, {},
-                      result);
+  return CompileGraph(options, node_def.name(), std::move(graph), args, result);
 }
 
 namespace {
@@ -1291,7 +1307,6 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
 Status XlaCompiler::CompileGraph(
     const XlaCompiler::CompileOptions& options, string const& name,
     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
-    absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
     CompilationResult* result) {
   VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
 
@@ -1344,12 +1359,6 @@ Status XlaCompiler::CompileGraph(
                                       &result->xla_input_shapes,
                                       options.is_entry_computation));
   context->set_args(std::move(arg_expressions));
-  // Propagate any aliases given to us by the user.
-  for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
-    builder.SetUpAlias(alias.output_index, alias.param_number,
-                       alias.param_index);
-  }
-
   PushNodeTokenMapping();
   // Use std::set instead of std::unordered_set to ensure determinism.
   std::set<string> output_node_token_inputs;
@@ -1402,7 +1411,8 @@ Status XlaCompiler::CompileGraph(
                        : ShapeRepresentationFn{},
       options.is_entry_computation,
       options.return_updated_values_for_all_resources,
-      options.always_return_tuple, &builder, result->computation.get(),
+      options.always_return_tuple, options.use_tuple_arg,
+      options.alias_resource_update, &builder, result->computation.get(),
      &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
       &result->resource_updates, &result->xla_output_shape));
 
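
[Editorial note, not part of the patch] The aliasing block added to BuildComputation distinguishes the tuple-parameter case (alias an element of parameter 0) from the flat-parameter case (alias the whole parameter numbered update.input_index). A standalone sketch of the same two XlaBuilder::SetUpAlias forms, with made-up shapes and a hypothetical helper name:

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/client/xla_computation.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/compiler/xla/statusor.h"

    xla::StatusOr<xla::XlaComputation> BuildWithTupleArgAlias() {
      xla::XlaBuilder b("alias_sketch");
      xla::Shape elem = xla::ShapeUtil::MakeShape(xla::F32, {2});
      xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape({elem, elem});

      // use_tuple_arg == true: all inputs arrive as one tuple parameter.
      xla::XlaOp arg = xla::Parameter(&b, 0, arg_shape, "args");
      xla::XlaOp updated = xla::Add(xla::GetTupleElement(arg, 0),
                                    xla::GetTupleElement(arg, 1));
      xla::Tuple(&b, {updated});

      // Output tuple element {0} may reuse the buffer of parameter 0,
      // element {1} -- the tuple-parameter form above.
      b.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                   /*param_index=*/{1});
      // With flat parameters the call would instead look like:
      //   b.SetUpAlias({0}, /*param_number=*/1, /*param_index=*/{});
      return b.Build();
    }
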
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 1ae82c3ea54..5ec5866632b 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -213,6 +213,12 @@ class XlaCompiler {
 
     // True when we should add XLA input & output to the graph/function.
     bool add_token_input_output = false;
+
+    // Resource updates are converted into inputs / outputs of the XLA
+    // computation; the two buffers are aliased if this option is true.
+    //
+    // Currently only supported on TPU.
+    bool alias_resource_update = false;
   };
 
   struct OutputDescription {
@@ -367,7 +373,6 @@ class XlaCompiler {
   Status CompileGraph(
       const CompileOptions& options, string const& name,
       std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
-      absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
      CompilationResult* result);
 
   // Compiles a single Op, given by `node_def`, into an
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 4a239c39030..cf8bd6b6ce4 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -183,9 +183,9 @@ TEST_F(XlaCompilerTest, EmptyReturnValues) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
 
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph(
-      XlaCompiler::CompileOptions(), "add", std::move(graph),
-      /*args=*/{}, /*user_aliases=*/{}, &result));
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph),
+                                     /*args=*/{}, &result));
 
   TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
 }
@@ -215,8 +215,7 @@ TEST_F(XlaCompilerTest, Simple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@@ -267,7 +266,7 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
   compile_options.always_return_tuple = false;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
@@ -319,8 +318,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
   XlaCompiler::CompileOptions compile_options;
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args,
-                                     /*user_aliases=*/{}, &result));
+                                     args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
   // Check that the return shapes are correctly tranposed.
@@ -366,8 +364,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
   XlaCompiler::CompileOptions compile_options;
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args,
-                                     /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_EQ(fast_mem_arg_count, 1);
 }
 
@@ -414,8 +411,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
   // Check that the return shapes are correctly tranposed.
@@ -456,8 +452,7 @@ TEST_F(XlaCompilerTest, TransposeVariables) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape transposed =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
   // Check that the return shapes are correctly tranposed.
@@ -507,7 +502,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) {
   compile_options.always_return_tuple = false;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
 }
@@ -537,9 +532,9 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
 
   XlaCompiler compiler(DefaultOptions());
   XlaCompiler::CompilationResult result;
-  Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(),
-                                        "reshape", std::move(graph), args,
-                                        /*user_aliases=*/{}, &result);
+  Status status =
+      compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
+                            std::move(graph), args, &result);
   EXPECT_FALSE(status.ok());
   EXPECT_TRUE(
       absl::StrContains(status.error_message(), "depends on a parameter"))
@@ -581,8 +576,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
     XlaCompiler::CompileOptions compile_options;
     XlaCompiler::CompilationResult result;
     TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
-                                       std::move(graph_copy), args,
-                                       /*user_aliases=*/{}, &result));
+                                       std::move(graph_copy), args, &result));
 
     ASSERT_EQ(2, result.outputs.size());
     EXPECT_FALSE(result.outputs[0].is_constant);
@@ -667,8 +661,7 @@ TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
   XlaCompiler::CompileOptions compile_options;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   ASSERT_EQ(2, result.outputs.size());
   EXPECT_FALSE(result.outputs[1].is_constant);
@@ -707,8 +700,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(1, resource->Get());
 
@@ -744,8 +736,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
     XlaCompiler compiler(options);
 
     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
-                                       std::move(graph), args,
-                                       /*user_aliases=*/{}, &results[i]));
+                                       std::move(graph), args, &results[i]));
   }
 
   for (int64 i = 1; i < test_count; ++i) {
@@ -811,8 +802,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   ASSERT_EQ(1, result.resource_updates.size());
   const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
@@ -871,8 +861,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(0, result.resource_updates.size());
 }
@@ -904,8 +893,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   EXPECT_EQ(1, result.resource_updates.size());
 }
@@ -980,8 +968,7 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 }
 
 // Tests CompileFunction with a local function lookup failing, fails with
@@ -1064,8 +1051,7 @@ TEST_F(XlaCompilerTest, Variables) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   RunAndCheckVariablesComputation(client_, result);
 }
 
@@ -1101,7 +1087,7 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
   auto compile_options = XlaCompiler::CompileOptions();
   compile_options.always_return_tuple = false;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
@@ -1138,8 +1124,7 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   xla::Shape result_shape =
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
 
@@ -1169,8 +1154,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
@@ -1220,8 +1204,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   RunAndCheckVariablesComputation(client_, result);
 }
 
@@ -1273,7 +1256,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                           client_->GetComputationShape(*result.computation));
@@ -1344,7 +1327,7 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, /*user_aliases=*/{}, &result));
+                                     args, &result));
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                           client_->GetComputationShape(*result.computation));
@@ -1425,8 +1408,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
   std::vector<XlaCompiler::Argument> args;
   XlaCompiler::CompilationResult result;
   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
-                                 std::move(graph), args, /*user_aliases=*/{},
-                                 &result);
+                                 std::move(graph), args, &result);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
       << status.error_message();
@@ -1451,8 +1433,7 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
   XlaCompiler::CompilationResult result;
   XlaCompiler compiler(DefaultOptions());
   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
-                                 std::move(graph), args, /*user_aliases=*/{},
-                                 &result);
+                                 std::move(graph), args, &result);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(absl::StrContains(status.error_message(),
                                 "is not in the list of allowed values"))
@@ -1478,8 +1459,7 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
     CopyGraph(*graph, graph_copy.get());
     XlaCompiler::CompilationResult result;
     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
-                                       std::move(graph_copy), args,
-                                       /*user_aliases=*/{}, &result));
+                                       std::move(graph_copy), args, &result));
   }
 }
 
@@ -1530,7 +1510,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
     CopyGraph(*graph, graph_copy.get());
     XlaCompiler::CompilationResult result;
     TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
-                                       args, /*user_aliases=*/{}, &result));
+                                       args, &result));
     EXPECT_EQ(result.xla_input_shapes.size(), 1);
     EXPECT_TRUE(result.xla_output_shape.IsTuple());
     EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
@@ -1548,7 +1528,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
     CopyGraph(*graph, graph_copy.get());
     XlaCompiler::CompilationResult result;
     TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
-                                       args, /*user_aliases=*/{}, &result));
+                                       args, &result));
     EXPECT_EQ(result.xla_input_shapes.size(), 2);
     EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
    EXPECT_TRUE(result.xla_output_shape.IsTuple());
@@ -1620,8 +1600,7 @@ TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   ASSERT_EQ(result.outputs.size(), 2);
   const XlaCompiler::OutputDescription& output0 = result.outputs[0];
   ASSERT_TRUE(output0.is_tensor_list);
@@ -1710,8 +1689,7 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
   compile_options.return_updated_values_for_all_resources = true;
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
   ASSERT_EQ(result.outputs.size(), 3);
   const XlaCompiler::OutputDescription& output1 = result.outputs[1];
   ASSERT_EQ(output1.input_index, 1);
@@ -1772,8 +1750,7 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
 
   XlaCompiler::CompilationResult result;
   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
-                                     std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+                                     std::move(graph), args, &result));
 
   // Tests that we set sharding on the root TUPLE instruction.
   const auto& hlo_module_proto = result.computation->proto();
@@ -1829,8 +1806,8 @@ TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
 
   XlaCompiler::CompilationResult result;
   auto options = XlaCompiler::CompileOptions();
-  TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
-                                     /*user_aliases=*/{}, &result));
+  TF_ASSERT_OK(
+      compiler.CompileGraph(options, "test", std::move(graph), args, &result));
 
   xla::Literal literal0 =
       xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index fe56ede4692..993394ea275 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -332,6 +332,12 @@ class XlaBuilder {
   // Adds a new input/output alias. Since the input/output shape information are
   // not available until the computation is built, and eventual error in the
   // arguments of this API will be detected only at computation Build() time.
+  //
+  // Note: the aliasing API is 'may-alias': only a buffer that is donated at
+  // runtime will be aliased with the output. If a buffer is not donated at
+  // runtime, XLA inserts a copy to prevent buffer clobbering.
+  //
+  // Currently only works on the TPU backend.
   void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                   const ShapeIndex& param_index) {
     input_output_aliases_.push_back({output_index, param_number, param_index});
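
[Editorial note, not part of the patch] The note added to SetUpAlias above documents that the alias is a may-alias hint resolved by buffer donation at runtime. Relatedly, the hlo_computation.cc change below compares the old and new entry root shapes ignoring layout before rebuilding the alias config. A small standalone sketch of that comparison, with made-up shapes:

    #include "tensorflow/compiler/xla/shape_util.h"

    // Two shapes that differ only in layout compare equal once layout is
    // ignored, which is the condition the updated root-swap check relies on.
    bool SameShapeIgnoringLayout() {
      xla::Shape a =
          xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
      xla::Shape b =
          xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
      return xla::Shape::Equal().IgnoreLayout()(a, b);  // true
    }
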
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 1ca13cd9c9f..122122aae55 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -335,8 +335,8 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
 
   if (parent() && parent()->has_entry_computation() &&
       parent()->entry_computation() == this) {
-    if (!Shape::Equal()(new_root_instruction->shape(),
-                        root_instruction_->shape())) {
+    if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
+                                       root_instruction_->shape())) {
       // Rebuild input output alias config now that we have a new output shape.
       parent()->input_output_alias_config() =
           HloInputOutputAliasConfig(new_root_instruction->shape());
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
index 689007ff9ab..65ea02b6db0 100644
--- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
@@ -64,7 +64,8 @@ class HloInputOutputAliasConfig {
   // Sets up alias config from `output_index` to `param_index` at
   // `param_number`.
   Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
-                    const ShapeIndex& param_index, AliasKind kind);
+                    const ShapeIndex& param_index,
+                    AliasKind kind = AliasKind::kUserAlias);
 
   // Returns the kind of alias for the given parameter number and parameter
   // index. If no alias exists, AliasKind::kNoAlias is returned.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index c38c30704bd..8d5af356275 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 6e87a95a14e..a8f9f612b0f 100755
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1780,8 +1780,12 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
   }
 
   TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
-      *module, [this](const Shape& shape) {
-        return target_metadata_->ShapeSize(shape);
+      *module, [this](const Shape& shape) -> int64 {
+        if (target_metadata_->IsLayoutSensitive()) {
+          return target_metadata_->ShapeSize(shape);
+        } else {
+          return 0;
+        }
       }));
   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 124adceda86..2e83361a591 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -214,6 +214,8 @@ class TargetVerifierMetadata {
 
   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
 
+  virtual bool IsLayoutSensitive() const = 0;
+
   TargetVerifierMetadata() {}
   virtual ~TargetVerifierMetadata() {}
 
@@ -245,6 +247,8 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata {
         layout_sensitive_, allow_mixed_precision_, shape_size_function_);
   }
 
+  bool IsLayoutSensitive() const override { return layout_sensitive_; }
+
  private:
   bool layout_sensitive_;
   bool allow_mixed_precision_;
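
[Editorial note, not part of the patch] With the defaulted AliasKind argument above, callers such as the buffer-donation test below can omit the kind and get AliasKind::kUserAlias. A minimal sketch of setting up an alias directly on a module's config; the module and tuple indices are hypothetical:

    #include "tensorflow/compiler/xla/service/hlo_module.h"

    // Alias output tuple element {0} with element {0} of parameter 0, relying
    // on the new default of AliasKind::kUserAlias.
    xla::Status AliasOutputWithParam(xla::HloModule* module) {
      return module->input_output_alias_config().SetUpAlias(
          /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0});
    }
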
diff --git a/tensorflow/compiler/xla/tests/buffer_donation_test.cc b/tensorflow/compiler/xla/tests/buffer_donation_test.cc
index b4a75e29cb2..44e958215a6 100644
--- a/tensorflow/compiler/xla/tests/buffer_donation_test.cc
+++ b/tensorflow/compiler/xla/tests/buffer_donation_test.cc
@@ -215,8 +215,12 @@ TEST_F(BufferDonationTest, SimpleWhileTupleTest) {
   auto gte1 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32v1_, while0, 1));
   builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
-
   module->AddEntryComputation(builder.Build());
+  // Input output aliasing is only supported on TPU.
+#if defined(XLA_TEST_BACKEND_TPU)
+  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
+  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
+#endif
 
   auto arg = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR0<int32>(0), LiteralUtil::CreateR1<float>({1.1f})});
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 0cc2c981bad..243289c8821 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -1477,6 +1477,105 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
   EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
 }
 
+TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
+  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
+  xla::Shape shape =
+      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
+  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
+      {element_shape, element_shape, element_shape, element_shape});
+  xla::XlaBuilder builder("ReuseBuffer");
+  auto param = xla::Parameter(&builder, 0, shape, "param");
+  auto p0 = xla::GetTupleElement(param, 0);
+  auto p1 = xla::GetTupleElement(param, 1);
+  auto add = xla::Add(p0, p1);
+  auto sub = xla::Sub(p0, p1);
+  xla::Tuple(&builder, {add, sub, p0, p1});
+
+  // Flip the tuple literals in the input handle.
+  builder.SetUpAlias({1}, 0, {0});
+  builder.SetUpAlias({0}, 0, {1});
+
+  auto computation = builder.Build().ValueOrDie();
+
+  auto literal0 = xla::LiteralUtil::CreateR1<int64>({1, 2});
+  auto literal1 = xla::LiteralUtil::CreateR1<int64>({5, 9});
+  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
+
+  xrt::XLAAllocation param_alloc;
+  *param_alloc.mutable_value() = literal.ToProto();
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = shape.ToProto();
+  *shapes->mutable_result() = return_shape.ToProto();
+  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
+
+  xrt::XRTExecutionConfig e;
+  e.set_release_input_handles(false);
+  e.set_release_compilation_handle(true);
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  XrtClientSession session(root);
+  auto e_config =
+      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+  auto c_data =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, c_data);
+  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                                param_alloc.SerializeAsString());
+  auto param_handle = ops::XRTAllocate(root, param_value);
+  TF_ASSERT_OK(root.status());
+
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({param_handle}, &outputs));
+
+  int64 alloc_handle = outputs[0].scalar<int64>()();
+
+  // Note that we release the result handle immediately, but since we aliased
+  // the output buffers onto the input allocation ones (held in alloc_handle),
+  // we can fetch the result from there.
+  auto result =
+      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
+  auto read_back = ops::XRTReadLiteral(root, result);
+  auto release = ops::XRTReleaseAllocationHandle(
+      root.WithControlDependencies(read_back), result);
+  TF_ASSERT_OK(root.status());
+
+  TF_EXPECT_OK(
+      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
+
+  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
+  auto exec_literal_parts = exec_literal.DecomposeTuple();
+  ASSERT_EQ(exec_literal_parts.size(), 4);
+
+  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
+  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
+
+  // Now we read back the original input handle values, which at this point
+  // should contain the result of the XLA computation.
+  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
+  TF_ASSERT_OK(root.status());
+  auto release_handle = ops::XRTReleaseAllocationHandle(
+      root.WithControlDependencies(read_handle), Input(alloc_handle));
+  TF_ASSERT_OK(root.status());
+
+  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
+                           {release_handle}, &outputs));
+
+  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
+
+  auto expected_literal0 = xla::LiteralUtil::CreateR1<int64>({6, 11});
+  auto expected_literal1 = xla::LiteralUtil::CreateR1<int64>({-4, -7});
+  // The first element of the tuple returned by the computation would be the
+  // add (expected_literal0), but since we flipped the buffers, the sub
+  // (expected_literal1) should come first.
+  auto expected_literal =
+      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
+
+  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
+}
+
 TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
   xrt::XLAAllocation p0;
   *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();