This avoids the awkward situation where an RNG which is clearly eligible for fusion becomes ineligible mid-fusion because it suddenly has an extra (dead) user. PiperOrigin-RevId: 173141716
2920 lines
108 KiB
C++
2920 lines
108 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
|
|
#include <algorithm>
|
|
#include <deque>
|
|
#include <ostream>
|
|
#include <set>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
|
|
#include "tensorflow/compiler/xla/layout_util.h"
|
|
#include "tensorflow/compiler/xla/literal_util.h"
|
|
#include "tensorflow/compiler/xla/protobuf_util.h"
|
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
|
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
|
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/status_macros.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/compiler/xla/util.h"
|
|
#include "tensorflow/compiler/xla/window_util.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
#include "tensorflow/core/lib/strings/str_util.h"
|
|
#include "tensorflow/core/lib/strings/strcat.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
|
|
namespace xla {
|
|
|
|
using ::tensorflow::str_util::Join;
|
|
using ::tensorflow::strings::StrAppend;
|
|
using ::tensorflow::strings::StrCat;
|
|
|
|
/* static */
// Deserializes a single HloInstruction from `proto`. Operand and control
// predecessor references are resolved by name through `instruction_map`;
// called computations (other than fused computations, which are embedded in
// the instruction proto itself) are resolved through `computation_map`.
// Returns an error if any referenced name is missing or a required field is
// absent.
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
    HloModule* module, const HloInstructionProto& proto,
    const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
    tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map) {
  TF_RET_CHECK(!proto.opcode().empty());
  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
  TF_RET_CHECK(proto.has_shape());

  auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
  // Resolve operands; every named operand must already be in the map.
  for (const string& operand_name : proto.operand_names()) {
    TF_RET_CHECK(ContainsKey(instruction_map, operand_name))
        << "No instruction named " << operand_name;
    instruction->AppendOperand(instruction_map.at(operand_name));
  }
  // Re-establish control edges: each named predecessor gains a control
  // dependency into this instruction.
  for (const string& predecessor_name : proto.control_predecessor_names()) {
    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name))
        << "No instruction named " << predecessor_name;
    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name)
                           ->AddControlDependencyTo(instruction.get()));
  }

  // In the proto, fused computations are held exclusively within the
  // HloInstructionProto and do not appear as an HloComputationProto within the
  // HloModuleProto.
  if (instruction->opcode() == HloOpcode::kFusion) {
    TF_RET_CHECK(proto.has_fused_instructions_computation());
    TF_RET_CHECK(!proto.fusion_kind().empty());
    TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
                        StringToFusionKind(proto.fusion_kind()));
    // Deserialize the embedded fused computation and register it with the
    // module; the fusion instruction owns it via called_computations_.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> fused_computation,
        HloComputation::CreateFromProto(
            module, proto.fused_instructions_computation(), computation_map,
            /*fusion_instruction=*/instruction.get()));
    instruction->called_computations_.push_back(
        module->AddEmbeddedComputation(std::move(fused_computation)));
  } else {
    // Non-fusion callees must already have been deserialized into
    // computation_map.
    for (const string& computation_name : proto.called_computation_names()) {
      TF_RET_CHECK(ContainsKey(*computation_map, computation_name))
          << "No computation named " << computation_name;
      instruction->called_computations_.push_back(
          computation_map->at(computation_name));
    }
  }

  TF_RET_CHECK(!proto.name().empty());
  instruction->name_ = proto.name();

  // Copy the remaining payload fields. These are copied unconditionally
  // regardless of opcode; fields not set in the proto are proto defaults.
  instruction->metadata_ = proto.metadata();
  if (proto.has_literal()) {
    instruction->literal_ = MakeUnique<Literal>(proto.literal());
  }
  instruction->parameter_number_ = proto.parameter_number();
  instruction->parameter_name_ = proto.parameter_name();

  instruction->tuple_index_ = proto.tuple_index();
  for (int64 dimension : proto.dimensions()) {
    instruction->dimensions_.push_back(dimension);
  }
  if (proto.has_window()) {
    instruction->window_ = MakeUnique<Window>(proto.window());
  }
  if (proto.has_convolution_dimension_numbers()) {
    instruction->convolution_dimension_numbers_ =
        MakeUnique<ConvolutionDimensionNumbers>(
            proto.convolution_dimension_numbers());
  }
  // Slice bounds are stored per-dimension in the proto; unpack them into
  // the three parallel vectors used by kSlice.
  for (const HloInstructionProto::SliceDimensions& slice_dimensions :
       proto.slice_dimensions()) {
    instruction->slice_starts_.push_back(slice_dimensions.start());
    instruction->slice_limits_.push_back(slice_dimensions.limit());
    instruction->slice_strides_.push_back(slice_dimensions.stride());
  }
  instruction->exponent_bits_ = proto.exponent_bits();
  instruction->mantissa_bits_ = proto.mantissa_bits();
  for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) {
    instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size);
  }
  if (proto.has_padding_config()) {
    instruction->padding_config_ =
        MakeUnique<PaddingConfig>(proto.padding_config());
  }
  instruction->outfeed_config_ = proto.outfeed_config();
  instruction->distribution_ = proto.distribution();
  instruction->epsilon_ = proto.epsilon();
  instruction->feature_index_ = proto.feature_index();
  instruction->channel_id_ = proto.channel_id();
  instruction->infeed_config_ = proto.infeed_config();
  instruction->custom_call_target_ = proto.custom_call_target();
  instruction->outfeed_shape_ = proto.outfeed_shape();

  return std::move(instruction);
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  // Parameters record their position and name; the HLO instruction name is
  // the parameter name prefixed with '%'.
  auto param = WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
  param->parameter_name_ = name;
  param->parameter_number_ = parameter_number;
  param->name_ = "%" + name;
  return param;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
  // Trace instructions produce no value (nil shape); the tag string is
  // stashed in the instruction's literal as u8s.
  auto trace =
      WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
  trace->operands_.push_back(operand);
  trace->literal_ = MakeUnique<Literal>();
  trace->literal_->append_u8s(tag);
  return trace;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    std::unique_ptr<Literal> literal) {
  // The constant's shape is taken directly from the literal it wraps; the
  // instruction takes ownership of the literal.
  const Shape& shape = literal->shape();
  auto constant = WrapUnique(new HloInstruction(HloOpcode::kConstant, shape));
  constant->literal_ = std::move(literal);
  return constant;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64 index) {
  // Extracts element `index` from the tuple-shaped `operand`.
  auto gte =
      WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape));
  gte->AppendOperand(operand);
  gte->tuple_index_ = index;
  return gte;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
  // Creates an RNG instruction drawing values from `distribution`;
  // `parameters` become the instruction's operands.
  auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape));
  instruction->distribution_ = distribution;
  // Note: `shape` is already recorded by the constructor; the redundant
  // `shape_ = shape` assignment was removed for consistency with every
  // other factory in this file.
  for (HloInstruction* param : parameters) {
    instruction->AppendOperand(param);
  }
  return instruction;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
    const Shape& shape, HloOpcode opcode,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  // Shared helper behind CreateUnary/CreateBinary/CreateTernary/etc.
  if (opcode == HloOpcode::kCopy) {
    // It is impossible to copy an opaque shape, we don't know how big it is.
    CHECK(!ShapeUtil::IsOpaque(shape));
  }
  auto inst = WrapUnique(new HloInstruction(opcode, shape));
  for (HloInstruction* operand : operands) {
    inst->AppendOperand(operand);
  }
  return inst;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
  // CreateUnary accepts only unary opcodes that carry no auxiliary fields;
  // any other opcode must go through its dedicated factory. The whitelist
  // is kept in alphabetical order.
  switch (opcode) {
    case HloOpcode::kAbs:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kFloor:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSort:
    case HloOpcode::kTanh:
      break;
    default:
      LOG(FATAL) << "Invalid unary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {operand});
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs) {
  // CreateBinary accepts only binary opcodes that carry no auxiliary
  // fields; any other opcode must go through its dedicated factory. The
  // whitelist is kept in alphabetical order.
  switch (opcode) {
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kDivide:
    case HloOpcode::kDot:
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kOr:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      break;
    default:
      LOG(FATAL) << "Invalid binary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs});
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs, HloInstruction* ehs) {
  // CreateTernary accepts only ternary opcodes that carry no auxiliary
  // fields; currently just clamp and select.
  switch (opcode) {
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
      break;
    default:
      LOG(FATAL) << "Invalid ternary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs, ehs});
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  // Tuple is the only opcode accepted here with arbitrary arity.
  CHECK_EQ(HloOpcode::kTuple, opcode);
  return CreateNary(shape, opcode, operands);
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    HloComputation* map_computation,
    tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
  // Maps `map_computation` over the operands; static operands are not yet
  // implemented.
  CHECK(static_operands.empty()) << "static_operands not yet supported";
  auto map = WrapUnique(new HloInstruction(HloOpcode::kMap, shape));
  for (HloInstruction* operand : operands) {
    map->AppendOperand(operand);
  }
  map->called_computations_.push_back(map_computation);
  return map;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  auto conv =
      WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape));
  // Tag dilated convolutions in the instruction name so they are easy to
  // spot in dumps.
  if (window_util::HasBaseDilation(window)) {
    conv->name_ = conv->name() + "-base-dilated";
  }
  if (window_util::HasWindowDilation(window)) {
    conv->name_ = conv->name() + "-window-dilated";
  }
  conv->AppendOperand(lhs);
  conv->AppendOperand(rhs);
  conv->window_ = MakeUnique<Window>(window);
  conv->convolution_dimension_numbers_ =
      MakeUnique<ConvolutionDimensionNumbers>(dimension_numbers);
  return conv;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
                                      HloInstruction* operand,
                                      const int exponent_bits,
                                      const int mantissa_bits) {
  // Records the target exponent/mantissa bit widths on the instruction.
  auto rp =
      WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape));
  rp->AppendOperand(operand);
  rp->mantissa_bits_ = mantissa_bits;
  rp->exponent_bits_ = exponent_bits;
  return rp;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(const Shape& shape,
                                      HloInstruction* operand) {
  // Single-operand cross-replica sum; no extra fields to populate.
  auto crs =
      WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
  crs->AppendOperand(operand);
  return crs;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    const Shape& shape, const string& config) {
  // Infeed takes no operands; the config string is attached verbatim.
  auto infeed = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
  infeed->set_infeed_config(config);
  return infeed;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    const Shape& shape, HloInstruction* operand,
    tensorflow::StringPiece outfeed_config) {
  // Outfeed produces no value (nil shape); the shape of the outfed data is
  // recorded separately in outfeed_shape_.
  auto outfeed =
      WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
  outfeed->AppendOperand(operand);
  outfeed->outfeed_shape_ = shape;
  outfeed->outfeed_config_ = outfeed_config.ToString();
  return outfeed;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, int64 channel_id) {
  // Send produces no value (nil shape); only the channel id is recorded.
  auto send =
      WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
  send->AppendOperand(operand);
  send->channel_id_ = channel_id;
  return send;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, int64 channel_id) {
  // Recv has no operands; only the channel id is recorded.
  auto recv = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
  recv->channel_id_ = channel_id;
  return recv;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  // Reverses `operand` along the listed dimensions.
  auto rev = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape));
  rev->AppendOperand(operand);
  rev->dimensions_.assign(dimensions.begin(), dimensions.end());
  return rev;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
    const Shape& shape, HloComputation* condition, HloComputation* body,
    HloInstruction* init) {
  auto loop = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
  loop->AppendOperand(init);
  // Ordering invariant: the body computation is stored before the condition
  // in called_computations_.
  loop->called_computations_.push_back(body);
  loop->called_computations_.push_back(condition);
  return loop;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  auto slice = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
  slice->AppendOperand(operand);
  slice->slice_starts_.assign(start_indices.begin(), start_indices.end());
  slice->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
  slice->slice_strides_.assign(strides.begin(), strides.end());
  // For backward compatibility with old serialized computations: if there are
  // no strides, assume all strides are 1.
  // TODO(b/63317920): remove this code.
  if (slice->slice_strides_.empty()) {
    slice->slice_strides_.assign(start_indices.size(), 1LL);
  }
  return slice;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  // Slice with runtime start indices and statically-known sizes.
  auto ds =
      WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape));
  ds->AppendOperand(operand);
  ds->AppendOperand(start_indices);
  ds->dynamic_slice_sizes_.assign(slice_sizes.begin(), slice_sizes.end());
  return ds;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
                                         HloInstruction* operand,
                                         HloInstruction* update,
                                         HloInstruction* start_indices) {
  // Operand order is significant: (input, update, start_indices).
  auto dus =
      WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
  dus->AppendOperand(operand);
  dus->AppendOperand(update);
  dus->AppendOperand(start_indices);
  return dus;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    int64 dimension) {
  // Concatenates all operands along a single dimension.
  auto concat =
      WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape));
  for (HloInstruction* operand : operands) {
    concat->AppendOperand(operand);
  }
  concat->dimensions_.push_back(dimension);
  return concat;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
    const Shape& shape, HloInstruction* operand) {
  // Element-type conversion; the target type is implied by `shape`.
  auto convert = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
  convert->AppendOperand(operand);
  return convert;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  // Reduces `arg` over `dimensions_to_reduce`, combining elements with
  // `reduce_computation`, seeded by `init_value`.
  auto reduce = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape));
  reduce->AppendOperand(arg);
  reduce->AppendOperand(init_value);
  reduce->dimensions_.assign(dimensions_to_reduce.begin(),
                             dimensions_to_reduce.end());
  reduce->called_computations_.push_back(reduce_computation);
  return reduce;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation) {
  // Windowed reduction: combines the elements under each window position
  // with `reduce_computation`, seeded by `init_value`.
  auto rw = WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
  rw->AppendOperand(operand);
  rw->AppendOperand(init_value);
  rw->window_ = MakeUnique<Window>(window);
  rw->called_computations_.push_back(reduce_computation);
  return rw;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
                                        HloInstruction* scale,
                                        HloInstruction* offset, float epsilon,
                                        int64 feature_index) {
  // Operand order is significant: (operand, scale, offset).
  auto bn =
      WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape));
  bn->AppendOperand(operand);
  bn->AppendOperand(scale);
  bn->AppendOperand(offset);
  bn->feature_index_ = feature_index;
  bn->epsilon_ = epsilon;
  return bn;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64 feature_index) {
  // Operand order is significant: (operand, scale, offset, mean, variance).
  auto bn =
      WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
  bn->AppendOperand(operand);
  bn->AppendOperand(scale);
  bn->AppendOperand(offset);
  bn->AppendOperand(mean);
  bn->AppendOperand(variance);
  bn->feature_index_ = feature_index;
  bn->epsilon_ = epsilon;
  return bn;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                    HloInstruction* scale, HloInstruction* mean,
                                    HloInstruction* variance,
                                    HloInstruction* grad_output, float epsilon,
                                    int64 feature_index) {
  // Operand order is significant:
  // (operand, scale, mean, variance, grad_output).
  auto bn_grad =
      WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape));
  bn_grad->AppendOperand(operand);
  bn_grad->AppendOperand(scale);
  bn_grad->AppendOperand(mean);
  bn_grad->AppendOperand(variance);
  bn_grad->AppendOperand(grad_output);
  bn_grad->feature_index_ = feature_index;
  bn_grad->epsilon_ = epsilon;
  return bn_grad;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter) {
  auto sas =
      WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape));
  sas->AppendOperand(operand);
  sas->AppendOperand(source);
  sas->AppendOperand(init_value);
  // Ordering invariant: the select computation is stored before scatter in
  // called_computations_.
  sas->called_computations_.push_back(select);
  sas->called_computations_.push_back(scatter);
  sas->window_ = MakeUnique<Window>(window);
  return sas;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  // Broadcasts `operand` into `shape`; `broadcast_dimensions` is stored in
  // dimensions_.
  auto bcast = WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape));
  bcast->AppendOperand(operand);
  bcast->dimensions_.assign(broadcast_dimensions.begin(),
                            broadcast_dimensions.end());
  return bcast;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
    const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
    const PaddingConfig& padding_config) {
  // Pads `operand` with `padding_value` as described by `padding_config`.
  auto pad = WrapUnique(new HloInstruction(HloOpcode::kPad, shape));
  pad->AppendOperand(operand);
  pad->AppendOperand(padding_value);
  pad->padding_config_ = MakeUnique<PaddingConfig>(padding_config);
  return pad;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
    const Shape& shape, HloInstruction* operand) {
  // A reshape must preserve the total element count.
  CHECK_EQ(ShapeUtil::ElementsIn(shape),
           ShapeUtil::ElementsIn(operand->shape()))
      << "shape: " << ShapeUtil::HumanString(shape)
      << " operand: " << ShapeUtil::HumanString(operand->shape());
  auto reshape = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
  reshape->AppendOperand(operand);
  return reshape;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  // `dimensions` must be a permutation of the operand's dimensions, and
  // `shape` must be the operand's shape under that permutation.
  CHECK_EQ(shape.dimensions().size(), dimensions.size());
  CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
  CHECK(std::equal(operand->shape().dimensions().begin(),
                   operand->shape().dimensions().end(),
                   Permute(dimensions, shape.dimensions()).begin()));
  auto transpose =
      WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
  transpose->AppendOperand(operand);
  transpose->dimensions_.assign(dimensions.begin(), dimensions.end());
  return transpose;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
  // Builds a fusion instruction whose fused computation initially contains a
  // clone of `fused_root`; parent and metadata are inherited from the root.
  auto fusion = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
  fusion->fusion_kind_ = fusion_kind;
  fusion->set_parent(fused_root->parent());
  fusion->set_metadata(fused_root->metadata());
  fusion->CloneAndFuseInternal(fused_root);
  return fusion;
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateFusionForBackwardConvolution(
    const Shape& shape, FusionKind fusion_kind, const Window& window,
    const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) {
  // Like CreateFusion, but additionally records the window and convolution
  // dimension numbers on the fusion instruction itself.
  auto fusion = CreateFusion(shape, fusion_kind, fused_root);
  fusion->window_ = MakeUnique<Window>(window);
  fusion->convolution_dimension_numbers_ =
      MakeUnique<ConvolutionDimensionNumbers>(conv_dnums);
  return fusion;
}
|
|
|
|
// Merges the fused computation of `instruction_to_merge` (itself a fusion
// instruction that is currently an operand of `this`) into this fusion
// instruction, then dismantles the merged instruction's clone.
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
  // The instruction being merged must feed this fusion instruction.
  CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
        operands().end());
  // Clone the instruction from which to merge fused instructions.
  std::unique_ptr<HloInstruction> clone = instruction_to_merge->Clone();
  // Replace uses of fused parameters with the corresponding operand of the
  // fusion. Add all non-parameter fused instructions to 'unfused_instructions'
  // to be merged into 'this'. This is done in reverse post order.
  std::vector<HloInstruction*> unfused_instructions;
  auto fused_instructions =
      clone->fused_instructions_computation()->MakeInstructionPostOrder();
  for (auto fused_it = fused_instructions.rbegin();
       fused_it != fused_instructions.rend(); ++fused_it) {
    auto fused_instruction = *fused_it;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Parameter i of the fused computation corresponds to operand i of
      // the (cloned) fusion instruction.
      TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith(
          clone->mutable_operand(fused_instruction->parameter_number())));
    } else {
      unfused_instructions.push_back(fused_instruction);
    }
  }
  // Reverse post order puts the fused root first.
  CHECK(unfused_instructions.front() == clone->fused_expression_root());
  // Replace instruction_to_merge use of 'this' with unfused_root.
  TF_CHECK_OK(
      instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
  // Fuse 'unfused_instructions' into 'this'.
  for (auto& instruction : unfused_instructions) {
    FuseInstruction(instruction);
    instruction->DetachFromOperands();
  }
  // The clone must be dead by now; tear it down together with its fused
  // computation.
  CHECK_EQ(0, clone->user_count());
  clone->DetachFromOperands();
  TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
      clone->fused_instructions_computation()));
}
|
|
|
|
// Merges `instruction_to_merge` (a fusion instruction) into `this` via
// multi-output fusion: the merged fusion's instructions are first unfused
// into this->parent_, then re-fused into 'this', with the merged root fused
// through FuseInstructionIntoMultiOutput().
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
  // Add all non-parameter fused instructions to 'unfused_instructions' to be
  // merged into 'this'. `old_to_new' maps the instructions in the fused node
  // to the disassembled fusion instructions.
  // Note that we add the unfused instructions to this->parent_ computation.
  // This is necessary because an instruction only receives its unique_id
  // when it is added to a computation.
  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
  std::vector<HloInstruction*> unfused_instructions;
  auto computation_to_merge =
      instruction_to_merge->fused_instructions_computation();
  auto post_order = computation_to_merge->MakeInstructionPostOrder();
  // Walk in reverse post order so the fused root is processed first (it
  // ends up at unfused_instructions.front()).
  for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
    auto fused_instruction = *rit;
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Parameters map to the merged fusion's corresponding operands; no
      // clone is created for them.
      InsertOrDie(&old_to_new, fused_instruction,
                  instruction_to_merge->mutable_operand(
                      fused_instruction->parameter_number()));
      continue;
    }

    // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
    // which clones again. This can be improved.
    auto cloned_instruction =
        parent_->AddInstruction(fused_instruction->Clone());
    unfused_instructions.push_back(cloned_instruction);
    InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
  }
  // Rewire each clone's operands from the old fused instructions to their
  // unfused counterparts.
  for (auto unfused_instruction : unfused_instructions) {
    for (int64 index = 0; index < unfused_instruction->operand_count();
         index++) {
      auto new_operand =
          FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
      TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
    }
  }

  HloInstruction* unfused_root = unfused_instructions.front();
  TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));

  // Remove the now-bypassed merged fusion instruction and its computation.
  TF_CHECK_OK(
      instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
  if (GetModule()) {
    TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
  }

  // Fuse the root instruction and generate multiple outputs.
  FuseInstructionIntoMultiOutput(unfused_root);
  TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
  // The remaining instructions are fused normally (single output).
  for (int64 i = 1; i < unfused_instructions.size(); i++) {
    auto instruction = unfused_instructions[i];
    FuseInstruction(instruction);
    TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
  }
}
|
|
|
|
HloInstruction* HloInstruction::FuseInstructionInternal(
    HloInstruction* instruction_to_fuse, bool add_output) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  // When add_output is false, the instruction being fused must currently
  // feed this fusion instruction.
  if (!add_output) {
    CHECK(IsUserOf(instruction_to_fuse));
  }
  // The actual cloning/fusing work is delegated to CloneAndFuseInternal.
  return CloneAndFuseInternal(instruction_to_fuse, add_output);
}
|
|
|
|
// Clones instruction_to_fuse into this fusion instruction's fused
// computation, rewiring operands/parameters so the fused computation stays
// closed over this instruction's operand list. When add_output is true (and
// instruction_to_fuse still has outside users afterwards), the fused root is
// expanded into a tuple so the fused value is also exposed as an extra fusion
// output. Returns the clone living inside the fused computation.
HloInstruction* HloInstruction::CloneAndFuseInternal(
    HloInstruction* instruction_to_fuse, bool add_output) {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(instruction_to_fuse->IsFusable());
  VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
  HloInstruction* clone = nullptr;
  if (called_computations_.empty()) {
    // New fusion instruction. It should not be a multioutput instruction.
    CHECK(!add_output);
    auto builder = HloComputation::Builder("fused_computation", this);
    builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
    called_computations_.push_back(
        CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
    clone = fused_expression_root();
  } else {
    clone = fused_instructions_computation()->AddInstruction(
        instruction_to_fuse->Clone(/*suffix=*/""));
    // When add_output is false, instruction_to_fuse is necessarily an operand
    // of the fusion instruction. After fusion this will no longer be the case.
    // Remove the operand from the operand list and remove its corresponding
    // fused parameter instruction. Renumber parameters as necessary to make
    // parameter numbers consistent with their index in the
    // fused_parameter_ vector.
    bool in_operand_list = std::find(operands_.begin(), operands_.end(),
                                     instruction_to_fuse) != operands_.end();
    CHECK(add_output || in_operand_list);
    const std::vector<HloInstruction*>& fused_parameters =
        fused_instructions_computation()->parameter_instructions();
    for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
      if (instruction_to_fuse == operands_[operand_num]) {
        // replace the fused parameter instruction's uses with the clone.
        HloInstruction* fused_parameter = fused_parameters[operand_num];
        TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));

        // Remove the corresponding fused parameter and operand from their
        // respective vectors.
        TF_CHECK_OK(
            fused_instructions_computation()->RemoveParameter(operand_num));
        operands_.erase(operands_.begin() + operand_num);
        // Only the first matching operand slot is removed here; a repeated
        // operand keeps its later slots (and their parameters).
        break;
      }
    }
    // We've cloned instruction_to_fuse into this fusion instruction, so this
    // fusion instruction is no longer a use of instruction_to_fuse.
    if (in_operand_list) {
      instruction_to_fuse->RemoveUser(this);
      // When the instruction_to_fuse does not have other users, we don't need
      // to generate a multioutput fusion instruction.
      if (instruction_to_fuse->user_count() == 0) {
        add_output = false;
      }
    }
  }

  // Reread the parameters in the computation. (The reference above may have
  // been invalidated by RemoveParameter.)
  const std::vector<HloInstruction*>& fused_parameters =
      fused_instructions_computation()->parameter_instructions();

  // Add each operand of the clone as an operand of the fusion instruction. A
  // complication is that some clone operands may already be operands of the
  // fusion instruction.
  for (int64 operand_num = 0; operand_num < clone->operand_count();
       ++operand_num) {
    HloInstruction* operand = clone->mutable_operand(operand_num);

    // See if this operand is already an operand of the fusion node.
    CHECK_EQ(operands_.size(), fused_parameters.size());
    HloInstruction* fused_param = nullptr;
    for (int64 i = 0; i < operands_.size(); ++i) {
      if (operands_[i] == operand) {
        fused_param = fused_parameters[i];
        break;
      }
    }

    if (fused_param == nullptr) {
      // Clone's operand was not already an operand of the fusion
      // instruction. Add it as an operand and add a corresponding fused
      // parameter instruction.
      int64 param_no = fused_parameters.size();
      // Name the parameter after the instruction it represents in the outer
      // (non-fusion) computation. Strip the leading "%" from the operand name
      // to avoid a double %%.
      string param_name =
          StrCat(operand->name().substr(1), ".param_", param_no);
      fused_param = fused_instructions_computation()->AddParameter(
          CreateParameter(param_no, operand->shape(), param_name));
      AppendOperand(operand);
    }
    TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
  }

  if (add_output) {
    CHECK_GT(instruction_to_fuse->user_count(), 0);
    // If this is already a multioutput fusion instruction, expand the root
    // tuple by 1.
    HloInstruction* fused_root = fused_expression_root();
    HloInstruction::InstructionVector tuple_elements;
    bool newly_created_tuple_instr = false;
    if (fused_root->opcode() == HloOpcode::kTuple) {
      tuple_elements = fused_root->operands();
    } else {
      tuple_elements.push_back(fused_root);
      newly_created_tuple_instr = true;
    }
    // If the clone is itself a tuple, splice its elements directly into the
    // new root tuple instead of nesting a tuple inside the tuple.
    if (clone->opcode() == HloOpcode::kTuple) {
      for (auto inst : clone->operands()) {
        tuple_elements.push_back(inst);
      }
    } else {
      tuple_elements.push_back(clone);
    }
    HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
        HloInstruction::CreateTuple(tuple_elements));
    fused_instructions_computation()->set_root_instruction(new_root);
    shape_ = new_root->shape();
    if (fused_root->opcode() == HloOpcode::kTuple) {
      TF_CHECK_OK(
          fused_instructions_computation()->RemoveInstruction(fused_root));
    }

    // If this is a newly created multioutput instruction, we need to update
    // the use of the original fusion instruction.
    if (newly_created_tuple_instr) {
      HloInstruction* new_instr = parent_->AddInstruction(
          HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
      TF_CHECK_OK(ReplaceAllUsesWith(new_instr));
    }
    // Index of the first tuple element contributed by the clone.
    int64 index = tuple_elements.size();
    if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
      index -= instruction_to_fuse->operand_count();
      std::vector<HloInstruction*> to_be_removed;
      // Redirect each outer GTE of instruction_to_fuse to the corresponding
      // element of this fusion's (now larger) output tuple.
      for (auto old_gte : instruction_to_fuse->users()) {
        CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
        int64 old_tuple_index = old_gte->tuple_index();
        HloInstruction* new_gte =
            parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_gte->shape(), this, index + old_tuple_index));
        TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
        to_be_removed.push_back(old_gte);
      }
      for (auto old_gte : to_be_removed) {
        TF_CHECK_OK(parent_->RemoveInstruction(old_gte));
      }
      // The cloned tuple was flattened into the root above, so the clone
      // itself is now dead inside the fused computation.
      TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
    } else {
      HloInstruction* new_gte =
          parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
              clone->shape(), this, index - 1));
      TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
    }
  }

  VLOG(2) << "New clone:\n" << clone->ToString();
  return clone;
}
|
|
|
|
// Returns the random distribution of this instruction. Valid only for kRng.
RandomDistribution HloInstruction::random_distribution() const {
  CHECK_EQ(opcode_, HloOpcode::kRng);
  return distribution_;
}
|
|
|
|
bool HloInstruction::HasSideEffect() const {
|
|
switch (opcode_) {
|
|
case HloOpcode::kSend:
|
|
case HloOpcode::kRecv:
|
|
case HloOpcode::kInfeed:
|
|
case HloOpcode::kOutfeed:
|
|
case HloOpcode::kTrace:
|
|
return true;
|
|
default: {
|
|
// Check if any of the called computations has a side effect.
|
|
for (const auto& computation : called_computations()) {
|
|
if (computation->HasSideEffect()) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
|
|
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
|
|
HloComputation* computation) {
|
|
std::unique_ptr<HloInstruction> instruction =
|
|
WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
|
|
for (auto operand : operands) {
|
|
instruction->AppendOperand(operand);
|
|
}
|
|
instruction->called_computations_.push_back(computation);
|
|
return instruction;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
|
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
|
|
tensorflow::StringPiece custom_call_target) {
|
|
std::unique_ptr<HloInstruction> instruction =
|
|
WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape));
|
|
for (auto operand : operands) {
|
|
instruction->AppendOperand(operand);
|
|
}
|
|
instruction->custom_call_target_ = custom_call_target.ToString();
|
|
return instruction;
|
|
}
|
|
|
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
    tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
  // The tuple's shape is built from the shapes of its elements, in order.
  std::vector<Shape> element_shapes;
  element_shapes.reserve(elements.size());
  for (HloInstruction* element : elements) {
    element_shapes.push_back(element->shape());
  }
  return CreateVariadic(ShapeUtil::MakeTupleShape(element_shapes),
                        HloOpcode::kTuple, elements);
}
|
|
|
|
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|
const Shape& shape,
|
|
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands) const {
|
|
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
|
|
VLOG(3) << " new operands:";
|
|
for (const HloInstruction* new_operand : new_operands) {
|
|
VLOG(3) << " " << new_operand->name();
|
|
}
|
|
|
|
std::unique_ptr<HloInstruction> clone;
|
|
|
|
// Explicitly call the factory for the instruction type. This is more robust
|
|
// in the face of code changes than copying fields explicitly. This also
|
|
// properly sets the user fields of the operands.
|
|
switch (opcode_) {
|
|
// Unary ops.
|
|
case HloOpcode::kAbs:
|
|
case HloOpcode::kRoundNearestAfz:
|
|
case HloOpcode::kBitcast:
|
|
case HloOpcode::kCeil:
|
|
case HloOpcode::kCopy:
|
|
case HloOpcode::kCos:
|
|
case HloOpcode::kExp:
|
|
case HloOpcode::kIsFinite:
|
|
case HloOpcode::kFloor:
|
|
case HloOpcode::kLog:
|
|
case HloOpcode::kNot:
|
|
case HloOpcode::kNegate:
|
|
case HloOpcode::kSign:
|
|
case HloOpcode::kSin:
|
|
case HloOpcode::kSort:
|
|
case HloOpcode::kTanh:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateUnary(shape, opcode_, new_operands[0]);
|
|
break;
|
|
// Binary ops.
|
|
case HloOpcode::kAdd:
|
|
case HloOpcode::kDivide:
|
|
case HloOpcode::kMultiply:
|
|
case HloOpcode::kSubtract:
|
|
case HloOpcode::kEq:
|
|
case HloOpcode::kGe:
|
|
case HloOpcode::kGt:
|
|
case HloOpcode::kLe:
|
|
case HloOpcode::kLt:
|
|
case HloOpcode::kNe:
|
|
case HloOpcode::kDot:
|
|
case HloOpcode::kMaximum:
|
|
case HloOpcode::kMinimum:
|
|
case HloOpcode::kPower:
|
|
case HloOpcode::kRemainder:
|
|
case HloOpcode::kAnd:
|
|
case HloOpcode::kOr:
|
|
case HloOpcode::kShiftLeft:
|
|
case HloOpcode::kShiftRightArithmetic:
|
|
case HloOpcode::kShiftRightLogical:
|
|
CHECK_EQ(new_operands.size(), 2);
|
|
clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
|
|
break;
|
|
// Ternary ops.
|
|
case HloOpcode::kClamp:
|
|
case HloOpcode::kSelect:
|
|
CHECK_EQ(new_operands.size(), 3);
|
|
clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
|
|
new_operands[2]);
|
|
break;
|
|
// Other supported ops.
|
|
case HloOpcode::kBroadcast:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateBroadcast(shape, new_operands[0], dimensions_);
|
|
break;
|
|
case HloOpcode::kCall:
|
|
clone = CreateCall(shape, new_operands, to_apply());
|
|
break;
|
|
case HloOpcode::kCustomCall:
|
|
clone = CreateCustomCall(shape, new_operands, custom_call_target_);
|
|
break;
|
|
case HloOpcode::kConcatenate:
|
|
clone = CreateConcatenate(shape, new_operands, dimensions(0));
|
|
break;
|
|
case HloOpcode::kConvert:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateConvert(shape, new_operands[0]);
|
|
break;
|
|
case HloOpcode::kReducePrecision:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
|
|
mantissa_bits_);
|
|
break;
|
|
case HloOpcode::kConvolution:
|
|
CHECK_EQ(new_operands.size(), 2);
|
|
clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
|
|
*convolution_dimension_numbers_);
|
|
break;
|
|
case HloOpcode::kCrossReplicaSum:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateCrossReplicaSum(shape, new_operands[0]);
|
|
break;
|
|
case HloOpcode::kGetTupleElement:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateGetTupleElement(shape, new_operands[0], tuple_index());
|
|
break;
|
|
case HloOpcode::kMap:
|
|
clone = CreateMap(shape, new_operands, to_apply());
|
|
break;
|
|
case HloOpcode::kPad:
|
|
CHECK_EQ(new_operands.size(), 2);
|
|
clone =
|
|
CreatePad(shape, new_operands[0], new_operands[1], *padding_config_);
|
|
break;
|
|
case HloOpcode::kReduce:
|
|
CHECK_EQ(new_operands.size(), 2);
|
|
clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_,
|
|
to_apply());
|
|
break;
|
|
case HloOpcode::kReduceWindow:
|
|
CHECK_EQ(new_operands.size(), 2);
|
|
clone = CreateReduceWindow(shape, new_operands[0], new_operands[1],
|
|
*window_, to_apply());
|
|
break;
|
|
case HloOpcode::kSelectAndScatter:
|
|
CHECK_EQ(new_operands.size(), 3);
|
|
clone =
|
|
CreateSelectAndScatter(shape, new_operands[0], select(), *window_,
|
|
new_operands[1], new_operands[2], scatter());
|
|
break;
|
|
case HloOpcode::kReverse:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateReverse(shape, new_operands[0], dimensions_);
|
|
break;
|
|
case HloOpcode::kRng:
|
|
clone = CreateRng(shape, distribution_, new_operands);
|
|
break;
|
|
case HloOpcode::kReshape:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateReshape(shape, new_operands[0]);
|
|
break;
|
|
case HloOpcode::kSlice:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
|
|
slice_strides_);
|
|
break;
|
|
case HloOpcode::kDynamicSlice:
|
|
clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1],
|
|
dynamic_slice_sizes_);
|
|
break;
|
|
case HloOpcode::kDynamicUpdateSlice:
|
|
CHECK_EQ(new_operands.size(), 3);
|
|
clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
|
|
new_operands[2]);
|
|
break;
|
|
case HloOpcode::kTranspose:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateTranspose(shape, new_operands[0], dimensions_);
|
|
break;
|
|
case HloOpcode::kTuple:
|
|
clone = CreateTuple(new_operands);
|
|
*clone->mutable_shape() = shape;
|
|
break;
|
|
case HloOpcode::kWhile:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone =
|
|
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
|
|
break;
|
|
case HloOpcode::kConstant:
|
|
clone = CreateConstant(literal_->CloneToUnique());
|
|
break;
|
|
case HloOpcode::kFusion:
|
|
clone = CloneFusionWithNewOperands(shape, new_operands);
|
|
break;
|
|
case HloOpcode::kParameter:
|
|
clone = CreateParameter(parameter_number_, shape, parameter_name_);
|
|
break;
|
|
case HloOpcode::kBatchNormTraining:
|
|
CHECK_EQ(new_operands.size(), 3);
|
|
clone =
|
|
CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
|
|
new_operands[2], epsilon(), feature_index());
|
|
break;
|
|
case HloOpcode::kBatchNormInference:
|
|
CHECK_EQ(new_operands.size(), 5);
|
|
clone = CreateBatchNormInference(
|
|
shape, new_operands[0], new_operands[1], new_operands[2],
|
|
new_operands[3], new_operands[4], epsilon(), feature_index());
|
|
break;
|
|
case HloOpcode::kInfeed:
|
|
CHECK_EQ(new_operands.size(), 0);
|
|
clone = CreateInfeed(shape, infeed_config());
|
|
break;
|
|
case HloOpcode::kOutfeed:
|
|
CHECK_EQ(new_operands.size(), 1);
|
|
clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
|
|
break;
|
|
case HloOpcode::kBatchNormGrad:
|
|
CHECK_EQ(new_operands.size(), 5);
|
|
clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
|
|
new_operands[2], new_operands[3],
|
|
new_operands[4], epsilon(), feature_index());
|
|
break;
|
|
case HloOpcode::kRecv:
|
|
case HloOpcode::kSend:
|
|
case HloOpcode::kUpdate:
|
|
case HloOpcode::kIndex:
|
|
case HloOpcode::kTrace:
|
|
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
|
|
}
|
|
clone->set_metadata(metadata_);
|
|
return clone;
|
|
}
|
|
|
|
HloInstruction::~HloInstruction() {}
|
|
|
|
// Clones this instruction (same shape and same operand pointers) and derives
// the clone's name from this instruction's name plus `suffix`, taking care
// not to stack suffixes when an instruction is cloned repeatedly.
std::unique_ptr<HloInstruction> HloInstruction::Clone(
    const string& suffix) const {
  std::unique_ptr<HloInstruction> clone =
      CloneWithNewOperands(shape_, operands_);
  if (suffix.empty()) {
    // An empty suffix means: keep the exact name.
    clone->name_ = name();
  } else {
    // If an instruction is cloned multiple times avoid names like
    // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
    // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
    // clone of foo.suffix2 is named foo.suffix3 and so on.
    const string dot_suffix = "." + suffix;
    size_t index = name().rfind(dot_suffix);
    if (index == string::npos) {
      // Existing name does not include ".suffix".
      clone->name_ = name() + dot_suffix;
    } else {
      // Existing name includes ".suffix". Determine if substring after
      // ".suffix" is numeric and should be replaced with an incremented number.
      string after_suffix = name().substr(index + dot_suffix.size());
      if (after_suffix.empty()) {
        // Existing name ends in ".suffix". New name should end in ".suffix2".
        clone->name_ = name() + "2";
      } else {
        // If names ends with .suffix[0-9]+ then replace with a suffix with the
        // numeric value incremented.
        int64 numeric_suffix;
        if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
          clone->name_ =
              StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
        } else {
          // Substring after ".suffix" is non-numeric.
          clone->name_ = name() + dot_suffix;
        }
      }
    }
  }
  clone->set_parent(parent_);
  return clone;
}
|
|
|
|
// Deep-clones a kFusion instruction: builds a new fusion instruction with the
// given operands and a fresh embedded computation whose instructions are
// clones of the originals, remapped via old_to_new.
std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
    const Shape& shape,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(parent() != nullptr);

  auto new_instruction =
      WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
  // Add the operands to our new fusion instruction.
  for (HloInstruction* new_operand : operands) {
    new_instruction->AppendOperand(new_operand);
  }
  // Clone all the fused instructions for the new fusion instruction.
  std::map<HloInstruction*, HloInstruction*> old_to_new;
  std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
  // Create the list of fused parameters by mapping through the cloned,
  // fused instructions.
  for (HloInstruction* old_fused_parameter :
       fused_instructions_computation()->parameter_instructions()) {
    new_fused_instructions.push_back(old_fused_parameter->Clone());
    HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
    InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
  }
  for (auto old_fused_instruction :
       fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
      // Parameters were cloned above; just assert the mapping exists.
      FindOrDie(old_to_new, old_fused_instruction);
      continue;
    }
    // Post order guarantees every operand has already been cloned, so the
    // lookups below cannot miss.
    std::vector<HloInstruction*> new_operands;
    for (int64 operand_idx = 0;
         operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
      HloInstruction* old_operand =
          old_fused_instruction->mutable_operand(operand_idx);
      new_operands.push_back(FindOrDie(old_to_new, old_operand));
    }
    new_fused_instructions.push_back(
        old_fused_instruction->CloneWithNewOperands(
            old_fused_instruction->shape(), new_operands));
    HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
    new_fused_instruction->set_parent(parent_);
    InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
  }
  new_instruction->fusion_kind_ = fusion_kind_;
  auto computation_builder = HloComputation::Builder(
      fused_instructions_computation()->name() + ".clone",
      new_instruction.get());
  // We iterated the fusion instructions in reverse post order which means
  // that we must reverse our new list of fusion instructions.
  for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
       new_fused_instruction_iter != new_fused_instructions.rend();
       ++new_fused_instruction_iter) {
    computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
  }
  // NOTE(review): local variable uses a member-style trailing underscore;
  // it is an ordinary local holding the original fused root.
  auto fused_root_ = fused_expression_root();
  new_instruction->called_computations_.push_back(
      CHECK_NOTNULL(GetModule())
          ->AddEmbeddedComputation(
              computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
  new_instruction->set_parent(parent_);
  return new_instruction;
}
|
|
|
|
std::pair<const HloInstruction*, ShapeIndex>
|
|
HloInstruction::LatestNonGteAncestorAndIndex() const {
|
|
const HloInstruction* hlo = this;
|
|
ShapeIndex index;
|
|
while (hlo->opcode() == HloOpcode::kGetTupleElement) {
|
|
index.push_back(hlo->tuple_index());
|
|
hlo = hlo->operand(0);
|
|
}
|
|
|
|
// We built up index in the reverse order from what we want.
|
|
std::reverse(index.begin(), index.end());
|
|
|
|
return {hlo, index};
|
|
}
|
|
|
|
const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
  // Follow operand(0) through any chain of get-tuple-element instructions.
  const HloInstruction* ancestor = this;
  while (ancestor->opcode() == HloOpcode::kGetTupleElement) {
    ancestor = ancestor->operand(0);
  }
  return ancestor;
}
|
|
|
|
// Returns the constant's literal value. Valid only for kConstant.
const Literal& HloInstruction::literal() const {
  CHECK_EQ(HloOpcode::kConstant, opcode_);
  return *literal_;
}
|
|
|
|
bool HloInstruction::CanHaveDimensionsField() const {
|
|
return (opcode() == HloOpcode::kReverse ||
|
|
opcode() == HloOpcode::kConcatenate ||
|
|
opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
|
|
opcode() == HloOpcode::kTranspose);
|
|
}
|
|
|
|
// Returns the dimensions field; CHECK-fails for opcodes without one.
const std::vector<int64>& HloInstruction::dimensions() const {
  CHECK(CanHaveDimensionsField());
  return dimensions_;
}
|
|
|
|
// Returns the dimension at `index` (opcode validity is checked by
// dimensions()).
int64 HloInstruction::dimensions(int64 index) const {
  return dimensions()[index];
}
|
|
|
|
// Returns the single dimension along which a kConcatenate joins its operands.
int64 HloInstruction::concatenate_dimension() const {
  CHECK(opcode() == HloOpcode::kConcatenate);
  CHECK_EQ(1, dimensions_.size());
  return dimensions(0);
}
|
|
|
|
// Returns the element index selected by a kGetTupleElement instruction.
int64 HloInstruction::tuple_index() const {
  CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
  return tuple_index_;
}
|
|
|
|
// Returns the i-th operand. Note: no explicit bounds check here; callers
// must pass a valid index.
const HloInstruction* HloInstruction::operand(int64 i) const {
  return operands_[i];
}
|
|
|
|
// Returns the i-th operand, mutable. CHECK-fails if the slot has been
// nulled out (e.g. after DetachFromOperands).
HloInstruction* HloInstruction::mutable_operand(int64 i) {
  CHECK(operands_[i] != nullptr);
  return operands_[i];
}
|
|
|
|
int64 HloInstruction::operand_index(const HloInstruction* target) const {
  // Linear scan; returns the position of the FIRST occurrence of `target`
  // (an instruction may appear as an operand more than once).
  for (int64 index = 0; index < operand_count(); ++index) {
    if (operand(index) == target) {
      return index;
    }
  }
  LOG(FATAL) << "target was not an operand";
}
|
|
|
|
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  // Adding the same edge twice is a no-op.
  const bool already_successor =
      std::find(control_successors_.begin(), control_successors_.end(),
                instruction) != control_successors_.end();
  if (!already_successor) {
    control_successors_.push_back(instruction);
    // The successor/predecessor lists must stay mirrored; the reverse edge
    // must not already exist.
    TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
                           instruction->control_predecessors_.end(),
                           this) == instruction->control_predecessors_.end());
    instruction->control_predecessors_.push_back(this);
  }
  return Status::OK();
}
|
|
|
|
Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  // Drop `instruction` from our successor list; the edge must exist.
  auto successor_pos = std::find(control_successors_.begin(),
                                 control_successors_.end(), instruction);
  TF_RET_CHECK(successor_pos != control_successors_.end());
  control_successors_.erase(successor_pos);
  // Drop the mirrored predecessor edge on `instruction`.
  auto predecessor_pos = std::find(instruction->control_predecessors_.begin(),
                                   instruction->control_predecessors_.end(),
                                   this);
  TF_RET_CHECK(predecessor_pos != instruction->control_predecessors_.end());
  instruction->control_predecessors_.erase(predecessor_pos);

  return Status::OK();
}
|
|
|
|
// Appends `operand` to the operand list and registers this instruction as a
// user of it (AddUser deduplicates the reverse edge).
void HloInstruction::AppendOperand(HloInstruction* operand) {
  operands_.push_back(operand);
  operand->AddUser(this);
}
|
|
|
|
void HloInstruction::AddUser(HloInstruction* user) {
  // user_set_ deduplicates; users_ mirrors it but preserves insertion order
  // for stable iteration. insert().second is true only on first insertion.
  if (user_set_.insert(user).second) {
    users_.push_back(user);
  }
}
|
|
|
|
// Returns true iff this instruction is a kConstant.
bool HloInstruction::IsConstant() const {
  return opcode_ == HloOpcode::kConstant;
}
|
|
|
|
bool HloInstruction::HasConstantOperand() const {
|
|
for (const HloInstruction* operand : operands_) {
|
|
if (operand->IsConstant()) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool HloInstruction::IdenticalSlowPath(
|
|
const HloInstruction& other,
|
|
std::function<bool(const HloComputation*, const HloComputation*)>
|
|
eq_computations) const {
|
|
// Perform opcode specific checks.
|
|
switch (opcode()) {
|
|
// The result of these instructions only depend upon their opcode and
|
|
// operands.
|
|
case HloOpcode::kAbs:
|
|
case HloOpcode::kRoundNearestAfz:
|
|
case HloOpcode::kAdd:
|
|
case HloOpcode::kCeil:
|
|
case HloOpcode::kClamp:
|
|
case HloOpcode::kCopy:
|
|
case HloOpcode::kCos:
|
|
case HloOpcode::kCrossReplicaSum:
|
|
case HloOpcode::kDivide:
|
|
case HloOpcode::kDot:
|
|
case HloOpcode::kEq:
|
|
case HloOpcode::kExp:
|
|
case HloOpcode::kFloor:
|
|
case HloOpcode::kGe:
|
|
case HloOpcode::kGt:
|
|
case HloOpcode::kIsFinite:
|
|
case HloOpcode::kLe:
|
|
case HloOpcode::kLog:
|
|
case HloOpcode::kAnd:
|
|
case HloOpcode::kNot:
|
|
case HloOpcode::kOr:
|
|
case HloOpcode::kLt:
|
|
case HloOpcode::kMaximum:
|
|
case HloOpcode::kMinimum:
|
|
case HloOpcode::kMultiply:
|
|
case HloOpcode::kNe:
|
|
case HloOpcode::kNegate:
|
|
case HloOpcode::kPower:
|
|
case HloOpcode::kRemainder:
|
|
case HloOpcode::kSelect:
|
|
case HloOpcode::kShiftLeft:
|
|
case HloOpcode::kShiftRightArithmetic:
|
|
case HloOpcode::kShiftRightLogical:
|
|
case HloOpcode::kSign:
|
|
case HloOpcode::kSin:
|
|
case HloOpcode::kSubtract:
|
|
case HloOpcode::kTanh:
|
|
case HloOpcode::kTuple:
|
|
return true;
|
|
|
|
// These opcodes have complex or special behavior so just return false.
|
|
case HloOpcode::kFusion:
|
|
case HloOpcode::kRng:
|
|
case HloOpcode::kTrace:
|
|
case HloOpcode::kWhile:
|
|
return false;
|
|
|
|
case HloOpcode::kParameter:
|
|
return parameter_number() == other.parameter_number() &&
|
|
// Check the shape too because `this` and `other` may be in
|
|
// different HloComputations.
|
|
ShapeUtil::Compatible(shape(), other.shape());
|
|
|
|
case HloOpcode::kBatchNormTraining:
|
|
case HloOpcode::kBatchNormInference:
|
|
case HloOpcode::kBatchNormGrad:
|
|
return feature_index() == other.feature_index() &&
|
|
epsilon() == other.epsilon();
|
|
|
|
// A constant is defined by the value in the literal.
|
|
case HloOpcode::kConstant:
|
|
return literal() == other.literal();
|
|
|
|
// A convert result is determined by the primitive type that the operand is
|
|
// converted into.
|
|
case HloOpcode::kConvert:
|
|
return shape().element_type() == other.shape().element_type();
|
|
|
|
// A reduce-precision operation is determined by the bit sizes.
|
|
case HloOpcode::kReducePrecision:
|
|
return exponent_bits() == other.exponent_bits() &&
|
|
mantissa_bits() == other.mantissa_bits();
|
|
|
|
// Convolution has a window and dimensions.
|
|
case HloOpcode::kConvolution:
|
|
return protobuf_util::ProtobufEquals(window(), other.window()) &&
|
|
protobuf_util::ProtobufEquals(
|
|
convolution_dimension_numbers(),
|
|
other.convolution_dimension_numbers());
|
|
|
|
// Reduction results are determined by the reduction dimension and the
|
|
// reduction computation.
|
|
case HloOpcode::kReduce:
|
|
return dimensions() == other.dimensions() &&
|
|
eq_computations(to_apply(), other.to_apply());
|
|
case HloOpcode::kReduceWindow:
|
|
return eq_computations(to_apply(), other.to_apply()) &&
|
|
protobuf_util::ProtobufEquals(window(), other.window());
|
|
|
|
// SelectAndScatter is determined by both select and scatter
|
|
// computation as well as the window configuration.
|
|
case HloOpcode::kSelectAndScatter:
|
|
return eq_computations(select(), other.select()) &&
|
|
eq_computations(scatter(), other.scatter()) &&
|
|
protobuf_util::ProtobufEquals(window(), other.window());
|
|
|
|
case HloOpcode::kReshape:
|
|
return ShapeUtil::Compatible(shape(), other.shape());
|
|
|
|
// Transpose result is determined by the final shape and the permutation.
|
|
case HloOpcode::kTranspose:
|
|
return ShapeUtil::Compatible(shape(), other.shape()) &&
|
|
dimensions() == other.dimensions();
|
|
|
|
// Remaining instructions with special values.
|
|
case HloOpcode::kBitcast:
|
|
return ShapeUtil::Equal(shape(), other.shape());
|
|
case HloOpcode::kBroadcast:
|
|
return ShapeUtil::Compatible(shape(), other.shape()) &&
|
|
dimensions() == other.dimensions();
|
|
case HloOpcode::kConcatenate:
|
|
return dimensions() == other.dimensions();
|
|
case HloOpcode::kGetTupleElement:
|
|
return tuple_index() == other.tuple_index();
|
|
case HloOpcode::kPad:
|
|
return protobuf_util::ProtobufEquals(padding_config(),
|
|
other.padding_config());
|
|
case HloOpcode::kSlice:
|
|
return slice_starts_ == other.slice_starts_ &&
|
|
slice_limits_ == other.slice_limits_;
|
|
case HloOpcode::kDynamicSlice:
|
|
return ShapeUtil::Compatible(shape(), other.shape()) &&
|
|
dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
|
|
case HloOpcode::kDynamicUpdateSlice:
|
|
return ShapeUtil::Compatible(shape(), other.shape());
|
|
case HloOpcode::kCall:
|
|
case HloOpcode::kMap:
|
|
return eq_computations(to_apply(), other.to_apply());
|
|
case HloOpcode::kCustomCall:
|
|
return custom_call_target_ == other.custom_call_target_;
|
|
case HloOpcode::kReverse:
|
|
return dimensions() == other.dimensions();
|
|
|
|
// These opcodes are not yet supported.
|
|
case HloOpcode::kIndex:
|
|
case HloOpcode::kInfeed:
|
|
case HloOpcode::kOutfeed:
|
|
case HloOpcode::kSort:
|
|
case HloOpcode::kUpdate:
|
|
case HloOpcode::kSend:
|
|
case HloOpcode::kRecv:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool HloInstruction::IsRank2Transpose() const {
|
|
return (opcode_ == HloOpcode::kTranspose) &&
|
|
dimensions_ == std::vector<int64>({1, 0}) &&
|
|
shape_.dimensions_size() == 2 &&
|
|
std::equal(shape_.dimensions().begin(), shape_.dimensions().end(),
|
|
operands_[0]->shape_.dimensions().rbegin());
|
|
}
|
|
|
|
void HloInstruction::RemoveUser(HloInstruction* user) {
  // `user` must currently be registered; remove it from the set first.
  auto in_set = user_set_.find(user);
  CHECK(in_set != user_set_.end());
  user_set_.erase(in_set);
  // users_ mirrors user_set_ but provides a stable iteration order. The scan
  // is linear in the number of users, yet vector traversal elsewhere is much
  // faster, so the trade-off is worthwhile.
  auto in_vec = std::find(users_.begin(), users_.end(), user);
  CHECK(in_vec != users_.end());
  users_.erase(in_vec);
}
|
|
|
|
// Replaces every occurrence of `this` in `user`'s operand list with
// `new_producer`, updating both use lists. Shapes must be compatible.
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());

  VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
          << " with " << new_producer->name();

  RemoveUser(user);

  // `this` must actually appear among the user's operands. (std::count can
  // never be negative, so the old `>= 0` form of this check was vacuous.)
  TF_RET_CHECK(
      std::count(user->operands_.begin(), user->operands_.end(), this) >= 1);
  std::replace(user->operands_.begin(), user->operands_.end(), this,
               new_producer);
  new_producer->AddUser(user);
  return Status::OK();
}
|
|
|
|
// Replaces operand `operand_num` with `new_operand`, keeping the use lists of
// the old and new operands consistent. Shapes must be compatible.
Status HloInstruction::ReplaceOperandWith(int64 operand_num,
                                          HloInstruction* new_operand) {
  TF_RET_CHECK(operand_num >= 0);
  TF_RET_CHECK(operand_num < operand_count());
  HloInstruction* old_operand = mutable_operand(operand_num);
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_operand->shape(), new_operand->shape()))
      << old_operand->shape().ShortDebugString() << " is not compatible with "
      << new_operand->shape().ShortDebugString();
  operands_[operand_num] = new_operand;

  VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
          << new_operand->name() << ", was " << old_operand->name();

  // Only drop the user edge when old_operand no longer appears ANYWHERE in
  // the operand list — it may be a repeated operand. Note the search runs
  // after the slot above was overwritten, so a remaining duplicate keeps the
  // edge intact.
  if (std::find(operands_.begin(), operands_.end(), old_operand) ==
      operands_.end()) {
    old_operand->RemoveUser(this);
  }
  new_operand->AddUser(this);
  return Status::OK();
}
|
|
|
|
// Redirects all users of `this` to consume `new_producer` instead, and makes
// new_producer the computation root if `this` was the root. Handles the
// self-referential case where new_producer is itself a user of `this`.
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // It's possible that new_producer is a user of this instruction as might
      // be the case when replacing an instruction with a kCopy of itself. In
      // this case, don't do the replacement to avoid creating a cycle in the
      // graph. new_producer remains the only user of this instruction.
      new_producer_is_user = true;
    } else {
      std::replace(user->operands_.begin(), user->operands_.end(), this,
                   new_producer);
      new_producer->AddUser(user);
    }
  }
  // Rebuild this instruction's user bookkeeping from scratch: everyone has
  // been redirected except possibly new_producer itself.
  users_.clear();
  user_set_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }
  if (parent_ && parent_->root_instruction() == this) {
    parent_->set_root_instruction(new_producer);
  }

  return Status::OK();
}
|
|
|
|
void HloInstruction::DetachFromOperands() {
|
|
VLOG(3) << "DetachFromOperands:\n " << ToString();
|
|
CHECK_EQ(0, user_count());
|
|
// An instruction may be repeated as an operand. To avoid calling RemoveUser
|
|
// twice on the same operand, keep a set of already detached operands.
|
|
std::set<HloInstruction*> detached_operands;
|
|
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
|
|
HloInstruction* operand = operands_[operand_num];
|
|
if (!ContainsKey(detached_operands, operand)) {
|
|
operand->RemoveUser(this);
|
|
detached_operands.insert(operand);
|
|
}
|
|
operands_[operand_num] = nullptr;
|
|
}
|
|
}
|
|
|
|
// Returns the single "applied" computation for the opcodes that carry one
// (kCall, kMap, kReduceWindow, kReduce). Fatal for any other opcode.
HloComputation* HloInstruction::to_apply() const {
  if (opcode_ == HloOpcode::kCall || opcode_ == HloOpcode::kMap ||
      opcode_ == HloOpcode::kReduceWindow || opcode_ == HloOpcode::kReduce) {
    // These opcodes carry exactly one called computation.
    CHECK_EQ(called_computations_.size(), 1);
    return called_computations_[0];
  }
  LOG(FATAL) << "Invalid opcode for to_apply(): "
             << HloOpcodeString(opcode());
}
|
|
|
|
// Replaces the single "applied" computation for the opcodes that carry one.
// Fatal for any other opcode or for fused instructions.
void HloInstruction::set_to_apply(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  if (opcode_ == HloOpcode::kCall || opcode_ == HloOpcode::kMap ||
      opcode_ == HloOpcode::kReduceWindow || opcode_ == HloOpcode::kReduce) {
    CHECK_EQ(called_computations_.size(), 1);
    called_computations_[0] = computation;
  } else {
    LOG(FATAL) << "Invalid opcode for to_apply(): "
               << HloOpcodeString(opcode());
  }
}
|
|
|
|
// Returns the target name of a kCustomCall instruction; fatal otherwise.
const string& HloInstruction::custom_call_target() const {
  CHECK_EQ(opcode_, HloOpcode::kCustomCall);
  return custom_call_target_;
}
|
|
|
|
// Returns the opaque config string of a kOutfeed instruction; fatal otherwise.
const string& HloInstruction::outfeed_config() const {
  CHECK_EQ(opcode_, HloOpcode::kOutfeed);
  return outfeed_config_;
}
|
|
|
|
// Returns the condition computation of a kWhile instruction; fatal otherwise.
HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kConditionComputationIndex];
}
|
|
|
|
// Returns the body computation of a kWhile instruction; fatal otherwise.
HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kBodyComputationIndex];
}
|
|
|
|
// Replaces the condition computation of a kWhile instruction. Fatal for
// other opcodes or for fused instructions.
void HloInstruction::set_while_condition(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kConditionComputationIndex] = computation;
}
|
|
|
|
// Replaces the body computation of a kWhile instruction. Fatal for other
// opcodes or for fused instructions.
void HloInstruction::set_while_body(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kBodyComputationIndex] = computation;
}
|
|
|
|
// Returns the select computation of a kSelectAndScatter instruction; fatal
// otherwise.
HloComputation* HloInstruction::select() const {
  CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
  return called_computations_[kSelectComputationIndex];
}
|
|
|
|
// Returns the scatter computation of a kSelectAndScatter instruction; fatal
// otherwise.
HloComputation* HloInstruction::scatter() const {
  CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
  return called_computations_[kScatterComputationIndex];
}
|
|
|
|
// Replaces the select computation of a kSelectAndScatter instruction. Fatal
// for other opcodes or for fused instructions.
void HloInstruction::set_select(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
  called_computations_[kSelectComputationIndex] = computation;
}
|
|
|
|
// Replaces the scatter computation of a kSelectAndScatter instruction. Fatal
// for other opcodes or for fused instructions.
void HloInstruction::set_scatter(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
  called_computations_[kScatterComputationIndex] = computation;
}
|
|
|
|
// Renders this instruction's type signature, e.g. "(f32[10], f32[]) -> f32[]".
string HloInstruction::SignatureString() const {
  // Build the comma-separated operand shape list by hand; output matches the
  // Join-based formulation exactly.
  string operands;
  const char* separator = "";
  for (HloInstruction* operand : operands_) {
    StrAppend(&operands, separator, ShapeUtil::HumanString(operand->shape()));
    separator = ", ";
  }
  return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
}
|
|
|
|
// Returns the opcode name, augmented with ":<fusion kind>" for fusion
// instructions (e.g. "fusion:kLoop").
string HloInstruction::ExtendedOpcodeStr() const {
  string result = HloOpcodeString(opcode());
  if (opcode() == HloOpcode::kFusion) {
    StrAppend(&result, ":", xla::ToString(fusion_kind()));
  }
  return result;
}
|
|
|
|
// Renders a one-line human-readable description of this instruction:
// "<name> = <shape> <opcode>(<operands>)[, attr]...[ # metadata=...]".
// `compact_operands` truncates long operand lists; `include_metadata`
// appends op metadata when any of its fields is set.
string HloInstruction::ToString(bool compact_operands,
                                bool include_metadata) const {
  string result =
      StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ",
             ExtendedOpcodeStr(), "(", OperandsToString(compact_operands), ")");
  // Opcode-specific attributes (dimensions, window, calls, ...) follow the
  // operand list, comma-separated.
  for (const string& extra : ExtraAttributesToString()) {
    StrAppend(&result, ", ", extra);
  }
  if (include_metadata &&
      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
       !metadata_.source_file().empty())) {
    StrAppend(&result, " # metadata=", metadata_.ShortDebugString());
  }
  return result;
}
|
|
|
|
// Renders the operand list for ToString(). Constants show their literal value
// (when small), parameters show their parameter number, and everything else
// shows operand shapes (plus names when not compact). In compact mode, long
// operand lists are truncated with a "...(+N)" suffix.
string HloInstruction::OperandsToString(bool compact) const {
  string operands;
  if (opcode() == HloOpcode::kConstant) {
    // For constants, show the actual value in place of an empty operand list.
    if (!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) {
      // Literal::ToString emits multidimensional arrays over multiple
      // lines. Compact this into one line by stripping out white space.
      string tmp = literal().ToString();
      std::replace(tmp.begin(), tmp.end(), '\n', ' ');
      std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
      bool first = true;
      // Concatenate elements in "v" with spaces separating them, but ignoring
      // empty entries.
      for (const auto& s : v) {
        if (s.empty()) {
          continue;
        }
        StrAppend(&operands, (first ? "" : " "), s);
        first = false;
      }
    } else {
      // Do not show large constants or tuples.
      operands = "{...}";
    }
  } else if (opcode() == HloOpcode::kParameter) {
    StrAppend(&operands, parameter_number_);
  } else {
    tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
    const int64 kMaxOperandsToShowIfCompact = 4;
    if (compact && slice.size() > kMaxOperandsToShowIfCompact) {
      slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
    }
    operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
      *out += ShapeUtil::HumanStringWithLayout(operand->shape());
      if (!compact) {
        StrAppend(out, " ", operand->name());
      }
    });
    // Indicate how many operands were elided by the compact truncation above.
    const int64 remaining = operands_.size() - slice.size();
    if (slice.size() != operands_.size()) {
      StrAppend(&operands, ", ...(+", remaining, ")");
    }
  }
  return operands;
}
|
|
|
|
// Collects opcode-specific attribute strings (dimensions, window, padding,
// slice bounds, convolution numbers, called computations, channel ids,
// tuple index, control successors) for ToString(). Each entry is rendered
// as "key=value" or a similar self-describing form.
std::vector<string> HloInstruction::ExtraAttributesToString() const {
  std::vector<string> extra;
  if (CanHaveDimensionsField()) {
    extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
  }
  if (window_ != nullptr) {
    extra.push_back(window_util::ToString(*window_));
  }
  if (padding_config_ != nullptr) {
    extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
  }
  if (!slice_starts_.empty() && !slice_limits_.empty()) {
    std::vector<string> bounds;
    bounds.reserve(slice_starts_.size());
    // Strides of 1 are the default; omit them entirely to keep the output
    // short, but only when *every* dimension has stride 1.
    const bool omit_stride =
        std::all_of(slice_strides_.begin(), slice_strides_.end(),
                    [](int64 stride) { return stride == 1; });
    for (int i = 0; i < slice_starts_.size(); ++i) {
      string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
      bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i],
                              stride_str, "]"));
    }
    extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
  }

  if (convolution_dimension_numbers_ != nullptr) {
    extra.push_back(ConvolutionDimensionNumbersToString());
  }

  // While and select-and-scatter name their computations by role; any other
  // instruction with called computations gets a generic "calls=" list.
  if (opcode() == HloOpcode::kWhile) {
    extra.push_back(StrCat("condition=", while_condition()->name()));
    extra.push_back(StrCat("body=", while_body()->name()));
  } else if (opcode() == HloOpcode::kSelectAndScatter) {
    extra.push_back(StrCat("select=", select()->name()));
    extra.push_back(StrCat("scatter=", scatter()->name()));
  } else if (!called_computations().empty()) {
    extra.push_back(StrCat(
        "calls=", Join(called_computations(), ", ",
                       [](string* out, const HloComputation* computation) {
                         StrAppend(out, computation->name());
                       })));
  }

  if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
    extra.push_back(StrCat("channel_id=", channel_id_));
  }

  if (opcode() == HloOpcode::kGetTupleElement) {
    extra.push_back(StrCat("index=", tuple_index()));
  }
  if (!control_successors_.empty()) {
    extra.push_back(StrCat(
        "control-successors=",
        Join(control_successors_, ", ", [](string* out, HloInstruction* succ) {
          StrAppend(out, succ->name());
        })));
  }
  return extra;
}
|
|
|
|
// Renders a terse one-line form: "<name> = <opcode>(<operand names>)".
string HloInstruction::ToShortString() const {
  // Assemble the operand-name list with an explicit loop; the output is
  // identical to the Join-based formulation.
  string operand_names;
  const char* separator = "";
  for (HloInstruction* operand : operands_) {
    StrAppend(&operand_names, separator, operand->name());
    separator = ", ";
  }
  return StrCat(name(), " = ", HloOpcodeString(opcode()), "(", operand_names,
                ")");
}
|
|
|
|
// Serializes this instruction to an HloInstructionProto. Operands, control
// predecessors and called computations are referenced by name; fusion
// instructions instead embed their fused computation wholesale. Fields that
// do not apply to this opcode are still copied with their default values.
HloInstructionProto HloInstruction::ToProto() const {
  HloInstructionProto proto;
  proto.set_name(name_);
  proto.set_opcode(HloOpcodeString(opcode_));
  *proto.mutable_shape() = shape_;
  for (const HloInstruction* operand : operands_) {
    *proto.add_operand_names() = operand->name();
  }
  for (const HloInstruction* control : control_predecessors_) {
    *proto.add_control_predecessor_names() = control->name();
  }

  *proto.mutable_metadata() = metadata_;
  if (literal_ != nullptr) {
    *proto.mutable_literal() = literal_->ToProto();
  }
  proto.set_parameter_number(parameter_number_);
  proto.set_parameter_name(parameter_name_);
  if (opcode() == HloOpcode::kFusion) {
    // Fusions serialize the entire fused computation inline rather than by
    // name, since it is owned by this instruction.
    proto.set_fusion_kind(xla::ToString(fusion_kind()));
    *proto.mutable_fused_instructions_computation() =
        fused_instructions_computation()->ToProto();
  } else {
    for (const HloComputation* computation : called_computations_) {
      *proto.add_called_computation_names() = computation->name();
    }
  }

  proto.set_tuple_index(tuple_index_);
  for (int64 dimension : dimensions_) {
    proto.add_dimensions(dimension);
  }
  if (window_ != nullptr) {
    *proto.mutable_window() = *window_;
  }
  if (convolution_dimension_numbers_ != nullptr) {
    *proto.mutable_convolution_dimension_numbers() =
        *convolution_dimension_numbers_;
  }
  // Slice bounds are stored as parallel (start, limit, stride) triples.
  for (int i = 0; i < slice_starts_.size(); ++i) {
    auto* slice_dimension = proto.add_slice_dimensions();
    slice_dimension->set_start(slice_starts_[i]);
    slice_dimension->set_limit(slice_limits_[i]);
    slice_dimension->set_stride(slice_strides_[i]);
  }
  proto.set_exponent_bits(exponent_bits_);
  proto.set_mantissa_bits(mantissa_bits_);
  for (int64 slice_size : dynamic_slice_sizes_) {
    proto.add_dynamic_slice_sizes(slice_size);
  }
  if (padding_config_ != nullptr) {
    *proto.mutable_padding_config() = *padding_config_;
  }
  proto.set_outfeed_config(outfeed_config_);
  if (opcode() == HloOpcode::kRng) {
    proto.set_distribution(distribution_);
  }
  proto.set_epsilon(epsilon_);
  proto.set_feature_index(feature_index_);
  proto.set_channel_id(channel_id_);
  proto.set_infeed_config(infeed_config_);
  proto.set_custom_call_target(custom_call_target_);
  *proto.mutable_outfeed_shape() = outfeed_shape_;

  return proto;
}
|
|
|
|
// Buckets this instruction into a coarse category string, primarily for
// profiling/reporting. Special-cases data-formatting ops, convolutions
// (annotated with dilation), and fusion kinds; everything else falls back to
// the opcode name or a generic elementwise bucket.
string HloInstruction::ToCategory() const {
  if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
      opcode() == HloOpcode::kReshape) {
    return "data formatting";
  }

  if (opcode() == HloOpcode::kConvolution) {
    string category = "convolution";
    if (window_util::HasBaseDilation(window())) {
      category += " base-dilated";
    }
    if (window_util::HasWindowDilation(window())) {
      category += " window-dilated";
    }
    return category;
  }

  if (opcode() == HloOpcode::kFusion) {
    // A binary fusion mixing a rank-1 operand with a higher-rank operand is
    // singled out, regardless of fusion kind.
    if (operands().size() == 2) {
      bool saw_rank_1 = false;
      bool saw_higher_rank = false;
      for (const auto* operand : operands()) {
        saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
        saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
      }
      if (saw_rank_1 && saw_higher_rank) {
        return "rank-1-broadcast binary fusion";
      }
    }
    switch (fusion_kind()) {
      case FusionKind::kLoop:
        if (IsElementwise()) {
          return "elementwise fusion";
        } else {
          return "non-elementwise fusion";
        }
      case FusionKind::kInput:
        return "input fusion";
      case FusionKind::kOutput:
        return "output fusion";
      case FusionKind::kTransposeDot:
        return "dot fusion";
      case FusionKind::kConvBackwardFilter:
      case FusionKind::kConvBackwardInput:
        return "convolution fusion";
      case FusionKind::kCustom:
        return "custom fusion";
    }
  }

  // The kFusion guard is defensive: the switch above already returns for
  // every fusion kind.
  if (IsElementwise() && opcode() != HloOpcode::kFusion) {
    return "non-fusion elementwise";
  }

  return HloOpcodeString(opcode());
}
|
|
|
|
HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
|
|
|
|
// Attaches (or clears, with nullptr) the trace instruction for this
// instruction.
void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
  trace_instruction_ = trace_instruction;
}
|
|
|
|
// Returns the tag string of a kTrace instruction, stored as a u8 literal.
// Fatal for other opcodes or when no literal is present.
string HloInstruction::TracingTag() const {
  CHECK_EQ(HloOpcode::kTrace, opcode());
  CHECK(literal_ != nullptr);
  return literal_->u8s_string();
}
|
|
|
|
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
|
|
|
|
// Returns whether this instruction may be fused into a fusion instruction.
// Traced instructions, I/O-like opcodes, and multi-user RNGs are excluded.
bool HloInstruction::IsFusable() const {
  // Instructions which are traced should not be fused.
  if (tracing()) {
    return false;
  }

  // Some kinds of instructions don't make sense to fuse.
  switch (opcode_) {
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kTrace:
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
      return false;
    // Only fuse Rng if it is used once, otherwise the random numbers generated
    // will be different in each fusion. If it is the root (user count = 0)
    // then it is the equivalent of having one user.
    case HloOpcode::kRng:
      return users_.size() <= 1;
    default:
      return true;
  }
}
|
|
|
|
// Returns the computation holding this fusion instruction's fused
// instructions. Fatal for non-fusion opcodes; the first called computation
// is required to be a fusion computation.
HloComputation* HloInstruction::fused_instructions_computation() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  CHECK(!called_computations_.empty());
  HloComputation* computation = called_computations_.front();
  CHECK(computation->IsFusionComputation());
  return computation;
}
|
|
|
|
// Returns the root of the fused computation; fatal for non-fusion opcodes.
HloInstruction* HloInstruction::fused_expression_root() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  return fused_instructions_computation()->root_instruction();
}
|
|
|
|
// Returns the fused computation's parameter with the given number; fatal for
// non-fusion opcodes.
HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  return fused_instructions_computation()->parameter_instruction(
      parameter_number);
}
|
|
|
|
// Returns all parameters of the fused computation; fatal for non-fusion
// opcodes.
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  return fused_instructions_computation()->parameter_instructions();
}
|
|
|
|
// Const iteration over the instructions of the fused computation; fatal for
// non-fusion opcodes.
const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  const HloComputation* subcomp = fused_instructions_computation();
  return subcomp->instructions();
}
|
|
|
|
// Mutable iteration over the instructions of the fused computation; fatal for
// non-fusion opcodes.
const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  CHECK_EQ(opcode_, HloOpcode::kFusion);
  return fused_instructions_computation()->instructions();
}
|
|
|
|
// Returns the number of instructions inside the fused computation.
int64 HloInstruction::fused_instruction_count() const {
  return fused_instructions_computation()->instruction_count();
}
|
|
|
|
// Base constructor: records the opcode and shape and derives a default name
// ("%<opcode>"). unique_id_ starts at -1, indicating the instruction has not
// yet been assigned an id (presumably done when added to a module — the
// DFS code below CHECKs for id >= 0).
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_("%" + HloOpcodeString(opcode)) {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
|
|
|
|
// Single dispatch point from opcode to the matching DfsHloVisitor handler.
// Each case forwards this instruction (and, for convenience, selected
// operands/attributes) to the visitor. Opcodes with no handler fall through
// to Unimplemented.
Status HloInstruction::Visit(DfsHloVisitor* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this, operands_[0]);
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this, operands_[0]);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this, *literal_);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this, operands_[0]);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    // All comparisons share one handler, parameterized by opcode.
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kNe:
      return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this, operands_[0], operands_[1]);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this, operands_[0], operands_[1]);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this, operands_[0], operands_[1]);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this, operands_[0], operands_[1]);
    case HloOpcode::kOr:
      return visitor->HandleOr(this, operands_[0], operands_[1]);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this, operands_[0], operands_[1]);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this, operands_[0],
                                                 operands_[1]);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this, operands_);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this, operands_[0], operands_[1]);
    case HloOpcode::kDot:
      return visitor->HandleDot(this, operands_[0], operands_[1]);
    case HloOpcode::kPower:
      return visitor->HandlePower(this, operands_[0], operands_[1]);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this, operands_[0], operands_[1]);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this, operands_[0], operands_[1],
                                   operands_[2]);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this, operands_[0], operands_[1],
                                        window());
    case HloOpcode::kCrossReplicaSum:
      return visitor->HandleCrossReplicaSum(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this, operands_);
    case HloOpcode::kMap:
      return visitor->HandleMap(this, operands_, to_apply(), {});
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this, operands_[0], operands_[1],
                                  operands_[2]);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this, operands_[0], operands_[1],
                                   dimensions_, to_apply());
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this, operands_[0], window(),
                                         to_apply());
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this, operands_[0]);
    case HloOpcode::kExp:
      return visitor->HandleExp(this, operands_[0]);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this, operands_[0]);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this, operands_[0]);
    case HloOpcode::kLog:
      return visitor->HandleLog(this, operands_[0]);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this, operands_[0]);
    case HloOpcode::kCos:
      return visitor->HandleCos(this, operands_[0]);
    case HloOpcode::kSin:
      return visitor->HandleSin(this, operands_[0]);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this, operands_[0]);
    case HloOpcode::kNot:
      return visitor->HandleNot(this, operands_[0]);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this, operands_[0]);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this, operands_[0]);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this, operands_[0], operands_[1]);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1],
                                               operands_[2]);
    case HloOpcode::kSort:
      return visitor->HandleSort(this, operands_[0]);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this, distribution_);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this, operands_, custom_call_target_);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);

    // These opcodes are not handled here.
    case HloOpcode::kIndex:
    case HloOpcode::kTrace:
    case HloOpcode::kUpdate:
      break;
  }
  return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s",
                       HloOpcodeString(opcode_).c_str());
}
|
|
|
|
// Work stack for the iterative post-order DFS below. Entries pair the
// instruction's unique id with its pointer (see PostOrderDFS for why both
// are kept). InlinedVector avoids heap allocation for shallow graphs.
using DFSStack =
    tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;

// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack,
                         HloInstruction* child) {
  CHECK(child != nullptr);
  const int id = child->unique_id();
  CHECK_GE(id, 0) << "instruction may not have a parent computation";
  switch (visitor->GetVisitState(id)) {
    case DfsHloVisitor::kVisiting:
      // Reaching a node that is still being visited means we followed a back
      // edge: the graph has a cycle.
      return false;

    case DfsHloVisitor::kVisited:
      // Nothing to do
      return true;

    case DfsHloVisitor::kNotVisited:
      dfs_stack->push_back(std::make_pair(id, child));
      return true;
  }
  // Note: the switch above is exhaustive over VisitState, so control never
  // reaches past it.
}
|
|
|
|
// Comparison over DFS-stack entries; wraps a client-provided comparison over
// the HloInstruction pointers (the int ids are ignored by the comparator).
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;

// Iterative post-order DFS from `root`: each node's operands (and, unless
// ignore_control_predecessors is set, its control predecessors) are visited
// before the node itself. `operand_order`, if non-null, orders siblings
// before they are expanded. Returns FailedPrecondition on a cycle.
static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    DfsHloVisitor::VisitState visit_state = visitor->GetVisitState(current_id);
    if (visit_state == DfsHloVisitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO " << current_node->name()
              << " as it was already visited.";
      continue;
    }

    // Second encounter while kVisiting: all children have been processed, so
    // this is the node's post-order visit.
    if (visit_state == DfsHloVisitor::kVisiting) {
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO " << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, DfsHloVisitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First encounter: mark as in-progress and push children. The node stays
    // on the stack and will be revisited (kVisiting branch above) once its
    // children are done.
    visitor->SetVisitState(current_id, DfsHloVisitor::kVisiting);

    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString().c_str());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString().c_str());
        }
      }
    }

    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
|
|
|
|
// Runs a post-order DFS rooted at this instruction, invoking `visitor` on
// every reachable instruction, then optionally calls FinishVisit.
Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit,
                              bool ignore_control_predecessors) {
  VLOG(3) << "HloInstruction::Accept(" << name() << ")";
  TF_RETURN_IF_ERROR(
      PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
  if (call_finish_visit) {
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
  }
  return Status::OK();
}
|
|
|
|
// Like Accept(), but siblings are expanded in the order induced by
// `operand_order`. Control predecessors are always followed here.
Status HloInstruction::AcceptWithOperandOrder(
    DfsHloVisitor* visitor, const CompareFunction& operand_order,
    bool call_finish_visit) {
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")";
  InternalCompareFunction func = [&operand_order](
                                     std::pair<int, const HloInstruction*> a,
                                     std::pair<int, const HloInstruction*> b) {
    // Call the client's comparison function on the actual HloInstruction*
    // objects (ignoring the internal ids we also have in our stack entries)
    return operand_order(a.second, b.second);
  };
  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
                                  /*ignore_control_predecessors=*/false));
  if (call_finish_visit) {
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
  }
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
  return Status::OK();
}
|
|
|
|
namespace {

// Returns true if the given order is a topological sort of the instructions
// it contains: no duplicates, and every operand appears in the order strictly
// before each of its users.
bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) {
  // Create a map from instruction to its position in 'order'.
  std::unordered_map<const HloInstruction*, int> order_position;
  for (int i = 0; i < order.size(); i++) {
    if (!order_position.insert({order[i], i}).second) {
      // Instruction order[i] is duplicated in the order.
      return false;
    }
  }
  // Verify that the operand of each instruction in the order is also in the
  // order *and* the operand's position is earlier (defs are before uses for
  // all ops).
  for (auto* instruction : order) {
    for (auto* operand : instruction->operands()) {
      if (!ContainsKey(order_position, operand) ||
          order_position.at(operand) >= order_position.at(instruction)) {
        return false;
      }
    }
  }

  return true;
}

}  // namespace
|
|
|
|
// Convenience overload: wraps a plain function in a FunctionVisitor and runs
// the standard post-order Accept.
Status HloInstruction::Accept(
    const FunctionVisitor::VisitorFunction& visitor_func) {
  FunctionVisitor visitor(visitor_func);
  return this->Accept(&visitor);
}
|
|
|
|
// Visits instructions in the caller-supplied `order` (which must be a
// topological sort), restricted to the transitive predecessors of this
// instruction. Ends with FinishVisit on this instruction.
Status HloInstruction::AcceptOrdered(
    DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
  VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")";
  TF_RET_CHECK(OrderIsTopologicalSort(order));

  // Compute the predecessors of this instruction. A plain DFS Accept with a
  // collecting lambda gives exactly the reachable set.
  std::unordered_set<const HloInstruction*> predecessors;
  TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) {
    predecessors.insert(instruction);
    return Status::OK();
  }));

  for (auto* const_instruction : order) {
    if (!ContainsKey(predecessors, const_instruction)) {
      // Instruction is not a predecessors of 'this'.
      continue;
    }

    // The visitor can mark instructions as visited to skip particular
    // instructions.
    if (visitor->DidVisit(*const_instruction)) {
      VLOG(3) << "Not visiting HLO " << const_instruction->name()
              << " as it was already visited.";
      continue;
    }

    // `order` carries const pointers but the visitor interface needs mutable
    // instructions.
    HloInstruction* instruction =
        const_cast<HloInstruction*>(const_instruction);

    TF_RETURN_IF_ERROR(visitor->Preprocess(instruction));
    VLOG(2) << "Visiting HLO " << instruction->name();
    TF_RETURN_IF_ERROR(instruction->Visit(visitor));
    visitor->SetVisited(*instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(instruction));
  }

  return visitor->FinishVisit(this);
}
|
|
|
|
// Returns the shape being outfed by a kOutfeed instruction; debug-fatal for
// other opcodes.
const Shape& HloInstruction::outfeed_shape() const {
  DCHECK_EQ(opcode_, HloOpcode::kOutfeed);
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
  return outfeed_shape_;
}
|
|
|
|
// Returns this instruction's result shape, debug-validating it first.
const Shape& HloInstruction::shape() const {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
  return shape_;
}
|
|
|
|
std::vector<int64> HloInstruction::OperandIndices(
|
|
const HloInstruction* operand) const {
|
|
std::vector<int64> result;
|
|
for (int64 i = 0; i < operand_count(); ++i) {
|
|
if (this->operand(i) == operand) {
|
|
result.push_back(i);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Returns true iff this opcode is a binary elementwise operation
// (arithmetic, comparison, logical, or shift). Must stay in sync with the
// binary section of IsElementwise().
bool HloInstruction::IsElementwiseBinary() const {
  switch (opcode_) {
    // Binary elementwise operations. If you update this, please update
    // IsElementwise() accordingly.
    case HloOpcode::kAdd:
    case HloOpcode::kDivide:
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      return true;
    default:
      return false;
  }
}
|
|
|
|
bool HloInstruction::IsElementwise() const {
|
|
switch (opcode_) {
|
|
// Nullary elementwise operations.
|
|
case HloOpcode::kConstant:
|
|
return true;
|
|
|
|
// Unary elementwise operations.
|
|
case HloOpcode::kAbs:
|
|
case HloOpcode::kRoundNearestAfz:
|
|
case HloOpcode::kCeil:
|
|
case HloOpcode::kConvert:
|
|
case HloOpcode::kCopy:
|
|
case HloOpcode::kCos:
|
|
case HloOpcode::kExp:
|
|
case HloOpcode::kFloor:
|
|
case HloOpcode::kIsFinite:
|
|
case HloOpcode::kLog:
|
|
case HloOpcode::kNot:
|
|
case HloOpcode::kNegate:
|
|
case HloOpcode::kReducePrecision:
|
|
case HloOpcode::kSign:
|
|
case HloOpcode::kSin:
|
|
case HloOpcode::kTanh:
|
|
return true;
|
|
|
|
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
|
// If you update this, please update IsElementwiseBinary() accordingly.
|
|
case HloOpcode::kAdd:
|
|
case HloOpcode::kDivide:
|
|
case HloOpcode::kEq:
|
|
case HloOpcode::kGe:
|
|
case HloOpcode::kGt:
|
|
case HloOpcode::kLe:
|
|
case HloOpcode::kLt:
|
|
case HloOpcode::kMaximum:
|
|
case HloOpcode::kMinimum:
|
|
case HloOpcode::kMultiply:
|
|
case HloOpcode::kNe:
|
|
case HloOpcode::kPower:
|
|
case HloOpcode::kRemainder:
|
|
case HloOpcode::kSubtract:
|
|
case HloOpcode::kAnd:
|
|
case HloOpcode::kOr:
|
|
case HloOpcode::kShiftLeft:
|
|
case HloOpcode::kShiftRightArithmetic:
|
|
case HloOpcode::kShiftRightLogical:
|
|
return true;
|
|
|
|
// Ternary elementwise operations.
|
|
case HloOpcode::kSelect:
|
|
return !ShapeUtil::IsTuple(shape_);
|
|
case HloOpcode::kClamp:
|
|
return true;
|
|
|
|
// Other operations.
|
|
case HloOpcode::kRng:
|
|
case HloOpcode::kMap:
|
|
return true;
|
|
case HloOpcode::kFusion:
|
|
if (fusion_kind() != FusionKind::kLoop) {
|
|
return false;
|
|
}
|
|
for (auto* fused : fused_instructions()) {
|
|
if (fused->opcode() != HloOpcode::kParameter &&
|
|
!fused->IsElementwise()) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
  // Only meaningful for elementwise instructions: an operand whose shape
  // differs from the instruction's own shape is implicitly broadcast up to
  // that shape.
  CHECK(IsElementwise());
  return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape());
}
|
|
|
|
namespace {
|
|
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
|
|
const HloInstruction* operand) {
|
|
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
|
|
return std::all_of(
|
|
operand_indices.begin(), operand_indices.end(),
|
|
[instruction](int64 operand_index) {
|
|
return instruction->IsElementwiseOnOperand(operand_index);
|
|
});
|
|
}
|
|
} // namespace
|
|
|
|
bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
  // For all instructions other than kFusion, being elementwise on one of the
  // operands is equivalent to being elementwise on all the operands.
  if (opcode() != HloOpcode::kFusion) {
    return IsElementwise();
  }

  CHECK_EQ(HloOpcode::kFusion, opcode());
  if (fusion_kind() != FusionKind::kLoop) {
    return false;
  }

  // A loop fusion is elementwise on an operand when every fused instruction
  // reachable (breadth-first over the use edges) from the corresponding
  // fused parameter is elementwise on it.
  HloInstruction* const start = fused_parameter(operand_idx);
  std::unordered_set<const HloInstruction*> seen = {start};
  std::deque<HloInstruction*> frontier = {start};
  while (!frontier.empty()) {
    HloInstruction* current = frontier.front();
    frontier.pop_front();
    for (HloInstruction* user : current->users()) {
      CHECK_GE(user->unique_id(), 0);
      if (ContainsKey(seen, user)) {
        continue;
      }
      // A user that is neither elementwise nor elementwise on `current`
      // breaks the chain; otherwise keep expanding through it.
      if (!user->IsElementwise() &&
          !IsInstructionElementwiseOnOperand(user, current)) {
        return false;
      }
      seen.insert(user);
      frontier.push_back(user);
    }
  }
  return true;
}
|
|
|
|
// A helper class for memoized, recursive computation of HloOpcode::kFusion
// in HloInstruction::OperandElementUse below.
class HloInstruction::FusionReusesParamElements {
 public:
  using UseKind = HloInstruction::UseKind;

  // Returns how fusion parameter `i` is used by the fused expression rooted
  // at `hlo`.
  //
  // We could rather iterate backwards thru fused_instructions_ here, as it is
  // in reverse postorder, and compute whether each fused instruction reuses
  // the value of this parameter, which would save stack space but not allow
  // us to finish early if we find a reuse.
  static UseKind Compute(int64 i, const HloInstruction& hlo) {
    tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
    return ComputeInternal(i, hlo, &memoization_cache);
  }

 private:
  // Recursive worker for Compute(). `cache` memoizes the per-instruction
  // result so shared subexpressions are folded in only once.
  static UseKind ComputeInternal(
      int64 i, const HloInstruction& hlo,
      tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
    // Reaching parameter i itself counts as a plain (single) use.
    if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
      return UseKind::kUse;
    }

    auto p = cache->emplace(&hlo, UseKind{});
    auto value_it = p.first;
    const bool key_is_new = p.second;

    // Only fold over the operands the first time this instruction is seen;
    // later visits return the memoized value.
    if (key_is_new) {
      for (int64 j = 0; j < hlo.operands_.size(); ++j) {
        UseKind old_val = value_it->second;

        // The next operation invalidates iterators.
        UseKind new_val =
            Plus(old_val, std::min(hlo.OperandElementUse(j),
                                   ComputeInternal(i, *hlo.operand(j), cache)));

        // Re-acquire the iterator. We could work harder to do this only if
        // absolutely necessary, but this code is not hot enough to warrant
        // that.
        value_it = cache->find(&hlo);
        value_it->second = new_val;
      }
    }
    return value_it->second;
  }

  // Fold operation for UseKinds. Note that two plain kUse values fold to
  // kUse (not kReuse), while any use combined with a permuting use folds to
  // kReuse.
  static UseKind Plus(UseKind a, UseKind b) {
    if (a == UseKind::kNoUse) {
      return b;
    } else if (b == UseKind::kNoUse) {
      return a;
    } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
      return UseKind::kReuse;
    } else if (a == UseKind::kUsePermutingElements ||
               b == UseKind::kUsePermutingElements) {
      return UseKind::kReuse;
    } else {
      CHECK(a == UseKind::kUse && b == UseKind::kUse);
      return UseKind::kUse;
    }
  }
};
|
|
|
|
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
  // Classifies how this instruction consumes the elements of operand `i`:
  // no use, a single use, a single use modulo element permutation, or reuse.
  switch (opcode_) {
    case HloOpcode::kBitcast:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTranspose:
      return UseKind::kUsePermutingElements;
    case HloOpcode::kPad:
    case HloOpcode::kReduce:
      // Pad reuses the padding value but not the padded array elements.
      // Reduce reuses the init value but not the operand array elements.
      if (i > 0) {
        return UseKind::kReuse;
      }
      return UseKind::kUsePermutingElements;
    case HloOpcode::kFusion:
      // Uses the memoizing, recursive computation defined above.
      return FusionReusesParamElements::Compute(i, *fused_expression_root());
    case HloOpcode::kDot:
      // Dot operations with inputs [A,B] * [B,1] do not re-use
      // elements on their left operand.
      // Dot operations with inputs [1,A] * [A,B] do not re-use
      // elements on their right operand.
      if (shape().dimensions_size() == 2 &&
          ((i == 0 && shape().dimensions(1) == 1) ||
           (i == 1 && shape().dimensions(0) == 1))) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
    case HloOpcode::kDynamicUpdateSlice:
      // Dynamic-update-slice reuses only operand 2 (start_indices).
      return (i == 0 || i == 1) ? UseKind::kUse : UseKind::kReuse;
    default:
      if (IsElementwise() && !ImplicitlyBroadcastsOperand(i)) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
  }
}
|
|
|
|
std::tuple<bool, std::vector<int64>, std::vector<int64>>
|
|
HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
|
|
if (HloOpcode::kReshape != opcode_) {
|
|
return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
|
|
}
|
|
return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
|
|
shape_);
|
|
}
|
|
|
|
// Maps each fusion kind to the spelling accepted by StringToFusionKind.
// Deliberately no default case: the compiler flags any newly added kind.
string ToString(HloInstruction::FusionKind kind) {
  switch (kind) {
    case HloInstruction::FusionKind::kLoop:
      return "kLoop";
    case HloInstruction::FusionKind::kInput:
      return "kInput";
    case HloInstruction::FusionKind::kOutput:
      return "kOutput";
    case HloInstruction::FusionKind::kTransposeDot:
      return "kTransposeDot";
    case HloInstruction::FusionKind::kConvBackwardFilter:
      return "kConvBackwardFilter";
    case HloInstruction::FusionKind::kConvBackwardInput:
      return "kConvBackwardInput";
    case HloInstruction::FusionKind::kCustom:
      return "kCustom";
  }
}
|
|
|
|
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name) {
  // Table-driven inverse of ToString(FusionKind); unknown names produce an
  // InvalidArgument status.
  static const std::pair<const char*, HloInstruction::FusionKind> kTable[] = {
      {"kLoop", HloInstruction::FusionKind::kLoop},
      {"kInput", HloInstruction::FusionKind::kInput},
      {"kOutput", HloInstruction::FusionKind::kOutput},
      {"kTransposeDot", HloInstruction::FusionKind::kTransposeDot},
      {"kConvBackwardFilter", HloInstruction::FusionKind::kConvBackwardFilter},
      {"kConvBackwardInput", HloInstruction::FusionKind::kConvBackwardInput},
      {"kCustom", HloInstruction::FusionKind::kCustom},
  };
  for (const auto& entry : kTable) {
    if (kind_name == entry.first) {
      return entry.second;
    }
  }
  return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
}
|
|
|
|
// Stream insertion for FusionKind; delegates to the ToString overload.
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
  return os << ToString(kind);
}
|
|
|
|
string HloInstruction::ConvolutionDimensionNumbersToString() const {
  // Renders the attached convolution dimension numbers as a
  // "dim_labels=<lhs>_<rhs>-><output>" string, using 'b'/'f' for the batch
  // and feature dimensions, "i"/"o" for the kernel input/output feature
  // dimensions, and the spatial-dimension ordinal for spatial dimensions.
  // Returns the empty string when no dimension numbers are attached.
  string result;
  if (convolution_dimension_numbers_ == nullptr) {
    return result;
  }
  const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_;
  // Show the given dimension labels in order of major to minor based on the
  // shape's layout.
  const auto append_dims = [&](const std::vector<string>& dims,
                               const Shape& shape) {
    CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
    for (int64 logical = 0; logical < dims.size(); ++logical) {
      int64 physical = logical;
      // Shapes without a layout are printed in logical dimension order.
      if (!shape.layout().minor_to_major().empty()) {
        physical = LayoutUtil::Major(shape.layout(), logical);
      }
      result += dims[physical];
    }
  };

  // lhs_dims[i] is the symbol of the logical dimension i for the lhs
  // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
  std::vector<string> lhs_dims(2 + dnums.spatial_dimensions().size());
  lhs_dims[dnums.input_batch_dimension()] = 'b';
  lhs_dims[dnums.input_feature_dimension()] = 'f';
  for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) {
    lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i);
  }

  std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
  rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
  rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
  // BUGFIX: bound the loop by the kernel's own spatial-dimension list rather
  // than the input's. The two lists should have equal length, but indexing
  // kernel_spatial_dimensions with the other list's size was inconsistent.
  for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
    rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
  }

  std::vector<string> output_dims(2 + dnums.spatial_dimensions().size());
  output_dims[dnums.output_batch_dimension()] = 'b';
  output_dims[dnums.output_feature_dimension()] = 'f';
  for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) {
    output_dims[dnums.spatial_dimensions(i)] = StrCat(i);
  }

  result += "dim_labels=";
  append_dims(lhs_dims, operand(0)->shape());
  result += "_";
  append_dims(rhs_dims, operand(1)->shape());
  result += "->";
  append_dims(output_dims, shape());
  return result;
}
|
|
|
|
bool HloInstruction::CouldBeBitcast() const {
|
|
switch (opcode_) {
|
|
case HloOpcode::kTranspose:
|
|
return true;
|
|
case HloOpcode::kReshape:
|
|
return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
HloModule* HloInstruction::GetModule() const {
  // An instruction that is not attached to a computation has no module.
  return parent_ != nullptr ? parent_->parent() : nullptr;
}
|
|
|
|
void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
  // Replaces name_ with a name guaranteed unique within the uniquer's scope.
  // (Removed a dead local: the original computed a parent-name string that
  // was never used.)
  name_ = name_uniquer->GetUniqueName(name_);
}
|
|
|
|
// Copies the given partition counts for the instruction's outer dimensions.
void HloInstruction::set_outer_dimension_partitions(
    const std::vector<int64>& outer_dimension_partitions) {
  outer_dimension_partitions_ = outer_dimension_partitions;
}
|
|
|
|
// Rewrites the layout of the (array) subshape at `shape_index` of this
// constant to `new_layout`, relaying out the backing literal to match.
void HloInstruction::RelayoutConstant(const Layout& new_layout,
                                      const ShapeIndex& shape_index) {
  // Only constants own a literal whose layout can be rewritten in place.
  CHECK_EQ(opcode(), HloOpcode::kConstant);
  Shape* mutable_array_subshape =
      ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
  CHECK(ShapeUtil::IsArray(*mutable_array_subshape));

  // Normally array_subshape will always have a layout, but this invariant is
  // temporarily broken in LayoutAssignment::AssignLayouts.

  // Relayout when the subshape has no layout yet, or when its layout differs
  // from the requested one; otherwise this is a no-op.
  if (!mutable_array_subshape->has_layout() ||
      !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
    literal_ = literal_->Relayout(new_layout, shape_index);
    *mutable_array_subshape->mutable_layout() = new_layout;
  }
}
|
|
|
|
} // namespace xla
|