Save unnecessary lookup in NodeMap constructor.
Run on XXXX (72 X 2991 MHz CPUs); 2019-09-16T15:05:55.08629839-07:00
CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB

Benchmark                    Base (ns)     New (ns)  Improvement
------------------------------------------------------------------
BM_NodeMapConstruct/8             1504         1324       +12.0%
BM_NodeMapConstruct/64           20784        17478       +15.9%
BM_NodeMapConstruct/512         213301       178139       +16.5%
BM_NodeMapConstruct/4k         2200710      1871703       +15.0%
BM_NodeMapConstruct/32k       25063213     22232692       +11.3%
BM_NodeMapConstruct/256k     370569289    326468834       +11.9%
BM_NodeMapConstruct/1M      1832415652   1596115918       +12.9%

PiperOrigin-RevId: 269444504
This commit is contained in:
parent
4e2bbf82aa
commit
cd48ba6db5
@ -46,6 +46,11 @@ tf_cc_test(
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/lib/bfloat16",
|
||||
"//tensorflow/core/lib/core:status",
|
||||
"//tensorflow/core/platform:notification",
|
||||
"//tensorflow/core/platform:types",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -80,10 +80,12 @@ NodeMap::NodeMap(GraphDef* graph) {
|
||||
auto rslt = nodes_.emplace(node_name, node);
|
||||
// Check that the graph doesn't contain multiple nodes with the same name.
|
||||
if (!rslt.second) {
|
||||
// The first node found with a given name becomes the canonical.
|
||||
LOG(WARNING) << "Duplicated node in the graph: " << node_name;
|
||||
}
|
||||
NodeDef* canonical = rslt.second ? node : rslt.first->second;
|
||||
for (const auto& input : node->input()) {
|
||||
outputs_[NodeName(input)].insert(nodes_[node_name]);
|
||||
outputs_[NodeName(input)].insert(canonical);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
@ -23,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/benchmark_testlib.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -599,6 +601,17 @@ TEST(SetTensorValueTest, Quantized) {
|
||||
/*error_msg=*/"");
|
||||
}
|
||||
|
||||
static void BM_NodeMapConstruct(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
GraphDef graph = test::CreateRandomGraph(size);
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; i++) {
|
||||
NodeMap node_map(&graph);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
BENCHMARK(BM_NodeMapConstruct)->Range(1, 1 << 20);
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user