STT-tensorflow/tensorflow/compiler/xla/service/hlo_replication_analysis.cc
Yunxing Dai 3a63cf6b99 Support x64<->x32 instructions in hlo replication analysis.
PiperOrigin-RevId: 334882856
Change-Id: Ic2699f2fa4513606300106e965187b9320045d2c
2020-10-01 13:25:16 -07:00

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
#include <memory>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
namespace {
// Determines whether `hlo` produces the same value across all replicas (or
// across all partitions, when cross_partition_spmd is set) at the given shape
// index, based on the knowledge accumulated so far in hlo_replication.
bool DetermineHloInstructionIsReplicated(
const HloInstruction* hlo, const ShapeIndex& index,
bool cross_partition_spmd,
const absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>&
hlo_replication) {
// Returns true if all operands are known to be replicated. Replication is
// read at the root shape index, so this helper assumes array-shaped operands.
const auto all_operands_replicated =
[&hlo_replication](const HloInstruction* inst) {
for (auto operand : inst->operands()) {
auto operand_it = hlo_replication.find(operand);
if (operand_it == hlo_replication.end() ||
!operand_it->second.element({})) {
return false;
}
}
return true;
};
if (hlo->opcode() == HloOpcode::kAllReduce ||
hlo->opcode() == HloOpcode::kAllGather) {
// An all-reduce/all-gather returns the same values across partitions/replicas
// as long as its operands are replicated.
if (all_operands_replicated(hlo)) {
return true;
}
if (!hlo->channel_id().has_value()) {
// No channel_id, so this collective operates across replicas only.
if (cross_partition_spmd) {
return false;
}
// A cross-replica-only all-reduce/all-gather produces replicated results
// only when it spans all cores, i.e., when there is at most one subgroup.
return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
} else {
bool global_id;
if (hlo->opcode() == HloOpcode::kAllReduce) {
global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
} else {
global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
}
if (global_id) {
bool replicated_across_partitions = true;
bool replicated_across_replicas = true;
const int64 num_partitions =
hlo->GetModule()->config().num_partitions();
for (const auto& group : hlo->replica_groups()) {
absl::flat_hash_set<int64> visited_partitions;
absl::flat_hash_set<int64> visited_replicas;
for (int64 id : group.replica_ids()) {
int64 rid = id / num_partitions;
int64 pid = id % num_partitions;
visited_partitions.insert(pid);
visited_replicas.insert(rid);
}
replicated_across_partitions &=
visited_partitions.size() == num_partitions;
replicated_across_replicas &=
visited_replicas.size() ==
hlo->GetModule()->config().replica_count();
}
return cross_partition_spmd ? replicated_across_partitions
: replicated_across_replicas;
}
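// With a channel_id but without global ids, the collective is applied
// across partitions, so in SPMD mode the result is replicated across
// partitions; across replicas it is replicated only when there is at most
// one subgroup.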
return cross_partition_spmd ? true
: hlo->replica_groups().empty() ||
hlo->replica_groups().size() == 1;
}
}
if (hlo->HasSideEffectNoRecurse()) {
return false;
}
if (hlo->opcode() == HloOpcode::kReplicaId) {
// ReplicaId returns the same value for all partitions in each replica.
return cross_partition_spmd;
}
if (hlo->opcode() == HloOpcode::kPartitionId) {
// PartitionId returns the same value for all replicas in each partition.
return !cross_partition_spmd;
}
auto it = hlo_replication.find(hlo);
if (hlo->opcode() == HloOpcode::kParameter) {
// Parameters should already have been assigned replication info, either by
// the entry seeding or by the callers of their computation.
return it != hlo_replication.end() && it->second.element(index);
}
if (it != hlo_replication.end() && !it->second.element(index)) {
// The HLO is already marked as non-replicated.
return false;
}
if (hlo->opcode() == HloOpcode::kConstant) {
return true;
}
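// These custom calls split a 64-bit value into its 32-bit halves or combine
// the halves back; they are pure data movement, so their results are
// replicated whenever their operands are.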
if (hlo->opcode() == HloOpcode::kCustomCall &&
(hlo->custom_call_target() == "X64SplitLow" ||
hlo->custom_call_target() == "X64SplitHigh" ||
hlo->custom_call_target() == "X64Combine")) {
return all_operands_replicated(hlo);
}
if (hlo->IsElementwise() || //
hlo->opcode() == HloOpcode::kConcatenate || //
hlo->opcode() == HloOpcode::kConvolution || //
hlo->opcode() == HloOpcode::kDot || //
hlo->opcode() == HloOpcode::kReduce || //
hlo->opcode() == HloOpcode::kBroadcast || //
hlo->opcode() == HloOpcode::kTranspose || //
hlo->opcode() == HloOpcode::kReshape || //
hlo->opcode() == HloOpcode::kBitcast || //
hlo->opcode() == HloOpcode::kReverse || //
hlo->opcode() == HloOpcode::kGather || //
hlo->opcode() == HloOpcode::kScatter || //
hlo->opcode() == HloOpcode::kIota || //
hlo->opcode() == HloOpcode::kPad || //
hlo->opcode() == HloOpcode::kSlice || //
hlo->opcode() == HloOpcode::kDynamicSlice || //
hlo->opcode() == HloOpcode::kDynamicUpdateSlice || //
hlo->opcode() == HloOpcode::kReduceWindow || //
hlo->opcode() == HloOpcode::kCopy) {
return all_operands_replicated(hlo);
}
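// Conservatively treat any other opcode as not replicated.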
return false;
}
} // namespace
bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
const HloComputation* computation, bool mark_everything_not_replicated) {
bool changed = false;
for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
// Assigns the shape tree to dest if dest doesn't have one yet, or combines
// it with the existing one by and'ing them. Returns whether anything was
// updated. Entries only ever flip from true to false, which guarantees that
// the fixed-point iteration over while loops below terminates.
auto assign_or_combine_shapetree = [&](ShapeTree<bool>&& to_combine,
const HloInstruction* dest) {
auto it = hlo_replication_.find(dest);
if (it == hlo_replication_.end()) {
hlo_replication_[dest] = std::move(to_combine);
return true;
}
bool updated = false;
it->second.ForEachMutableElement(
[&](const ShapeIndex& index, bool* element) {
if (*element && !to_combine.element(index)) {
*element = false;
updated = true;
}
});
return updated;
};
// Assigns or combines source's shape tree to dest. Returns whether anything
// was updated.
auto propagate_shapetree = [&](const HloInstruction* source,
const HloInstruction* dest) {
auto source_it = hlo_replication_.find(source);
if (source_it == hlo_replication_.end()) {
return false;
}
return assign_or_combine_shapetree(ShapeTree<bool>(source_it->second),
dest);
};
// For the opcodes below that we handle specially, we don't need to check
// mark_everything_not_replicated explicitly because, if it is set, their
// operands will already have been marked as not replicated.
if (inst->opcode() == HloOpcode::kWhile) {
// Since the while body's input and output alias each other, we need to run
// the propagation multiple times until a fixed point is reached.
while (true) {
// First, propagate the input's and body root's shape trees to the
// parameters of the body and condition.
bool updated = propagate_shapetree(
inst->operand(0),
inst->while_condition()->parameter_instruction(0));
updated |= propagate_shapetree(
inst->while_body()->root_instruction(),
inst->while_condition()->parameter_instruction(0));
updated |= propagate_shapetree(
inst->operand(0), inst->while_body()->parameter_instruction(0));
updated |=
propagate_shapetree(inst->while_body()->root_instruction(),
inst->while_body()->parameter_instruction(0));
// Compute the condition.
updated |= ComputeHloReplicationOnComputation(
inst->while_condition(), mark_everything_not_replicated);
// Compute the body. If the condition is not replicated, different replicas
// may run different numbers of iterations, so everything computed in the
// body must be marked as not replicated.
if (!ContainsKey(loops_known_with_same_iterations_, inst) &&
!hlo_replication_[inst->while_condition()->root_instruction()]
.element({})) {
updated |= ComputeHloReplicationOnComputation(
inst->while_body(), /*mark_everything_not_replicated=*/true);
} else {
updated |= ComputeHloReplicationOnComputation(
inst->while_body(), mark_everything_not_replicated);
}
if (!updated) {
break;
}
changed = true;
}
// Propagate the input's and body root's shape trees to the while HLO.
changed |= propagate_shapetree(inst->operand(0), inst);
changed |=
propagate_shapetree(inst->while_body()->root_instruction(), inst);
} else if (inst->opcode() == HloOpcode::kCall ||
inst->opcode() == HloOpcode::kFusion) {
auto called = inst->called_computations().front();
for (int64 i = 0; i < inst->operand_count(); ++i) {
changed |= propagate_shapetree(inst->operand(i),
called->parameter_instruction(i));
}
changed |= ComputeHloReplicationOnComputation(
called, mark_everything_not_replicated);
changed |= propagate_shapetree(called->root_instruction(), inst);
} else if (inst->opcode() == HloOpcode::kConditional) {
// Propagate inputs' shape trees to the called computations' parameters.
for (int64 i = 0; i < inst->called_computations().size(); ++i) {
changed |= propagate_shapetree(
inst->operand(i + 1),
inst->called_computations()[i]->parameter_instruction(0));
}
// If the predicate is not replicated, different replicas may take different
// branches, so the conditional's result must be treated as not replicated.
if (!hlo_replication_[inst->operand(0)].element({})) {
for (auto called : inst->called_computations()) {
changed |= ComputeHloReplicationOnComputation(
called,
/*mark_everything_not_replicated=*/true);
}
changed |= assign_or_combine_shapetree(
ShapeTree<bool>(inst->shape(), false), inst);
} else {
for (auto called : inst->called_computations()) {
changed |= ComputeHloReplicationOnComputation(
called, mark_everything_not_replicated);
changed |= propagate_shapetree(called->root_instruction(), inst);
}
}
} else if (inst->opcode() == HloOpcode::kTupleSelect) {
if (!hlo_replication_[inst->operand(0)].element({})) {
// The predicate is not replicated, so the selected value may differ across
// replicas.
changed |= assign_or_combine_shapetree(
ShapeTree<bool>(inst->shape(), false), inst);
} else {
changed |= propagate_shapetree(inst->operand(1), inst);
changed |= propagate_shapetree(inst->operand(2), inst);
}
} else if (inst->opcode() == HloOpcode::kTuple) {
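// A tuple's replication is tracked per element: each operand's shape tree
// is copied into the corresponding subtree, so e.g. the replication of
// tuple(a, b) at index {1} is exactly the replication of b at index {}.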
ShapeTree<bool> shape_tree(inst->shape(), true);
for (int64 i = 0; i < inst->operand_count(); ++i) {
shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(i)], {}, {i});
}
changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
} else if (inst->opcode() == HloOpcode::kGetTupleElement) {
ShapeTree<bool> shape_tree(inst->shape(), true);
shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(0)],
{inst->tuple_index()}, {});
changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
} else if (inst->opcode() == HloOpcode::kInfeed && cross_partition_spmd_) {
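// In SPMD mode, an infeed element is replicated across partitions exactly
// when its sharding annotation says so; without a sharding we conservatively
// treat the whole infeed result as not replicated.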
ShapeTree<bool> shape_tree(inst->shape(), false);
if (inst->has_sharding()) {
auto sharding = inst->sharding().GetAsShapeTree(inst->shape());
shape_tree.ForEachMutableElement(
[&sharding](const ShapeIndex& index, bool* data) {
*data = sharding.element(index).IsReplicated();
});
}
changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
} else {
if (mark_everything_not_replicated) {
changed |= assign_or_combine_shapetree(
ShapeTree<bool>(inst->shape(), false), inst);
} else {
ShapeTree<bool> shape_tree(inst->shape(), true);
ShapeUtil::ForEachSubshape(
inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
*shape_tree.mutable_element(index) =
DetermineHloInstructionIsReplicated(
inst, index, cross_partition_spmd_, hlo_replication_);
return Status::OK();
});
changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
}
}
}
return changed;
}
void HloReplicationAnalysis::ComputeHloReplication() {
// Seed the analysis with the entry computation's parameters according to
// user annotations: replicated modules read
// `parameter_replicated_at_leaf_buffers`, whereas SPMD-partitioned modules
// read HloSharding attributes.
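// For example, a parameter of shape (f32[], (f32[2], f32[3])) has three leaf
// buffers; parameter_replicated_at_leaf_buffers = {true, false, true} would
// mark the first and third leaves as replicated across replicas.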
auto entry = module_->entry_computation();
for (int i = 0; i < entry->num_parameters(); ++i) {
auto param = entry->parameter_instruction(i);
ShapeTree<bool> shape_tree(param->shape(), false);
if (cross_partition_spmd_ && param->has_sharding()) {
auto sharding_tree =
param->sharding().AsShapeTree(param->shape()).ValueOrDie();
ShapeUtil::ForEachSubshape(
param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
return Status::OK();
}
*shape_tree.mutable_element(index) =
sharding_tree.element(index).IsReplicated();
return Status::OK();
});
} else if (!cross_partition_spmd_) {
const auto& replication = param->parameter_replicated_at_leaf_buffers();
int leaf_index = 0;
ShapeUtil::ForEachSubshape(
param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
return Status::OK();
}
if (replication && replication->at(leaf_index)) {
*shape_tree.mutable_element(index) = true;
}
++leaf_index;
return Status::OK();
});
}
hlo_replication_[param] = std::move(shape_tree);
}
ComputeHloReplicationOnComputation(entry,
/*mark_everything_not_replicated=*/false);
}
bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
const HloInstruction* inst, const ShapeIndex& index) const {
auto it = hlo_replication_.find(inst);
if (it == hlo_replication_.end()) {
return false;
}
return it->second.element(index);
}
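// Example usage (a minimal sketch; `module` is assumed to be an HloModule
// owned by the caller, inside a function that can propagate Status):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloReplicationAnalysis> analysis,
//       HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false));
//   bool replicated = analysis->HloInstructionIsReplicatedAt(
//       module->entry_computation()->root_instruction(), /*index=*/{});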
/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module,
bool cross_partition_spmd) {
const absl::flat_hash_set<const HloInstruction*> empty;
return Run(module, cross_partition_spmd, &empty);
}
/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations) {
auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
module, cross_partition_spmd, loops_known_with_same_iterations));
analysis->ComputeHloReplication();
return analysis;
}
} // namespace xla