diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 35ca93d9345..84d813fe771 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -165,6 +165,7 @@ tf_cc_test( cc_library( name = "cost_estimator", + srcs = ["cost_estimator.cc"], hdrs = ["cost_estimator.h"], visibility = ["//visibility:public"], deps = [ @@ -173,6 +174,16 @@ cc_library( ], ) +tf_cc_test( + name = "cost_estimator_test", + srcs = ["cost_estimator_test.cc"], + deps = [ + ":cost_estimator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "virtual_placer", srcs = ["virtual_placer.cc"], diff --git a/tensorflow/core/grappler/costs/cost_estimator.cc b/tensorflow/core/grappler/costs/cost_estimator.cc new file mode 100644 index 00000000000..0fc4e99689b --- /dev/null +++ b/tensorflow/core/grappler/costs/cost_estimator.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/cost_estimator.h" + +namespace tensorflow { +namespace grappler { + +Costs CombineCosts(const Costs& left, const Costs& right) { + CHECK_NE(left.max_memory, kMemoryUnknown); + CHECK_NE(left.max_per_op_buffers, kMemoryUnknown); + CHECK_NE(left.max_per_op_streaming, kMemoryUnknown); + + Costs result = left; + result.execution_time += right.execution_time; + result.compute_time += right.compute_time; + result.memory_time += right.memory_time; + result.intermediate_memory_time += right.intermediate_memory_time; + result.intermediate_memory_read_time += right.intermediate_memory_read_time; + result.intermediate_memory_write_time += right.intermediate_memory_write_time; + + if (right.max_per_op_buffers != kMemoryUnknown) { + result.max_per_op_buffers = + std::max(left.max_per_op_buffers, right.max_per_op_buffers); + } + if (right.max_per_op_streaming != kMemoryUnknown) { + result.max_per_op_streaming = + std::max(left.max_per_op_streaming, right.max_per_op_streaming); + } + + result.num_ops_total += right.num_ops_total; + if (right.inaccurate) { + result.inaccurate = true; + } + result.num_ops_with_unknown_shapes += right.num_ops_with_unknown_shapes; + if (right.max_memory != kMemoryUnknown) { + result.max_memory += right.max_memory; + } + + return result; +} + +// Multiplies Costs by a scalar. +// Equivalent to applying CombineCosts "multiplier" times. +// Note the field regarding num_ops are not multiplied. +Costs MultiplyCosts(const Costs& costs, int multiplier) { + CHECK_GE(multiplier, 0); + if (multiplier == 0) { + return Costs::ZeroCosts(); + } + if (multiplier == 1) { + return costs; + } + + Costs result = costs; + result.execution_time *= multiplier; + result.compute_time *= multiplier; + result.memory_time *= multiplier; + result.intermediate_memory_time *= multiplier; + result.intermediate_memory_read_time *= multiplier; + result.intermediate_memory_write_time *= multiplier; + if (result.max_memory != kMemoryUnknown) { + result.max_memory *= multiplier; + } + return result; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index 5876d6f2e77..9815d3d0c04 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ #define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ -#include -#include -#include #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -204,6 +201,12 @@ Costs Costs::ZeroCosts() { return costs; } +Costs CombineCosts(const Costs& left, const Costs& right); + +// Multiplies Costs by a scalar. +// Equivalent to applying CombineCosts "multiplier" times. +Costs MultiplyCosts(const Costs& costs, int multiplier); + // Given a GrapperItem and an optimized implementation of the corresponding // TensorFlow graph, the CostEstimator attempts to predicts the actual cost of // running the graph. diff --git a/tensorflow/core/grappler/costs/cost_estimator_test.cc b/tensorflow/core/grappler/costs/cost_estimator_test.cc new file mode 100644 index 00000000000..62197a43dfb --- /dev/null +++ b/tensorflow/core/grappler/costs/cost_estimator_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/cost_estimator.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(CostEstimatorTest, CombineCosts) { + Costs c = Costs::ZeroCosts(); + c.execution_time = Costs::NanoSeconds(1); + c.compute_time = Costs::NanoSeconds(2); + c.memory_time = Costs::NanoSeconds(3); + c.intermediate_memory_time = Costs::NanoSeconds(4); + c.intermediate_memory_read_time = Costs::NanoSeconds(5); + c.intermediate_memory_write_time = Costs::NanoSeconds(6); + c.max_memory = 1; + c.max_per_op_buffers = 2; + c.max_per_op_streaming = 3; + c.num_ops_total = 1; + c.inaccurate = false; + c.num_ops_with_unknown_shapes = 0; + + Costs sum = CombineCosts(c, c); + + EXPECT_EQ(sum.execution_time, Costs::NanoSeconds(2)); + EXPECT_EQ(sum.compute_time, Costs::NanoSeconds(4)); + EXPECT_EQ(sum.memory_time, Costs::NanoSeconds(6)); + EXPECT_EQ(sum.intermediate_memory_time, Costs::NanoSeconds(8)); + EXPECT_EQ(sum.intermediate_memory_read_time, Costs::NanoSeconds(10)); + EXPECT_EQ(sum.intermediate_memory_write_time, Costs::NanoSeconds(12)); + EXPECT_EQ(sum.max_memory, 2); + EXPECT_EQ(sum.max_per_op_buffers, 2); + EXPECT_EQ(sum.max_per_op_streaming, 3); + EXPECT_EQ(sum.num_ops_total, 2); + EXPECT_FALSE(sum.inaccurate); + EXPECT_EQ(sum.num_ops_with_unknown_shapes, 0); +} + +TEST(CostEstimatorTest, MultiplyCosts) { + Costs c = Costs::ZeroCosts(); + c.execution_time = Costs::NanoSeconds(1); + c.compute_time = Costs::NanoSeconds(2); + c.memory_time = Costs::NanoSeconds(3); + c.intermediate_memory_time = Costs::NanoSeconds(4); + c.intermediate_memory_read_time = Costs::NanoSeconds(5); + c.intermediate_memory_write_time = Costs::NanoSeconds(6); + c.max_memory = 1; + c.max_per_op_buffers = 2; + c.max_per_op_streaming = 3; + c.num_ops_total = 1; + c.inaccurate = false; + c.num_ops_with_unknown_shapes = 0; + + Costs product = MultiplyCosts(c, 10); + + EXPECT_EQ(product.execution_time, Costs::NanoSeconds(10)); + EXPECT_EQ(product.compute_time, Costs::NanoSeconds(20)); + EXPECT_EQ(product.memory_time, Costs::NanoSeconds(30)); + EXPECT_EQ(product.intermediate_memory_time, Costs::NanoSeconds(40)); + EXPECT_EQ(product.intermediate_memory_read_time, Costs::NanoSeconds(50)); + EXPECT_EQ(product.intermediate_memory_write_time, Costs::NanoSeconds(60)); + EXPECT_EQ(product.max_memory, 10); + EXPECT_EQ(product.max_per_op_buffers, 2); + EXPECT_EQ(product.max_per_op_streaming, 3); + EXPECT_EQ(product.num_ops_total, 1); + EXPECT_FALSE(product.inaccurate); + EXPECT_EQ(product.num_ops_with_unknown_shapes, 0); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index d5492468f44..a3c023089ba 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -34,41 +34,9 @@ limitations under the License. namespace tensorflow { namespace grappler { + namespace { -Costs CombineCosts(const Costs& left, const Costs& right) { - CHECK_NE(left.max_memory, kMemoryUnknown); - CHECK_NE(left.max_per_op_buffers, kMemoryUnknown); - CHECK_NE(left.max_per_op_streaming, kMemoryUnknown); - - Costs result = left; - result.execution_time += right.execution_time; - result.compute_time += right.compute_time; - result.memory_time += right.memory_time; - result.intermediate_memory_time += right.intermediate_memory_time; - - result.num_ops_total += right.num_ops_total; - if (right.inaccurate) result.inaccurate = true; - result.num_ops_with_unknown_shapes += right.num_ops_with_unknown_shapes; - - if (right.max_memory != kMemoryUnknown) { - result.max_memory += right.max_memory; - } - if (right.max_per_op_buffers != kMemoryUnknown) { - result.max_per_op_buffers = - std::max(left.max_per_op_buffers, right.max_per_op_buffers); - } - if (right.max_per_op_streaming != kMemoryUnknown) { - result.max_per_op_streaming = - std::max(left.max_per_op_streaming, right.max_per_op_streaming); - } - VLOG(4) << "costs execution_time=" << result.execution_time.count() - << " max_memory=" << result.max_memory - << " max_per_op_buffers=" << result.max_per_op_buffers - << " max_per_op_streaming=" << result.max_per_op_streaming; - return result; -} - // Key to the cached _Recv ops map, and its hash and predicate structures. struct RecvNodeDescriptor { const NodeDef* node;