[XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications
PiperOrigin-RevId: 325084126 Change-Id: Id8bf89ba6601d7bb1efc2b167e6e9accf5913114
This commit is contained in:
parent
8846105326
commit
b2f5d100d1
@ -665,7 +665,7 @@ Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
|
|||||||
HloInstruction* inst;
|
HloInstruction* inst;
|
||||||
HloInstruction* user;
|
HloInstruction* user;
|
||||||
int64 index;
|
int64 index;
|
||||||
std::tie(inst, user, index) = operands.back();
|
std::tie (inst, user, index) = operands.back();
|
||||||
operands.pop_back();
|
operands.pop_back();
|
||||||
|
|
||||||
// Skip the op types that are not commutative with multiply.
|
// Skip the op types that are not commutative with multiply.
|
||||||
@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
|||||||
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
|
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
|
||||||
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
|
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
|
||||||
(ShapeUtil::ElementIsIntegral(add->shape()) ||
|
(ShapeUtil::ElementIsIntegral(add->shape()) ||
|
||||||
options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
|
IsAllFpConstantPowerOf2(c))) {
|
||||||
return ReplaceWithNewInstruction(
|
return ReplaceWithNewInstruction(
|
||||||
add, HloInstruction::CreateBinary(
|
add, HloInstruction::CreateBinary(
|
||||||
add->shape(), HloOpcode::kMultiply,
|
add->shape(), HloOpcode::kMultiply,
|
||||||
@ -2667,17 +2667,6 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
|
||||||
HloInstruction* abs_operand;
|
|
||||||
if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
|
|
||||||
!ShapeUtil::ElementIsComplex(abs_operand->shape())) {
|
|
||||||
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
|
|
||||||
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
|
|
||||||
changed_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
HloInstruction *convert_operand, *operand;
|
HloInstruction *convert_operand, *operand;
|
||||||
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
||||||
@ -3048,8 +3037,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
|
|||||||
HloInstruction* new_broadcast = computation_->AddInstruction(
|
HloInstruction* new_broadcast = computation_->AddInstruction(
|
||||||
HloInstruction::CreateBroadcast(user->shape(), operand, {}));
|
HloInstruction::CreateBroadcast(user->shape(), operand, {}));
|
||||||
// Use HloInstruction::ReplaceAllUsesWith instead of
|
// Use HloInstruction::ReplaceAllUsesWith instead of
|
||||||
// HloComputation::ReplaceWithNewInstruction because we are replacing
|
// HloComputation::ReplaceWithNewInstruction because we are replacing an
|
||||||
// an instruction other than the visited instruction.
|
// instruction other than the visited instruction.
|
||||||
changed_ = true;
|
changed_ = true;
|
||||||
return user->ReplaceAllUsesWith(new_broadcast);
|
return user->ReplaceAllUsesWith(new_broadcast);
|
||||||
}
|
}
|
||||||
@ -3166,11 +3155,9 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
|
|||||||
|
|
||||||
// Eliminate a convert pair if it is a no-op. The following are a few
|
// Eliminate a convert pair if it is a no-op. The following are a few
|
||||||
// example cases that are being handled:
|
// example cases that are being handled:
|
||||||
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
|
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
|
||||||
// $TYPE2
|
|
||||||
// and convert(A, $TYPE1) is an upcast
|
// and convert(A, $TYPE1) is an upcast
|
||||||
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of
|
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
|
||||||
// $TYPE2
|
|
||||||
// and convert(A, $TYPE1) is an upcast and is an integral conversion from
|
// and convert(A, $TYPE1) is an upcast and is an integral conversion from
|
||||||
// unsigned to signed (only signed to unsigned conversion is NOT allowed)
|
// unsigned to signed (only signed to unsigned conversion is NOT allowed)
|
||||||
// 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
|
// 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
|
||||||
@ -3306,8 +3293,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
|||||||
pad->shape(), nonzero_pad->mutable_shape()));
|
pad->shape(), nonzero_pad->mutable_shape()));
|
||||||
simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
|
simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
|
||||||
|
|
||||||
// Second, construct the slice instruction to perform the negative
|
// Second, construct the slice instruction to perform the negative padding.
|
||||||
// padding.
|
|
||||||
std::vector<int64> start_indices;
|
std::vector<int64> start_indices;
|
||||||
std::vector<int64> end_indices;
|
std::vector<int64> end_indices;
|
||||||
std::vector<int64> strides;
|
std::vector<int64> strides;
|
||||||
@ -3460,8 +3446,8 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
|||||||
|
|
||||||
Shape changed_shape;
|
Shape changed_shape;
|
||||||
for (HloInstruction* user_operand : user->operands()) {
|
for (HloInstruction* user_operand : user->operands()) {
|
||||||
// If this is a broadcast operand that is not our original broadcast
|
// If this is a broadcast operand that is not our original broadcast input
|
||||||
// input to this function then we might need to change the input.
|
// to this function then we might need to change the input.
|
||||||
if (is_compatible_broadcast(user_operand)) {
|
if (is_compatible_broadcast(user_operand)) {
|
||||||
// If this is a broadcast from a scalar value rewrite a broadcast from
|
// If this is a broadcast from a scalar value rewrite a broadcast from
|
||||||
// the scalar to the new shape enforced from the other broadcast
|
// the scalar to the new shape enforced from the other broadcast
|
||||||
@ -3632,16 +3618,16 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
|||||||
// If M < N, then {0, ..., M} % N ==> {0, ..., M}.
|
// If M < N, then {0, ..., M} % N ==> {0, ..., M}.
|
||||||
//
|
//
|
||||||
// Currently this only covers the case when N is a broadcasted constant
|
// Currently this only covers the case when N is a broadcasted constant
|
||||||
// scalar. We could also cover the case when N is a non-broadcasted
|
// scalar. We could also cover the case when N is a non-broadcasted constant
|
||||||
// constant with the same value repeated.
|
// with the same value repeated.
|
||||||
HloInstruction* iota;
|
HloInstruction* iota;
|
||||||
HloInstruction* divisor;
|
HloInstruction* divisor;
|
||||||
if (Match(remainder,
|
if (Match(remainder,
|
||||||
m::Remainder(m::Iota(&iota),
|
m::Remainder(m::Iota(&iota),
|
||||||
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
|
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
|
||||||
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
|
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
|
||||||
// conservative; the iota may overflow and count up to a smaller value
|
// conservative; the iota may overflow and count up to a smaller value than
|
||||||
// than this. But that's OK for our purposes here.)
|
// this. But that's OK for our purposes here.)
|
||||||
int64 iota_upper_bound = iota->shape().dimensions(
|
int64 iota_upper_bound = iota->shape().dimensions(
|
||||||
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
||||||
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
||||||
@ -3654,8 +3640,8 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
|||||||
// (X + N) % N = X % N, so long as X + N does not overflow.
|
// (X + N) % N = X % N, so long as X + N does not overflow.
|
||||||
//
|
//
|
||||||
// We don't have range tracking in XLA that would let us know whether X + N
|
// We don't have range tracking in XLA that would let us know whether X + N
|
||||||
// overflows, so for now we only do this simplification when X is an iota.
|
// overflows, so for now we only do this simplification when X is an iota. We
|
||||||
// We could add other operations where it's easy to see a range, such as
|
// could add other operations where it's easy to see a range, such as
|
||||||
// remainder, convert, etc., though at some point we'd probably want a
|
// remainder, convert, etc., though at some point we'd probably want a
|
||||||
// range-tracking analysis.
|
// range-tracking analysis.
|
||||||
HloInstruction* bcast;
|
HloInstruction* bcast;
|
||||||
@ -3667,9 +3653,9 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
|||||||
m::Broadcast(m::ConstantEffectiveScalar(&addend))),
|
m::Broadcast(m::ConstantEffectiveScalar(&addend))),
|
||||||
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
|
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
|
||||||
addend == divisor) {
|
addend == divisor) {
|
||||||
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat
|
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
|
||||||
// above that iota_upper_bound is conservative, and the true upper bound
|
// that iota_upper_bound is conservative, and the true upper bound may be
|
||||||
// may be smaller.
|
// smaller.
|
||||||
int64 iota_upper_bound = iota->shape().dimensions(
|
int64 iota_upper_bound = iota->shape().dimensions(
|
||||||
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
||||||
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
||||||
@ -3774,9 +3760,9 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
|
|||||||
|
|
||||||
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
|
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
|
||||||
HloInstruction* slice) {
|
HloInstruction* slice) {
|
||||||
// Only try to do this for effective scalars. We could do the same for
|
// Only try to do this for effective scalars. We could do the same for slicing
|
||||||
// slicing out larger pieces of padding (replacing with a broadcast of the
|
// out larger pieces of padding (replacing with a broadcast of the padding
|
||||||
// padding value), but this is probably not worth it.
|
// value), but this is probably not worth it.
|
||||||
if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
|
if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -3877,8 +3863,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allowing a slice to move through a reverse with any necessary updates to
|
// Allowing a slice to move through a reverse with any necessary updates to the
|
||||||
// the slice config.
|
// slice config.
|
||||||
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
|
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
|
||||||
HloInstruction* slice) {
|
HloInstruction* slice) {
|
||||||
VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
|
VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
|
||||||
@ -3906,8 +3892,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
|
|||||||
<< new_limits[rdim];
|
<< new_limits[rdim];
|
||||||
}
|
}
|
||||||
// New slice formed from the reverse_operand, but strides and shape of the
|
// New slice formed from the reverse_operand, but strides and shape of the
|
||||||
// slice output remains the same. New slice's starts and limits are
|
// slice output remains the same. New slice's starts and limits are updated
|
||||||
// updated for ONLY the reversed dimensions as indicated above.
|
// for ONLY the reversed dimensions as indicated above.
|
||||||
HloInstruction* new_slice = computation_->AddInstruction(
|
HloInstruction* new_slice = computation_->AddInstruction(
|
||||||
HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
|
HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
|
||||||
new_limits, new_strides));
|
new_limits, new_strides));
|
||||||
@ -3934,8 +3920,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
|||||||
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
|
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
|
||||||
// Is the result of the slice the pad operand.
|
// Is the result of the slice the pad operand.
|
||||||
bool slice_undoes_pad = true;
|
bool slice_undoes_pad = true;
|
||||||
// Can the slice be moved to the pad_operand without any padding being
|
// Can the slice be moved to the pad_operand without any padding being read.
|
||||||
// read.
|
|
||||||
bool slice_inside_pad = true;
|
bool slice_inside_pad = true;
|
||||||
// Does this slice slice out pading only.
|
// Does this slice slice out pading only.
|
||||||
bool slice_in_padding = false;
|
bool slice_in_padding = false;
|
||||||
@ -4070,8 +4055,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not try to reorder slices and reshapes after layout assignment as it
|
// Do not try to reorder slices and reshapes after layout assignment as it may
|
||||||
// may be invalid.
|
// be invalid.
|
||||||
if (!options_.is_layout_sensitive()) {
|
if (!options_.is_layout_sensitive()) {
|
||||||
TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
|
TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
|
||||||
}
|
}
|
||||||
@ -4121,8 +4106,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
|||||||
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
|
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
|
||||||
return ReplaceInstruction(dynamic_slice, operand);
|
return ReplaceInstruction(dynamic_slice, operand);
|
||||||
}
|
}
|
||||||
// DynamicSlice where operand has the same size as the output is simply
|
// DynamicSlice where operand has the same size as the output is simply equal
|
||||||
// equal to operand.
|
// to operand.
|
||||||
if (SameShape(operand, dynamic_slice)) {
|
if (SameShape(operand, dynamic_slice)) {
|
||||||
return ReplaceInstruction(dynamic_slice, operand);
|
return ReplaceInstruction(dynamic_slice, operand);
|
||||||
}
|
}
|
||||||
@ -4453,8 +4438,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
|||||||
// Convert Reduce(concat({a,b,...})) to
|
// Convert Reduce(concat({a,b,...})) to
|
||||||
// map(reduce(a),map(reduce(b),...,))
|
// map(reduce(a),map(reduce(b),...,))
|
||||||
//
|
//
|
||||||
// This should make fusion easier or use less memory bandwidth in the
|
// This should make fusion easier or use less memory bandwidth in the unfused
|
||||||
// unfused case.
|
// case.
|
||||||
if (arg->opcode() == HloOpcode::kConcatenate &&
|
if (arg->opcode() == HloOpcode::kConcatenate &&
|
||||||
absl::c_linear_search(reduce->dimensions(),
|
absl::c_linear_search(reduce->dimensions(),
|
||||||
arg->concatenate_dimension())) {
|
arg->concatenate_dimension())) {
|
||||||
@ -4473,9 +4458,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
HloInstruction *dot, *lhs, *rhs;
|
HloInstruction *dot, *lhs, *rhs;
|
||||||
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced
|
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
|
||||||
// were batch dimensions of the dot. The transformation supports reducing
|
// batch dimensions of the dot. The transformation supports reducing other
|
||||||
// other dimensions as well.
|
// dimensions as well.
|
||||||
if (options_.enable_dot_strength_reduction() &&
|
if (options_.enable_dot_strength_reduction() &&
|
||||||
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
|
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
|
||||||
Match(reduce->to_apply()->root_instruction(),
|
Match(reduce->to_apply()->root_instruction(),
|
||||||
@ -4547,13 +4532,13 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
|
|||||||
if (options_.enable_window_reduce_to_reduce_replacement()) {
|
if (options_.enable_window_reduce_to_reduce_replacement()) {
|
||||||
// A reduce window can be expressed as a reduce and a reshape if all
|
// A reduce window can be expressed as a reduce and a reshape if all
|
||||||
// dimensions either have a window size of one or the entire dimension. If
|
// dimensions either have a window size of one or the entire dimension. If
|
||||||
// there is no stride, dilation, or padding, this is as easy as checking
|
// there is no stride, dilation, or padding, this is as easy as checking the
|
||||||
// the size of the output shape and window dimension.
|
// size of the output shape and window dimension.
|
||||||
//
|
//
|
||||||
// The reshape is a bitcast since it adds one-sized dimensions. Often
|
// The reshape is a bitcast since it adds one-sized dimensions. Often these
|
||||||
// these ones are immediately removed as well with another reshape. The
|
// ones are immediately removed as well with another reshape. The
|
||||||
// implementation of reduce tends to be slightly more efficient at
|
// implementation of reduce tends to be slightly more efficient at reducing
|
||||||
// reducing entire dimensions compared to reduce window.
|
// entire dimensions compared to reduce window.
|
||||||
auto effective_reduce_dims = [&] {
|
auto effective_reduce_dims = [&] {
|
||||||
if (window_util::HasStride(window) || window_util::HasDilation(window) ||
|
if (window_util::HasStride(window) || window_util::HasDilation(window) ||
|
||||||
window_util::HasPadding(window)) {
|
window_util::HasPadding(window)) {
|
||||||
@ -5068,8 +5053,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
|||||||
|
|
||||||
auto new_dim = swapped_window.add_dimensions();
|
auto new_dim = swapped_window.add_dimensions();
|
||||||
new_dim->set_size(input_size);
|
new_dim->set_size(input_size);
|
||||||
// If the kernel is not reversed, the activations must be manually
|
// If the kernel is not reversed, the activations must be manually reversed.
|
||||||
// reversed.
|
|
||||||
if (!window_dims[spatial_dim].window_reversal()) {
|
if (!window_dims[spatial_dim].window_reversal()) {
|
||||||
reverse_dimensions.push_back(
|
reverse_dimensions.push_back(
|
||||||
dnums.kernel_spatial_dimensions(spatial_dim));
|
dnums.kernel_spatial_dimensions(spatial_dim));
|
||||||
@ -5089,8 +5073,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
|||||||
dilated_kernel_size);
|
dilated_kernel_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't transform if a naive convolution implementation would not have
|
// Don't transform if a naive convolution implementation would not have fewer
|
||||||
// fewer flops.
|
// flops.
|
||||||
if (kernel_product <= swapped_kernel_product) {
|
if (kernel_product <= swapped_kernel_product) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -5168,11 +5152,11 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stride ignores part of the output, which matrix multiplication does not
|
// Stride ignores part of the output, which matrix multiplication does not do,
|
||||||
// do, so require no stride. Padding and base (lhs) dilation both implicitly
|
// so require no stride. Padding and base (lhs) dilation both implicitly
|
||||||
// extend the data, which matrix multiplication also does not do, so require
|
// extend the data, which matrix multiplication also does not do, so require
|
||||||
// no padding and no base (lhs) dilation. Window (rhs) dilation has no
|
// no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
|
||||||
// effect for a 1x1 window, so window dilation is no problem.
|
// for a 1x1 window, so window dilation is no problem.
|
||||||
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
|
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
|
||||||
window_util::HasBaseDilation(window)) {
|
window_util::HasBaseDilation(window)) {
|
||||||
return false;
|
return false;
|
||||||
@ -5225,9 +5209,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We already checked feature_dimension is most minor, so data in
|
// We already checked feature_dimension is most minor, so data in input_shape
|
||||||
// input_shape and row-major {conv_width,input_channels} are bitwise
|
// and row-major {conv_width,input_channels} are bitwise identical.
|
||||||
// identical.
|
|
||||||
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||||
input_shape.element_type(), {conv_width, input_channels});
|
input_shape.element_type(), {conv_width, input_channels});
|
||||||
simplifier_->UpdateLayout(&new_input_shape);
|
simplifier_->UpdateLayout(&new_input_shape);
|
||||||
|
@ -97,14 +97,6 @@ class AlgebraicSimplifierOptions {
|
|||||||
return enable_scalar_multiply_reduction_;
|
return enable_scalar_multiply_reduction_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also the algebraic simplifer to treat floating point values like real
|
|
||||||
// numbers.
|
|
||||||
void set_enable_floats_are_real(bool enable_floats_are_real) {
|
|
||||||
enable_floats_are_real_ = enable_floats_are_real;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool enable_floats_are_real() const { return enable_floats_are_real_; }
|
|
||||||
|
|
||||||
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
|
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
|
||||||
// can be optimized by replacement with simpler operations.
|
// can be optimized by replacement with simpler operations.
|
||||||
void set_enable_window_reduce_to_reduce_replacement(
|
void set_enable_window_reduce_to_reduce_replacement(
|
||||||
@ -166,7 +158,6 @@ class AlgebraicSimplifierOptions {
|
|||||||
bool enable_conv_simplification_{true};
|
bool enable_conv_simplification_{true};
|
||||||
bool enable_conv_operand_swap_{true};
|
bool enable_conv_operand_swap_{true};
|
||||||
bool enable_scalar_multiply_reduction_{false};
|
bool enable_scalar_multiply_reduction_{false};
|
||||||
bool enable_floats_are_real_{false};
|
|
||||||
bool enable_window_reduce_to_reduce_replacement_{true};
|
bool enable_window_reduce_to_reduce_replacement_{true};
|
||||||
bool enable_reduce_of_reshape_{true};
|
bool enable_reduce_of_reshape_{true};
|
||||||
bool replace_transpose_with_bitcast_{true};
|
bool replace_transpose_with_bitcast_{true};
|
||||||
|
@ -117,22 +117,6 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
|
|||||||
m::ConstantScalar(0.125))));
|
m::ConstantScalar(0.125))));
|
||||||
}
|
}
|
||||||
|
|
||||||
// (Abs(A)) * (Abs(A)) => (A*A)
|
|
||||||
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
|
|
||||||
const char* kModuleStr = R"(
|
|
||||||
HloModule m
|
|
||||||
test {
|
|
||||||
p = f32[] parameter(0)
|
|
||||||
a = f32[] abs(p)
|
|
||||||
ROOT z = f32[] multiply(a, a)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
|
||||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
|
||||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
|
||||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
|
||||||
}
|
|
||||||
|
|
||||||
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
|
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
|
||||||
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
|
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
|
||||||
const char* kModuleStr = R"(
|
const char* kModuleStr = R"(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user