Internal change
PiperOrigin-RevId: 241258289
This commit is contained in:
parent
5e8df789cc
commit
77b06b0577
@ -1076,33 +1076,38 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
|||||||
//
|
//
|
||||||
// (Backends can do this transformation, but generally only if the constant is
|
// (Backends can do this transformation, but generally only if the constant is
|
||||||
// a scalar.)
|
// a scalar.)
|
||||||
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
|
if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
|
||||||
Shape result_shape = b->literal().shape();
|
(Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
|
||||||
|
Shape result_shape = c->literal().shape();
|
||||||
Literal new_literal(result_shape);
|
Literal new_literal(result_shape);
|
||||||
switch (result_shape.element_type()) {
|
switch (result_shape.element_type()) {
|
||||||
case F16:
|
case F16:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
|
||||||
break;
|
break;
|
||||||
case F32:
|
case F32:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
|
||||||
break;
|
break;
|
||||||
case BF16:
|
case BF16:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
|
||||||
break;
|
break;
|
||||||
case F64:
|
case F64:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
|
||||||
break;
|
break;
|
||||||
case C64:
|
case C64:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
|
||||||
break;
|
break;
|
||||||
case C128:
|
case C128:
|
||||||
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
|
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
auto inverse = computation_->AddInstruction(
|
auto inverse = computation_->AddInstruction(
|
||||||
simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone())));
|
simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
|
||||||
|
if (b != c) {
|
||||||
|
inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
b->shape(), inverse, b->dimensions()));
|
||||||
|
}
|
||||||
TF_ASSIGN_OR_RETURN(auto new_divide,
|
TF_ASSIGN_OR_RETURN(auto new_divide,
|
||||||
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
|
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
|
||||||
return ReplaceInstruction(divide, new_divide);
|
return ReplaceInstruction(divide, new_divide);
|
||||||
|
@ -853,6 +853,26 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
|
|||||||
GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
|
GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A / Broadcast(Const) => A * Broadcast(InvertedConst)
|
||||||
|
TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p = f32[4] parameter(0)
|
||||||
|
c = f32[] constant(256.0)
|
||||||
|
b = f32[4] broadcast(c), dimensions={}
|
||||||
|
ROOT d = f32[4] divide(p, b)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Multiply(
|
||||||
|
m::Parameter(0),
|
||||||
|
m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
|
||||||
|
}
|
||||||
|
|
||||||
// pow(pow(A, X), Y) => pow(A, X*Y)
|
// pow(pow(A, X), Y) => pow(A, X*Y)
|
||||||
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
|
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
|
||||||
auto m = CreateNewVerifiedModule();
|
auto m = CreateNewVerifiedModule();
|
||||||
|
Loading…
Reference in New Issue
Block a user