[XLA] Split out HloDynamicUpdateSliceInstruction
This doesn't have any benefit in terms of sizeof(HloInstruction), but it's awkward to have a sublcass for DS and not DUS. Also adds an intermediate class in the hierarchy that avoids having to hard-code the index operand's number. PiperOrigin-RevId: 225033893
This commit is contained in:
parent
90a840fbcb
commit
c99ecfa992
@ -914,12 +914,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* update,
|
||||
HloInstruction* start_indices) {
|
||||
auto instruction = absl::WrapUnique(
|
||||
new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
|
||||
instruction->AppendOperand(operand);
|
||||
instruction->AppendOperand(update);
|
||||
instruction->AppendOperand(start_indices);
|
||||
return instruction;
|
||||
return absl::make_unique<HloDynamicUpdateSliceInstruction>(
|
||||
shape, operand, update, start_indices);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
|
||||
|
@ -1994,12 +1994,21 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
|
||||
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
|
||||
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
|
||||
absl::Span<const int64> slice_sizes)
|
||||
: HloInstruction(HloOpcode::kDynamicSlice, shape),
|
||||
: HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
|
||||
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
|
||||
AppendOperand(operand);
|
||||
AppendOperand(start_indices);
|
||||
}
|
||||
|
||||
HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
|
||||
const Shape& shape, HloInstruction* operand, HloInstruction* update,
|
||||
HloInstruction* start_indices)
|
||||
: HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
|
||||
AppendOperand(operand);
|
||||
AppendOperand(update);
|
||||
AppendOperand(start_indices);
|
||||
}
|
||||
|
||||
HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
for (int64 slice_size : dynamic_slice_sizes_) {
|
||||
|
@ -1171,7 +1171,14 @@ class HloPadInstruction : public HloInstruction {
|
||||
PaddingConfig padding_config_;
|
||||
};
|
||||
|
||||
class HloDynamicSliceInstruction : public HloInstruction {
|
||||
class HloDynamicIndexInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
|
||||
: HloInstruction(opcode, shape) {}
|
||||
virtual int64 index_operand_number() const = 0;
|
||||
};
|
||||
|
||||
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
|
||||
public:
|
||||
explicit HloDynamicSliceInstruction(const Shape& shape,
|
||||
HloInstruction* operand,
|
||||
@ -1189,6 +1196,8 @@ class HloDynamicSliceInstruction : public HloInstruction {
|
||||
// Returns a serialized representation of this instruction.
|
||||
HloInstructionProto ToProto() const override;
|
||||
|
||||
int64 index_operand_number() const override { return 1; }
|
||||
|
||||
private:
|
||||
std::vector<string> ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const override;
|
||||
@ -1206,6 +1215,16 @@ class HloDynamicSliceInstruction : public HloInstruction {
|
||||
std::vector<int64> dynamic_slice_sizes_;
|
||||
};
|
||||
|
||||
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
|
||||
public:
|
||||
explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* update,
|
||||
HloInstruction* start_indices);
|
||||
|
||||
int64 index_operand_number() const override { return 2; }
|
||||
};
|
||||
|
||||
class HloGatherInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloGatherInstruction(
|
||||
|
Loading…
Reference in New Issue
Block a user