1946 lines
83 KiB
C++
1946 lines
83 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
// All HloInstruction subclasses are put in this file.
|
|
|
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
|
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
|
|
|
|
#include "absl/memory/memory.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/shape.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
|
|
namespace xla {
|
|
|
|
class HloBatchNormInstruction : public HloInstruction {
|
|
public:
|
|
// Returns feature_index field associated with the instruction. The index
|
|
// represents the index of the feature dimension.
|
|
int64 feature_index() const { return feature_index_; }
|
|
|
|
// Returns a epsilon value associated with the instruction. The is a small
|
|
// number added to the variance to avoid divide-by-zero error.
|
|
float epsilon() const { return epsilon_; }
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
protected:
|
|
explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* scale, float epsilon,
|
|
int64 feature_index);
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// A small float number added to the variance to avoid divide-by-zero error.
|
|
float epsilon_ = 0.0f;
|
|
|
|
// An integer value representing the index of the feature dimension.
|
|
int64 feature_index_ = -1;
|
|
};
|
|
|
|
class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
|
|
public:
|
|
explicit HloBatchNormTrainingInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* scale,
|
|
HloInstruction* offset,
|
|
float epsilon, int64 feature_index);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
|
|
public:
|
|
explicit HloBatchNormInferenceInstruction(
|
|
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
|
|
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
|
|
float epsilon, int64 feature_index);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloBatchNormGradInstruction : public HloBatchNormInstruction {
|
|
public:
|
|
explicit HloBatchNormGradInstruction(
|
|
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
|
|
HloInstruction* mean, HloInstruction* variance,
|
|
HloInstruction* grad_output, float epsilon, int64 feature_index);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloFftInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
|
|
FftType fft_type,
|
|
absl::Span<const int64> fft_length);
|
|
FftType fft_type() const { return fft_type_; }
|
|
|
|
const std::vector<int64>& fft_length() const { return fft_length_; }
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// Describes FFT type for an FFT instruction.
|
|
FftType fft_type_ = FftType::FFT;
|
|
|
|
// Indicates the FFT length for an FFT instruction.
|
|
std::vector<int64> fft_length_;
|
|
};
|
|
|
|
class HloCopyStartInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
|
|
bool is_cross_program_prefetch);
|
|
|
|
bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
bool is_cross_program_prefetch_;
|
|
};
|
|
|
|
class HloCompareInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
|
HloInstruction* rhs,
|
|
ComparisonDirection direction,
|
|
absl::optional<Comparison::Type> type);
|
|
ComparisonDirection direction() const { return compare_.GetDirection(); }
|
|
Comparison::Type type() const { return compare_.GetType(); }
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
Comparison compare_;
|
|
};
|
|
|
|
class HloTriangularSolveInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
|
|
HloInstruction* b,
|
|
const TriangularSolveOptions& options);
|
|
const TriangularSolveOptions& triangular_solve_options() const {
|
|
return triangular_solve_options_;
|
|
}
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
TriangularSolveOptions triangular_solve_options_;
|
|
};
|
|
|
|
class HloCholeskyInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
|
|
const CholeskyOptions& options);
|
|
const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
CholeskyOptions cholesky_options_;
|
|
};
|
|
|
|
// Class that represents instructions that synchronize and transfer data between
|
|
// partitioned devices. Send/Recv and collective instructions (AllReduce,
|
|
// AllToAll, CollectivePermute) belong to this instruction type. A group of
|
|
// instructions (of the same opcode) with the same channel_id communicate during
|
|
// execution.
|
|
class HloChannelInstruction : public HloInstruction {
|
|
public:
|
|
// Returns the channel id associated with the instruction. The id is
|
|
// shared between each Send/Recv pair or a group of collective instructions
|
|
// and is globally unique to identify each channel.
|
|
absl::optional<int64> channel_id() const { return channel_id_; }
|
|
void set_channel_id(const absl::optional<int64>& channel_id);
|
|
|
|
// Whether this instruction is identical to `other` except for the values of
|
|
// channel IDs, as long as both have channel IDs or neither has a channel ID.
|
|
virtual bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const {
|
|
return channel_id_.has_value() == other.channel_id().has_value();
|
|
}
|
|
|
|
protected:
|
|
explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
|
|
const absl::optional<int64>& channel_id);
|
|
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
|
|
// Do not override IdenticalSlowPath(). Override
|
|
// IdenticalSlowPathIgnoringChannelIdValues() instead.
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const final;
|
|
|
|
absl::optional<int64> channel_id_;
|
|
};
|
|
|
|
class HloSendRecvInstruction : public HloChannelInstruction {
|
|
public:
|
|
// Returns whether this send/recv instruction sends data to/from the host.
|
|
bool is_host_transfer() const { return is_host_transfer_; }
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
protected:
|
|
explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
|
|
int64 channel_id, bool is_host_transfer);
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Whether this send/recv instruction sends data to/from the host.
|
|
bool is_host_transfer_;
|
|
};
|
|
|
|
class HloSendInstruction : public HloSendRecvInstruction {
|
|
public:
|
|
explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
|
|
int64 channel_id, bool is_host_transfer);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloSendDoneInstruction : public HloSendRecvInstruction {
|
|
public:
|
|
explicit HloSendDoneInstruction(HloSendInstruction* operand,
|
|
bool is_host_transfer);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloRecvInstruction : public HloSendRecvInstruction {
|
|
public:
|
|
explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
|
|
int64 channel_id, bool is_host_transfer);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloRecvDoneInstruction : public HloSendRecvInstruction {
|
|
public:
|
|
explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
|
|
bool is_host_transfer);
|
|
|
|
private:
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
};
|
|
|
|
class HloCollectiveInstruction : public HloChannelInstruction {
|
|
public:
|
|
const std::vector<ReplicaGroup>& replica_groups() const {
|
|
return replica_groups_;
|
|
}
|
|
|
|
// Returns true if the layout of the AllReduce is enforced by XLA client (as
|
|
// the layout set in the shape). The only reason for the client to set the
|
|
// layout is to separately compile computations that communicate with
|
|
// AllReduce. Since this field is only set `true` by the client, the compiler
|
|
// only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
|
|
// `false` for all other cases.
|
|
//
|
|
// When this is `true`, there may be communication endpoints outside the
|
|
// current compilation unit, so the compiler considers this AllReduce as
|
|
// side-effecting to disable compiler transformations. The compiler is free to
|
|
// transform unconstrained AllReduces differently across compilation units.
|
|
// It is an error for an HloModule to have a mix of constrained and
|
|
// unconstrained AllReduce instructions (checked by HloVerifier).
|
|
bool constrain_layout() const { return constrain_layout_; }
|
|
|
|
protected:
|
|
explicit HloCollectiveInstruction(
|
|
HloOpcode opcode, const Shape& shape,
|
|
absl::Span<HloInstruction* const> operands,
|
|
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
|
|
const absl::optional<int64>& channel_id);
|
|
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
std::vector<ReplicaGroup> replica_groups_;
|
|
bool constrain_layout_;
|
|
};
|
|
|
|
class HloAllGatherInstruction : public HloCollectiveInstruction {
|
|
public:
|
|
explicit HloAllGatherInstruction(
|
|
const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
|
|
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
|
|
const absl::optional<int64>& channel_id, bool use_global_device_ids);
|
|
// Same as HloAllReduceInstruction::use_global_device_ids.
|
|
bool use_global_device_ids() const { return use_global_device_ids_; }
|
|
|
|
// The dimension on which data from different participants are concatenated.
|
|
int64 all_gather_dimension() const { return all_gather_dimension_; }
|
|
|
|
protected:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
int64 all_gather_dimension_;
|
|
bool use_global_device_ids_;
|
|
};
|
|
|
|
class HloAllReduceInstruction : public HloCollectiveInstruction {
|
|
public:
|
|
explicit HloAllReduceInstruction(
|
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
|
HloComputation* reduce_computation,
|
|
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
|
|
const absl::optional<int64>& channel_id, bool use_global_device_ids);
|
|
|
|
// Returns true if the AllReduce does no communication, so it's equivalent
|
|
// to a mem copy.
|
|
bool IsNoop() const;
|
|
|
|
// Returns true if the ids in the ReplicaGroup config represent a global id of
|
|
// (replica_id * partition_count + partition_id) instead of a replica id.
|
|
// This enables more flexible grouping of devices if this all-reduce is both
|
|
// cross-partition and cross-replica.
|
|
//
|
|
// For example with 2 replicas and 4 partitions,
|
|
// replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that
|
|
// group[0] = (0,0), (0,1), (1,0), (1,1)
|
|
// group[1] = (0,2), (0,3), (1,2), (1,3)
|
|
// where each pair is (replica_id, partition_id).
|
|
bool use_global_device_ids() const { return use_global_device_ids_; }
|
|
|
|
protected:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
bool use_global_device_ids_;
|
|
};
|
|
|
|
class HloAllToAllInstruction : public HloCollectiveInstruction {
|
|
public:
|
|
explicit HloAllToAllInstruction(
|
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
|
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
|
|
const absl::optional<int64>& channel_id,
|
|
const absl::optional<int64>& split_dimension);
|
|
|
|
// AllToAll can optionally take a split dimension, which means that this
|
|
// AllToAll takes a single (flattened) array operand and produces an array
|
|
// output (instead of taking a list of operands and producing a tuple).
|
|
//
|
|
// split_dimension specifies which dimension in the operand is split across
|
|
// devices in each replica_group, and also means the concatenated dimension
|
|
// on the output (i.e., input and the output shapes are the same).
|
|
absl::optional<int64> split_dimension() const { return split_dimension_; }
|
|
void set_split_dimension(int64 dim) { split_dimension_ = dim; }
|
|
|
|
protected:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
absl::optional<int64> split_dimension_;
|
|
};
|
|
|
|
class HloCollectivePermuteInstruction : public HloChannelInstruction {
|
|
public:
|
|
explicit HloCollectivePermuteInstruction(
|
|
HloOpcode opcode, const Shape& shape, HloInstruction* operand,
|
|
const std::vector<std::pair<int64, int64>>& source_target_pairs,
|
|
const absl::optional<int64>& channel_id);
|
|
|
|
const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
|
|
return source_target_pairs_;
|
|
}
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPathIgnoringChannelIdValues(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
const std::vector<std::pair<int64, int64>> source_target_pairs_;
|
|
};
|
|
|
|
class HloReverseInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
|
|
absl::Span<const int64> dimensions);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
};
|
|
|
|
class HloConcatenateInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloConcatenateInstruction(const Shape& shape,
|
|
absl::Span<HloInstruction* const> operands,
|
|
int64 dimension);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Accessor for the dimension in which a concatenate HLO should occur.
|
|
int64 concatenate_dimension() const { return dimensions(0); }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
};
|
|
|
|
class HloReduceInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloReduceInstruction(const Shape& shape,
|
|
absl::Span<HloInstruction* const> args,
|
|
absl::Span<const int64> dimensions_to_reduce,
|
|
HloComputation* reduce_computation);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Returns the number of input arrays (and, consequentially, the number of
|
|
// init values) this reduce has.
|
|
int64 input_count() const { return operand_count() / 2; }
|
|
|
|
// Returns the input tensors to be reduced.
|
|
absl::Span<HloInstruction* const> inputs() const {
|
|
return absl::MakeSpan(operands()).subspan(0, input_count());
|
|
}
|
|
|
|
// Returns the init values of the reduction.
|
|
absl::Span<HloInstruction* const> init_values() const {
|
|
return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
|
|
}
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
};
|
|
|
|
class HloSortInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloSortInstruction(const Shape& shape, int64 dimension,
|
|
absl::Span<HloInstruction* const> operands,
|
|
HloComputation* compare, bool is_stable);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Returns the sort dimension for this instruction
|
|
int64 sort_dimension() const { return dimensions(0); }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
// Returns the key operand to this instruction.
|
|
const HloInstruction* keys() const { return operand(0); }
|
|
HloInstruction* mutable_keys() { return mutable_operand(0); }
|
|
// Returns the number of value operands.
|
|
int64 values_count() const { return operand_count() - 1; }
|
|
bool is_stable() const { return is_stable_; }
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
bool is_stable_;
|
|
};
|
|
|
|
class HloTransposeInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
|
|
absl::Span<const int64> dimensions);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Returns whether this instruction does a rank-2 transposition.
|
|
bool IsRank2Transpose() const;
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
};
|
|
|
|
class HloBroadcastInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
|
|
absl::Span<const int64> broadcast_dimension);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
};
|
|
|
|
class HloDynamicReshapeInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloDynamicReshapeInstruction(
|
|
const Shape& shape, HloInstruction* data_operand,
|
|
absl::Span<HloInstruction* const> dim_sizes);
|
|
|
|
// Returns the input dim sizes dimensions, which is operands[1:]
|
|
absl::Span<HloInstruction* const> dim_sizes() const {
|
|
return absl::MakeSpan(operands()).subspan(1, operand_count());
|
|
}
|
|
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// Returns the input dim size dimension, which is operands[1+i]
|
|
HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; }
|
|
};
|
|
|
|
class HloReshapeInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
|
|
int64 inferred_dimension);
|
|
int64 inferred_dimension() const { return inferred_dimension_; }
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
int64 inferred_dimension_;
|
|
};
|
|
|
|
class HloMapInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloMapInstruction(const Shape& shape,
|
|
absl::Span<HloInstruction* const> operands,
|
|
HloComputation* map_computation);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
const std::vector<int64>& dimensions() const override { return dimensions_; }
|
|
int64 dimensions(int64 index) const override { return dimensions()[index]; }
|
|
std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
bool IsElementwiseImpl(
|
|
const absl::optional<int64>& operand_idx) const override;
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::vector<int64> dimensions_;
|
|
};
|
|
|
|
class HloSliceInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
|
|
absl::Span<const int64> start_indices,
|
|
absl::Span<const int64> limit_indices,
|
|
absl::Span<const int64> strides);
|
|
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Returns the start index in the given dimension for a slice node.
|
|
int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; }
|
|
const std::vector<int64>& slice_starts() const { return slice_starts_; }
|
|
std::vector<int64>* mutable_slice_starts() { return &slice_starts_; }
|
|
|
|
// Returns the (exclusive) limit index in the given dimension for a slice
|
|
// node.
|
|
int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; }
|
|
const std::vector<int64>& slice_limits() const { return slice_limits_; }
|
|
std::vector<int64>* mutable_slice_limits() { return &slice_limits_; }
|
|
|
|
// Returns the stride in the given dimension for a slice node.
|
|
int64 slice_strides(int64 dimension) const {
|
|
return slice_strides_[dimension];
|
|
}
|
|
const std::vector<int64>& slice_strides() const { return slice_strides_; }
|
|
std::vector<int64>* mutable_slice_strides() { return &slice_strides_; }
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// Describes the [begin, end) index range for a slice.
|
|
std::vector<int64> slice_starts_;
|
|
std::vector<int64> slice_limits_;
|
|
std::vector<int64> slice_strides_;
|
|
};
|
|
|
|
class HloConstantInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloConstantInstruction(Literal literal);
|
|
explicit HloConstantInstruction(Literal literal, const Shape& shape);
|
|
// Used when the literal is too large and dropped.
|
|
explicit HloConstantInstruction(const Shape& shape);
|
|
// Returns the literal associated with this instruction.
|
|
const Literal& literal() const { return *literal_; }
|
|
// Returns the (mutable) literal associated with this instruction.
|
|
Literal* mutable_literal() { return &literal_.value(); }
|
|
// Returns whether there is literal associated with this instruction.
|
|
bool HasLiteral() const { return literal_.has_value(); }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Change the layout for an Constant Hlo instruction to match new_layout. For
|
|
// tuple shaped constants shape_index is the path to the internal array
|
|
// subshape whose layout needs to be changed.
|
|
void RelayoutConstant(const Layout& new_layout,
|
|
const ShapeIndex& shape_index = {});
|
|
|
|
private:
|
|
bool IsElementwiseImpl(
|
|
const absl::optional<int64>& operand_idx) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
string OperandsToStringWithCanonicalNameMap(
|
|
const HloPrintOptions& options,
|
|
CanonicalNameMap* canonical_name_map) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
absl::optional<Literal> literal_;
|
|
};
|
|
|
|
class HloTraceInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
|
|
// Returns a tag to be used in tracing.
|
|
string TracingTag() const { return literal_.GetR1U8AsString(); }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
Literal literal_;
|
|
};
|
|
|
|
class HloFusionInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
|
|
HloInstruction* fused_root);
|
|
|
|
explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
|
|
absl::Span<HloInstruction* const> operands,
|
|
HloComputation* fusion_computation);
|
|
|
|
string ToCategory() const override;
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Adds a new operand the fusion instruction.
|
|
HloInstruction* AddFusionOperand(HloInstruction* new_operand);
|
|
|
|
// Merges the fused instructions from 'instruction_to_merge' into the
|
|
// fused instruction set of 'this', updating operands as necessary.
|
|
//
|
|
// Precondition: 'instruction_to_merge' must be an operand of 'this'.
|
|
void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
|
|
|
|
// Merges the fused instructions from instruction_to_merge into the fused
|
|
// instruction set of 'this' and generates multioutput fusion instructions.
|
|
// All the users of instruction_to_merge will be redirected to 'this'
|
|
// instruction. instruction_to_merge will be removed from its parent
|
|
// computation.
|
|
void MergeFusionInstructionIntoMultiOutput(
|
|
HloFusionInstruction* instruction_to_merge);
|
|
|
|
// Fuses the given instruction in this fusion instruction. instruction_to_fuse
|
|
// is cloned and the clone is placed in the fusion
|
|
// instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
|
|
// than moved to cleanly handle the case where the instruction has a use
|
|
// outside the fusion instruction. Moving such an instruction into a fusion
|
|
// instruction would violate the single-result invariant of HLO instructions
|
|
// and significantly complicate code generation.
|
|
HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
|
|
return FuseInstructionInternal(instruction_to_fuse);
|
|
}
|
|
|
|
// Fuses the given instruction in this fusion instruction and generates a
|
|
// multioutput fusion instruction. A clone of the instruction_to_fuse will
|
|
// be part of the output of fusion instructions. The users of
|
|
// instruction_to_fuse will be redirected to this fusion instructions.
|
|
// instruction_to_fuse is unchanged otherwise.
|
|
HloInstruction* FuseInstructionIntoMultiOutput(
|
|
HloInstruction* instruction_to_fuse) {
|
|
return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
|
|
}
|
|
|
|
// Returns the computation for this fused instruction.
|
|
HloComputation* fused_instructions_computation() const;
|
|
|
|
// Returns the root instruction of the fused expression contained within this
|
|
// fusion instruction.
|
|
HloInstruction* fused_expression_root() const;
|
|
|
|
// Returns the list of fused instructions inside this fusion instruction. The
|
|
// returned type is a range of HloInstruction*s.
|
|
const tensorflow::gtl::iterator_range<UnwrappingIterator<
|
|
std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
|
|
fused_instructions() const;
|
|
|
|
const tensorflow::gtl::iterator_range<
|
|
UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
|
|
fused_instructions();
|
|
|
|
// Gets the number of instructions inside this fusion instruction.
|
|
int64 fused_instruction_count() const;
|
|
|
|
// Returns the fused parameter instruction in this fusion instruction
|
|
// corresponding to the given parameter number.
|
|
HloInstruction* fused_parameter(int64 parameter_number) const;
|
|
|
|
// Returns the vector of fused parameters inside this fusion instruction.
|
|
const std::vector<HloInstruction*>& fused_parameters() const;
|
|
|
|
// Returns true if this instruction is a fusion instruction that generates
|
|
// multiple outputs.
|
|
const bool IsMultiOutputFusion() const {
|
|
return fused_expression_root()->opcode() == HloOpcode::kTuple;
|
|
}
|
|
|
|
FusionKind fusion_kind() const { return fusion_kind_; }
|
|
|
|
void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
|
|
|
|
// If multiple operands are the same instruction, keeps only one of them.
|
|
Status DeduplicateFusionOperands();
|
|
|
|
private:
|
|
// Fuses the given instruction into this fusion instruction.
|
|
// instruction_to_fuse is cloned and the clone is placed in the fusion
|
|
// instruction. The users of instruction_to_fuse will be redirected to this
|
|
// fusion instruction. instruction_to_fuse is unchanged otherwise. When
|
|
// add_output is true, a clone of the instruction_to_fuse will be added as
|
|
// additional output resulting in a multi-output fusion.
|
|
HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
|
|
bool add_output = false);
|
|
// Clones the given instruction_to_fuse and insert the clone into this fusion
|
|
// instruction. If add_output is true, a clone of instruction_to_fuse will
|
|
// be in the output of the this fusion instruction (part of the tuple of the
|
|
// fusion root).
|
|
HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
|
|
bool add_output = false);
|
|
|
|
bool IsElementwiseImpl(
|
|
const absl::optional<int64>& operand_idx) const override;
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
uint64 InnerHash() const override;
|
|
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// The type of the fusion. Used by kFusion only.
|
|
FusionKind fusion_kind_;
|
|
};
|
|
|
|
class HloRngInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloRngInstruction(const Shape& shape,
|
|
RandomDistribution distribution,
|
|
absl::Span<HloInstruction* const> parameters);
|
|
// Returns the random distribution for this rng node.
|
|
RandomDistribution random_distribution() const { return distribution_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
bool IsElementwiseImpl(
|
|
const absl::optional<int64>& operand_idx) const override;
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// The distribution requested for random number generation.
|
|
RandomDistribution distribution_;
|
|
};
|
|
|
|
class HloParameterInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloParameterInstruction(int64 parameter_number, const Shape& shape,
|
|
const string& name);
|
|
int64 parameter_number() const { return parameter_number_; }
|
|
|
|
// Sets and gets the whether all replicas will receive the same parameter data
|
|
// for each leaf buffer in data parallelism.
|
|
void set_parameter_replicated_at_leaf_buffers(
|
|
absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
|
|
CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
|
|
parameter_replicated_at_leaf_buffers.size());
|
|
parameter_replicated_at_leaf_buffers_.emplace(
|
|
parameter_replicated_at_leaf_buffers.begin(),
|
|
parameter_replicated_at_leaf_buffers.end());
|
|
}
|
|
void set_parameter_replicated_at_leaf_buffers(
|
|
const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
|
|
CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
|
|
parameter_replicated_at_leaf_buffers.size());
|
|
parameter_replicated_at_leaf_buffers_ =
|
|
parameter_replicated_at_leaf_buffers;
|
|
}
|
|
const absl::optional<std::vector<bool>>&
|
|
parameter_replicated_at_leaf_buffers() const {
|
|
return parameter_replicated_at_leaf_buffers_;
|
|
}
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
string OperandsToStringWithCanonicalNameMap(
|
|
const HloPrintOptions& options,
|
|
CanonicalNameMap* canonical_name_map) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
int64 parameter_number_ = 0;
|
|
|
|
// Specifies whether each buffer has the same parameter value on all replicas
|
|
// in data parallelism.
|
|
absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
|
|
};
|
|
|
|
class HloGetTupleElementInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloGetTupleElementInstruction(const Shape& shape,
|
|
HloInstruction* operand, int64 index);
|
|
// Returns the tuple index associated with this instruction.
|
|
int64 tuple_index() const { return tuple_index_; }
|
|
// Sets the tuple index associated with this instruction.
|
|
void set_tuple_index(int64 new_tuple_index) {
|
|
tuple_index_ = new_tuple_index;
|
|
}
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
int64 tuple_index_ = -1;
|
|
};
|
|
|
|
class HloReducePrecisionInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloReducePrecisionInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
const int exponent_bits,
|
|
const int mantissa_bits);
|
|
// Returns the number of exponent bits for a reduce-precision node.
|
|
int32 exponent_bits() const { return exponent_bits_; }
|
|
// Returns the number of mantissa bits for a reduce-precision node.
|
|
int32 mantissa_bits() const { return mantissa_bits_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// The bit sizes for a reduce-precision operation.
|
|
int32 exponent_bits_ = 0;
|
|
int32 mantissa_bits_ = 0;
|
|
};
|
|
|
|
class HloInfeedInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloInfeedInstruction(const Shape& infeed_shape,
|
|
HloInstruction* token_operand,
|
|
const string& config);
|
|
// Returns the infeed configuration string. The infeed configuration includes
|
|
// any metadata needed for the backend compiler (e.g., infeed buffer address)
|
|
// and is target-dependent.
|
|
string infeed_config() const { return infeed_config_; }
|
|
void set_infeed_config(const string& config) { infeed_config_ = config; }
|
|
// Returns the shape of the data received by the infeed. This is not the same
|
|
// as the shape of the infeed instruction which produces a tuple containing
|
|
// the infeed data shape and a TOKEN.
|
|
const Shape& infeed_shape() const {
|
|
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
|
|
return ShapeUtil::GetSubshape(shape(), {0});
|
|
}
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// The string representation of the infeed configuration.
|
|
string infeed_config_;
|
|
};
|
|
|
|
class HloOutfeedInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloOutfeedInstruction(const Shape& outfeed_shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* token_operand,
|
|
absl::string_view outfeed_config);
|
|
// Returns the shape for the Outfeed instruction.
|
|
const Shape& outfeed_shape() const { return outfeed_shape_; }
|
|
// Returns the mutable shape for the Outfeed instruction.
|
|
Shape* mutable_outfeed_shape() { return &outfeed_shape_; }
|
|
// Returns the config for the Outfeed instruction.
|
|
const string& outfeed_config() const { return outfeed_config_; }
|
|
void set_outfeed_config(const string& config) { outfeed_config_ = config; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// Shape of outfeed request.
|
|
Shape outfeed_shape_;
|
|
// Outfeed configuration information, only present for kOutfeed.
|
|
string outfeed_config_;
|
|
};
|
|
|
|
class HloConvolutionInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloConvolutionInstruction(
|
|
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
|
int64 feature_group_count, int64 batch_group_count, const Window& window,
|
|
const ConvolutionDimensionNumbers& dimension_numbers,
|
|
const PrecisionConfig& precision_config);
|
|
const Window& window() const override { return window_; }
|
|
void set_window(const Window& window) override { window_ = window; }
|
|
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
|
|
return convolution_dimension_numbers_;
|
|
}
|
|
void set_convolution_dimension_numbers(
|
|
const ConvolutionDimensionNumbers& dnums) {
|
|
convolution_dimension_numbers_ = dnums;
|
|
}
|
|
// The number of feature groups. Must be a divisor of the input feature
|
|
// dimension and output feature dimension.
|
|
int64 feature_group_count() const { return feature_group_count_; }
|
|
void set_feature_group_count(int64 num_feature_groups) {
|
|
feature_group_count_ = num_feature_groups;
|
|
}
|
|
// The number of batch groups. Must be a divisor of the input batch dimension.
|
|
int64 batch_group_count() const { return batch_group_count_; }
|
|
void set_batch_group_count(int64 num_batch_groups) {
|
|
batch_group_count_ = num_batch_groups;
|
|
}
|
|
|
|
// Returns the information used to tell the implementation information about
|
|
// what sort of precision is requested. The meaning of the field is backend
|
|
// specific. At the moment, it is only supported for kConvolution and kDot.
|
|
// Transformations on one kDot or kConvolution to another will preserve this
|
|
// information. Transformations to other HLOs will not preserve this
|
|
// information but it is presumed that the alternate lowering is strictly
|
|
// superior.
|
|
const PrecisionConfig& precision_config() const { return precision_config_; }
|
|
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
|
|
|
|
string ToCategory() const override;
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
// The number of feature groups. Must be a divisor of the input feature
|
|
// dimension and output feature dimension.
|
|
int64 feature_group_count_;
|
|
// The number of batch groups. Must be a divisor of the input batch dimension.
|
|
int64 batch_group_count_;
|
|
// Describes the window used for a convolution.
|
|
Window window_;
|
|
// Describes the dimension numbers used for a convolution.
|
|
ConvolutionDimensionNumbers convolution_dimension_numbers_;
|
|
// Information used to communicate to the implementation about the algorithm
|
|
// used to produce results. See the documentation on precision_config().
|
|
PrecisionConfig precision_config_;
|
|
};
|
|
|
|
class HloReduceWindowInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloReduceWindowInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* init_value,
|
|
const Window& window,
|
|
HloComputation* reduce_computation);
|
|
explicit HloReduceWindowInstruction(
|
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
|
absl::Span<HloInstruction* const> init_values, const Window& window,
|
|
HloComputation* reduce_computation);
|
|
const Window& window() const override { return window_; }
|
|
void set_window(const Window& window) override { window_ = window; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
// Returns the number of input arrays (and, consequentially, the number of
|
|
// init values) this reduce has.
|
|
int64 input_count() const { return operand_count() / 2; }
|
|
// Returns the input tensors to be reduced.
|
|
absl::Span<HloInstruction* const> input_arrays() const {
|
|
return absl::MakeSpan(operands()).subspan(0, input_count());
|
|
}
|
|
// Returns the init values of the reduction.
|
|
absl::Span<HloInstruction* const> init_values() const {
|
|
return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
|
|
}
|
|
// Returns the shapes of input tensors to be reduced.
|
|
absl::InlinedVector<const Shape*, 2> input_array_shapes() const {
|
|
absl::InlinedVector<const Shape*, 2> shapes;
|
|
for (const auto* op : input_arrays()) {
|
|
VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
|
|
shapes.push_back(&op->shape());
|
|
VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
|
|
}
|
|
return shapes;
|
|
}
|
|
// Returns the init values of the reduction.
|
|
absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
|
|
absl::InlinedVector<const Shape*, 2> shapes;
|
|
for (const auto* op : init_values()) {
|
|
shapes.push_back(&op->shape());
|
|
}
|
|
return shapes;
|
|
}
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
Window window_;
|
|
};
|
|
|
|
class HloSelectAndScatterInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloSelectAndScatterInstruction(
|
|
const Shape& shape, HloInstruction* operand, HloComputation* select,
|
|
const Window& window, HloInstruction* source, HloInstruction* init_value,
|
|
HloComputation* scatter);
|
|
const Window& window() const override { return window_; }
|
|
void set_window(const Window& window) override { window_ = window; }
|
|
// Gets/sets the select or scatter HloComputation for SelectAndScatter. The
|
|
// setters should only be called by HloModule or HloComputation methods.
|
|
HloComputation* select() const {
|
|
return called_computations()[kSelectComputationIndex];
|
|
}
|
|
|
|
HloComputation* scatter() const {
|
|
return called_computations()[kScatterComputationIndex];
|
|
}
|
|
|
|
void set_select(HloComputation* computation) {
|
|
// Don't allow changing the computation for fused instructions so we don't
|
|
// have to recompute called_instructions for the entire fusion instruction.
|
|
CHECK(!IsFused());
|
|
set_called_computation(kSelectComputationIndex, computation);
|
|
}
|
|
|
|
void set_scatter(HloComputation* computation) {
|
|
// Don't allow changing the computation for fused instructions so we don't
|
|
// have to recompute called_instructions for the entire fusion instruction.
|
|
CHECK(!IsFused());
|
|
set_called_computation(kScatterComputationIndex, computation);
|
|
}
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
Window window_;
|
|
};
|
|
|
|
class HloCustomCallInstruction : public HloInstruction {
|
|
public:
|
|
HloCustomCallInstruction(const Shape& shape,
|
|
absl::Span<HloInstruction* const> operands,
|
|
absl::string_view custom_call_target, string opaque);
|
|
|
|
// Constructor for a custom call with constrained layout. 'shape' and
|
|
// 'operands_with_layout' must all have layouts.
|
|
HloCustomCallInstruction(const Shape& shape,
|
|
absl::Span<HloInstruction* const> operands,
|
|
absl::string_view custom_call_target, string opaque,
|
|
absl::Span<const Shape> operand_shapes_with_layout);
|
|
|
|
// Constructor for a custom call with a to_apply computation.
|
|
HloCustomCallInstruction(const Shape& shape,
|
|
absl::Span<HloInstruction* const> operands,
|
|
HloComputation* to_apply,
|
|
absl::string_view custom_call_target, string opaque);
|
|
|
|
// Constructor for a custom call with multiple computations.
|
|
HloCustomCallInstruction(
|
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
|
absl::Span<HloComputation* const> called_computations,
|
|
absl::string_view custom_call_target, string opaque);
|
|
|
|
const Window& window() const override {
|
|
CHECK(window_ != nullptr);
|
|
return *window_;
|
|
}
|
|
|
|
void set_window(const Window& window) override {
|
|
window_ = absl::make_unique<Window>(window);
|
|
}
|
|
|
|
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
|
|
CHECK(convolution_dimension_numbers_ != nullptr);
|
|
return *convolution_dimension_numbers_;
|
|
}
|
|
|
|
void set_convolution_dimension_numbers(
|
|
const ConvolutionDimensionNumbers& dnums) {
|
|
convolution_dimension_numbers_ =
|
|
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
|
|
}
|
|
// TODO(jpienaar): Remove this accessor in the follow up.
|
|
const string& opaque() const { return raw_backend_config_string(); }
|
|
const string& custom_call_target() const { return custom_call_target_; }
|
|
void set_feature_group_count(int64 feature_group_count) {
|
|
feature_group_count_ = feature_group_count;
|
|
}
|
|
void set_batch_group_count(int64 batch_group_count) {
|
|
batch_group_count_ = batch_group_count;
|
|
}
|
|
// Sets whether this custom call has a side-effect - by default a custom call
|
|
// has no side-effects.
|
|
void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
|
|
custom_call_has_side_effect_ = custom_call_has_side_effect;
|
|
}
|
|
int64 feature_group_count() const { return feature_group_count_; }
|
|
int64 batch_group_count() const { return batch_group_count_; }
|
|
bool custom_call_has_side_effect() const {
|
|
return custom_call_has_side_effect_;
|
|
}
|
|
// Returns padding type used for ops like convolution.
|
|
PaddingType padding_type() const {
  // Returned by value.
  return padding_type_;
}
|
|
|
|
// Overwrites the stored padding type.
void set_padding_type(PaddingType padding_type) {
  this->padding_type_ = padding_type;
}
|
|
|
|
// Read-only access to the precision configuration stored on this custom call.
const PrecisionConfig& precision_config() const {
  return precision_config_;
}
|
|
// Mutable access to the precision configuration. The returned pointer refers
// to a member of this instruction and is never null.
PrecisionConfig* mutable_precision_config() {
  PrecisionConfig* const config = &precision_config_;
  return config;
}
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Returns whether the result and operand layouts are constrained.
|
|
bool layout_constrained() const {
  // Simple read of the stored flag.
  return layout_constrained_;
}
|
|
|
|
// Returns the shapes (with layout) of the operands. CHECKs if this custom
|
|
// call does not have constrained layouts.
|
|
const std::vector<Shape>& operand_shapes_with_layout() const {
|
|
CHECK(layout_constrained());
|
|
return operand_shapes_with_layout_;
|
|
}
|
|
// Gets a list of output/operand buffer pairs that alias each other, where the
|
|
// output buffer is represented as a ShapeIndex, and the operand buffer is
|
|
// represented as the operand index and the ShapeIndex. By default this list
|
|
// is empty.
|
|
const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
|
|
output_to_operand_aliasing() const {
|
|
return output_to_operand_aliasing_;
|
|
}
|
|
// Sets the list of output/operand buffer pairs that alias each other.
|
|
void set_output_to_operand_aliasing(
|
|
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
aliasing) {
|
|
output_to_operand_aliasing_ = std::move(aliasing);
|
|
}
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
// Name of a global symbol to call.
|
|
string custom_call_target_;
|
|
// Describes the window in a windowed operation such as convolution.
|
|
std::unique_ptr<Window> window_;
|
|
// Describes the dimension numbers used for a convolution.
|
|
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
|
|
// The number of feature groups. This is used for grouped convolutions.
|
|
int64 feature_group_count_;
|
|
int64 batch_group_count_;
|
|
// Whether the result and operand layouts are constrained.
|
|
bool layout_constrained_;
|
|
// Information used to communicate to the implementation about the algorithm
|
|
// used to produce results for convolution instructions.
|
|
PrecisionConfig precision_config_;
|
|
// Describes the padding type for convolution instructions.
|
|
PaddingType padding_type_;
|
|
// For layout-constrained custom calls, this vector holds the shape with
|
|
// layout for each operand.
|
|
std::vector<Shape> operand_shapes_with_layout_;
|
|
// Whether this custom call has a side-effect.
|
|
bool custom_call_has_side_effect_;
|
|
// A list of output/operand buffer pairs that alias each other. See comment of
|
|
// output_to_operand_aliasing().
|
|
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
|
output_to_operand_aliasing_;
|
|
};
|
|
|
|
class HloPadInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
|
|
HloInstruction* padding_value,
|
|
const PaddingConfig& padding_config);
|
|
// Returns the padding configuration for a pad node.
|
|
const PaddingConfig& padding_config() const { return padding_config_; }
|
|
PaddingConfig* mutable_padding_config() { return &padding_config_; }
|
|
// Returns the padding value.
|
|
const HloInstruction* padding_value() const { return operand(1); }
|
|
HloInstruction* mutable_padding_value() { return mutable_operand(1); }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// The padding configuration that describes the edge padding and interior
|
|
// padding of this pad instruction.
|
|
PaddingConfig padding_config_;
|
|
};
|
|
|
|
class HloDynamicIndexInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
|
|
: HloInstruction(opcode, shape) {}
|
|
virtual int64 first_index_operand_number() const = 0;
|
|
|
|
// Returns a subspan of operands which represent the start indices.
|
|
absl::Span<HloInstruction* const> index_operands() const {
|
|
return absl::MakeSpan(operands()).subspan(first_index_operand_number());
|
|
}
|
|
|
|
// Returns the shapes of the index operands.
|
|
std::vector<Shape> index_shapes() const {
|
|
std::vector<Shape> shapes;
|
|
auto indices = index_operands();
|
|
for (const HloInstruction* index : indices) {
|
|
shapes.push_back(index->shape());
|
|
}
|
|
return shapes;
|
|
}
|
|
};
|
|
|
|
class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
|
|
public:
|
|
explicit HloDynamicSliceInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* start_indices,
|
|
absl::Span<const int64> slice_sizes);
|
|
explicit HloDynamicSliceInstruction(
|
|
const Shape& shape, HloInstruction* operand,
|
|
absl::Span<HloInstruction* const> start_indices,
|
|
absl::Span<const int64> slice_sizes);
|
|
// Old methods kept for smooth subclassing transition END.
|
|
// Returns the size of the slice in the given dimension for a dynamic
|
|
// slice node.
|
|
int64 slice_sizes(int64 dimension) const {
|
|
return dynamic_slice_sizes_[dimension];
|
|
}
|
|
const std::vector<int64>& dynamic_slice_sizes() const {
|
|
return dynamic_slice_sizes_;
|
|
}
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
int64 first_index_operand_number() const override { return 1; }
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
// Describes the [start, start + size) range size for a dynamic slice
|
|
// ('start' is specified dynamically in the second operand of the operation).
|
|
std::vector<int64> dynamic_slice_sizes_;
|
|
};
|
|
|
|
class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
|
|
public:
|
|
explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* update,
|
|
HloInstruction* start_indices);
|
|
explicit HloDynamicUpdateSliceInstruction(
|
|
const Shape& shape, HloInstruction* operand, HloInstruction* update,
|
|
absl::Span<HloInstruction* const> start_indices);
|
|
|
|
int64 first_index_operand_number() const override { return 2; }
|
|
};
|
|
|
|
class HloGatherInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloGatherInstruction(
|
|
const Shape& shape, HloInstruction* operand,
|
|
HloInstruction* start_indices,
|
|
const GatherDimensionNumbers& gather_dim_numbers,
|
|
absl::Span<const int64> slice_sizes, bool indices_are_sorted);
|
|
const GatherDimensionNumbers& gather_dimension_numbers() const {
|
|
CHECK(gather_dimension_numbers_ != nullptr);
|
|
return *gather_dimension_numbers_;
|
|
}
|
|
absl::Span<const int64> gather_slice_sizes() const {
|
|
return gather_slice_sizes_;
|
|
}
|
|
bool indices_are_sorted() const { return indices_are_sorted_; }
|
|
void set_indices_are_sorted(bool indices_are_sorted) {
|
|
indices_are_sorted_ = indices_are_sorted;
|
|
}
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Creates an instance of GatherDimensionNumbers.
|
|
static GatherDimensionNumbers MakeGatherDimNumbers(
|
|
absl::Span<const int64> offset_dims,
|
|
absl::Span<const int64> collapsed_slice_dims,
|
|
absl::Span<const int64> start_index_map, int64 index_vector_dim);
|
|
// Returns the dump string of the given gather dimension numbers.
|
|
static string GatherDimensionNumbersToString(
|
|
const GatherDimensionNumbers& gather_dimension_numbers);
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
|
|
std::vector<int64> gather_slice_sizes_;
|
|
bool indices_are_sorted_;
|
|
};
|
|
|
|
class HloScatterInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloScatterInstruction(
|
|
const Shape& shape, HloInstruction* operand,
|
|
HloInstruction* scatter_indices, HloInstruction* updates,
|
|
HloComputation* update_computation,
|
|
const ScatterDimensionNumbers& scatter_dim_numbers,
|
|
bool indices_are_sorted, bool unique_indices);
|
|
const ScatterDimensionNumbers& scatter_dimension_numbers() const {
|
|
CHECK(scatter_dimension_numbers_ != nullptr);
|
|
return *scatter_dimension_numbers_;
|
|
}
|
|
bool indices_are_sorted() const { return indices_are_sorted_; }
|
|
void set_indices_are_sorted(bool indices_are_sorted) {
|
|
indices_are_sorted_ = indices_are_sorted;
|
|
}
|
|
bool unique_indices() const override { return unique_indices_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Creates an instance of ScatterDimensionNumbers.
|
|
static ScatterDimensionNumbers MakeScatterDimNumbers(
|
|
absl::Span<const int64> update_window_dims,
|
|
absl::Span<const int64> inserted_window_dims,
|
|
absl::Span<const int64> scatter_dims_to_operand_dims,
|
|
int64 index_vector_dim);
|
|
// Returns the dump string of the given scatter dimension numbers.
|
|
static string ScatterDimensionNumbersToString(
|
|
const ScatterDimensionNumbers& scatter_dimension_numbers);
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
|
|
bool indices_are_sorted_;
|
|
bool unique_indices_;
|
|
};
|
|
|
|
class HloIotaInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
int64 iota_dimension() const { return iota_dimension_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
const int64 iota_dimension_;
|
|
};
|
|
|
|
class HloDotInstruction : public HloInstruction {
|
|
public:
|
|
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
|
|
// dimensions specified in 'dimension_numbers'.
|
|
explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
|
|
HloInstruction* rhs,
|
|
const DotDimensionNumbers& dimension_numbers,
|
|
const PrecisionConfig& precision_config);
|
|
|
|
// Returns data on the dimension numbers used for a dot operation.
|
|
const DotDimensionNumbers& dot_dimension_numbers() const {
|
|
return dot_dimension_numbers_;
|
|
}
|
|
|
|
// Returns the information used to tell the implementation information about
|
|
// what sort of precision is requested. The meaning of the field is backend
|
|
// specific. At the moment, it is only supported for kConvolution and kDot.
|
|
// Transformations on one kDot or kConvolution to another will preserve this
|
|
// information. Transformations to other HLOs will not preserve this
|
|
// information but it is presumed that the alternate lowering is strictly
|
|
// superior.
|
|
const PrecisionConfig& precision_config() const { return precision_config_; }
|
|
PrecisionConfig* mutable_precision_config() { return &precision_config_; }
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
// Returns the dump string of the dot dimension numbers.
|
|
string DotDimensionNumbersToString() const;
|
|
|
|
// Describes the dimension numbers used for a dot.
|
|
DotDimensionNumbers dot_dimension_numbers_;
|
|
|
|
// Information used to communicate to the implementation about the algorithm
|
|
// used to produce results. See the documentation on precision_config().
|
|
PrecisionConfig precision_config_;
|
|
};
|
|
|
|
class HloDomainInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloDomainInstruction(
|
|
const Shape& shape, HloInstruction* operand,
|
|
std::unique_ptr<DomainMetadata> operand_side_metadata,
|
|
std::unique_ptr<DomainMetadata> user_side_metadata);
|
|
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
// Retrieves the operand side metadata of a kDomain instruction.
|
|
const DomainMetadata& operand_side_metadata() const {
|
|
return *operand_side_metadata_;
|
|
}
|
|
// Retrieves the user side metadata of a kDomain instruction.
|
|
const DomainMetadata& user_side_metadata() const {
|
|
return *user_side_metadata_;
|
|
}
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
std::unique_ptr<DomainMetadata> operand_side_metadata_;
|
|
std::unique_ptr<DomainMetadata> user_side_metadata_;
|
|
};
|
|
|
|
class HloGetDimensionSizeInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloGetDimensionSizeInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
int64 dimension);
|
|
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
int64 dimension() const { return dimension_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
int64 dimension_;
|
|
};
|
|
|
|
class HloSetDimensionSizeInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloSetDimensionSizeInstruction(const Shape& shape,
|
|
HloInstruction* operand,
|
|
HloInstruction* val, int64 dimension);
|
|
|
|
// Returns the dimension sizes or numbers associated with this instruction.
|
|
int64 dimension() const { return dimension_; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
int64 dimension_;
|
|
};
|
|
|
|
class HloRngGetAndUpdateStateInstruction : public HloInstruction {
|
|
public:
|
|
explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, int64 delta);
|
|
|
|
// Returns the delta value.
|
|
int64 delta() const { return delta_; }
|
|
void set_delta(int64 delta) { delta_ = delta; }
|
|
// Returns a serialized representation of this instruction.
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
// Implementation for non-common logic of CloneWithNewOperands.
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
int64 delta_;
|
|
};
|
|
|
|
class HloRngBitGeneratorInstruction : public HloInstruction {
|
|
public:
|
|
HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
|
|
RandomAlgorithm algorithm);
|
|
|
|
RandomAlgorithm algorithm() const { return algorithm_; }
|
|
HloInstructionProto ToProto() const override;
|
|
|
|
private:
|
|
std::vector<string> ExtraAttributesToStringImpl(
|
|
const HloPrintOptions& options) const override;
|
|
bool IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
|
eq_computations) const override;
|
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
|
HloCloneContext* context) const override;
|
|
|
|
RandomAlgorithm algorithm_;
|
|
};
|
|
|
|
} // namespace xla
|
|
|
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
|