Don't use the _output_shapes attribute in the op_level_cost_estimator since

there is no guarantee that it will be present or accurate.

PiperOrigin-RevId: 157898989
This commit is contained in:
Benoit Steiner 2017-06-02 18:34:22 -07:00 committed by TensorFlower Gardener
parent 6f4204c3d3
commit bb7a8d8e72
10 changed files with 216 additions and 58 deletions

View File

@ -267,3 +267,18 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
],
)
# Test target built from analytical_cost_estimator_test.cc. It depends on the
# analytical cost estimator and virtual scheduler libraries from this package,
# plus the virtual cluster and the TensorFlow C++ ops/test libraries.
cc_test(
name = "analytical_cost_estimator_test",
srcs = ["analytical_cost_estimator_test.cc"],
deps = [
":analytical_cost_estimator",
":virtual_scheduler",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/clusters:virtual_cluster",
],
)

View File

@ -97,7 +97,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
node_costs.compute_time.asMicroSeconds().count());
cost_node->set_memory_time(
node_costs.memory_time.asMicroSeconds().count());
for (const auto& output : node_info.outputs) {
for (const auto& output : node_info.op_info.outputs()) {
auto output_info = cost_node->add_output_info();
output_info->set_dtype(output.dtype());
auto shape = output_info->mutable_shape();

View File

@ -0,0 +1,110 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
class AnalyticalCostEstimatorTest : public ::testing::Test {
 protected:
  // Builds cluster_: a VirtualCluster described by two hand-written device
  // entries, one CPU and one GPU.
  void SetUp() override {
    std::unordered_map<string, DeviceProperties> device_map;

    {
      // operator[] default-constructs the entry, which is then filled in place.
      DeviceProperties& cpu =
          device_map["/job:localhost/replica:0/task:0/cpu:0"];
      cpu.set_type("CPU");
      cpu.set_num_cores(4);
      cpu.set_frequency(2600);
      cpu.set_bandwidth(24 * 1024 * 1024);
    }

    {
      DeviceProperties& gpu =
          device_map["/job:localhost/replica:0/task:0/gpu:0"];
      gpu.set_type("GPU");
      gpu.set_num_cores(12);
      gpu.set_frequency(1100);
      gpu.set_bandwidth(180 * 1024 * 1024);
      // GPU entries also carry an "architecture" value in their environment.
      (*gpu.mutable_environment())["architecture"] = "6";
    }

    cluster_.reset(new VirtualCluster(device_map));
  }

  // Returns a GrapplerItem whose graph is a small network: a convolution with
  // bias and ReLU, a reshape to 2D, a matmul with bias, then softmax and log.
  // The "lsm" node is registered as the only fetch.
  GrapplerItem CreateMiniGraph() {
    const int kBatch = 1;
    const int kWidth = 28;
    const int kHeight = 28;
    const int kNumChannels = 1;
    const int kNumLabels = 10;
    const int kKernelSize = 3;
    const int kConvFilters = 32;
    // Number of values per example once the convolution output is flattened.
    const int kFlatSize = kWidth * kHeight * kConvFilters;

    Scope root = Scope::NewRootScope();

    // Inputs. The "label" node is added to the graph but nothing consumes it.
    auto image = ops::RandomUniform(root.WithOpName("image"),
                                    {kBatch, kWidth, kHeight, kNumChannels},
                                    DT_FLOAT);
    ops::RandomUniform(root.WithOpName("label"), {kBatch, kNumLabels},
                       DT_FLOAT);

    // Convolutional layer.
    auto conv_weights = ops::Variable(
        root.WithOpName("W"),
        {kKernelSize, kKernelSize, kNumChannels, kConvFilters}, DT_FLOAT);
    auto conv_bias =
        ops::Variable(root.WithOpName("B"), {kConvFilters}, DT_FLOAT);
    auto conv = ops::Conv2D(root.WithOpName("conv"), image, conv_weights,
                            {1, 1, 1, 1}, "SAME");
    auto biased_conv = ops::Add(root.WithOpName("bias"), conv, conv_bias);
    auto activated = ops::Relu(root.WithOpName("relu"), biased_conv);

    // Flatten to [batch, features].
    auto flat_dims =
        ops::Const(root.WithOpName("flat_shape"), {kBatch, kFlatSize});
    auto flattened =
        ops::Reshape(root.WithOpName("flat"), activated, flat_dims);

    // Fully connected layer followed by log(softmax(...)).
    auto fc_weights = ops::Variable(root.WithOpName("W2"),
                                    {kFlatSize, kNumLabels}, DT_FLOAT);
    auto fc_bias =
        ops::Variable(root.WithOpName("B2"), {kNumLabels}, DT_FLOAT);
    auto product = ops::MatMul(root.WithOpName("matmul"), flattened, fc_weights);
    auto logits = ops::Add(root.WithOpName("logits"), product, fc_bias);
    auto probabilities = ops::Softmax(root.WithOpName("softmax"), logits);
    ops::Log(root.WithOpName("lsm"), probabilities);

    GrapplerItem result;
    result.fetch.push_back("lsm");
    TF_CHECK_OK(root.ToGraphDef(&result.graph));
    return result;
  }

  // Virtual cluster created in SetUp() and handed to the estimator under test.
  std::unique_ptr<VirtualCluster> cluster_;
};
// Runs the analytical estimator over the mini graph and checks the predicted
// total execution time and that the summary is not flagged as inaccurate.
TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
  GrapplerItem mini_graph = CreateMiniGraph();

  // NOTE(review): the meaning of the second constructor argument is not
  // visible from this file -- confirm against AnalyticalCostEstimator.
  AnalyticalCostEstimator cost_estimator(cluster_.get(), true);
  TF_ASSERT_OK(cost_estimator.Initialize(mini_graph));

  Costs total_costs;
  CostGraphDef predicted_graph;
  TF_ASSERT_OK(cost_estimator.PredictCosts(mini_graph.graph, &predicted_graph,
                                           &total_costs));

  EXPECT_EQ(Costs::NanoSeconds(9108), total_costs.execution_time);
  EXPECT_FALSE(total_costs.inaccurate);
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -76,32 +76,21 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
const DeviceProperties& device) const {
double gflops = -1;
double bandwidth = -1;
if (device.bandwidth() > 0) {
bandwidth = device.bandwidth() / 1e6;
}
if (device.type() == "CPU") {
DeviceProperties local_cpu;
if (device.num_cores() <= 0 || device.frequency() <= 0) {
local_cpu = GetLocalCPUInfo();
} else {
local_cpu = device;
}
// Check if vector instructions are available, and refine performance
// prediction based on this.
// Frequencies are stored in MHz in the DeviceProperties.
gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3;
gflops = device.num_cores() * device.frequency() * 1e-3;
if (bandwidth < 0) {
if (local_cpu.bandwidth() > 0) {
bandwidth = local_cpu.bandwidth() / 1e6;
if (device.bandwidth() > 0) {
bandwidth = device.bandwidth() / 1e6;
} else {
bandwidth = 32;
}
}
} else if (device.type() == "GPU") {
const DeviceProperties local_gpu = GetLocalGPUInfo(0);
const string architecture = local_gpu.environment().at("architecture");
const string architecture = device.environment().at("architecture");
int cores_per_multiprocessor;
if (architecture < "3") {
// Fermi
@ -110,17 +99,18 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
// Kepler
cores_per_multiprocessor = 192;
} else if (architecture < "6") {
// Maxwell
// Maxwell
cores_per_multiprocessor = 128;
} else {
// Pascal.
// Pascal
cores_per_multiprocessor = 64;
}
gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 *
gflops = device.num_cores() * device.frequency() * 1e-3 *
cores_per_multiprocessor * kOpsPerMac;
if (bandwidth < 0) {
CHECK(local_gpu.bandwidth() > 0);
bandwidth = local_gpu.bandwidth() / 1e6;
if (device.bandwidth() > 0) {
bandwidth = device.bandwidth() / 1e6;
} else {
bandwidth = 100;
}
}
@ -507,14 +497,13 @@ int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations(
return ops;
}
if (op_features.attr().find("_output_shapes") == op_features.attr().end()) {
if (op_features.outputs_size() != 1) {
// Need _output_shapes for input shape.
LOG(ERROR) << "No output shape in Conv2DBackPropInput op feaure.";
LOG(ERROR) << "No output shape in Conv2DBackPropInput op.";
return ops;
}
const auto& input_shape =
op_features.attr().at("_output_shapes").list().shape(0);
const auto& input_shape = op_features.outputs(0).shape();
ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
input_shape, op_features.inputs(1).shape(), op_features,
found_unknown_shapes);
@ -542,14 +531,13 @@ int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations(
return ops;
}
if (op_features.attr().find("_output_shapes") == op_features.attr().end()) {
// Need _output_shapes for filter shape.
LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feaure.";
if (op_features.outputs_size() != 1) {
// Need _output_shapes for input shape.
LOG(ERROR) << "No output shape in Conv2DBackPropFilter op.";
return ops;
}
const auto& filter_shape =
op_features.attr().at("_output_shapes").list().shape(0);
const auto& filter_shape = op_features.outputs(0).shape();
ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
op_features.inputs(0).shape(), filter_shape, op_features,
found_unknown_shapes);
@ -598,28 +586,19 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
const OpInfo& op_features, bool* found_unknown_shapes) const {
int64 total_output_size = 0;
// use float as default for calculations
DataType dt = DT_FLOAT;
for (const auto& item : op_features.attr()) {
VLOG(1) << "Key:" << item.first
<< " Value:" << SummarizeAttrValue(item.second);
if (item.first == "_output_shapes") {
for (const auto& original_output_shape : item.second.list().shape()) {
int64 output_size = 1;
int num_dims = std::max(1, original_output_shape.dim_size());
auto output_shape = MaybeGetMinimumShape(
original_output_shape, num_dims, found_unknown_shapes);
for (const auto& dim : output_shape.dim()) {
output_size *= dim.size();
}
output_size *= DataTypeSize(dt);
total_output_size += output_size;
VLOG(1) << "Output Size: " << output_size
<< " Total Output Size:" << total_output_size;
}
}
if (item.first == "T") {
dt = item.second.type();
for (const auto& output : op_features.outputs()) {
DataType dt = output.dtype();
const auto& original_output_shape = output.shape();
int64 output_size = DataTypeSize(dt);
int num_dims = std::max(1, original_output_shape.dim_size());
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
found_unknown_shapes);
for (const auto& dim : output_shape.dim()) {
output_size *= dim.size();
}
total_output_size += output_size;
VLOG(1) << "Output Size: " << output_size
<< " Total Output Size:" << total_output_size;
}
return total_output_size;
}

View File

@ -33,7 +33,7 @@ message OpInfo {
// Custom parameters impacting the behavior of the op.
map<string, AttrValue> attr = 2;
// Input types, shapes and values if known.
// Input data types, shapes and values if known.
message TensorProperties {
DataType dtype = 1;
TensorShapeProto shape = 2;
@ -41,6 +41,9 @@ message OpInfo {
};
repeated TensorProperties inputs = 3;
// Optional description of the op outputs
repeated TensorProperties outputs = 5;
// Device on which the operation is run.
DeviceProperties device = 4;
}

View File

@ -316,13 +316,17 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
NodeInfo node_info;
node_info.name = node->name();
node_info.device_name = graph_properties_.GetDeviceName(node->name());
node_info.outputs = graph_properties_.GetOutputProperties(node->name());
std::vector<OpInfo::TensorProperties> outputs =
graph_properties_.GetOutputProperties(node->name());
auto& op_info = node_info.op_info;
op_info.set_op(node->op());
*op_info.mutable_attr() = node->attr();
for (auto& input : inputs) {
op_info.add_inputs()->Swap(&input);
}
for (auto& output : outputs) {
op_info.add_outputs()->Swap(&output);
}
op_info.mutable_device()->Swap(&device);
// add some more to the node_info.
return node_info;

View File

@ -95,7 +95,6 @@ struct NodeInfo {
OpInfo op_info;
string name;
string device_name;
std::vector<OpInfo::TensorProperties> outputs;
};
// The virtual scheduler emulates execution of nodes in a graph, considering

View File

@ -126,9 +126,9 @@ TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
EXPECT_EQ(ops_executed.count("c2"), 0);
// Check input / output properties.
EXPECT_EQ(1, ops_executed["x"].outputs.size());
EXPECT_EQ(1, ops_executed["y"].outputs.size());
EXPECT_EQ(1, ops_executed["f"].outputs.size());
EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
}

View File

@ -3797,10 +3797,17 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":array_ops",
":client_testlib",
":cost_analyzer",
":framework_for_generated_wrappers",
":math_ops",
":nn",
":nn_grad",
":random_ops",
":state_ops",
":training",
":variables",
"//tensorflow/core:protos_all_py",
"//third_party/py/numpy",
],

View File

@ -19,11 +19,18 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
class PyWrapOptimizeGraphTest(test.TestCase):
@ -51,6 +58,40 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Also print the report to make it easier to debug
print("{}".format(report))
def testSmallNetwork(self):
  """Checks that the cost report of a small conv net lists its main ops.

  Builds a one-layer convolutional classifier with a cross-entropy loss and
  an Adam training step, exports the default graph as a MetaGraphDef, and
  verifies that the generated cost report mentions the forward ops, the
  convolution gradient ops and the optimizer update op.
  """
  image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
  label = array_ops.placeholder(dtypes.float32, shape=[1, 10])

  # Convolutional layer: 5x5 kernel, 1 input channel, 32 filters.
  w = variables.Variable(
      random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1))
  b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1))
  conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME")
  h_conv = nn_ops.relu(conv + b)
  h_conv_flat = array_ops.reshape(h_conv, [1, -1])

  # Fully connected layer: 25088 = 28 * 28 * 32 flattened features.
  w_fc = variables.Variable(
      random_ops.truncated_normal([25088, 10], stddev=0.1))
  b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
  y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)

  cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum(
      label * math_ops.log(y_conv), reduction_indices=[1]))
  # Only the side effect matters: minimize() adds the gradient and ApplyAdam
  # ops to the default graph.
  _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  report = cost_analyzer.GenerateCostReport(mg)

  # assertIn reports the missing op name and the report content on failure,
  # unlike assertTrue(x in y) which only says "False is not true".
  expected_ops = (
      b"MatMul",
      b"ApplyAdam",
      b"Conv2D",
      b"Conv2DBackpropInput",
      b"Conv2DBackpropFilter",
      b"Softmax",
  )
  for op_name in expected_ops:
    self.assertIn(op_name, report)

  # Also print the report to make it easier to debug
  print("{}".format(report))
  # print("{}".format(mg.graph_def))
if __name__ == "__main__":
test.main()