[XLA] Split out HloDynamicUpdateSliceInstruction

This doesn't have any benefit in terms of sizeof(HloInstruction), but it's awkward to have a sublcass for DS and not DUS. Also adds an intermediate class in the hierarchy that avoids having to hard-code the index operand's number.

PiperOrigin-RevId: 225033893
This commit is contained in:
Michael Kuperstein 2018-12-11 11:01:27 -08:00 committed by TensorFlower Gardener
parent 90a840fbcb
commit c99ecfa992
3 changed files with 32 additions and 8 deletions

View File

@ -914,12 +914,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices) {
auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(update);
instruction->AppendOperand(start_indices);
return instruction;
return absl::make_unique<HloDynamicUpdateSliceInstruction>(
shape, operand, update, start_indices);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(

View File

@ -1994,12 +1994,21 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kDynamicSlice, shape),
: HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
AppendOperand(operand);
AppendOperand(start_indices);
}
HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* update,
HloInstruction* start_indices)
: HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
AppendOperand(operand);
AppendOperand(update);
AppendOperand(start_indices);
}
HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
for (int64 slice_size : dynamic_slice_sizes_) {

View File

@ -1171,7 +1171,14 @@ class HloPadInstruction : public HloInstruction {
PaddingConfig padding_config_;
};
class HloDynamicSliceInstruction : public HloInstruction {
class HloDynamicIndexInstruction : public HloInstruction {
public:
explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
: HloInstruction(opcode, shape) {}
virtual int64 index_operand_number() const = 0;
};
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
public:
explicit HloDynamicSliceInstruction(const Shape& shape,
HloInstruction* operand,
@ -1189,6 +1196,8 @@ class HloDynamicSliceInstruction : public HloInstruction {
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
int64 index_operand_number() const override { return 1; }
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
@ -1206,6 +1215,16 @@ class HloDynamicSliceInstruction : public HloInstruction {
std::vector<int64> dynamic_slice_sizes_;
};
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
public:
explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices);
int64 index_operand_number() const override { return 2; }
};
class HloGatherInstruction : public HloInstruction {
public:
explicit HloGatherInstruction(