[XLA] Fix algebraic simplifier to handle instruction without layout when

transforming A / Const to A * (1 / Const).

If the instruction that contains the constant value doesn't have a layout,
algebraic simplifier crashes when creating the inverse of the constant value.
The change is to always use the shape of the literal to create the new value.

PiperOrigin-RevId: 232700905
This commit is contained in:
Bixia Zheng 2019-02-06 10:36:35 -08:00 committed by TensorFlower Gardener
parent 218df759b6
commit 3bfbcb2abd
2 changed files with 26 additions and 3 deletions

View File

@ -892,7 +892,6 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
} // namespace
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
Shape* shape;
HloInstruction *a, *b, *c, *d;
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
// A/1 => A
@ -955,6 +954,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
break;
}
Shape* shape;
// exp(A)/exp(B) => exp(A-B)
if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
.WithShape(m::Shape(&shape)))) {
@ -1005,8 +1005,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
Literal new_literal(b->shape());
switch (b->shape().element_type()) {
Shape result_shape = b->literal().shape();
Literal new_literal(result_shape);
switch (result_shape.element_type()) {
case F16:
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
break;

View File

@ -4795,5 +4795,27 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
EXPECT_THAT(root, op::Constant());
}
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
Shape shape = ShapeUtil::MakeShape(F32, {});
shape.clear_layout();
auto builder = HloComputation::Builder(TestName());
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
HloInstruction* const_value = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
param, const_value));
std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
AlgebraicSimplifierOptions options;
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Multiply());
}
} // namespace
} // namespace xla