STT-tensorflow/tensorflow/python/grappler/cost_analyzer_test.py
A. Unique TensorFlower 18722841f5 Fix incorrect "is" comparisons to instead use "==".
PiperOrigin-RevId: 237140147
2019-03-06 16:23:38 -08:00

165 lines
6.3 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the cost analyzer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
class CostAnalysisTest(test.TestCase):
@test_util.run_deprecated_v1
def testBasicCost(self):
"""Make sure arguments can be passed correctly."""
a = constant_op.constant(10, name="a")
b = constant_op.constant(20, name="b")
c = math_ops.add_n([a, b], name="c")
d = math_ops.add_n([b, c], name="d")
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(d)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateCostReport(mg, per_node_report=True)
# Check the report headers
self.assertTrue(b"Total time measured in ns (serialized):" in report)
self.assertTrue(b"Total time measured in ns (actual):" in report)
self.assertTrue(b"Total time analytical in ns (upper bound):" in report)
self.assertTrue(b"Total time analytical in ns (lower bound):" in report)
self.assertTrue(b"Overall efficiency (analytical upper/actual):" in report)
self.assertTrue(b"Overall efficiency (analytical lower/actual):" in report)
self.assertTrue(b"Below is the per-node report summary:" in report)
# Also print the report to make it easier to debug
print("{}".format(report))
@test_util.run_deprecated_v1
def testVerbose(self):
"""Make sure the full report is generated with verbose=True."""
a = constant_op.constant(10, name="a")
b = constant_op.constant(20, name="b")
c = math_ops.add_n([a, b], name="c")
d = math_ops.add_n([b, c], name="d")
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(d)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateCostReport(
mg, per_node_report=True, verbose=True)
# Check the report headers
self.assertTrue(b"Below is the full per-node report:" in report)
# Also print the report to make it easier to debug
print("{}".format(report))
@test_util.run_deprecated_v1
def testSmallNetworkCost(self):
image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
w = variables.Variable(
random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1))
b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1))
conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME")
h_conv = nn_ops.relu(conv + b)
h_conv_flat = array_ops.reshape(h_conv, [1, -1])
w_fc = variables.Variable(
random_ops.truncated_normal([25088, 10], stddev=0.1))
b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)
cross_entropy = math_ops.reduce_mean(
-math_ops.reduce_sum(label * math_ops.log(y_conv), axis=[1]))
_ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateCostReport(mg)
# Print the report to make it easier to debug
print("{}".format(report))
self.assertTrue(b"MatMul" in report)
self.assertTrue(b"ApplyAdam" in report)
self.assertTrue(b"Conv2D" in report)
self.assertTrue(b"Conv2DBackpropFilter" in report)
self.assertTrue(b"Softmax" in report)
for op_type in [b"MatMul", b"Conv2D", b"Conv2DBackpropFilter"]:
matcher = re.compile(
br"\s+" + op_type + br",\s*(\d+),\s*(\d+),\s*([\d\.eE+-]+)%,\s*" +
br"([\d\.eE+-]+)%,\s*(-?\d+),\s*(\d+),", re.MULTILINE)
m = matcher.search(report)
op_count = int(m.group(1))
# upper = int(m.group(5))
lower = int(m.group(6))
if op_type == b"MatMul":
self.assertEqual(3, op_count)
else:
self.assertEqual(1, op_count)
self.assertTrue(0 <= lower)
# self.assertTrue(0 < upper)
# self.assertTrue(lower <= upper)
@test_util.run_deprecated_v1
def testBasicMemory(self):
"""Make sure arguments can be passed correctly."""
with test_util.device(use_gpu=False):
a = constant_op.constant(10, name="a")
b = constant_op.constant(20, name="b")
c = math_ops.add_n([a, b], name="c")
d = math_ops.add_n([b, c], name="d")
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(d)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
report = cost_analyzer.GenerateMemoryReport(mg)
# Print the report to make it easier to debug
print("{}".format(report))
# Check the report
self.assertTrue(
"Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
"16 bytes"
in report)
self.assertTrue(" a:0 uses 4 bytes" in report)
self.assertTrue(" b:0 uses 4 bytes" in report)
self.assertTrue(" c:0 uses 4 bytes" in report)
self.assertTrue(" d:0 uses 4 bytes" in report)
if __name__ == "__main__":
test.main()