diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7c72cdfeb50..1c6fe0ab5db 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -310,7 +311,10 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); - + if (!input_output_aliases_.empty()) { + TF_RETURN_IF_ERROR( + PopulateInputOutputAlias(module, program_shape, input_output_aliases_)); + } *(module->mutable_dynamic_parameter_binding()) = dynamic_parameter_binding_.ToProto(); @@ -323,6 +327,34 @@ StatusOr XlaBuilder::Build(int64 root_id) { return std::move(computation); } +/* static */ Status XlaBuilder::PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases) { + HloInputOutputAliasConfig config(program_shape.result()); + for (auto& alias : input_output_aliases) { + // The HloInputOutputAliasConfig does not do parameter validation as it only + // carries the result shape. Maybe it should be constructed with a + // ProgramShape to allow full validation. We will still get an error when + // trying to compile the HLO module, but would be better to have validation + // at this stage. + if (alias.param_number >= program_shape.parameters_size()) { + return InvalidArgument("Invalid parameter number %ld (total %ld)", + alias.param_number, + program_shape.parameters_size()); + } + const Shape& parameter_shape = program_shape.parameters(alias.param_number); + if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) { + return InvalidArgument("Invalid parameter %ld index: %s", + alias.param_number, + alias.param_index.ToString().c_str()); + } + TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number, + alias.param_index)); + } + *module->mutable_input_output_alias() = config.ToProto(); + return Status::OK(); +} + StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, absl::Span broadcast_dimensions) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 6e9b025e5d7..68ddf2cdd89 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -276,7 +276,22 @@ class XlaBuilder { int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num); + // Adds a new input/output alias. Since the input/ouput shape information are + // not available until the computation is built, and eventual error in the + // arguments of this API will be detected only at computation Build() time. + void SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + input_output_aliases_.push_back({output_index, param_number, param_index}); + } + private: + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + ShapeIndex output_index; + int64 param_number; + ShapeIndex param_index; + }; + // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id); @@ -730,6 +745,12 @@ class XlaBuilder { int64 GetNextId() { return ++next_id_; } + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static Status PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases); + string name_; // Name to use for the built computation. // The next sequential ID for every instruction/computation contained within @@ -749,6 +770,9 @@ class XlaBuilder { // Dynamic parameter configuration of this computation. DynamicParameterBinding dynamic_parameter_binding_; + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index b3f5be300d3..ba929a12009 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -455,5 +455,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { ::testing::HasSubstr("All operands to AfterAll must be tokens")); } +TEST_F(XlaBuilderTest, CheckInputOutputAlias) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); + auto add = Add(p0, p1); + auto sub = Sub(p0, p1); + auto root = Tuple(&b, {add, sub}); + + b.SetUpAlias({1}, 0, {}); + b.SetUpAlias({0}, 1, {}); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + + const HloInputOutputAliasConfig& config = module->input_output_alias_config(); + EXPECT_TRUE(config.ParameterHasAlias(0, {})); + EXPECT_TRUE(config.ParameterHasAlias(1, {})); + + auto alias_p0 = config.GetAliasedOutput(0, {}); + ASSERT_TRUE(alias_p0.has_value()); + EXPECT_EQ(*alias_p0, ShapeIndex({1})); + + auto alias_p1 = config.GetAliasedOutput(1, {}); + ASSERT_TRUE(alias_p1.has_value()); + EXPECT_EQ(*alias_p1, ShapeIndex({0})); +} + } // namespace } // namespace xla