[XLA] Fix algebraic simplifier to handle instruction without layout when
transforming A / Const to A * (1 / Const). If the instruction that contains the constant value doesn't have a layout, algebraic simplifier crashes when creating the inverse of the constant value. The change is to always use the shape of the literal to create the new value. PiperOrigin-RevId: 232700905
This commit is contained in:
parent
218df759b6
commit
3bfbcb2abd
@ -892,7 +892,6 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
||||
} // namespace
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
Shape* shape;
|
||||
HloInstruction *a, *b, *c, *d;
|
||||
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
|
||||
// A/1 => A
|
||||
@ -955,6 +954,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
break;
|
||||
}
|
||||
|
||||
Shape* shape;
|
||||
// exp(A)/exp(B) => exp(A-B)
|
||||
if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
|
||||
.WithShape(m::Shape(&shape)))) {
|
||||
@ -1005,8 +1005,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
// (Backends can do this transformation, but generally only if the constant is
|
||||
// a scalar.)
|
||||
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
|
||||
Literal new_literal(b->shape());
|
||||
switch (b->shape().element_type()) {
|
||||
Shape result_shape = b->literal().shape();
|
||||
Literal new_literal(result_shape);
|
||||
switch (result_shape.element_type()) {
|
||||
case F16:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
|
||||
break;
|
||||
|
@ -4795,5 +4795,27 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
|
||||
EXPECT_THAT(root, op::Constant());
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {});
|
||||
shape.clear_layout();
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param"));
|
||||
|
||||
HloInstruction* const_value = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
|
||||
param, const_value));
|
||||
|
||||
std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Multiply());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user