Simplify and fix some bugs in constant folding of neutral/absorbing element optimizations.
PiperOrigin-RevId: 178293088
This commit is contained in:
parent
0509f07cc2
commit
2ea11416c9
@ -680,6 +680,7 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
|
|||||||
const TensorShapeProto& shape,
|
const TensorShapeProto& shape,
|
||||||
AttrValue* attr_tensor) {
|
AttrValue* attr_tensor) {
|
||||||
TensorProto* t = attr_tensor->mutable_tensor();
|
TensorProto* t = attr_tensor->mutable_tensor();
|
||||||
|
t->set_dtype(type);
|
||||||
*t->mutable_tensor_shape() = shape;
|
*t->mutable_tensor_shape() = shape;
|
||||||
switch (type) {
|
switch (type) {
|
||||||
SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
|
SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
|
||||||
@ -1332,45 +1333,47 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
|
|||||||
}
|
}
|
||||||
const TensorShapeProto& output_shape =
|
const TensorShapeProto& output_shape =
|
||||||
properties.GetOutputProperties(node.name())[0].shape();
|
properties.GetOutputProperties(node.name())[0].shape();
|
||||||
const TensorShapeProto& x_shape =
|
|
||||||
properties.GetInputProperties(node.name())[0].shape();
|
// Simplify element-wise multiplication by ones or addition of zeros.
|
||||||
const TensorShapeProto& y_shape =
|
const TensorShapeProto& y_shape =
|
||||||
properties.GetInputProperties(node.name())[1].shape();
|
properties.GetInputProperties(node.name())[1].shape();
|
||||||
const bool x_is_zero = IsZeros(*x);
|
const bool x_is_zero = IsZeros(*x);
|
||||||
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
|
const bool x_is_one = IsOnes(*x);
|
||||||
const bool y_is_zero = IsZeros(*y);
|
|
||||||
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
|
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
|
||||||
|
if (y_matches_output_shape &&
|
||||||
// Simplify addition of zeros.
|
((is_mul && x_is_one) || (is_add && x_is_zero))) {
|
||||||
if (is_add) {
|
// 1 * y = y or 0 + y = y.
|
||||||
if (x_is_zero && y_matches_output_shape) {
|
ReplaceAddOrMulWithIdentity(1, &node);
|
||||||
// 0 + y = y.
|
continue;
|
||||||
ReplaceAddOrMulWithIdentity(1, &node);
|
|
||||||
continue;
|
|
||||||
} else if (y_is_zero && x_matches_output_shape) {
|
|
||||||
// x + 0 = y.
|
|
||||||
ReplaceAddOrMulWithIdentity(0, &node);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
const TensorShapeProto& x_shape =
|
||||||
// Simplify element-wise multiplication by ones.
|
properties.GetInputProperties(node.name())[0].shape();
|
||||||
if (is_mul) {
|
const bool y_is_zero = IsZeros(*y);
|
||||||
if (IsOnes(*x) && y_matches_output_shape) {
|
const bool y_is_one = IsOnes(*y);
|
||||||
// 1 * y = y.
|
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
|
||||||
ReplaceAddOrMulWithIdentity(1, &node);
|
if (x_matches_output_shape &&
|
||||||
continue;
|
((is_mul && y_is_one) || (is_add && y_is_zero))) {
|
||||||
}
|
// x * 1 = x or x + 0 = x
|
||||||
if (IsOnes(*y) && x_matches_output_shape) {
|
ReplaceAddOrMulWithIdentity(0, &node);
|
||||||
// x * 1 = x.
|
continue;
|
||||||
ReplaceAddOrMulWithIdentity(0, &node);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simplify multiplication and matmul by zeros.
|
// Simplify multiplication and matmul by zeros.
|
||||||
if (x_is_zero || y_is_zero) {
|
if (!is_add && (x_is_zero || y_is_zero)) {
|
||||||
TF_RETURN_IF_ERROR(ReplaceAddOrMulWithConstant(0, output_shape, &node));
|
const PartialTensorShape shp(output_shape);
|
||||||
|
if (shp.IsFullyDefined()) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ReplaceAddOrMulWithConstant(0, output_shape, &node));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Even if an input shape is only partially known, we may known that it
|
||||||
|
// matches the output shape and thus forward the corresponding zero
|
||||||
|
// input.
|
||||||
|
if (is_mul && x_is_zero && x_matches_output_shape) {
|
||||||
|
ReplaceAddOrMulWithIdentity(0, &node);
|
||||||
|
} else if (is_mul && y_is_zero && y_matches_output_shape) {
|
||||||
|
ReplaceAddOrMulWithIdentity(1, &node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -198,6 +198,120 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output x_known =
|
||||||
|
ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape(TensorShape({2, 2})));
|
||||||
|
Output x_partially_known =
|
||||||
|
ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
|
||||||
|
Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
|
||||||
|
Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
|
||||||
|
Output zeros_partially_known =
|
||||||
|
ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
|
||||||
|
Output zeros_unknown =
|
||||||
|
ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
|
||||||
|
|
||||||
|
// Multiplies without any additional ops to supply the output shape.
|
||||||
|
int count = 0;
|
||||||
|
std::vector<Output> muls;
|
||||||
|
std::unordered_set<string> not_converted;
|
||||||
|
std::unordered_set<string> to_const;
|
||||||
|
std::unordered_set<string> to_identity;
|
||||||
|
for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
|
||||||
|
for (const auto* zeros :
|
||||||
|
{&zeros_known, &zeros_partially_known, &zeros_unknown}) {
|
||||||
|
const string name = strings::StrCat("mul_", count++);
|
||||||
|
muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
|
||||||
|
if (x == &x_partially_known && zeros == &zeros_partially_known) {
|
||||||
|
to_identity.insert(name);
|
||||||
|
} else if (x == &x_unknown || zeros == &zeros_unknown) {
|
||||||
|
not_converted.insert(name);
|
||||||
|
} else {
|
||||||
|
to_const.insert(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
|
||||||
|
nullptr /* cpu_device */);
|
||||||
|
GraphDef output;
|
||||||
|
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
LOG(INFO) << output.DebugString();
|
||||||
|
|
||||||
|
EXPECT_EQ(15, output.node_size());
|
||||||
|
for (int i = 0; i < output.node_size(); ++i) {
|
||||||
|
const NodeDef& node = output.node(i);
|
||||||
|
const string& name = node.name();
|
||||||
|
if (to_const.count(name) > 0) {
|
||||||
|
EXPECT_EQ("Const", node.op()) << node.name();
|
||||||
|
} else if (to_identity.count(name) > 0) {
|
||||||
|
EXPECT_EQ("Identity", node.op()) << node.name();
|
||||||
|
} else if (not_converted.count(name) > 0) {
|
||||||
|
EXPECT_EQ("Mul", node.op()) << node.name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
|
||||||
|
Output x_partially_known =
|
||||||
|
ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
|
||||||
|
Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
|
||||||
|
Output zeros_partially_known =
|
||||||
|
ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
|
||||||
|
Output zeros_unknown =
|
||||||
|
ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
|
||||||
|
|
||||||
|
// If at least one of the inputs to AddN has a known shape, shape inference
|
||||||
|
// will propagate the shape back to the inputs of AddN, making the
|
||||||
|
// output shapes of all its inputs known
|
||||||
|
std::vector<Output> muls_deduced_output_shape;
|
||||||
|
std::unordered_set<string> to_const;
|
||||||
|
int count = 0;
|
||||||
|
for (const auto& x : {x_partially_known, x_unknown}) {
|
||||||
|
for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
|
||||||
|
const string name = strings::StrCat("mul_", count++);
|
||||||
|
muls_deduced_output_shape.push_back(
|
||||||
|
ops::Mul(s.WithOpName(name), x, zeros));
|
||||||
|
to_const.insert(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We add a known shape as input to AddN to propagate it back to the
|
||||||
|
// multiplies above, which means they can all be turned into Const nodes.
|
||||||
|
muls_deduced_output_shape.push_back(known_shape);
|
||||||
|
Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
|
||||||
|
nullptr /* cpu_device */);
|
||||||
|
GraphDef output;
|
||||||
|
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
LOG(INFO) << output.DebugString();
|
||||||
|
|
||||||
|
EXPECT_EQ(10, output.node_size());
|
||||||
|
for (int i = 0; i < output.node_size(); ++i) {
|
||||||
|
const NodeDef& node = output.node(i);
|
||||||
|
const string& name = node.name();
|
||||||
|
if (to_const.count(name) > 0) {
|
||||||
|
EXPECT_EQ("Const", node.op()) << node.name();
|
||||||
|
EXPECT_EQ(2, node.input_size());
|
||||||
|
EXPECT_TRUE(IsControlInput(node.input(0)));
|
||||||
|
EXPECT_TRUE(IsControlInput(node.input(1)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ConstantFoldingTest, CreateConstNodes) {
|
TEST_F(ConstantFoldingTest, CreateConstNodes) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user