[XLA] Add the ability to disable the Reduce(Reshape(X)) to Reduce(X) rewrite and have TransposeFolding default to folding all operands of Convolution and Dot

PiperOrigin-RevId: 299518278
Change-Id: I295cc8ec91d5059b6f760bd4edd945fc10996820
This commit is contained in:
Blake Hechtman 2020-03-07 00:33:38 -08:00 committed by TensorFlower Gardener
parent edd36f52a3
commit c05dd7b074
3 changed files with 20 additions and 3 deletions

View File

@ -3647,7 +3647,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.
if (arg->opcode() == HloOpcode::kReshape) {
if (options_.enable_reduce_of_reshape() &&
arg->opcode() == HloOpcode::kReshape) {
std::vector<std::pair<int64, int64>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
arg->shape());

View File

@ -107,6 +107,12 @@ class AlgebraicSimplifierOptions {
return metadata_.cudnn_batchnorm_forward_training_metadata;
}
void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) {
enable_reduce_of_reshape_ = enable_reduce_of_reshape;
}
bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }
private:
// Metadata struct can be used to store any metadata information encapsulated
// with the AlgebraicSimplierOptions that can be later used in an
@ -126,6 +132,7 @@ class AlgebraicSimplifierOptions {
bool enable_dot_to_multiply_rewrite_{true};
bool enable_conv_simplification_{true};
bool enable_window_reduce_to_reduce_replacement_{true};
bool enable_reduce_of_reshape_{true};
int64 very_small_gather_size_{4};
Metadata metadata_;
};

View File

@ -39,6 +39,13 @@ class TransposeFolding : public HloModulePass {
const OperandIndices&) {
return {};
}
// Helper function to always fold transposes.
static OperandIndices AlwaysFoldTranspose(const HloInstruction&,
const OperandIndices& ids) {
return ids;
}
// transposable_gemm_operands returns the set of operands it wants to fold if
// the instruction argument is implemented as a GEMM kernel that supports
// transposing its arguments.
@ -47,8 +54,10 @@ class TransposeFolding : public HloModulePass {
// the instruction argument is implemented as a convolution that supports
// transposing its arguments.
explicit TransposeFolding(
TransposableGemmOperandsFn transposable_gemm_operands,
TransposableConvOperandsFn transposable_conv_operands);
TransposableGemmOperandsFn transposable_gemm_operands =
AlwaysFoldTranspose,
TransposableConvOperandsFn transposable_conv_operands =
AlwaysFoldTranspose);
absl::string_view name() const override { return "transpose-folding"; }
StatusOr<bool> Run(HloModule* module) override;