[XLA] Make XLA Scalar like
PiperOrigin-RevId: 266200816
This commit is contained in:
parent
f0ebc5b745
commit
79ba1e5243
@ -1781,6 +1781,9 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/lib/core:bits",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
@ -876,8 +876,7 @@ std::unique_ptr<HloInstruction> TryDivideToShift(
|
||||
int64 b_value = c->literal().GetFirstElement<T>();
|
||||
if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
|
||||
// Handle negative dividends by negating the result of the division.
|
||||
HloInstruction* zero_like_a = BroadcastZeros(
|
||||
computation, a->shape().element_type(), a->shape().dimensions());
|
||||
HloInstruction* zero_like_a = MakeScalarLike(a, 0);
|
||||
|
||||
Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
|
||||
simplifier->UpdateLayout(&changed_shape);
|
||||
@ -893,19 +892,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(
|
||||
a->shape(), HloOpcode::kSelect, dividend_is_negative,
|
||||
negated_dividend, a));
|
||||
|
||||
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
|
||||
|
||||
auto* shift_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
shift_amount = computation->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
|
||||
}
|
||||
|
||||
auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
|
||||
divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
|
||||
shift_amount));
|
||||
MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value))));
|
||||
|
||||
auto* neqated_quotient =
|
||||
computation->AddInstruction(HloInstruction::CreateUnary(
|
||||
@ -918,16 +907,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(
|
||||
} else {
|
||||
uint64 b_value = c->literal().GetFirstElement<T>();
|
||||
if (IsPowerOfTwo(b_value)) {
|
||||
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
|
||||
HloInstruction* shift_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
shift_amount = computation->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
|
||||
}
|
||||
return HloInstruction::CreateBinary(
|
||||
divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount);
|
||||
divide->shape(), HloOpcode::kShiftRightLogical, a,
|
||||
MakeScalarLike(a, tensorflow::Log2Floor64(b_value)));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1915,24 +1897,18 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
|
||||
HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
|
||||
};
|
||||
auto result = get_value(0);
|
||||
auto one = computation_->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::One(index_shape.element_type())));
|
||||
auto index = one;
|
||||
auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
|
||||
auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
|
||||
index_shape.element_type());
|
||||
for (int64 i = 1; i < operand_elements; ++i) {
|
||||
auto broadcasted_index = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(iter_shape, index, {}));
|
||||
for (int64 i = 0; i < operand_elements; ++i) {
|
||||
auto index_mask =
|
||||
computation_->AddInstruction(HloInstruction::CreateCompare(
|
||||
pred_shape, gather->mutable_operand(1), broadcasted_index,
|
||||
pred_shape, gather->mutable_operand(1),
|
||||
MakeScalarLike(gather->mutable_operand(1), i),
|
||||
ComparisonDirection::kGe));
|
||||
result = computation_->AddInstruction(
|
||||
HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
|
||||
index_mask, get_value(i), result));
|
||||
index = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
index->shape(), HloOpcode::kAdd, index, one));
|
||||
}
|
||||
return ReplaceInstruction(gather, result);
|
||||
}
|
||||
@ -2380,27 +2356,18 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
||||
HloInstruction* rhs;
|
||||
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
|
||||
|
||||
auto replace_with_pred_broadcast = [&](bool value) {
|
||||
return ReplaceWithNewInstruction(
|
||||
compare,
|
||||
HloInstruction::CreateBroadcast(
|
||||
compare->shape(),
|
||||
computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(value))),
|
||||
{}));
|
||||
};
|
||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||
return replace_with_pred_broadcast(false);
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||
} else if (compare->comparison_direction() == ComparisonDirection::kGt &&
|
||||
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
|
||||
return replace_with_pred_broadcast(false);
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||
} else if (compare->comparison_direction() == ComparisonDirection::kGe &&
|
||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||
return replace_with_pred_broadcast(true);
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||
} else if (compare->comparison_direction() == ComparisonDirection::kLe &&
|
||||
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
|
||||
return replace_with_pred_broadcast(true);
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||
}
|
||||
if (lhs == rhs &&
|
||||
primitive_util::IsIntegralType(lhs->shape().element_type())) {
|
||||
@ -2408,11 +2375,11 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
||||
case ComparisonDirection::kGt:
|
||||
case ComparisonDirection::kLt:
|
||||
case ComparisonDirection::kNe:
|
||||
return replace_with_pred_broadcast(false);
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||
case ComparisonDirection::kEq:
|
||||
case ComparisonDirection::kGe:
|
||||
case ComparisonDirection::kLe:
|
||||
return replace_with_pred_broadcast(true);
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -2590,16 +2557,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
HloInstruction *lhs, *rhs;
|
||||
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
|
||||
if (IsAll(rhs, 0)) {
|
||||
auto one = simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::One(power->shape().element_type()).Clone());
|
||||
std::unique_ptr<HloInstruction> ones;
|
||||
if (ShapeUtil::IsScalar(power->shape())) {
|
||||
ones = std::move(one);
|
||||
} else {
|
||||
ones = HloInstruction::CreateBroadcast(
|
||||
power->shape(), computation_->AddInstruction(std::move(one)), {});
|
||||
}
|
||||
return ReplaceWithNewInstruction(power, std::move(ones));
|
||||
return ReplaceInstruction(power, MakeScalarLike(power, 1));
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
|
||||
@ -2625,18 +2583,9 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
|
||||
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
||||
if (IsAll(rhs, -1)) {
|
||||
auto* one = computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::One(rhs->shape().element_type()).Clone()));
|
||||
|
||||
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
|
||||
// broadcast in divide HLO as we are trying to eliminate implicit
|
||||
// broadcasting at HLO level.
|
||||
auto* broadcast_one = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(power->shape(), one, {}));
|
||||
return ReplaceWithNewInstruction(
|
||||
power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
|
||||
broadcast_one, lhs));
|
||||
MakeScalarLike(lhs, 1), lhs));
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: "
|
||||
@ -2774,16 +2723,9 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(
|
||||
a->shape(), HloOpcode::kSelect, dividend_is_negative,
|
||||
negated_dividend, a));
|
||||
|
||||
auto* mask_amount = computation->AddInstruction(
|
||||
simplifier->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::CreateR0<T>(b_value - 1)));
|
||||
if (!ShapeUtil::IsScalar(b->shape())) {
|
||||
mask_amount = computation->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
|
||||
}
|
||||
|
||||
auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
|
||||
remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount));
|
||||
remainder->shape(), HloOpcode::kAnd, abs_dividend,
|
||||
MakeScalarLike(abs_dividend, b_value - 1)));
|
||||
|
||||
auto* neqated_quotient =
|
||||
computation->AddInstruction(HloInstruction::CreateUnary(
|
||||
@ -4052,14 +3994,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
// Zero-sized input or filter.
|
||||
if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
|
||||
ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
|
||||
return ReplaceWithNewInstruction(
|
||||
convolution,
|
||||
HloInstruction::CreateBroadcast(
|
||||
convolution->shape(),
|
||||
computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated(
|
||||
LiteralUtil::Zero(convolution->shape().element_type()))),
|
||||
{}));
|
||||
return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
|
||||
}
|
||||
|
||||
// Try to merge padding/dilation of the input with the convolution's window.
|
||||
|
@ -1455,10 +1455,8 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0))));
|
||||
EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast);
|
||||
EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
|
||||
1);
|
||||
EXPECT_THAT(root, GmockMatch(m::Divide(m::Constant(), m::Parameter(0))));
|
||||
EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
|
||||
|
@ -166,6 +166,22 @@ HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
|
||||
}
|
||||
|
||||
// Makes a scalar that is elementwise compatible with the shape of the base
|
||||
// instruction.
//
// `value` is first materialized as a rank-0 constant, converted to the element
// type of `base`'s shape, and added to the computation that contains `base`
// (base->parent()). The conversion result is unwrapped with ValueOrDie().
// NOTE(review): presumably a value that cannot be converted to the element
// type of `base` is a fatal error here -- confirm against Literal::Convert.
//
// Return value:
//  - If `base` has rank 0, the constant's shape is overwritten with `base`'s
//    shape and the constant itself is returned.
//  - Otherwise a broadcast of the constant to `base`'s shape (with an empty
//    dimensions list) is added to the same computation and returned.
|
||||
template <class NativeT>
|
||||
HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
|
||||
auto scalar = base->parent()->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
|
||||
.Convert(base->shape().element_type())
|
||||
.ValueOrDie()));
|
||||
// A rank-0 base needs no broadcast: the constant adopts the base's shape.
if (base->shape().rank() == 0) {
|
||||
*scalar->mutable_shape() = base->shape();
|
||||
return scalar;
|
||||
}
|
||||
// Non-scalar base: expand the constant to the full shape of the base.
return base->parent()->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(base->shape(), scalar, {}));
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Some other miscellaneous helpers to generate common HLO patterns. All of
|
||||
// these add all the instructions they generate into the computation containing
|
||||
|
Loading…
x
Reference in New Issue
Block a user