Internal change
PiperOrigin-RevId: 241258289
This commit is contained in:
parent
5e8df789cc
commit
77b06b0577
@ -1076,33 +1076,38 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
//
|
||||
// (Backends can do this transformation, but generally only if the constant is
|
||||
// a scalar.)
|
||||
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
|
||||
Shape result_shape = b->literal().shape();
|
||||
if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
|
||||
(Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
|
||||
Shape result_shape = c->literal().shape();
|
||||
Literal new_literal(result_shape);
|
||||
switch (result_shape.element_type()) {
|
||||
case F16:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
|
||||
TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
|
||||
break;
|
||||
case F32:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
|
||||
TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
|
||||
break;
|
||||
case BF16:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
|
||||
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
|
||||
break;
|
||||
case F64:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
|
||||
TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
|
||||
break;
|
||||
case C64:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
|
||||
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
|
||||
break;
|
||||
case C128:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
|
||||
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
|
||||
break;
|
||||
default:
|
||||
return Status::OK();
|
||||
}
|
||||
auto inverse = computation_->AddInstruction(
|
||||
simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone())));
|
||||
simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
|
||||
if (b != c) {
|
||||
inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
b->shape(), inverse, b->dimensions()));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto new_divide,
|
||||
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
|
||||
return ReplaceInstruction(divide, new_divide);
|
||||
|
@ -853,6 +853,26 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
|
||||
}
|
||||
|
||||
// A / Broadcast(Const) => A * Broadcast(InvertedConst)
|
||||
TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p = f32[4] parameter(0)
|
||||
c = f32[] constant(256.0)
|
||||
b = f32[4] broadcast(c), dimensions={}
|
||||
ROOT d = f32[4] divide(p, b)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Multiply(
|
||||
m::Parameter(0),
|
||||
m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
|
||||
}
|
||||
|
||||
// pow(pow(A, X), Y) => pow(A, X*Y)
|
||||
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
|
Loading…
Reference in New Issue
Block a user