STT-tensorflow/tensorflow/compiler/xla/service/hlo_replication_analysis.h
HyoukJoong Lee 3c28370a9c Combine cross-replica / cross-partition AllReduce after SPMD partition
PiperOrigin-RevId: 283610192
Change-Id: I801097d159c39d8137457c55906d455e0ee7733d
2019-12-03 13:32:47 -08:00

94 lines
4.1 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// An HLO pass that determines whether each instruction in the module outputs
// the same value across replicas or across partitions (depending on the value
// `cross_partition_spmd`). It propagates sources of replicated values to
// the rest of the module, where sources include cross-replica-sum, annotated
// entry parameters, and constants.
class HloReplicationAnalysis {
public:
// Runs the analysis on module and returns the result or an error.
static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
const HloModule* module, bool cross_partition_spmd);
// Same as above, but the caller can provide additional annotations: a set of
// while loops that are known to have the same iteration counts across
// replicas or partitions.
static StatusOr<std::unique_ptr<HloReplicationAnalysis>> Run(
const HloModule* module, bool cross_partition_spmd,
const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations);
// Returns if the HLO instruction outputs the same value (i.e., replicated) at
// the given index across all replicas or partitions.
bool HloInstructionIsReplicatedAt(const HloInstruction* inst,
const ShapeIndex& index) const;
private:
HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd,
const absl::flat_hash_set<const HloInstruction*>*
loops_known_with_same_iterations)
: module_(module),
cross_partition_spmd_(cross_partition_spmd),
loops_known_with_same_iterations_(*loops_known_with_same_iterations) {}
// Computes hlo_replication_.
void ComputeHloReplication();
// A helper function to recursively compute hlo_replication on a computation.
// Returns whether hlo_replication_ is changed.
bool ComputeHloReplicationOnComputation(const HloComputation* computation,
bool mark_everything_not_replicated);
const HloModule* module_;
// If true, run this replication analysis for replicated values across
// partitions (not across replicas) on an SPMD partitioned module. This means
// that HloInstructionIsReplicatedAt() returns true if the value is identical
// across partitions for each replica. The module-level parameter and root
// instructions may have HloSharding attributes that indicate whether values
// are identical across partitions.
//
// If false, HloReplicationAnalysis runs across replicas.
bool cross_partition_spmd_;
// A set of while loops that are known to have the same iteration counts
// across replicas or partitions. This is provided by the caller as additional
// annotations.
const absl::flat_hash_set<const HloInstruction*>&
loops_known_with_same_iterations_;
// A map from each analyzed HLO instruction to a shape tree that represents
// whether the instruction outputs the same value across replicas or
// partitions at each shape index.
absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>> hlo_replication_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_