Refactoring for internal use.
PiperOrigin-RevId: 238065622
This commit is contained in:
parent
ab9b930f75
commit
92875ced0f
@ -165,6 +165,7 @@ tf_cc_test(
|
||||
|
||||
cc_library(
|
||||
name = "cost_estimator",
|
||||
srcs = ["cost_estimator.cc"],
|
||||
hdrs = ["cost_estimator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
@ -173,6 +174,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Unit test target for the helpers built from cost_estimator.cc.
tf_cc_test(
    name = "cost_estimator_test",
    srcs = ["cost_estimator_test.cc"],
    deps = [
        ":cost_estimator",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
|
||||
|
||||
cc_library(
|
||||
name = "virtual_placer",
|
||||
srcs = ["virtual_placer.cc"],
|
||||
|
||||
81
tensorflow/core/grappler/costs/cost_estimator.cc
Normal file
81
tensorflow/core/grappler/costs/cost_estimator.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Merges two cost estimates into a single one.
//
// Starting from a copy of `left`:
//   * every time component and both op counters have `right`'s value added;
//   * `inaccurate` becomes true if `right` is inaccurate;
//   * `max_memory` has `right`'s value added when that value is known;
//   * each per-op maximum becomes the larger of the two inputs when `right`'s
//     value is known.
//
// CHECK-fails if any of `left`'s three memory fields is kMemoryUnknown.
Costs CombineCosts(const Costs& left, const Costs& right) {
  // The left operand must carry fully known memory figures.
  CHECK_NE(left.max_memory, kMemoryUnknown);
  CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
  CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);

  Costs combined = left;

  // Op counters and the accuracy flag.
  combined.num_ops_total += right.num_ops_total;
  combined.num_ops_with_unknown_shapes += right.num_ops_with_unknown_shapes;
  combined.inaccurate = combined.inaccurate || right.inaccurate;

  // Additive time components.
  combined.execution_time += right.execution_time;
  combined.compute_time += right.compute_time;
  combined.memory_time += right.memory_time;
  combined.intermediate_memory_time += right.intermediate_memory_time;
  combined.intermediate_memory_read_time +=
      right.intermediate_memory_read_time;
  combined.intermediate_memory_write_time +=
      right.intermediate_memory_write_time;

  // Memory figures from `right` only contribute when they are known.
  const bool right_memory_known = right.max_memory != kMemoryUnknown;
  if (right_memory_known) {
    combined.max_memory += right.max_memory;
  }

  const bool right_buffers_known = right.max_per_op_buffers != kMemoryUnknown;
  if (right_buffers_known) {
    combined.max_per_op_buffers =
        std::max(left.max_per_op_buffers, right.max_per_op_buffers);
  }

  const bool right_streaming_known =
      right.max_per_op_streaming != kMemoryUnknown;
  if (right_streaming_known) {
    combined.max_per_op_streaming =
        std::max(left.max_per_op_streaming, right.max_per_op_streaming);
  }

  return combined;
}
|
||||
|
||||
// Multiplies Costs by a scalar.
|
||||
// Equivalent to applying CombineCosts "multiplier" times.
|
||||
// Note the field regarding num_ops are not multiplied.
|
||||
Costs MultiplyCosts(const Costs& costs, int multiplier) {
|
||||
CHECK_GE(multiplier, 0);
|
||||
if (multiplier == 0) {
|
||||
return Costs::ZeroCosts();
|
||||
}
|
||||
if (multiplier == 1) {
|
||||
return costs;
|
||||
}
|
||||
|
||||
Costs result = costs;
|
||||
result.execution_time *= multiplier;
|
||||
result.compute_time *= multiplier;
|
||||
result.memory_time *= multiplier;
|
||||
result.intermediate_memory_time *= multiplier;
|
||||
result.intermediate_memory_read_time *= multiplier;
|
||||
result.intermediate_memory_write_time *= multiplier;
|
||||
if (result.max_memory != kMemoryUnknown) {
|
||||
result.max_memory *= multiplier;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
@ -16,9 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
|
||||
@ -204,6 +201,12 @@ Costs Costs::ZeroCosts() {
|
||||
return costs;
|
||||
}
|
||||
|
||||
// Combines two Costs into one. Starting from `left`, the time fields and op
// counters of `right` are added, `inaccurate` is set if `right` is inaccurate,
// max_memory is increased by `right`'s value when that value is known, and
// each per-op maximum takes the larger of the two when `right`'s is known.
// CHECK-fails if any of `left`'s memory fields is kMemoryUnknown.
Costs CombineCosts(const Costs& left, const Costs& right);
|
||||
|
||||
// Multiplies Costs by a scalar.
|
||||
// Equivalent to applying CombineCosts "multiplier" times.
|
||||
// Note: the num_ops fields are not multiplied. `multiplier` must be >= 0
// (CHECK-enforced); a multiplier of 0 returns Costs::ZeroCosts().
Costs MultiplyCosts(const Costs& costs, int multiplier);
|
||||
|
||||
// Given a GrapplerItem and an optimized implementation of the corresponding
|
||||
// TensorFlow graph, the CostEstimator attempts to predict the actual cost of
|
||||
// running the graph.
|
||||
|
||||
88
tensorflow/core/grappler/costs/cost_estimator_test.cc
Normal file
88
tensorflow/core/grappler/costs/cost_estimator_test.cc
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// Combining a Costs value with itself should double every additive field
// while leaving the per-op maxima and the accuracy flag untouched.
TEST(CostEstimatorTest, CombineCosts) {
  Costs input = Costs::ZeroCosts();

  // Time components.
  input.execution_time = Costs::NanoSeconds(1);
  input.compute_time = Costs::NanoSeconds(2);
  input.memory_time = Costs::NanoSeconds(3);
  input.intermediate_memory_time = Costs::NanoSeconds(4);
  input.intermediate_memory_read_time = Costs::NanoSeconds(5);
  input.intermediate_memory_write_time = Costs::NanoSeconds(6);

  // Memory figures.
  input.max_memory = 1;
  input.max_per_op_buffers = 2;
  input.max_per_op_streaming = 3;

  // Op bookkeeping.
  input.num_ops_total = 1;
  input.inaccurate = false;
  input.num_ops_with_unknown_shapes = 0;

  const Costs combined = CombineCosts(input, input);

  // Summed fields.
  EXPECT_EQ(combined.execution_time, Costs::NanoSeconds(2));
  EXPECT_EQ(combined.compute_time, Costs::NanoSeconds(4));
  EXPECT_EQ(combined.memory_time, Costs::NanoSeconds(6));
  EXPECT_EQ(combined.intermediate_memory_time, Costs::NanoSeconds(8));
  EXPECT_EQ(combined.intermediate_memory_read_time, Costs::NanoSeconds(10));
  EXPECT_EQ(combined.intermediate_memory_write_time, Costs::NanoSeconds(12));
  EXPECT_EQ(combined.max_memory, 2);

  // Maxima are unchanged when both sides are equal.
  EXPECT_EQ(combined.max_per_op_buffers, 2);
  EXPECT_EQ(combined.max_per_op_streaming, 3);

  // Op bookkeeping.
  EXPECT_EQ(combined.num_ops_total, 2);
  EXPECT_FALSE(combined.inaccurate);
  EXPECT_EQ(combined.num_ops_with_unknown_shapes, 0);
}
|
||||
|
||||
// Multiplying by 10 should scale the time components and max_memory, and
// leave the per-op maxima, op counters and accuracy flag as they were.
TEST(CostEstimatorTest, MultiplyCosts) {
  Costs input = Costs::ZeroCosts();

  // Time components.
  input.execution_time = Costs::NanoSeconds(1);
  input.compute_time = Costs::NanoSeconds(2);
  input.memory_time = Costs::NanoSeconds(3);
  input.intermediate_memory_time = Costs::NanoSeconds(4);
  input.intermediate_memory_read_time = Costs::NanoSeconds(5);
  input.intermediate_memory_write_time = Costs::NanoSeconds(6);

  // Memory figures.
  input.max_memory = 1;
  input.max_per_op_buffers = 2;
  input.max_per_op_streaming = 3;

  // Op bookkeeping.
  input.num_ops_total = 1;
  input.inaccurate = false;
  input.num_ops_with_unknown_shapes = 0;

  const Costs scaled = MultiplyCosts(input, 10);

  // Scaled fields.
  EXPECT_EQ(scaled.execution_time, Costs::NanoSeconds(10));
  EXPECT_EQ(scaled.compute_time, Costs::NanoSeconds(20));
  EXPECT_EQ(scaled.memory_time, Costs::NanoSeconds(30));
  EXPECT_EQ(scaled.intermediate_memory_time, Costs::NanoSeconds(40));
  EXPECT_EQ(scaled.intermediate_memory_read_time, Costs::NanoSeconds(50));
  EXPECT_EQ(scaled.intermediate_memory_write_time, Costs::NanoSeconds(60));
  EXPECT_EQ(scaled.max_memory, 10);

  // Fields that must not be scaled.
  EXPECT_EQ(scaled.max_per_op_buffers, 2);
  EXPECT_EQ(scaled.max_per_op_streaming, 3);
  EXPECT_EQ(scaled.num_ops_total, 1);
  EXPECT_FALSE(scaled.inaccurate);
  EXPECT_EQ(scaled.num_ops_with_unknown_shapes, 0);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
@ -34,41 +34,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
// Combines two Costs into one and logs the result at VLOG level 4.
//
// Starting from a copy of `left`:
//   * every time component, including the intermediate memory read/write
//     times, has `right`'s value added;
//   * the op counters are summed and `inaccurate` is set if `right` is
//     inaccurate;
//   * `max_memory` has `right`'s value added when that value is known;
//   * each per-op maximum becomes the larger of the two inputs when `right`'s
//     value is known.
//
// CHECK-fails if any of `left`'s three memory fields is kMemoryUnknown.
Costs CombineCosts(const Costs& left, const Costs& right) {
  CHECK_NE(left.max_memory, kMemoryUnknown);
  CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
  CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);

  Costs result = left;
  result.execution_time += right.execution_time;
  result.compute_time += right.compute_time;
  result.memory_time += right.memory_time;
  result.intermediate_memory_time += right.intermediate_memory_time;
  // Previously these two components of `right` were silently dropped, so the
  // combined read/write times only reflected `left`.
  result.intermediate_memory_read_time += right.intermediate_memory_read_time;
  result.intermediate_memory_write_time += right.intermediate_memory_write_time;

  result.num_ops_total += right.num_ops_total;
  if (right.inaccurate) result.inaccurate = true;
  result.num_ops_with_unknown_shapes += right.num_ops_with_unknown_shapes;

  // Unknown memory figures on the right-hand side leave `left`'s values as-is.
  if (right.max_memory != kMemoryUnknown) {
    result.max_memory += right.max_memory;
  }
  if (right.max_per_op_buffers != kMemoryUnknown) {
    result.max_per_op_buffers =
        std::max(left.max_per_op_buffers, right.max_per_op_buffers);
  }
  if (right.max_per_op_streaming != kMemoryUnknown) {
    result.max_per_op_streaming =
        std::max(left.max_per_op_streaming, right.max_per_op_streaming);
  }
  VLOG(4) << "costs execution_time=" << result.execution_time.count()
          << " max_memory=" << result.max_memory
          << " max_per_op_buffers=" << result.max_per_op_buffers
          << " max_per_op_streaming=" << result.max_per_op_streaming;
  return result;
}
|
||||
|
||||
// Key to the cached _Recv ops map, and its hash and predicate structures.
|
||||
struct RecvNodeDescriptor {
|
||||
const NodeDef* node;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user