[XLA] Split out HloDynamicUpdateSliceInstruction
This doesn't have any benefit in terms of sizeof(HloInstruction), but it's awkward to have a sublcass for DS and not DUS. Also adds an intermediate class in the hierarchy that avoids having to hard-code the index operand's number. PiperOrigin-RevId: 225033893
This commit is contained in:
parent
90a840fbcb
commit
c99ecfa992
@ -914,12 +914,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
|
|||||||
HloInstruction* operand,
|
HloInstruction* operand,
|
||||||
HloInstruction* update,
|
HloInstruction* update,
|
||||||
HloInstruction* start_indices) {
|
HloInstruction* start_indices) {
|
||||||
auto instruction = absl::WrapUnique(
|
return absl::make_unique<HloDynamicUpdateSliceInstruction>(
|
||||||
new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
|
shape, operand, update, start_indices);
|
||||||
instruction->AppendOperand(operand);
|
|
||||||
instruction->AppendOperand(update);
|
|
||||||
instruction->AppendOperand(start_indices);
|
|
||||||
return instruction;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
|
||||||
|
@ -1994,12 +1994,21 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
|
|||||||
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
|
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
|
||||||
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
|
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
|
||||||
absl::Span<const int64> slice_sizes)
|
absl::Span<const int64> slice_sizes)
|
||||||
: HloInstruction(HloOpcode::kDynamicSlice, shape),
|
: HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape),
|
||||||
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
|
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
|
||||||
AppendOperand(operand);
|
AppendOperand(operand);
|
||||||
AppendOperand(start_indices);
|
AppendOperand(start_indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction(
|
||||||
|
const Shape& shape, HloInstruction* operand, HloInstruction* update,
|
||||||
|
HloInstruction* start_indices)
|
||||||
|
: HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) {
|
||||||
|
AppendOperand(operand);
|
||||||
|
AppendOperand(update);
|
||||||
|
AppendOperand(start_indices);
|
||||||
|
}
|
||||||
|
|
||||||
HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
|
HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
|
||||||
HloInstructionProto proto = HloInstruction::ToProto();
|
HloInstructionProto proto = HloInstruction::ToProto();
|
||||||
for (int64 slice_size : dynamic_slice_sizes_) {
|
for (int64 slice_size : dynamic_slice_sizes_) {
|
||||||
|
@ -1171,7 +1171,14 @@ class HloPadInstruction : public HloInstruction {
|
|||||||
PaddingConfig padding_config_;
|
PaddingConfig padding_config_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class HloDynamicSliceInstruction : public HloInstruction {
|
class HloDynamicIndexInstruction : public HloInstruction {
|
||||||
|
public:
|
||||||
|
explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
|
||||||
|
: HloInstruction(opcode, shape) {}
|
||||||
|
virtual int64 index_operand_number() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
|
||||||
public:
|
public:
|
||||||
explicit HloDynamicSliceInstruction(const Shape& shape,
|
explicit HloDynamicSliceInstruction(const Shape& shape,
|
||||||
HloInstruction* operand,
|
HloInstruction* operand,
|
||||||
@ -1189,6 +1196,8 @@ class HloDynamicSliceInstruction : public HloInstruction {
|
|||||||
// Returns a serialized representation of this instruction.
|
// Returns a serialized representation of this instruction.
|
||||||
HloInstructionProto ToProto() const override;
|
HloInstructionProto ToProto() const override;
|
||||||
|
|
||||||
|
int64 index_operand_number() const override { return 1; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<string> ExtraAttributesToStringImpl(
|
std::vector<string> ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& options) const override;
|
const HloPrintOptions& options) const override;
|
||||||
@ -1206,6 +1215,16 @@ class HloDynamicSliceInstruction : public HloInstruction {
|
|||||||
std::vector<int64> dynamic_slice_sizes_;
|
std::vector<int64> dynamic_slice_sizes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
|
||||||
|
public:
|
||||||
|
explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
|
||||||
|
HloInstruction* operand,
|
||||||
|
HloInstruction* update,
|
||||||
|
HloInstruction* start_indices);
|
||||||
|
|
||||||
|
int64 index_operand_number() const override { return 2; }
|
||||||
|
};
|
||||||
|
|
||||||
class HloGatherInstruction : public HloInstruction {
|
class HloGatherInstruction : public HloInstruction {
|
||||||
public:
|
public:
|
||||||
explicit HloGatherInstruction(
|
explicit HloGatherInstruction(
|
||||||
|
Loading…
Reference in New Issue
Block a user