Allow input/output alias information to be populated via the XLA builder.
PiperOrigin-RevId: 226935597
This commit is contained in:
parent
b58527bd2d
commit
c220247fe3
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/client/sharding_builder.h"
|
#include "tensorflow/compiler/xla/client/sharding_builder.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||||
@ -310,7 +311,10 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
|
|||||||
module->add_computations()->Swap(&e.second);
|
module->add_computations()->Swap(&e.second);
|
||||||
}
|
}
|
||||||
module->add_computations()->Swap(&entry);
|
module->add_computations()->Swap(&entry);
|
||||||
|
if (!input_output_aliases_.empty()) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
|
||||||
|
}
|
||||||
*(module->mutable_dynamic_parameter_binding()) =
|
*(module->mutable_dynamic_parameter_binding()) =
|
||||||
dynamic_parameter_binding_.ToProto();
|
dynamic_parameter_binding_.ToProto();
|
||||||
|
|
||||||
@ -323,6 +327,34 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
|
|||||||
return std::move(computation);
|
return std::move(computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ Status XlaBuilder::PopulateInputOutputAlias(
|
||||||
|
HloModuleProto* module, const ProgramShape& program_shape,
|
||||||
|
const std::vector<InputOutputAlias>& input_output_aliases) {
|
||||||
|
HloInputOutputAliasConfig config(program_shape.result());
|
||||||
|
for (auto& alias : input_output_aliases) {
|
||||||
|
// The HloInputOutputAliasConfig does not do parameter validation as it only
|
||||||
|
// carries the result shape. Maybe it should be constructed with a
|
||||||
|
// ProgramShape to allow full validation. We will still get an error when
|
||||||
|
// trying to compile the HLO module, but would be better to have validation
|
||||||
|
// at this stage.
|
||||||
|
if (alias.param_number >= program_shape.parameters_size()) {
|
||||||
|
return InvalidArgument("Invalid parameter number %ld (total %ld)",
|
||||||
|
alias.param_number,
|
||||||
|
program_shape.parameters_size());
|
||||||
|
}
|
||||||
|
const Shape& parameter_shape = program_shape.parameters(alias.param_number);
|
||||||
|
if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
|
||||||
|
return InvalidArgument("Invalid parameter %ld index: %s",
|
||||||
|
alias.param_number,
|
||||||
|
alias.param_index.ToString().c_str());
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
|
||||||
|
alias.param_index));
|
||||||
|
}
|
||||||
|
*module->mutable_input_output_alias() = config.ToProto();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
|
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
|
||||||
const Shape& shape, const XlaOp& operand,
|
const Shape& shape, const XlaOp& operand,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
@ -276,7 +276,22 @@ class XlaBuilder {
|
|||||||
int64 target_param_num,
|
int64 target_param_num,
|
||||||
ShapeIndex target_param_index, int64 target_dim_num);
|
ShapeIndex target_param_index, int64 target_dim_num);
|
||||||
|
|
||||||
|
// Adds a new input/output alias. Since the input/ouput shape information are
|
||||||
|
// not available until the computation is built, and eventual error in the
|
||||||
|
// arguments of this API will be detected only at computation Build() time.
|
||||||
|
void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index) {
|
||||||
|
input_output_aliases_.push_back({output_index, param_number, param_index});
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Describes an input/output alias as inserted by the SetUpAlias() API.
|
||||||
|
struct InputOutputAlias {
|
||||||
|
ShapeIndex output_index;
|
||||||
|
int64 param_number;
|
||||||
|
ShapeIndex param_index;
|
||||||
|
};
|
||||||
|
|
||||||
// Build helper which takes the id of the root operation..
|
// Build helper which takes the id of the root operation..
|
||||||
StatusOr<XlaComputation> Build(int64 root_id);
|
StatusOr<XlaComputation> Build(int64 root_id);
|
||||||
|
|
||||||
@ -730,6 +745,12 @@ class XlaBuilder {
|
|||||||
|
|
||||||
int64 GetNextId() { return ++next_id_; }
|
int64 GetNextId() { return ++next_id_; }
|
||||||
|
|
||||||
|
// Populates the module with the input/output alias information stored within
|
||||||
|
// the input_output_aliases vector.
|
||||||
|
static Status PopulateInputOutputAlias(
|
||||||
|
HloModuleProto* module, const ProgramShape& program_shape,
|
||||||
|
const std::vector<InputOutputAlias>& input_output_aliases);
|
||||||
|
|
||||||
string name_; // Name to use for the built computation.
|
string name_; // Name to use for the built computation.
|
||||||
|
|
||||||
// The next sequential ID for every instruction/computation contained within
|
// The next sequential ID for every instruction/computation contained within
|
||||||
@ -749,6 +770,9 @@ class XlaBuilder {
|
|||||||
// Dynamic parameter configuration of this computation.
|
// Dynamic parameter configuration of this computation.
|
||||||
DynamicParameterBinding dynamic_parameter_binding_;
|
DynamicParameterBinding dynamic_parameter_binding_;
|
||||||
|
|
||||||
|
// Holds the input/output alias information populated by the SetUpAlias() API.
|
||||||
|
std::vector<InputOutputAlias> input_output_aliases_;
|
||||||
|
|
||||||
// A map from XlaOp::Handle to the index in the instructions_ vector where the
|
// A map from XlaOp::Handle to the index in the instructions_ vector where the
|
||||||
// instruction is held.
|
// instruction is held.
|
||||||
absl::flat_hash_map<int64, int64> handle_to_index_;
|
absl::flat_hash_map<int64, int64> handle_to_index_;
|
||||||
|
@ -455,5 +455,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
|
|||||||
::testing::HasSubstr("All operands to AfterAll must be tokens"));
|
::testing::HasSubstr("All operands to AfterAll must be tokens"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
|
||||||
|
XlaBuilder b(TestName());
|
||||||
|
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0");
|
||||||
|
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1");
|
||||||
|
auto add = Add(p0, p1);
|
||||||
|
auto sub = Sub(p0, p1);
|
||||||
|
auto root = Tuple(&b, {add, sub});
|
||||||
|
|
||||||
|
b.SetUpAlias({1}, 0, {});
|
||||||
|
b.SetUpAlias({0}, 1, {});
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root));
|
||||||
|
|
||||||
|
const HloInputOutputAliasConfig& config = module->input_output_alias_config();
|
||||||
|
EXPECT_TRUE(config.ParameterHasAlias(0, {}));
|
||||||
|
EXPECT_TRUE(config.ParameterHasAlias(1, {}));
|
||||||
|
|
||||||
|
auto alias_p0 = config.GetAliasedOutput(0, {});
|
||||||
|
ASSERT_TRUE(alias_p0.has_value());
|
||||||
|
EXPECT_EQ(*alias_p0, ShapeIndex({1}));
|
||||||
|
|
||||||
|
auto alias_p1 = config.GetAliasedOutput(1, {});
|
||||||
|
ASSERT_TRUE(alias_p1.has_value());
|
||||||
|
EXPECT_EQ(*alias_p1, ShapeIndex({0}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user