Fix bug in Pow optimizer rule when broadcasting is involved.
Minor cleanup by moving the helper function ShapesEqual to GraphProperties and adding unit tests for it. PiperOrigin-RevId: 213876779
This commit is contained in:
parent
d388770922
commit
17dbe77f5a
@ -7,10 +7,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
# Platform specific build config
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_protos_grappler",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static",
|
||||
@ -97,7 +93,6 @@ cc_library(
|
||||
deps = [
|
||||
":evaluation_utils",
|
||||
":graph_optimizer",
|
||||
":symbolic_shapes",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -107,6 +102,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:symbolic_shapes",
|
||||
],
|
||||
)
|
||||
|
||||
@ -261,7 +257,6 @@ cc_library(
|
||||
":constant_folding",
|
||||
":graph_optimizer",
|
||||
":graph_optimizer_stage",
|
||||
":symbolic_shapes",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -270,6 +265,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:symbolic_shapes",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
],
|
||||
)
|
||||
@ -648,7 +644,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_optimizer",
|
||||
":symbolic_shapes",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -658,6 +653,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:frame",
|
||||
"//tensorflow/core/grappler/utils:symbolic_shapes",
|
||||
],
|
||||
)
|
||||
|
||||
@ -714,31 +710,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "symbolic_shapes",
|
||||
srcs = ["symbolic_shapes.cc"],
|
||||
hdrs = ["symbolic_shapes.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + tf_protos_grappler(),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "symbolic_shapes_test",
|
||||
srcs = ["symbolic_shapes_test.cc"],
|
||||
deps = [
|
||||
":symbolic_shapes",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "debug_stripper",
|
||||
srcs = ["debug_stripper.cc"],
|
||||
|
@ -35,8 +35,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
|
||||
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
@ -2367,26 +2367,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
}
|
||||
|
||||
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
||||
const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
|
||||
for (int i = 0; i < p.shape().dim_size(); ++i) {
|
||||
if (p.shape().dim(i).size() < 0) {
|
||||
const auto& pow_props =
|
||||
ctx().graph_properties->GetInputProperties(node->name())[1];
|
||||
for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
|
||||
if (pow_props.shape().dim(i).size() < 0) {
|
||||
// skip if p is is not fully defined.
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
|
||||
Tensor pow(p.dtype(), p.shape());
|
||||
if (!pow.FromProto(p.value())) {
|
||||
if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
|
||||
Tensor pow(pow_props.dtype(), pow_props.shape());
|
||||
if (!pow.FromProto(pow_props.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
p.value().DebugString());
|
||||
pow_props.value().DebugString());
|
||||
}
|
||||
|
||||
complex128 prev, curr;
|
||||
for (int i = 0; i < pow.NumElements(); ++i) {
|
||||
if (!GetElementUnexhaustive(pow, i,
|
||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
|
||||
DT_COMPLEX64, DT_COMPLEX128},
|
||||
&curr)) {
|
||||
if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
|
||||
// input data type is not supported by Pow. Skip.
|
||||
return Status::OK();
|
||||
}
|
||||
@ -2399,12 +2397,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
NodeDef *x, *y;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
|
||||
const auto& value_props =
|
||||
ctx().graph_properties->GetInputProperties(node->name())[0];
|
||||
const TensorShapeProto& output_shape =
|
||||
ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
|
||||
if (curr == complex128(2, 0)) {
|
||||
node->set_op("Square");
|
||||
node->set_input(1, AsControlDependency(y->name()));
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(y);
|
||||
} else if (curr == complex128(1, 0)) {
|
||||
} else if (curr == complex128(1, 0) &&
|
||||
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
|
||||
// Pow could be used to broadcast, so make sure the shapes of the two
|
||||
// arguments are identical before replacing Pow with Identity.
|
||||
node->set_op("Identity");
|
||||
node->set_input(1, AsControlDependency(y->name()));
|
||||
AddToOptimizationQueue(node);
|
||||
@ -2414,20 +2419,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
node->set_input(1, AsControlDependency(y->name()));
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(y);
|
||||
} else if (curr == complex128(0, 0)) {
|
||||
const auto& b =
|
||||
ctx().graph_properties->GetInputProperties(node->name())[0];
|
||||
for (int i = 0; i < b.shape().dim_size(); ++i) {
|
||||
if (b.shape().dim(i).size() < 0) {
|
||||
} else if (curr == complex128(0, 0) &&
|
||||
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
|
||||
for (int i = 0; i < value_props.shape().dim_size(); ++i) {
|
||||
if (value_props.shape().dim(i).size() < 0) {
|
||||
// skip if b is is not fully defined.
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
if (TensorShape::IsValid(b.shape()) && b.has_value()) {
|
||||
Tensor base(b.dtype(), b.shape());
|
||||
if (!base.FromProto(b.value())) {
|
||||
if (TensorShape::IsValid(value_props.shape()) &&
|
||||
value_props.has_value()) {
|
||||
Tensor base(value_props.dtype(), value_props.shape());
|
||||
if (!base.FromProto(value_props.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
b.value().DebugString());
|
||||
value_props.value().DebugString());
|
||||
}
|
||||
node->set_op("Const");
|
||||
Tensor c(base.dtype(), base.shape());
|
||||
@ -2585,12 +2590,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
|
||||
~ConvertExpm1Stage() override = default;
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override {
|
||||
if (!IsSub(*node))
|
||||
return false;
|
||||
if (!IsSub(*node)) return false;
|
||||
|
||||
NodeDef* input;
|
||||
if (!GetInputNode(node->input(0), &input).ok())
|
||||
return false;
|
||||
if (!GetInputNode(node->input(0), &input).ok()) return false;
|
||||
|
||||
return IsExp(*input);
|
||||
}
|
||||
@ -2610,10 +2613,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& t =
|
||||
ctx().graph_properties->GetInputProperties(exp->name())[0];
|
||||
const auto& c =
|
||||
ctx().graph_properties->GetInputProperties(node->name())[1];
|
||||
const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
|
||||
const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
|
||||
for (int k = 0; k < c.shape().dim_size(); ++k) {
|
||||
// Skip if c shape is not fully determined.
|
||||
if (c.shape().dim(k).size() < 0) {
|
||||
|
@ -2474,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
|
||||
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
|
||||
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
|
||||
auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
|
||||
auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
|
||||
auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
|
||||
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
|
||||
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
|
||||
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
|
||||
@ -2481,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
|
||||
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
|
||||
Output out = ops::Pow(s.WithOpName("out"), x, y);
|
||||
Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
|
||||
Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
|
||||
item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
|
||||
"out_1", "out", "out_bcast1", "out_bcast2"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
EXPECT_EQ(7, tensors_expected.size());
|
||||
EXPECT_EQ(9, tensors_expected.size());
|
||||
|
||||
GraphDef got;
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyConvertPow(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &got);
|
||||
auto tensors = EvaluateNodes(got, item.fetch);
|
||||
EXPECT_EQ(7, tensors.size());
|
||||
EXPECT_EQ(9, tensors.size());
|
||||
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
|
||||
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
|
||||
}
|
||||
@ -2509,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
AddNode("y_.5", "Const", {}, {}, &want);
|
||||
AddNode("y_1", "Const", {}, {}, &want);
|
||||
AddNode("y", "Const", {}, {}, &want);
|
||||
AddNode("z", "Const", {}, {}, &want);
|
||||
AddNode("ones", "Const", {}, {}, &want);
|
||||
AddNode("zeros", "Const", {}, {}, &want);
|
||||
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
|
||||
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
|
||||
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
|
||||
@ -2517,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
|
||||
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
|
||||
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
|
||||
AddNode("out", "Pow", {"x", "y"}, {}, &want);
|
||||
AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
|
||||
AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
|
||||
|
||||
CompareGraphs(want, got);
|
||||
}
|
||||
|
@ -32,8 +32,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
|
||||
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool ShapesEqual(const TensorShapeProto& shape1,
|
||||
const TensorShapeProto& shape2) {
|
||||
if (shape1.unknown_rank() || shape2.unknown_rank()) {
|
||||
return false;
|
||||
}
|
||||
if (shape1.dim_size() != shape2.dim_size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < shape1.dim_size(); ++i) {
|
||||
if (shape1.dim(i).size() != shape2.dim(i).size()) {
|
||||
return false;
|
||||
}
|
||||
if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
|
||||
BCast::Vec* shape, int64* min_id) {
|
||||
if (shape_node.op() == "Shape") {
|
||||
@ -2348,7 +2329,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
|
||||
properties.GetInputProperties(node->name())[1].shape();
|
||||
const bool x_is_zero = IsZeros(*x);
|
||||
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
|
||||
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
|
||||
const bool y_matches_output_shape =
|
||||
ShapesSymbolicallyEqual(output_shape, y_shape);
|
||||
if (y_matches_output_shape &&
|
||||
((is_mul && x_is_one) || (is_add && x_is_zero))) {
|
||||
// 1 * y = y or 0 + y = y.
|
||||
@ -2378,7 +2360,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
|
||||
properties.GetInputProperties(node->name())[0].shape();
|
||||
const bool y_is_zero = IsZeros(*y);
|
||||
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
|
||||
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
|
||||
const bool x_matches_output_shape =
|
||||
ShapesSymbolicallyEqual(output_shape, x_shape);
|
||||
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
|
||||
((is_add || is_sub) && y_is_zero))) {
|
||||
// x * 1 = x or x / 1 = x or x +/- 0 = x
|
||||
|
@ -20,10 +20,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
|
||||
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -1,6 +1,10 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_protos_grappler",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "scc",
|
||||
@ -210,3 +214,28 @@ tf_cc_test(
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "symbolic_shapes",
|
||||
srcs = ["symbolic_shapes.cc"],
|
||||
hdrs = ["symbolic_shapes.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + tf_protos_grappler(),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "symbolic_shapes_test",
|
||||
srcs = ["symbolic_shapes_test.cc"],
|
||||
deps = [
|
||||
":symbolic_shapes",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator,
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
Loading…
x
Reference in New Issue
Block a user