Fix bug in Pow optimizer rule when broadcasting is involved.

Minor cleanup by moving the helper function ShapesEqual to GraphProperties and adding unit tests for it.

PiperOrigin-RevId: 213876779
This commit is contained in:
A. Unique TensorFlower 2018-09-20 13:56:49 -07:00 committed by TensorFlower Gardener
parent d388770922
commit 17dbe77f5a
9 changed files with 89 additions and 95 deletions

View File

@ -7,10 +7,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_protos_grappler",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
@ -97,7 +93,6 @@ cc_library(
deps = [
":evaluation_utils",
":graph_optimizer",
":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -107,6 +102,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@ -261,7 +257,6 @@ cc_library(
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@ -270,6 +265,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@ -648,7 +644,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -658,6 +653,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
"//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@ -714,31 +710,6 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "symbolic_shapes",
srcs = ["symbolic_shapes.cc"],
hdrs = ["symbolic_shapes.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
] + tf_protos_grappler(),
)
tf_cc_test(
name = "symbolic_shapes_test",
srcs = ["symbolic_shapes_test.cc"],
deps = [
":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "debug_stripper",
srcs = ["debug_stripper.cc"],

View File

@ -35,8 +35,8 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@ -2367,26 +2367,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int i = 0; i < p.shape().dim_size(); ++i) {
if (p.shape().dim(i).size() < 0) {
const auto& pow_props =
ctx().graph_properties->GetInputProperties(node->name())[1];
for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
if (pow_props.shape().dim(i).size() < 0) {
// skip if p is is not fully defined.
return Status::OK();
}
}
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor pow(p.dtype(), p.shape());
if (!pow.FromProto(p.value())) {
if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
Tensor pow(pow_props.dtype(), pow_props.shape());
if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
if (!GetElementUnexhaustive(pow, i,
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_COMPLEX128},
&curr)) {
if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
@ -2399,12 +2397,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
const auto& value_props =
ctx().graph_properties->GetInputProperties(node->name())[0];
const TensorShapeProto& output_shape =
ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(1, 0)) {
} else if (curr == complex128(1, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
// Pow could be used to broadcast, so make sure the shapes of the two
// arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
@ -2414,20 +2419,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(0, 0)) {
const auto& b =
ctx().graph_properties->GetInputProperties(node->name())[0];
for (int i = 0; i < b.shape().dim_size(); ++i) {
if (b.shape().dim(i).size() < 0) {
} else if (curr == complex128(0, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
for (int i = 0; i < value_props.shape().dim_size(); ++i) {
if (value_props.shape().dim(i).size() < 0) {
// skip if b is is not fully defined.
return Status::OK();
}
}
if (TensorShape::IsValid(b.shape()) && b.has_value()) {
Tensor base(b.dtype(), b.shape());
if (!base.FromProto(b.value())) {
if (TensorShape::IsValid(value_props.shape()) &&
value_props.has_value()) {
Tensor base(value_props.dtype(), value_props.shape());
if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
b.value().DebugString());
value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
@ -2585,12 +2590,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
~ConvertExpm1Stage() override = default;
bool IsSupported(const NodeDef* node) const override {
if (!IsSub(*node))
return false;
if (!IsSub(*node)) return false;
NodeDef* input;
if (!GetInputNode(node->input(0), &input).ok())
return false;
if (!GetInputNode(node->input(0), &input).ok()) return false;
return IsExp(*input);
}
@ -2610,10 +2613,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
return Status::OK();
}
const auto& t =
ctx().graph_properties->GetInputProperties(exp->name())[0];
const auto& c =
ctx().graph_properties->GetInputProperties(node->name())[1];
const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {

View File

@ -2474,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@ -2481,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
"out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(7, tensors_expected.size());
EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
EXPECT_EQ(7, tensors.size());
EXPECT_EQ(9, tensors.size());
for (int i = 0; i < 7; ++i) {
for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@ -2509,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
AddNode("z", "Const", {}, {}, &want);
AddNode("ones", "Const", {}, {}, &want);
AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@ -2517,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
namespace {
bool ShapesEqual(const TensorShapeProto& shape1,
const TensorShapeProto& shape2) {
if (shape1.unknown_rank() || shape2.unknown_rank()) {
return false;
}
if (shape1.dim_size() != shape2.dim_size()) {
return false;
}
for (int i = 0; i < shape1.dim_size(); ++i) {
if (shape1.dim(i).size() != shape2.dim(i).size()) {
return false;
}
if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
return false;
}
}
return true;
}
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@ -2348,7 +2329,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
const bool y_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
@ -2378,7 +2360,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
const bool x_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x

View File

@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {

View File

@ -1,6 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_protos_grappler",
)
cc_library(
name = "scc",
@ -210,3 +214,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
cc_library(
name = "symbolic_shapes",
srcs = ["symbolic_shapes.cc"],
hdrs = ["symbolic_shapes.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
] + tf_protos_grappler(),
)
tf_cc_test(
name = "symbolic_shapes_test",
srcs = ["symbolic_shapes_test.cc"],
deps = [
":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator,
} // namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"