[XLA] Extract a common visitor for rewriting instructions.

PiperOrigin-RevId: 257293820
This commit is contained in:
George Karpenkov 2019-07-09 16:20:43 -07:00 committed by TensorFlower Gardener
parent 9fbc062b58
commit a7440c393a
13 changed files with 65 additions and 98 deletions

View File

@ -644,6 +644,7 @@ cc_library(
hdrs = ["call_inliner.h"],
deps = [
":call_graph",
":hlo",
":hlo_dce",
":hlo_pass",
"//tensorflow/compiler/xla:statusor",

View File

@ -168,13 +168,8 @@ bool IsUnstridedSlice(const HloInstruction* hlo) {
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleAdd(HloInstruction* add) override;
Status HandleAnd(HloInstruction* logical_and) override;
@ -250,9 +245,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleMap(HloInstruction* map) override;
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
// Runs the visitor on a computation.
static bool Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options,
@ -350,35 +342,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* broadcast);
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithNewInstruction(
HloInstruction* old_instruction,
std::unique_ptr<HloInstruction> new_instruction) {
VLOG(3) << "Replacing instruction:";
VLOG(3) << " old: " << old_instruction->ToString();
VLOG(3) << " new: " << new_instruction->ToString();
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
old_instruction, std::move(new_instruction)));
changed_ = true;
return Status::OK();
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceInstruction(HloInstruction* old_instruction,
HloInstruction* new_instruction) {
VLOG(3) << "Replacing instruction:";
VLOG(3) << " old: " << old_instruction->ToString();
VLOG(3) << " new: " << new_instruction->ToString();
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(old_instruction, new_instruction));
changed_ = true;
return Status::OK();
}
StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
@ -445,7 +408,7 @@ bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
AlgebraicSimplifier* simplifier) {
AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
return visitor.changed_ || visitor.changed();
}
bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
@ -1723,6 +1686,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
}
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
CHECK(computation_ == dot->parent());
HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
if (options_.is_layout_sensitive()) {

View File

@ -46,13 +46,8 @@ using absl::optional;
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
@ -63,9 +58,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
static bool Run(HloComputation* computation, bool rewrite_training_op,
bool rewrite_inference_op, bool rewrite_grad_op);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
~BatchNormExpanderVisitor() override = default;
private:
@ -133,28 +125,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
elements_per_feature_u32);
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithNewInstruction(
HloInstruction* old_instruction,
std::unique_ptr<HloInstruction> new_instruction) {
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
old_instruction, std::move(new_instruction)));
changed_ = true;
return Status::OK();
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceInstruction(HloInstruction* old_instruction,
HloInstruction* new_instruction) {
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(old_instruction, new_instruction));
changed_ = true;
return Status::OK();
}
// Current HloComputation instance the BatchNormExpander is
// traversing.
HloComputation* computation_;
@ -162,9 +132,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
// Whether rewrite has occurred.
bool changed_ = false;
};
} // namespace
@ -179,7 +146,7 @@ bool BatchNormExpanderVisitor::Run(HloComputation* computation,
/*rewrite_inference_op=*/rewrite_inference_op,
/*rewrite_grad_op=*/rewrite_grad_op);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
return visitor.changed();
}
Status BatchNormExpanderVisitor::HandleBatchNormTraining(

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/core/lib/core/errors.h"

View File

@ -20,6 +20,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -29,9 +31,6 @@ limitations under the License.
namespace xla {
class HloComputation;
class HloInstruction;
// DfsHloVisitor with default action based on the HloInstruction being visited.
// Users should not use this class directly, but use the type aliases
// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
@ -246,6 +245,52 @@ using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
using ConstDfsHloVisitorWithDefault =
DfsHloVisitorWithDefaultBase<const HloInstruction*>;
// A common base class for visitors performing rewriting operation.
//
// Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while
// visiting; those helpers record that the computation was modified, so the
// caller can query changed() after the traversal to decide whether the pass
// made progress.
class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
// Returns whether any instruction was replaced during the visit.
bool changed() const { return changed_; }
protected:
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
// The replacement is performed in old_instruction's parent computation,
// so the visitor does not need to track the current computation itself.
Status ReplaceWithNewInstruction(
HloInstruction* old_instruction,
std::unique_ptr<HloInstruction> new_instruction) {
VLOG(3) << "Replacing instruction:";
VLOG(3) << " old: " << old_instruction->ToString();
VLOG(3) << " new: " << new_instruction->ToString();
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
old_instruction, std::move(new_instruction)));
changed_ = true;
return Status::OK();
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
// new_instruction must already live in old_instruction's parent computation.
Status ReplaceInstruction(HloInstruction* old_instruction,
HloInstruction* new_instruction) {
VLOG(3) << "Replacing instruction:";
VLOG(3) << " old: " << old_instruction->ToString();
VLOG(3) << " new: " << new_instruction->ToString();
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceInstruction(
old_instruction, new_instruction));
changed_ = true;
return Status::OK();
}
// Set to true by the Replace* helpers above once any replacement succeeds.
// Protected (not private) because some subclasses read it directly, e.g.
// AlgebraicSimplifierVisitor::Run accesses visitor.changed_.
bool changed_ = false;
};
// (Const)FunctionVisitor lets you transform an
// std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
//

View File

@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
namespace xla {

View File

@ -61,7 +61,7 @@ static complex128 GetScalarConstantAsComplex(const Literal &literal) {
// and provided C has no other users).
// We then guide the buffer assignment to alias the buffer of the custom call
// and C.
class GemmRewriterVisitor : public DfsHloVisitorWithDefault {
class GemmRewriterVisitor : public DfsHloRewriteVisitor {
public:
Status HandleDot(HloInstruction *instr) override {
if (IsMatrixMultiplication(*instr)) {
@ -107,9 +107,7 @@ class GemmRewriterVisitor : public DfsHloVisitorWithDefault {
config.set_alpha_real(new_alpha.real());
config.set_alpha_imag(new_alpha.imag());
TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
TF_RETURN_IF_ERROR(
instr->parent()->ReplaceInstruction(instr, existing_gemm));
changed_ = true;
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
}
}
return Status::OK();
@ -141,27 +139,12 @@ class GemmRewriterVisitor : public DfsHloVisitorWithDefault {
}
return Status::OK();
}
Status DefaultAction(HloInstruction *) override { return Status::OK(); }
bool IsChanged() { return changed_; }
private:
Status ReplaceWithNewInstruction(
HloInstruction *instr, std::unique_ptr<HloInstruction> replacement) {
TF_RETURN_IF_ERROR(instr->parent()->ReplaceWithNewInstruction(
instr, std::move(replacement)));
changed_ = true;
return Status::OK();
}
bool changed_ = false;
};
static StatusOr<bool> RunOnComputation(HloComputation *computation) {
GemmRewriterVisitor visitor;
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.IsChanged();
return visitor.changed();
}
StatusOr<bool> GemmRewriter::Run(HloModule *module) {

View File

@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"