[XLA] Extract a common visitor for rewriting instructions.
PiperOrigin-RevId: 257293820
This commit is contained in:
parent
9fbc062b58
commit
a7440c393a
@ -644,6 +644,7 @@ cc_library(
|
||||
hdrs = ["call_inliner.h"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
":hlo",
|
||||
":hlo_dce",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -168,13 +168,8 @@ bool IsUnstridedSlice(const HloInstruction* hlo) {
|
||||
// algebraic expressions to simplified forms. Note: This only supports
|
||||
// simplifications that simply look at the operands of an instruction. For the
|
||||
// more general case a worklist based approach would be needed.
|
||||
class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
||||
public:
|
||||
// Default visitor action is to do nothing and return OK.
|
||||
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleAdd(HloInstruction* add) override;
|
||||
|
||||
Status HandleAnd(HloInstruction* logical_and) override;
|
||||
@ -250,9 +245,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleMap(HloInstruction* map) override;
|
||||
|
||||
// Returns whether algebraic simplification has occurred.
|
||||
const bool changed() const { return changed_; }
|
||||
|
||||
// Runs the visitor on a computation.
|
||||
static bool Run(HloComputation* computation,
|
||||
const AlgebraicSimplifierOptions& options,
|
||||
@ -350,35 +342,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
HloInstruction* broadcast);
|
||||
|
||||
// Replaces the existing HLO instruction old_instruction, with
|
||||
// new_instruction, and marks the optimizer status as changed.
|
||||
// Returns the Status representing the result of the replace operation.
|
||||
Status ReplaceWithNewInstruction(
|
||||
HloInstruction* old_instruction,
|
||||
std::unique_ptr<HloInstruction> new_instruction) {
|
||||
VLOG(3) << "Replacing instruction:";
|
||||
VLOG(3) << " old: " << old_instruction->ToString();
|
||||
VLOG(3) << " new: " << new_instruction->ToString();
|
||||
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
|
||||
old_instruction, std::move(new_instruction)));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replaces the existing HLO instruction old_instruction, with
|
||||
// new_instruction, and marks the optimizer status as changed.
|
||||
// Returns the Status representing the result of the replace operation.
|
||||
Status ReplaceInstruction(HloInstruction* old_instruction,
|
||||
HloInstruction* new_instruction) {
|
||||
VLOG(3) << "Replacing instruction:";
|
||||
VLOG(3) << " old: " << old_instruction->ToString();
|
||||
VLOG(3) << " new: " << new_instruction->ToString();
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_->ReplaceInstruction(old_instruction, new_instruction));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
|
||||
StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
|
||||
const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
|
||||
@ -445,7 +408,7 @@ bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
|
||||
AlgebraicSimplifier* simplifier) {
|
||||
AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
return visitor.changed_ || visitor.changed();
|
||||
}
|
||||
|
||||
bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
|
||||
@ -1723,6 +1686,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
||||
CHECK(computation_ == dot->parent());
|
||||
HloInstruction *lhs, *rhs;
|
||||
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
|
||||
if (options_.is_layout_sensitive()) {
|
||||
|
@ -46,13 +46,8 @@ using absl::optional;
|
||||
|
||||
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
|
||||
// operations into smaller operations.
|
||||
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
|
||||
class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
|
||||
public:
|
||||
// Default visitor action is to do nothing and return OK.
|
||||
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
|
||||
|
||||
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
|
||||
@ -63,9 +58,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
|
||||
static bool Run(HloComputation* computation, bool rewrite_training_op,
|
||||
bool rewrite_inference_op, bool rewrite_grad_op);
|
||||
|
||||
// Returns whether any batch norm ops were rewritten.
|
||||
const bool changed() const { return changed_; }
|
||||
|
||||
~BatchNormExpanderVisitor() override = default;
|
||||
|
||||
private:
|
||||
@ -133,28 +125,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
|
||||
elements_per_feature_u32);
|
||||
}
|
||||
|
||||
// Replaces the existing HLO instruction old_instruction, with
|
||||
// new_instruction, and marks the optimizer status as changed.
|
||||
// Returns the Status representing the result of the replace operation.
|
||||
Status ReplaceWithNewInstruction(
|
||||
HloInstruction* old_instruction,
|
||||
std::unique_ptr<HloInstruction> new_instruction) {
|
||||
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
|
||||
old_instruction, std::move(new_instruction)));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replaces the existing HLO instruction old_instruction, with
|
||||
// new_instruction, and marks the optimizer status as changed.
|
||||
// Returns the Status representing the result of the replace operation.
|
||||
Status ReplaceInstruction(HloInstruction* old_instruction,
|
||||
HloInstruction* new_instruction) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation_->ReplaceInstruction(old_instruction, new_instruction));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
// Current HloComputation instance the BatchNormExpander is
|
||||
// traversing.
|
||||
HloComputation* computation_;
|
||||
@ -162,9 +132,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
|
||||
bool rewrite_training_op_;
|
||||
bool rewrite_inference_op_;
|
||||
bool rewrite_grad_op_;
|
||||
|
||||
// Whether rewrite has occurred.
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -179,7 +146,7 @@ bool BatchNormExpanderVisitor::Run(HloComputation* computation,
|
||||
/*rewrite_inference_op=*/rewrite_inference_op,
|
||||
/*rewrite_grad_op=*/rewrite_grad_op);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
return visitor.changed();
|
||||
}
|
||||
|
||||
Status BatchNormExpanderVisitor::HandleBatchNormTraining(
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -29,9 +31,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class HloComputation;
|
||||
class HloInstruction;
|
||||
|
||||
// DfsHloVisitor with default action based on the HloInstruction being visited.
|
||||
// Users should not use this class directly, but use the type aliases
|
||||
// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
|
||||
@ -246,6 +245,52 @@ using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
|
||||
using ConstDfsHloVisitorWithDefault =
|
||||
DfsHloVisitorWithDefaultBase<const HloInstruction*>;
|
||||
|
||||
// A common base class for visitors performing rewriting operation.
|
||||
//
|
||||
// Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while
|
||||
// visiting.
|
||||
class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
// Default visitor action is to do nothing and return OK.
|
||||
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool changed() const { return changed_; }
|
||||
|
||||
protected:
|
||||
// Replaces the existing HLO instruction old_instruction, with
|
||||
// new_instruction, and marks the optimizer status as changed.
|
||||
// Returns the Status representing the result of the replace operation.
|
||||
Status ReplaceWithNewInstruction(
|
||||
HloInstruction* old_instruction,
|
||||
std::unique_ptr<HloInstruction> new_instruction) {
|
||||
VLOG(3) << "Replacing instruction:";
|
||||
VLOG(3) << " old: " << old_instruction->ToString();
|
||||
VLOG(3) << " new: " << new_instruction->ToString();
|
||||
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
|
||||
old_instruction, std::move(new_instruction)));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replaces the existing HLO instruction old_instruction, with
|
||||
// new_instruction, and marks the optimizer status as changed.
|
||||
// Returns the Status representing the result of the replace operation.
|
||||
Status ReplaceInstruction(HloInstruction* old_instruction,
|
||||
HloInstruction* new_instruction) {
|
||||
VLOG(3) << "Replacing instruction:";
|
||||
VLOG(3) << " old: " << old_instruction->ToString();
|
||||
VLOG(3) << " new: " << new_instruction->ToString();
|
||||
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction(
|
||||
old_instruction, new_instruction));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
// (Const)FunctionVisitor lets you transform an
|
||||
// std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
|
||||
//
|
||||
|
@ -14,8 +14,10 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
|
||||
namespace xla {
|
||||
|
@ -61,7 +61,7 @@ static complex128 GetScalarConstantAsComplex(const Literal &literal) {
|
||||
// and provided C has no other users).
|
||||
// We then guide the buffer assignment to alias the buffer of the custom call
|
||||
// and C.
|
||||
class GemmRewriterVisitor : public DfsHloVisitorWithDefault {
|
||||
class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
public:
|
||||
Status HandleDot(HloInstruction *instr) override {
|
||||
if (IsMatrixMultiplication(*instr)) {
|
||||
@ -107,9 +107,7 @@ class GemmRewriterVisitor : public DfsHloVisitorWithDefault {
|
||||
config.set_alpha_real(new_alpha.real());
|
||||
config.set_alpha_imag(new_alpha.imag());
|
||||
TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
|
||||
TF_RETURN_IF_ERROR(
|
||||
instr->parent()->ReplaceInstruction(instr, existing_gemm));
|
||||
changed_ = true;
|
||||
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -141,27 +139,12 @@ class GemmRewriterVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DefaultAction(HloInstruction *) override { return Status::OK(); }
|
||||
|
||||
bool IsChanged() { return changed_; }
|
||||
|
||||
private:
|
||||
Status ReplaceWithNewInstruction(
|
||||
HloInstruction *instr, std::unique_ptr<HloInstruction> replacement) {
|
||||
TF_RETURN_IF_ERROR(instr->parent()->ReplaceWithNewInstruction(
|
||||
instr, std::move(replacement)));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
static StatusOr<bool> RunOnComputation(HloComputation *computation) {
|
||||
GemmRewriterVisitor visitor;
|
||||
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
|
||||
return visitor.IsChanged();
|
||||
return visitor.changed();
|
||||
}
|
||||
|
||||
StatusOr<bool> GemmRewriter::Run(HloModule *module) {
|
||||
|
@ -30,7 +30,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/iterator_util.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
|
Loading…
x
Reference in New Issue
Block a user