diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 21b1dbc1676..5c1f1a61cc2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -914,12 +914,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = absl::WrapUnique( - new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(update); - instruction->AppendOperand(start_indices); - return instruction; + return absl::make_unique( + shape, operand, update, start_indices); } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1ea02cf9c03..2fe6395efec 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1994,12 +1994,21 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes) - : HloInstruction(HloOpcode::kDynamicSlice, shape), + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + AppendOperand(start_indices); +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b5c28137a14..5420d4ce11f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1171,7 +1171,14 @@ class HloPadInstruction : public HloInstruction { PaddingConfig padding_config_; }; -class HloDynamicSliceInstruction : public HloInstruction { +class HloDynamicIndexInstruction : public HloInstruction { + public: + explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) + : HloInstruction(opcode, shape) {} + virtual int64 index_operand_number() const = 0; +}; + +class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { public: explicit HloDynamicSliceInstruction(const Shape& shape, HloInstruction* operand, @@ -1189,6 +1196,8 @@ class HloDynamicSliceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + int64 index_operand_number() const override { return 1; } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1206,6 +1215,16 @@ class HloDynamicSliceInstruction : public HloInstruction { std::vector dynamic_slice_sizes_; }; +class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { + public: + explicit HloDynamicUpdateSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices); + + int64 index_operand_number() const override { return 2; } +}; + class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction(