[Grappler] Replace deprecated node name parsing code in Grappler with calls to ParseTensorName, which is also a lot faster.
Run on XXXX (72 X 2993 MHz CPUs); 2019-09-18T11:56:10.508535133-07:00 CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB Benchmark Base (ns) New (ns) Improvement ------------------------------------------------------------------------------- BM_NodePositionIfSameNode_Match_7 14.6 7.87 +46.1% BM_NodePositionIfSameNode_Match_0 7.33 7.33 +0.0% BM_NodePositionIfSameNode_Match_Ctrl 8.77 7.06 +19.5% BM_NodePositionIfSameNode_NoMatch_0 2.17 1.90 +12.4% BM_NodePositionIfSameNode_NoMatch_end 7.07 5.43 +23.2% BM_NodeNameAsStringPiece/1 3.79 4.17 -10.0% BM_NodeNameAsStringPiece/8 6.24 4.17 +33.2% BM_NodeNameAsStringPiece/64 28.3 4.15 +85.3% BM_NodeNameAsStringPiece/512 221 4.14 +98.1% BM_NodeNameAsStringPiece/1k 432 4.14 +99.0% BM_ParseNodeNameAsStringPiece_foo 4.81 4.56 +5.2% BM_ParseNodeNameAsStringPiece_foo_bar_baz 7.45 4.34 +41.7% BM_ParseNodeNameAsStringPiece_foo_bar_baz_ctrl 7.04 5.80 +17.6% BM_ParseNodeNameAsStringPiece_foo123 13.9 6.80 +51.1% BM_ParseNodeNameAsStringPiece_foo_bar_baz_123 15.9 6.52 +59.0% BM_ParseNodeNameAsStringPiece_foo_bar_baz_123_ctrl 7.89 7.08 +10.3% Skip graph copy in ShapeOptimizer in more cases. PiperOrigin-RevId: 269893789
This commit is contained in:
parent
b45e94d159
commit
959487fd0e
@ -42,12 +42,16 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
bool has_size = false;
|
||||
bool has_shape = false;
|
||||
bool has_prod = false;
|
||||
auto is_int = [](const NodeDef& node) -> bool {
|
||||
return node.attr().at("T").type() == DT_INT32 ||
|
||||
node.attr().at("T").type() == DT_INT64;
|
||||
};
|
||||
for (const NodeDef& node : item.graph.node()) {
|
||||
if (IsShape(node)) {
|
||||
has_shape = true;
|
||||
} else if (IsProd(node)) {
|
||||
} else if (IsProd(node) && is_int(node)) {
|
||||
has_prod = true;
|
||||
} else if (IsDiv(node)) {
|
||||
} else if (IsDiv(node) && is_int(node)) {
|
||||
has_div = true;
|
||||
} else if (IsSize(node)) {
|
||||
has_size = true;
|
||||
|
@ -22,6 +22,8 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -120,54 +122,45 @@ bool IsSameInput(const string& name1, const string& name2);
|
||||
|
||||
// Returns the trailing position number (or zero if no number is present) if
|
||||
// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
|
||||
// Returns -2 if NodeName(input_name) is not equal to node_name.
|
||||
// Note: This function is used very heavily, and this hand-optimized
|
||||
// version is 3-4x faster than the version using Scanner, which it replaced.
|
||||
// This is worth the reduction in readability.
|
||||
inline int NodePositionIfSameNode(const string& input_name,
|
||||
const string& node_name) {
|
||||
if (input_name.empty()) return -2;
|
||||
const bool is_ctrl = input_name[0] == '^';
|
||||
auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
|
||||
auto node_it = node_name.begin();
|
||||
if (node_name.empty() ||
|
||||
std::distance(input_it, input_name.end()) < node_name.size()) {
|
||||
// Returns -2 if input_name is empty or NodeName(input_name) is not equal to
|
||||
// node_name.
|
||||
inline int NodePositionIfSameNode(absl::string_view input_name,
|
||||
absl::string_view node_name) {
|
||||
bool is_control = absl::StartsWith(input_name, "^");
|
||||
if (is_control) input_name.remove_prefix(1);
|
||||
if (input_name.empty() || node_name.empty() ||
|
||||
input_name.size() < node_name.size()) {
|
||||
return -2;
|
||||
}
|
||||
while (node_it != node_name.end()) {
|
||||
if (*input_it++ != *node_it++) {
|
||||
return -2;
|
||||
}
|
||||
TensorId id = ParseTensorName(input_name);
|
||||
if (id.first != node_name) return -2;
|
||||
if (is_control) return -1;
|
||||
return id.second;
|
||||
}
|
||||
|
||||
// Returns the node name and position in a single call.
|
||||
inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
|
||||
int* position) {
|
||||
const bool is_control = absl::StartsWith(name, "^");
|
||||
TensorId id = ParseTensorName(name);
|
||||
if (position) {
|
||||
*position = is_control ? -1 : id.second;
|
||||
}
|
||||
if (input_it == input_name.end()) {
|
||||
return is_ctrl ? -1 : 0;
|
||||
} else if (*input_it++ == ':') {
|
||||
StringPiece remaining(&(*input_it),
|
||||
std::distance(input_it, input_name.end()));
|
||||
int position;
|
||||
if (!strings::safe_strto32(remaining, &position)) {
|
||||
return -2;
|
||||
}
|
||||
return is_ctrl ? -1 : position;
|
||||
} else {
|
||||
return -2;
|
||||
if (is_control && id.second >= 0) {
|
||||
id.first.remove_prefix(1);
|
||||
}
|
||||
return id.first;
|
||||
}
|
||||
|
||||
// Returns the node name and position in a single call.
|
||||
inline string ParseNodeName(const string& name, int* position) {
|
||||
return string(ParseNodeNameAsStringPiece(name, position));
|
||||
}
|
||||
|
||||
// Return the node name corresponding to 'name' if name is valid, or the empty
|
||||
// string otherwise.
|
||||
inline StringPiece NodeNameAsStringPiece(const string& name) {
|
||||
static const string empty;
|
||||
if (name.empty()) return StringPiece(empty);
|
||||
const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
|
||||
auto end_it = begin_it;
|
||||
while (end_it != name.end() && *end_it != ':') {
|
||||
++end_it;
|
||||
}
|
||||
if (end_it != name.end() && *end_it != ':') {
|
||||
return StringPiece(empty);
|
||||
}
|
||||
return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
|
||||
return ParseNodeNameAsStringPiece(name, nullptr);
|
||||
}
|
||||
|
||||
// Return the node name corresponding to 'name' if name is valid, or the empty
|
||||
@ -176,43 +169,6 @@ inline string NodeName(const string& name) {
|
||||
return string(NodeNameAsStringPiece(name));
|
||||
}
|
||||
|
||||
// Returns the node name and position in a single call.
|
||||
// DEPRECATED(ezhulenev): Use TensorId and ParseTensorName.
|
||||
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
|
||||
int* position) {
|
||||
static const string empty;
|
||||
if (name.empty()) {
|
||||
*position = 0;
|
||||
return StringPiece(empty);
|
||||
}
|
||||
const bool is_ctrl = name[0] == '^';
|
||||
const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
|
||||
*position = is_ctrl ? -1 : 0;
|
||||
auto end_it = begin_it;
|
||||
while (end_it != name.end() && *end_it != ':') {
|
||||
++end_it;
|
||||
}
|
||||
const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
|
||||
if (end_it != name.end()) {
|
||||
if (*end_it != ':') {
|
||||
return StringPiece(empty);
|
||||
} else if (!is_ctrl) {
|
||||
++end_it;
|
||||
StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
|
||||
if (!strings::safe_strto32(remaining, position)) {
|
||||
return StringPiece(empty);
|
||||
}
|
||||
}
|
||||
}
|
||||
return node_name;
|
||||
}
|
||||
|
||||
// Returns the node name and position in a single call.
|
||||
// DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName.
|
||||
inline string ParseNodeName(const string& name, int* position) {
|
||||
return string(ParseNodeNameAsStringPiece(name, position));
|
||||
}
|
||||
|
||||
inline int NodePosition(const string& name) {
|
||||
int position;
|
||||
ParseNodeNameAsStringPiece(name, &position);
|
||||
|
@ -454,6 +454,16 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
|
||||
BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
|
||||
BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
|
||||
|
||||
static void BM_NodeNameAsStringPiece(int iters, int size) {
|
||||
string input(size + 3, 'x');
|
||||
input[size] = ':';
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
StringPiece node_name = NodeNameAsStringPiece(input);
|
||||
CHECK_GT(node_name.size(), 0);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_NodeNameAsStringPiece)->Range(1, 1024);
|
||||
|
||||
#define BM_ParseNodeNameAsStringPiece(I, NAME) \
|
||||
static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
|
||||
string input = I; \
|
||||
|
Loading…
x
Reference in New Issue
Block a user