Internal change

PiperOrigin-RevId: 241258289
This commit is contained in:
David Majnemer 2019-03-31 21:44:54 -07:00 committed by TensorFlower Gardener
parent 5e8df789cc
commit 77b06b0577
2 changed files with 34 additions and 9 deletions

View File

@ -1076,33 +1076,38 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
//
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
Shape result_shape = b->literal().shape();
if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
(Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
Shape result_shape = c->literal().shape();
Literal new_literal(result_shape);
switch (result_shape.element_type()) {
case F16:
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
break;
case F32:
TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
break;
case BF16:
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
break;
case F64:
TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
break;
case C64:
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
break;
case C128:
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
break;
default:
return Status::OK();
}
auto inverse = computation_->AddInstruction(
simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone())));
simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
if (b != c) {
inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
b->shape(), inverse, b->dimensions()));
}
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);

View File

@ -853,6 +853,26 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
}
// A / Broadcast(Const) => A * Broadcast(InvertedConst)
TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
const char* kModuleStr = R"(
HloModule m
test {
p = f32[4] parameter(0)
c = f32[] constant(256.0)
b = f32[4] broadcast(c), dimensions={}
ROOT d = f32[4] divide(p, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(
m::Parameter(0),
m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
}
// pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
auto m = CreateNewVerifiedModule();