[XLA] Fix algebraic simplifier to handle instruction without layout when
transforming A / Const to A * (1 / Const). If the instruction that contains the constant value doesn't have a layout, algebraic simplifier crashes when creating the inverse of the constant value. The change is to always use the shape of the literal to create the new value. PiperOrigin-RevId: 232700905
This commit is contained in:
parent
218df759b6
commit
3bfbcb2abd
@ -892,7 +892,6 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||||
Shape* shape;
|
|
||||||
HloInstruction *a, *b, *c, *d;
|
HloInstruction *a, *b, *c, *d;
|
||||||
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
|
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
|
||||||
// A/1 => A
|
// A/1 => A
|
||||||
@ -955,6 +954,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Shape* shape;
|
||||||
// exp(A)/exp(B) => exp(A-B)
|
// exp(A)/exp(B) => exp(A-B)
|
||||||
if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
|
if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
|
||||||
.WithShape(m::Shape(&shape)))) {
|
.WithShape(m::Shape(&shape)))) {
|
||||||
@ -1005,8 +1005,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
|||||||
// (Backends can do this transformation, but generally only if the constant is
|
// (Backends can do this transformation, but generally only if the constant is
|
||||||
// a scalar.)
|
// a scalar.)
|
||||||
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
|
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
|
||||||
Literal new_literal(b->shape());
|
Shape result_shape = b->literal().shape();
|
||||||
switch (b->shape().element_type()) {
|
Literal new_literal(result_shape);
|
||||||
|
switch (result_shape.element_type()) {
|
||||||
case F16:
|
case F16:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
|
||||||
break;
|
break;
|
||||||
|
@ -4795,5 +4795,27 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
|
|||||||
EXPECT_THAT(root, op::Constant());
|
EXPECT_THAT(root, op::Constant());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
|
||||||
|
Shape shape = ShapeUtil::MakeShape(F32, {});
|
||||||
|
shape.clear_layout();
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
HloInstruction* param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, shape, "param"));
|
||||||
|
|
||||||
|
HloInstruction* const_value = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
|
||||||
|
param, const_value));
|
||||||
|
|
||||||
|
std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
AlgebraicSimplifierOptions options;
|
||||||
|
AlgebraicSimplifier simplifier(options);
|
||||||
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_THAT(root, op::Multiply());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user