Refactoring for internal use.
PiperOrigin-RevId: 238065622
This commit is contained in:
parent
ab9b930f75
commit
92875ced0f
@ -165,6 +165,7 @@ tf_cc_test(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cost_estimator",
|
name = "cost_estimator",
|
||||||
|
srcs = ["cost_estimator.cc"],
|
||||||
hdrs = ["cost_estimator.h"],
|
hdrs = ["cost_estimator.h"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -173,6 +174,16 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test target for the ":cost_estimator" library; builds cost_estimator_test.cc
# and links it against the TensorFlow test support targets listed in deps.
tf_cc_test(
|
||||||
|
name = "cost_estimator_test",
|
||||||
|
srcs = ["cost_estimator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":cost_estimator",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "virtual_placer",
|
name = "virtual_placer",
|
||||||
srcs = ["virtual_placer.cc"],
|
srcs = ["virtual_placer.cc"],
|
||||||
|
|||||||
81
tensorflow/core/grappler/costs/cost_estimator.cc
Normal file
81
tensorflow/core/grappler/costs/cost_estimator.cc
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
// Merges `right` into a copy of `left` and returns the merged Costs.
//
// Preconditions (enforced with CHECK_NE): left.max_memory,
// left.max_per_op_buffers and left.max_per_op_streaming must not be
// kMemoryUnknown.
//
// Merge rules:
//   * all time fields, num_ops_total and num_ops_with_unknown_shapes are
//     summed;
//   * inaccurate becomes true if `right` is inaccurate;
//   * max_memory is summed when right.max_memory is known;
//   * max_per_op_buffers / max_per_op_streaming take the larger of the two
//     values when the right-hand value is known.
// Fields of `right` equal to kMemoryUnknown leave the left-hand value as is.
Costs CombineCosts(const Costs& left, const Costs& right) {
  CHECK_NE(left.max_memory, kMemoryUnknown);
  CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
  CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);

  Costs combined = left;

  // Time fields accumulate.
  combined.execution_time += right.execution_time;
  combined.compute_time += right.compute_time;
  combined.memory_time += right.memory_time;
  combined.intermediate_memory_time += right.intermediate_memory_time;
  combined.intermediate_memory_read_time += right.intermediate_memory_read_time;
  combined.intermediate_memory_write_time +=
      right.intermediate_memory_write_time;

  // Op bookkeeping accumulates; inaccuracy is sticky.
  combined.num_ops_total += right.num_ops_total;
  combined.num_ops_with_unknown_shapes += right.num_ops_with_unknown_shapes;
  if (right.inaccurate) combined.inaccurate = true;

  // Memory fields are only touched when the right-hand value is known.
  // `combined` still holds the left-hand values at this point.
  if (right.max_memory != kMemoryUnknown) {
    combined.max_memory += right.max_memory;
  }
  if (right.max_per_op_buffers != kMemoryUnknown) {
    combined.max_per_op_buffers =
        std::max(combined.max_per_op_buffers, right.max_per_op_buffers);
  }
  if (right.max_per_op_streaming != kMemoryUnknown) {
    combined.max_per_op_streaming =
        std::max(combined.max_per_op_streaming, right.max_per_op_streaming);
  }

  return combined;
}
|
||||||
|
|
||||||
|
// Multiplies Costs by a scalar.
|
||||||
|
// Equivalent to applying CombineCosts "multiplier" times.
|
||||||
|
// Note the field regarding num_ops are not multiplied.
|
||||||
|
Costs MultiplyCosts(const Costs& costs, int multiplier) {
|
||||||
|
CHECK_GE(multiplier, 0);
|
||||||
|
if (multiplier == 0) {
|
||||||
|
return Costs::ZeroCosts();
|
||||||
|
}
|
||||||
|
if (multiplier == 1) {
|
||||||
|
return costs;
|
||||||
|
}
|
||||||
|
|
||||||
|
Costs result = costs;
|
||||||
|
result.execution_time *= multiplier;
|
||||||
|
result.compute_time *= multiplier;
|
||||||
|
result.memory_time *= multiplier;
|
||||||
|
result.intermediate_memory_time *= multiplier;
|
||||||
|
result.intermediate_memory_read_time *= multiplier;
|
||||||
|
result.intermediate_memory_write_time *= multiplier;
|
||||||
|
if (result.max_memory != kMemoryUnknown) {
|
||||||
|
result.max_memory *= multiplier;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace grappler
|
||||||
|
} // end namespace tensorflow
|
||||||
@ -16,9 +16,6 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
||||||
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
|
||||||
|
|
||||||
#include <chrono>
|
|
||||||
#include <cmath>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/protobuf/config.pb.h"
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
|
|
||||||
@ -204,6 +201,12 @@ Costs Costs::ZeroCosts() {
|
|||||||
return costs;
|
return costs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Costs CombineCosts(const Costs& left, const Costs& right);
|
||||||
|
|
||||||
|
// Multiplies Costs by a scalar.
|
||||||
|
// Equivalent to applying CombineCosts "multiplier" times.
|
||||||
|
Costs MultiplyCosts(const Costs& costs, int multiplier);
|
||||||
|
|
||||||
// Given a GrapplerItem and an optimized implementation of the corresponding
|
// Given a GrapplerItem and an optimized implementation of the corresponding
|
||||||
// TensorFlow graph, the CostEstimator attempts to predict the actual cost of
|
// TensorFlow graph, the CostEstimator attempts to predict the actual cost of
|
||||||
// running the graph.
|
// running the graph.
|
||||||
|
|||||||
88
tensorflow/core/grappler/costs/cost_estimator_test.cc
Normal file
88
tensorflow/core/grappler/costs/cost_estimator_test.cc
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Combines a fully populated Costs with itself and checks each field:
// times, max_memory and num_ops_total double, the per-op maxima keep their
// value, and the inaccurate / unknown-shape fields stay at false / 0.
TEST(CostEstimatorTest, CombineCosts) {
  Costs unit = Costs::ZeroCosts();

  // Time fields.
  unit.execution_time = Costs::NanoSeconds(1);
  unit.compute_time = Costs::NanoSeconds(2);
  unit.memory_time = Costs::NanoSeconds(3);
  unit.intermediate_memory_time = Costs::NanoSeconds(4);
  unit.intermediate_memory_read_time = Costs::NanoSeconds(5);
  unit.intermediate_memory_write_time = Costs::NanoSeconds(6);

  // Memory fields.
  unit.max_memory = 1;
  unit.max_per_op_buffers = 2;
  unit.max_per_op_streaming = 3;

  // Op bookkeeping.
  unit.num_ops_total = 1;
  unit.inaccurate = false;
  unit.num_ops_with_unknown_shapes = 0;

  Costs doubled = CombineCosts(unit, unit);

  EXPECT_EQ(doubled.execution_time, Costs::NanoSeconds(2));
  EXPECT_EQ(doubled.compute_time, Costs::NanoSeconds(4));
  EXPECT_EQ(doubled.memory_time, Costs::NanoSeconds(6));
  EXPECT_EQ(doubled.intermediate_memory_time, Costs::NanoSeconds(8));
  EXPECT_EQ(doubled.intermediate_memory_read_time, Costs::NanoSeconds(10));
  EXPECT_EQ(doubled.intermediate_memory_write_time, Costs::NanoSeconds(12));
  EXPECT_EQ(doubled.max_memory, 2);
  EXPECT_EQ(doubled.max_per_op_buffers, 2);
  EXPECT_EQ(doubled.max_per_op_streaming, 3);
  EXPECT_EQ(doubled.num_ops_total, 2);
  EXPECT_FALSE(doubled.inaccurate);
  EXPECT_EQ(doubled.num_ops_with_unknown_shapes, 0);
}
|
||||||
|
|
||||||
|
// Multiplies a fully populated Costs by 10 and checks each field: times and
// max_memory are scaled by 10, while the per-op maxima, num_ops_total,
// inaccurate and num_ops_with_unknown_shapes keep their input values.
TEST(CostEstimatorTest, MultiplyCosts) {
  Costs unit = Costs::ZeroCosts();

  // Time fields.
  unit.execution_time = Costs::NanoSeconds(1);
  unit.compute_time = Costs::NanoSeconds(2);
  unit.memory_time = Costs::NanoSeconds(3);
  unit.intermediate_memory_time = Costs::NanoSeconds(4);
  unit.intermediate_memory_read_time = Costs::NanoSeconds(5);
  unit.intermediate_memory_write_time = Costs::NanoSeconds(6);

  // Memory fields.
  unit.max_memory = 1;
  unit.max_per_op_buffers = 2;
  unit.max_per_op_streaming = 3;

  // Op bookkeeping.
  unit.num_ops_total = 1;
  unit.inaccurate = false;
  unit.num_ops_with_unknown_shapes = 0;

  Costs tenfold = MultiplyCosts(unit, 10);

  EXPECT_EQ(tenfold.execution_time, Costs::NanoSeconds(10));
  EXPECT_EQ(tenfold.compute_time, Costs::NanoSeconds(20));
  EXPECT_EQ(tenfold.memory_time, Costs::NanoSeconds(30));
  EXPECT_EQ(tenfold.intermediate_memory_time, Costs::NanoSeconds(40));
  EXPECT_EQ(tenfold.intermediate_memory_read_time, Costs::NanoSeconds(50));
  EXPECT_EQ(tenfold.intermediate_memory_write_time, Costs::NanoSeconds(60));
  EXPECT_EQ(tenfold.max_memory, 10);
  EXPECT_EQ(tenfold.max_per_op_buffers, 2);
  EXPECT_EQ(tenfold.max_per_op_streaming, 3);
  EXPECT_EQ(tenfold.num_ops_total, 1);
  EXPECT_FALSE(tenfold.inaccurate);
  EXPECT_EQ(tenfold.num_ops_with_unknown_shapes, 0);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace grappler
|
||||||
|
} // namespace tensorflow
|
||||||
@ -34,41 +34,9 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Merges `right` into a copy of `left`, logs a summary at VLOG level 4 and
// returns the merged Costs.
//
// Preconditions (enforced with CHECK_NE): left.max_memory,
// left.max_per_op_buffers and left.max_per_op_streaming must not be
// kMemoryUnknown.
//
// execution_time, compute_time, memory_time, intermediate_memory_time,
// num_ops_total and num_ops_with_unknown_shapes are summed; inaccurate becomes
// true if `right` is inaccurate; max_memory is summed and the per-op maxima
// take the larger value, each only when the right-hand value is known.
Costs CombineCosts(const Costs& left, const Costs& right) {
  CHECK_NE(left.max_memory, kMemoryUnknown);
  CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
  CHECK_NE(left.max_per_op_streaming, kMemoryUnknown);

  Costs combined = left;

  // Time fields accumulate.
  combined.execution_time += right.execution_time;
  combined.compute_time += right.compute_time;
  combined.memory_time += right.memory_time;
  combined.intermediate_memory_time += right.intermediate_memory_time;

  // Op bookkeeping accumulates; inaccuracy is sticky.
  combined.num_ops_total += right.num_ops_total;
  combined.num_ops_with_unknown_shapes += right.num_ops_with_unknown_shapes;
  if (right.inaccurate) {
    combined.inaccurate = true;
  }

  // Memory fields are only touched when the right-hand value is known.
  // `combined` still holds the left-hand values at this point.
  if (right.max_memory != kMemoryUnknown) {
    combined.max_memory += right.max_memory;
  }
  if (right.max_per_op_buffers != kMemoryUnknown) {
    combined.max_per_op_buffers =
        std::max(combined.max_per_op_buffers, right.max_per_op_buffers);
  }
  if (right.max_per_op_streaming != kMemoryUnknown) {
    combined.max_per_op_streaming =
        std::max(combined.max_per_op_streaming, right.max_per_op_streaming);
  }

  VLOG(4) << "costs execution_time=" << combined.execution_time.count()
          << " max_memory=" << combined.max_memory
          << " max_per_op_buffers=" << combined.max_per_op_buffers
          << " max_per_op_streaming=" << combined.max_per_op_streaming;
  return combined;
}
|
|
||||||
|
|
||||||
// Key to the cached _Recv ops map, and its hash and predicate structures.
|
// Key to the cached _Recv ops map, and its hash and predicate structures.
|
||||||
struct RecvNodeDescriptor {
|
struct RecvNodeDescriptor {
|
||||||
const NodeDef* node;
|
const NodeDef* node;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user