[XLA] Add a MakeScalarLike helper and use it to simplify scalar-broadcast construction in the algebraic simplifier.

PiperOrigin-RevId: 266200816
This commit is contained in:
Blake Hechtman 2019-08-29 12:25:28 -07:00 committed by TensorFlower Gardener
parent f0ebc5b745
commit 79ba1e5243
4 changed files with 39 additions and 87 deletions

View File

@ -1781,6 +1781,9 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core/lib/core:bits",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",

View File

@ -876,8 +876,7 @@ std::unique_ptr<HloInstruction> TryDivideToShift(
int64 b_value = c->literal().GetFirstElement<T>();
if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
// Handle negative dividends by negating the result of the division.
HloInstruction* zero_like_a = BroadcastZeros(
computation, a->shape().element_type(), a->shape().dimensions());
HloInstruction* zero_like_a = MakeScalarLike(a, 0);
Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
simplifier->UpdateLayout(&changed_shape);
@ -893,19 +892,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(
a->shape(), HloOpcode::kSelect, dividend_is_negative,
negated_dividend, a));
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
auto* shift_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
if (!ShapeUtil::IsScalar(b->shape())) {
shift_amount = computation->AddInstruction(
HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
}
auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
shift_amount));
MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value))));
auto* neqated_quotient =
computation->AddInstruction(HloInstruction::CreateUnary(
@ -918,16 +907,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(
} else {
uint64 b_value = c->literal().GetFirstElement<T>();
if (IsPowerOfTwo(b_value)) {
int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
HloInstruction* shift_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(log2_abs_b_value)));
if (!ShapeUtil::IsScalar(b->shape())) {
shift_amount = computation->AddInstruction(
HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
}
return HloInstruction::CreateBinary(
divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount);
divide->shape(), HloOpcode::kShiftRightLogical, a,
MakeScalarLike(a, tensorflow::Log2Floor64(b_value)));
}
}
@ -1915,24 +1897,18 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
};
auto result = get_value(0);
auto one = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::One(index_shape.element_type())));
auto index = one;
auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
index_shape.element_type());
for (int64 i = 1; i < operand_elements; ++i) {
auto broadcasted_index = computation_->AddInstruction(
HloInstruction::CreateBroadcast(iter_shape, index, {}));
for (int64 i = 0; i < operand_elements; ++i) {
auto index_mask =
computation_->AddInstruction(HloInstruction::CreateCompare(
pred_shape, gather->mutable_operand(1), broadcasted_index,
pred_shape, gather->mutable_operand(1),
MakeScalarLike(gather->mutable_operand(1), i),
ComparisonDirection::kGe));
result = computation_->AddInstruction(
HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
index_mask, get_value(i), result));
index = computation_->AddInstruction(HloInstruction::CreateBinary(
index->shape(), HloOpcode::kAdd, index, one));
}
return ReplaceInstruction(gather, result);
}
@ -2380,27 +2356,18 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
HloInstruction* rhs;
CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
auto replace_with_pred_broadcast = [&](bool value) {
return ReplaceWithNewInstruction(
compare,
HloInstruction::CreateBroadcast(
compare->shape(),
computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(value))),
{}));
};
if (compare->comparison_direction() == ComparisonDirection::kLt &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
return replace_with_pred_broadcast(false);
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
} else if (compare->comparison_direction() == ComparisonDirection::kGt &&
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
return replace_with_pred_broadcast(false);
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
} else if (compare->comparison_direction() == ComparisonDirection::kGe &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
return replace_with_pred_broadcast(true);
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
} else if (compare->comparison_direction() == ComparisonDirection::kLe &&
IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
return replace_with_pred_broadcast(true);
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
}
if (lhs == rhs &&
primitive_util::IsIntegralType(lhs->shape().element_type())) {
@ -2408,11 +2375,11 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
case ComparisonDirection::kGt:
case ComparisonDirection::kLt:
case ComparisonDirection::kNe:
return replace_with_pred_broadcast(false);
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
case ComparisonDirection::kEq:
case ComparisonDirection::kGe:
case ComparisonDirection::kLe:
return replace_with_pred_broadcast(true);
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
}
}
return Status::OK();
@ -2590,16 +2557,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
HloInstruction *lhs, *rhs;
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
} else {
ones = HloInstruction::CreateBroadcast(
power->shape(), computation_->AddInstruction(std::move(one)), {});
}
return ReplaceWithNewInstruction(power, std::move(ones));
return ReplaceInstruction(power, MakeScalarLike(power, 1));
}
VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
@ -2625,18 +2583,9 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
// broadcasting at HLO level.
auto* broadcast_one = computation_->AddInstruction(
HloInstruction::CreateBroadcast(power->shape(), one, {}));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
broadcast_one, lhs));
MakeScalarLike(lhs, 1), lhs));
}
VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: "
@ -2774,16 +2723,9 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(
a->shape(), HloOpcode::kSelect, dividend_is_negative,
negated_dividend, a));
auto* mask_amount = computation->AddInstruction(
simplifier->CreateConstantWithLayoutUpdated(
LiteralUtil::CreateR0<T>(b_value - 1)));
if (!ShapeUtil::IsScalar(b->shape())) {
mask_amount = computation->AddInstruction(
HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
}
auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount));
remainder->shape(), HloOpcode::kAnd, abs_dividend,
MakeScalarLike(abs_dividend, b_value - 1)));
auto* neqated_quotient =
computation->AddInstruction(HloInstruction::CreateUnary(
@ -4052,14 +3994,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// Zero-sized input or filter.
if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
return ReplaceWithNewInstruction(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated(
LiteralUtil::Zero(convolution->shape().element_type()))),
{}));
return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
}
// Try to merge padding/dilation of the input with the convolution's window.

View File

@ -1455,10 +1455,8 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0))));
EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast);
EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
1);
EXPECT_THAT(root, GmockMatch(m::Divide(m::Constant(), m::Parameter(0))));
EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {

View File

@ -166,6 +166,22 @@ HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
}
// Builds a constant holding `value`, converted to the element type of `base`,
// and (for non-scalar `base`) broadcast to `base`'s shape, so the result is
// elementwise compatible with `base`. All created instructions are added to
// `base`'s computation.
template <class NativeT>
HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
  const auto& base_shape = base->shape();
  auto* scalar = base->parent()->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
                                         .Convert(base_shape.element_type())
                                         .ValueOrDie()));
  if (base_shape.rank() != 0) {
    // Explicitly broadcast to the full shape: XLA has no implicit
    // broadcasting at the HLO level.
    return base->parent()->AddInstruction(
        HloInstruction::CreateBroadcast(base_shape, scalar, {}));
  }
  // Scalar base: copy its shape (presumably to pick up the layout as well)
  // onto the freshly created constant.
  *scalar->mutable_shape() = base_shape;
  return scalar;
}
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing