Allow input/output alias information to be populated via the XLA builder.

PiperOrigin-RevId: 226935597
This commit is contained in:
Davide Libenzi 2018-12-26 11:12:32 -08:00 committed by TensorFlower Gardener
parent b58527bd2d
commit c220247fe3
3 changed files with 83 additions and 1 deletions

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@ -310,7 +311,10 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
module->add_computations()->Swap(&e.second);
}
module->add_computations()->Swap(&entry);
if (!input_output_aliases_.empty()) {
TF_RETURN_IF_ERROR(
PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
}
*(module->mutable_dynamic_parameter_binding()) =
dynamic_parameter_binding_.ToProto();
@ -323,6 +327,34 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
return std::move(computation);
}
/* static */ Status XlaBuilder::PopulateInputOutputAlias(
HloModuleProto* module, const ProgramShape& program_shape,
const std::vector<InputOutputAlias>& input_output_aliases) {
HloInputOutputAliasConfig config(program_shape.result());
for (auto& alias : input_output_aliases) {
// The HloInputOutputAliasConfig does not do parameter validation as it only
// carries the result shape. Maybe it should be constructed with a
// ProgramShape to allow full validation. We will still get an error when
// trying to compile the HLO module, but would be better to have validation
// at this stage.
if (alias.param_number >= program_shape.parameters_size()) {
return InvalidArgument("Invalid parameter number %ld (total %ld)",
alias.param_number,
program_shape.parameters_size());
}
const Shape& parameter_shape = program_shape.parameters(alias.param_number);
if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
return InvalidArgument("Invalid parameter %ld index: %s",
alias.param_number,
alias.param_index.ToString().c_str());
}
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
alias.param_index));
}
*module->mutable_input_output_alias() = config.ToProto();
return Status::OK();
}
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
const Shape& shape, const XlaOp& operand,
absl::Span<const int64> broadcast_dimensions) {

View File

@ -276,7 +276,22 @@ class XlaBuilder {
int64 target_param_num,
ShapeIndex target_param_index, int64 target_dim_num);
// Adds a new input/output alias. Since the input/ouput shape information are
// not available until the computation is built, and eventual error in the
// arguments of this API will be detected only at computation Build() time.
void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index) {
input_output_aliases_.push_back({output_index, param_number, param_index});
}
private:
// Describes an input/output alias as inserted by the SetUpAlias() API.
struct InputOutputAlias {
ShapeIndex output_index;
int64 param_number;
ShapeIndex param_index;
};
// Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id);
@ -730,6 +745,12 @@ class XlaBuilder {
int64 GetNextId() { return ++next_id_; }
// Populates the module with the input/output alias information stored within
// the input_output_aliases vector.
static Status PopulateInputOutputAlias(
HloModuleProto* module, const ProgramShape& program_shape,
const std::vector<InputOutputAlias>& input_output_aliases);
string name_; // Name to use for the built computation.
// The next sequential ID for every instruction/computation contained within
@ -749,6 +770,9 @@ class XlaBuilder {
// Dynamic parameter configuration of this computation.
DynamicParameterBinding dynamic_parameter_binding_;
// Holds the input/output alias information populated by the SetUpAlias() API.
std::vector<InputOutputAlias> input_output_aliases_;
// A map from XlaOp::Handle to the index in the instructions_ vector where the
// instruction is held.
absl::flat_hash_map<int64, int64> handle_to_index_;

View File

@ -455,5 +455,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
::testing::HasSubstr("All operands to AfterAll must be tokens"));
}
TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
XlaBuilder b(TestName());
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0");
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1");
auto add = Add(p0, p1);
auto sub = Sub(p0, p1);
auto root = Tuple(&b, {add, sub});
b.SetUpAlias({1}, 0, {});
b.SetUpAlias({0}, 1, {});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root));
const HloInputOutputAliasConfig& config = module->input_output_alias_config();
EXPECT_TRUE(config.ParameterHasAlias(0, {}));
EXPECT_TRUE(config.ParameterHasAlias(1, {}));
auto alias_p0 = config.GetAliasedOutput(0, {});
ASSERT_TRUE(alias_p0.has_value());
EXPECT_EQ(*alias_p0, ShapeIndex({1}));
auto alias_p1 = config.GetAliasedOutput(1, {});
ASSERT_TRUE(alias_p1.has_value());
EXPECT_EQ(*alias_p1, ShapeIndex({0}));
}
} // namespace
} // namespace xla