diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index 3a1cfb64b99..4bcb4dfc791 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -42,12 +42,16 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, bool has_size = false; bool has_shape = false; bool has_prod = false; + auto is_int = [](const NodeDef& node) -> bool { + return node.attr().at("T").type() == DT_INT32 || + node.attr().at("T").type() == DT_INT64; + }; for (const NodeDef& node : item.graph.node()) { if (IsShape(node)) { has_shape = true; - } else if (IsProd(node)) { + } else if (IsProd(node) && is_int(node)) { has_prod = true; - } else if (IsDiv(node)) { + } else if (IsDiv(node) && is_int(node)) { has_div = true; } else if (IsSize(node)) { has_size = true; diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 71dbefeb5d4..7a1b65e1729 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -22,6 +22,8 @@ limitations under the License. #include #include #include + +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -120,54 +122,45 @@ bool IsSameInput(const string& name1, const string& name2); // Returns the trailing position number (or zero if no number is present) if // NodeName(input_name) is equal to node_name. Returns -1 for control inputs. -// Returns -2 if NodeName(input_name) is not equal to node_name. -// Note: This function is used very heavily, and this hand-optimized -// version is 3-4x faster than the version using Scanner, which it replaced. -// This is worth the reduction in readability. -inline int NodePositionIfSameNode(const string& input_name, - const string& node_name) { - if (input_name.empty()) return -2; - const bool is_ctrl = input_name[0] == '^'; - auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); - auto node_it = node_name.begin(); - if (node_name.empty() || - std::distance(input_it, input_name.end()) < node_name.size()) { +// Returns -2 if input_name is empty or NodeName(input_name) is not equal to +// node_name. +inline int NodePositionIfSameNode(absl::string_view input_name, + absl::string_view node_name) { + bool is_control = absl::StartsWith(input_name, "^"); + if (is_control) input_name.remove_prefix(1); + if (input_name.empty() || node_name.empty() || + input_name.size() < node_name.size()) { return -2; } - while (node_it != node_name.end()) { - if (*input_it++ != *node_it++) { - return -2; - } + TensorId id = ParseTensorName(input_name); + if (id.first != node_name) return -2; + if (is_control) return -1; + return id.second; +} + +// Returns the node name and position in a single call. +inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name, + int* position) { + const bool is_control = absl::StartsWith(name, "^"); + TensorId id = ParseTensorName(name); + if (position) { + *position = is_control ? -1 : id.second; } - if (input_it == input_name.end()) { - return is_ctrl ? -1 : 0; - } else if (*input_it++ == ':') { - StringPiece remaining(&(*input_it), - std::distance(input_it, input_name.end())); - int position; - if (!strings::safe_strto32(remaining, &position)) { - return -2; - } - return is_ctrl ? -1 : position; - } else { - return -2; + if (is_control && id.second >= 0) { + id.first.remove_prefix(1); } + return id.first; +} + +// Returns the node name and position in a single call. +inline string ParseNodeName(const string& name, int* position) { + return string(ParseNodeNameAsStringPiece(name, position)); } // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. inline StringPiece NodeNameAsStringPiece(const string& name) { - static const string empty; - if (name.empty()) return StringPiece(empty); - const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin(); - auto end_it = begin_it; - while (end_it != name.end() && *end_it != ':') { - ++end_it; - } - if (end_it != name.end() && *end_it != ':') { - return StringPiece(empty); - } - return StringPiece(&(*begin_it), std::distance(begin_it, end_it)); + return ParseNodeNameAsStringPiece(name, nullptr); } // Return the node name corresponding to 'name' if name is valid, or the empty @@ -176,43 +169,6 @@ inline string NodeName(const string& name) { return string(NodeNameAsStringPiece(name)); } -// Returns the node name and position in a single call. -// DEPRECATED(ezhulenev): Use TensorId and ParseTensorName. -inline StringPiece ParseNodeNameAsStringPiece(const string& name, - int* position) { - static const string empty; - if (name.empty()) { - *position = 0; - return StringPiece(empty); - } - const bool is_ctrl = name[0] == '^'; - const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin(); - *position = is_ctrl ? -1 : 0; - auto end_it = begin_it; - while (end_it != name.end() && *end_it != ':') { - ++end_it; - } - const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it)); - if (end_it != name.end()) { - if (*end_it != ':') { - return StringPiece(empty); - } else if (!is_ctrl) { - ++end_it; - StringPiece remaining(&(*end_it), std::distance(end_it, name.end())); - if (!strings::safe_strto32(remaining, position)) { - return StringPiece(empty); - } - } - } - return node_name; -} - -// Returns the node name and position in a single call. -// DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName. -inline string ParseNodeName(const string& name, int* position) { - return string(ParseNodeNameAsStringPiece(name, position)); -} - inline int NodePosition(const string& name) { int position; ParseNodeNameAsStringPiece(name, &position); diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index a65704147ea..777630ff98c 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -454,6 +454,16 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl); BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0); BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end); +static void BM_NodeNameAsStringPiece(int iters, int size) { + string input(size + 3, 'x'); + input[size] = ':'; + for (int i = 0; i < iters; ++i) { + StringPiece node_name = NodeNameAsStringPiece(input); + CHECK_GT(node_name.size(), 0); + } +} +BENCHMARK(BM_NodeNameAsStringPiece)->Range(1, 1024); + #define BM_ParseNodeNameAsStringPiece(I, NAME) \ static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \ string input = I; \