Simplify and fix some bugs in constant folding of neutral/absorbing element optimizations.

PiperOrigin-RevId: 178293088
This commit is contained in:
A. Unique TensorFlower 2017-12-07 14:12:44 -08:00 committed by TensorFlower Gardener
parent 0509f07cc2
commit 2ea11416c9
2 changed files with 148 additions and 31 deletions

View File

@ -680,6 +680,7 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
const TensorShapeProto& shape, const TensorShapeProto& shape,
AttrValue* attr_tensor) { AttrValue* attr_tensor) {
TensorProto* t = attr_tensor->mutable_tensor(); TensorProto* t = attr_tensor->mutable_tensor();
t->set_dtype(type);
*t->mutable_tensor_shape() = shape; *t->mutable_tensor_shape() = shape;
switch (type) { switch (type) {
SET_TENSOR_VAL_CASE(DT_FLOAT, float, float); SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
@ -1332,45 +1333,47 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
} }
const TensorShapeProto& output_shape = const TensorShapeProto& output_shape =
properties.GetOutputProperties(node.name())[0].shape(); properties.GetOutputProperties(node.name())[0].shape();
const TensorShapeProto& x_shape =
properties.GetInputProperties(node.name())[0].shape(); // Simplify element-wise multiplication by ones or addition of zeros.
const TensorShapeProto& y_shape = const TensorShapeProto& y_shape =
properties.GetInputProperties(node.name())[1].shape(); properties.GetInputProperties(node.name())[1].shape();
const bool x_is_zero = IsZeros(*x); const bool x_is_zero = IsZeros(*x);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); const bool x_is_one = IsOnes(*x);
const bool y_is_zero = IsZeros(*y);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
// Simplify addition of zeros. ((is_mul && x_is_one) || (is_add && x_is_zero))) {
if (is_add) { // 1 * y = y or 0 + y = y.
if (x_is_zero && y_matches_output_shape) {
// 0 + y = y.
ReplaceAddOrMulWithIdentity(1, &node);
continue;
} else if (y_is_zero && x_matches_output_shape) {
// x + 0 = y.
ReplaceAddOrMulWithIdentity(0, &node);
continue;
}
}
// Simplify element-wise multiplication by ones.
if (is_mul) {
if (IsOnes(*x) && y_matches_output_shape) {
// 1 * y = y.
ReplaceAddOrMulWithIdentity(1, &node); ReplaceAddOrMulWithIdentity(1, &node);
continue; continue;
} }
if (IsOnes(*y) && x_matches_output_shape) { const TensorShapeProto& x_shape =
// x * 1 = x. properties.GetInputProperties(node.name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
if (x_matches_output_shape &&
((is_mul && y_is_one) || (is_add && y_is_zero))) {
// x * 1 = x or x + 0 = x
ReplaceAddOrMulWithIdentity(0, &node); ReplaceAddOrMulWithIdentity(0, &node);
continue; continue;
} }
}
// Simplify multiplication and matmul by zeros. // Simplify multiplication and matmul by zeros.
if (x_is_zero || y_is_zero) { if (!is_add && (x_is_zero || y_is_zero)) {
TF_RETURN_IF_ERROR(ReplaceAddOrMulWithConstant(0, output_shape, &node)); const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined()) {
TF_RETURN_IF_ERROR(
ReplaceAddOrMulWithConstant(0, output_shape, &node));
continue;
}
// Even if an input shape is only partially known, we may know that it
// matches the output shape and thus forward the corresponding zero
// input.
if (is_mul && x_is_zero && x_matches_output_shape) {
ReplaceAddOrMulWithIdentity(0, &node);
} else if (is_mul && y_is_zero && y_matches_output_shape) {
ReplaceAddOrMulWithIdentity(1, &node);
}
} }
} }
} }

View File

@ -198,6 +198,120 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} }
} }
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x_known =
ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
Output x_partially_known =
ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
Output zeros_partially_known =
ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
Output zeros_unknown =
ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
// Multiplies without any additional ops to supply the output shape.
int count = 0;
std::vector<Output> muls;
std::unordered_set<string> not_converted;
std::unordered_set<string> to_const;
std::unordered_set<string> to_identity;
for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
for (const auto* zeros :
{&zeros_known, &zeros_partially_known, &zeros_unknown}) {
const string name = strings::StrCat("mul_", count++);
muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
if (x == &x_partially_known && zeros == &zeros_partially_known) {
to_identity.insert(name);
} else if (x == &x_unknown || zeros == &zeros_unknown) {
not_converted.insert(name);
} else {
to_const.insert(name);
}
}
}
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
LOG(INFO) << output.DebugString();
EXPECT_EQ(15, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
if (to_const.count(name) > 0) {
EXPECT_EQ("Const", node.op()) << node.name();
} else if (to_identity.count(name) > 0) {
EXPECT_EQ("Identity", node.op()) << node.name();
} else if (not_converted.count(name) > 0) {
EXPECT_EQ("Mul", node.op()) << node.name();
}
}
}
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
Output x_partially_known =
ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
Output zeros_partially_known =
ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
Output zeros_unknown =
ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
// If at least one of the inputs to AddN has a known shape, shape inference
// will propagate the shape back to the inputs of AddN, making the
// output shapes of all its inputs known
std::vector<Output> muls_deduced_output_shape;
std::unordered_set<string> to_const;
int count = 0;
for (const auto& x : {x_partially_known, x_unknown}) {
for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
const string name = strings::StrCat("mul_", count++);
muls_deduced_output_shape.push_back(
ops::Mul(s.WithOpName(name), x, zeros));
to_const.insert(name);
}
}
// We add a known shape as input to AddN to propagate it back to the
// multiplies above, which means they can all be turned into Const nodes.
muls_deduced_output_shape.push_back(known_shape);
Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
LOG(INFO) << output.DebugString();
EXPECT_EQ(10, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
if (to_const.count(name) > 0) {
EXPECT_EQ("Const", node.op()) << node.name();
EXPECT_EQ(2, node.input_size());
EXPECT_TRUE(IsControlInput(node.input(0)));
EXPECT_TRUE(IsControlInput(node.input(1)));
}
}
}
TEST_F(ConstantFoldingTest, CreateConstNodes) { TEST_F(ConstantFoldingTest, CreateConstNodes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::NewRootScope();