Internal change

PiperOrigin-RevId: 241258289
This commit is contained in:
David Majnemer 2019-03-31 21:44:54 -07:00 committed by TensorFlower Gardener
parent 5e8df789cc
commit 77b06b0577
2 changed files with 34 additions and 9 deletions

View File

@ -1076,33 +1076,38 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
// //
// (Backends can do this transformation, but generally only if the constant is // (Backends can do this transformation, but generally only if the constant is
// a scalar.) // a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
Shape result_shape = b->literal().shape(); (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
Shape result_shape = c->literal().shape();
Literal new_literal(result_shape); Literal new_literal(result_shape);
switch (result_shape.element_type()) { switch (result_shape.element_type()) {
case F16: case F16:
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal)); TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
break; break;
case F32: case F32:
TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal)); TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
break; break;
case BF16: case BF16:
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal)); TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
break; break;
case F64: case F64:
TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal)); TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
break; break;
case C64: case C64:
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal)); TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
break; break;
case C128: case C128:
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal)); TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
break; break;
default: default:
return Status::OK(); return Status::OK();
} }
auto inverse = computation_->AddInstruction( auto inverse = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone()))); simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
if (b != c) {
inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
b->shape(), inverse, b->dimensions()));
}
TF_ASSIGN_OR_RETURN(auto new_divide, TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide); return ReplaceInstruction(divide, new_divide);

View File

@ -853,6 +853,26 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
} }
// A / Broadcast(Const) => A * Broadcast(InvertedConst)
TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
const char* kModuleStr = R"(
HloModule m
test {
p = f32[4] parameter(0)
c = f32[] constant(256.0)
b = f32[4] broadcast(c), dimensions={}
ROOT d = f32[4] divide(p, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(
m::Parameter(0),
m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
}
// pow(pow(A, X), Y) => pow(A, X*Y) // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest, PowerOfPower) { TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
auto m = CreateNewVerifiedModule(); auto m = CreateNewVerifiedModule();