[XLA] Add the ability to disable the Reduce(Reshape(X)) to Reduce(X) rewrite and have TransposeFolding default to folding all operands of Convolution and Dot
PiperOrigin-RevId: 299518278 Change-Id: I295cc8ec91d5059b6f760bd4edd945fc10996820
This commit is contained in:
parent
edd36f52a3
commit
c05dd7b074
@ -3647,7 +3647,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
// A reshape that collapses multiple dimensions into a dimension being
|
||||
// reduced can just reduce all of those dimensions instead of doing a
|
||||
// collapsing reshape before a reduction.
|
||||
if (arg->opcode() == HloOpcode::kReshape) {
|
||||
if (options_.enable_reduce_of_reshape() &&
|
||||
arg->opcode() == HloOpcode::kReshape) {
|
||||
std::vector<std::pair<int64, int64>> unmodified_dims =
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
|
||||
arg->shape());
|
||||
|
@ -107,6 +107,12 @@ class AlgebraicSimplifierOptions {
|
||||
return metadata_.cudnn_batchnorm_forward_training_metadata;
|
||||
}
|
||||
|
||||
void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) {
|
||||
enable_reduce_of_reshape_ = enable_reduce_of_reshape;
|
||||
}
|
||||
|
||||
bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }
|
||||
|
||||
private:
|
||||
// Metadata struct can be used to store any metadata information encapsulated
|
||||
// with the AlgebraicSimplierOptions that can be later used in an
|
||||
@ -126,6 +132,7 @@ class AlgebraicSimplifierOptions {
|
||||
bool enable_dot_to_multiply_rewrite_{true};
|
||||
bool enable_conv_simplification_{true};
|
||||
bool enable_window_reduce_to_reduce_replacement_{true};
|
||||
bool enable_reduce_of_reshape_{true};
|
||||
int64 very_small_gather_size_{4};
|
||||
Metadata metadata_;
|
||||
};
|
||||
|
@ -39,6 +39,13 @@ class TransposeFolding : public HloModulePass {
|
||||
const OperandIndices&) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Helper function to always fold transposes.
|
||||
static OperandIndices AlwaysFoldTranspose(const HloInstruction&,
|
||||
const OperandIndices& ids) {
|
||||
return ids;
|
||||
}
|
||||
|
||||
// transposable_gemm_operands returns the set of operands it wants to fold if
|
||||
// the instruction argument is implemented as a GEMM kernel that supports
|
||||
// transposing its arguments.
|
||||
@ -47,8 +54,10 @@ class TransposeFolding : public HloModulePass {
|
||||
// the instruction argument is implemented as a convolution that supports
|
||||
// transposing its arguments.
|
||||
explicit TransposeFolding(
|
||||
TransposableGemmOperandsFn transposable_gemm_operands,
|
||||
TransposableConvOperandsFn transposable_conv_operands);
|
||||
TransposableGemmOperandsFn transposable_gemm_operands =
|
||||
AlwaysFoldTranspose,
|
||||
TransposableConvOperandsFn transposable_conv_operands =
|
||||
AlwaysFoldTranspose);
|
||||
absl::string_view name() const override { return "transpose-folding"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
Loading…
x
Reference in New Issue
Block a user