Add a lot of operations' flops calculations
PiperOrigin-RevId: 168256746
This commit is contained in:
parent
80ed8afc02
commit
509372c2ee
@ -82,6 +82,7 @@ py_library(
|
|||||||
"//tensorflow/core/profiler:protos_all_py",
|
"//tensorflow/core/profiler:protos_all_py",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python/profiler/internal:flops_registry",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -7,6 +7,16 @@ load("//tensorflow:tensorflow.bzl", "py_test")
|
|||||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||||
|
|
||||||
|
# Library target for flops_registry.py; it depends on the framework ops and
# graph_util targets.
py_library(
    name = "flops_registry",
    srcs = ["flops_registry.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:graph_util",
    ],
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "model_analyzer_testlib",
|
name = "model_analyzer_testlib",
|
||||||
srcs = ["model_analyzer_testlib.py"],
|
srcs = ["model_analyzer_testlib.py"],
|
||||||
|
446
tensorflow/python/profiler/internal/flops_registry.py
Normal file
446
tensorflow/python/profiler/internal/flops_registry.py
Normal file
@ -0,0 +1,446 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Register flops statistics for various TensorFlow operations.
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import graph_util
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
|
||||||
|
|
||||||
|
# List of all ops which have implemented flops statistics.
|
||||||
|
IMPLEMENTED_OPS = {
    # Unary ops
    "Reciprocal",
    "Square",
    "Rsqrt",
    "Log",
    "Neg",
    "AssignSub",
    "AssignAdd",
    "L2Loss",
    "Softmax",
    # Binary ops
    "Add",
    "Sub",
    "Mul",
    "RealDiv",
    "Maximum",
    "Minimum",
    "Pow",
    "RsqrtGrad",
    "GreaterEqual",
    "Greater",
    "LessEqual",
    "Less",
    "Equal",
    "NotEqual",
    "SquaredDifference",
    # Reduction ops
    "Mean",
    "Sum",
    "ArgMax",
    "ArgMin",
    "BiasAddGrad",
    # Convolution and pooling
    "AvgPool",
    "MaxPool",
    "AvgPoolGrad",
    "MaxPoolGrad",
    "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN",
    # Ops implemented in core tensorflow:
    "MatMul",
    "Conv2D",
    "DepthwiseConv2dNative",
    "BiasAdd",
    "Dilation2D",
}
|
||||||
|
|
||||||
|
|
||||||
|
def _zero_flops(graph, node):
  """Reports a "flops" statistic of zero, whatever node is given."""
  # Neither argument is consulted; they exist only so that the signature
  # matches the other (graph, node) flops functions in this module.
  del node, graph
  return ops.OpStats("flops", 0)
|
||||||
|
|
||||||
|
|
||||||
|
def _list_product(lst):
|
||||||
|
"""Computes product of element of the list."""
|
||||||
|
result = 1
|
||||||
|
for item in lst:
|
||||||
|
result *= item
|
||||||
|
return result
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Unary operations
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _unary_op_flops(graph, node, ops_per_element=1):
  """Shared flops computation for ops costed per element of their first input.

  Args:
    graph: graph that contains `node`.
    node: node definition of the operation being measured.
    ops_per_element: number of flops charged for every input element.

  Returns:
    An `ops.OpStats` "flops" statistic equal to the element count of the
    node's first input multiplied by `ops_per_element`.
  """
  input_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                           node.input[0])
  # The element count is only meaningful once every dimension is known.
  input_shape.assert_is_fully_defined()
  total_flops = input_shape.num_elements() * ops_per_element
  return ops.OpStats("flops", total_flops)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Reciprocal", "flops")
def _reciprocal_flops(graph, node):
  """Flops statistic for Reciprocal: one flop per element of the input."""
  return _unary_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Square", "flops")
def _square_flops(graph, node):
  """Flops statistic for Square: one flop per element of the input."""
  return _unary_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Rsqrt", "flops")
def _rsqrt_flops(graph, node):
  """Flops statistic for Rsqrt: two flops per element of the input."""
  # Rsqrt(x) = 1 / sqrt(x), i.e. a square root followed by a reciprocal.
  flops_per_element = 2
  return _unary_op_flops(graph, node, ops_per_element=flops_per_element)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Log", "flops")
def _log_flops(graph, node):
  """Flops statistic for Log: one flop per element of the input."""
  return _unary_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Neg", "flops")
def _neg_flops(graph, node):
  """Flops statistic for Neg: one flop per element of the input."""
  return _unary_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("AssignSub", "flops")
def _assign_sub_flops(graph, node):
  """Flops statistic for AssignSub: one flop per element of the first input."""
  return _unary_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("AssignAdd", "flops")
def _assign_add_flops(graph, node):
  """Flops statistic for AssignAdd: one flop per element of the first input."""
  return _unary_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("L2Loss", "flops")
def _l2_loss_flops(graph, node):
  """Flops statistic for L2Loss: 3*N - 1 flops for an input of N elements."""
  input_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                           node.input[0])
  input_shape.assert_is_fully_defined()
  num_elements = input_shape.num_elements()
  # TensorFlow uses an inefficient implementation costing (3*N - 1) flops;
  # an optimal implementation would need only 2*N flops.
  return ops.OpStats("flops", num_elements * 3 - 1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Softmax", "flops")
def _softmax_flops(graph, node):
  """Flops statistic for Softmax: five flops per element of the input."""
  # Approximate flops breakdown of the Softmax implementation, for n elements:
  #   2*n -- compute shifted logits
  #   n   -- exp of shifted logits
  #   2*n -- compute softmax from exp of shifted logits
  flops_per_element = 2 + 1 + 2
  return _unary_op_flops(graph, node, ops_per_element=flops_per_element)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Binary operations
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Shared flops computation for ops costed per element of their output.

  Args:
    graph: graph that contains `node`.
    node: node definition of the operation being measured.
    ops_per_element: number of flops charged for every output element.

  Returns:
    An `ops.OpStats` "flops" statistic equal to the element count of the
    node's own output multiplied by `ops_per_element`.
  """
  # The count is taken from the node's output shape, not from either input.
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  total_flops = output_shape.num_elements() * ops_per_element
  return ops.OpStats("flops", total_flops)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Add", "flops")
def _add_flops(graph, node):
  """Flops statistic for Add: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
  """Flops statistic for Sub: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Mul", "flops")
def _mul_flops(graph, node):
  """Flops statistic for Mul: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("RealDiv", "flops")
def _real_div_flops(graph, node):
  """Flops statistic for RealDiv: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Maximum", "flops")
def _maximum_flops(graph, node):
  """Flops statistic for Maximum: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Minimum", "flops")
def _minimum_flops(graph, node):
  """Flops statistic for Minimum: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Pow", "flops")
def _pow_flops(graph, node):
  """Flops statistic for Pow: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("RsqrtGrad", "flops")
def _rsqrt_grad_flops(graph, node):
  """Flops statistic for RsqrtGrad: four flops per element of the output."""
  flops_per_element = 4
  return _binary_per_element_op_flops(graph, node,
                                      ops_per_element=flops_per_element)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("GreaterEqual", "flops")
def _greater_equal_flops(graph, node):
  """Flops statistic for GreaterEqual: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Greater", "flops")
def _greater_flops(graph, node):
  """Flops statistic for Greater: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("LessEqual", "flops")
def _less_equal_flops(graph, node):
  """Flops statistic for LessEqual: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Less", "flops")
def _less_flops(graph, node):
  """Flops statistic for Less: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Equal", "flops")
def _equal_flops(graph, node):
  """Flops statistic for Equal: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("NotEqual", "flops")
def _not_equal_flops(graph, node):
  """Flops statistic for NotEqual: one flop per element of the output."""
  return _binary_per_element_op_flops(graph, node, ops_per_element=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("SquaredDifference", "flops")
def _squared_difference_flops(graph, node):
  """Flops statistic for SquaredDifference: two flops per output element."""
  # NOTE(review): presumably one subtraction plus one multiplication per
  # element -- confirm against the kernel.
  flops_per_element = 2
  return _binary_per_element_op_flops(graph, node,
                                      ops_per_element=flops_per_element)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Reduction ops
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Shared flops computation for reduction operations.

  With N elements in the node's first input and M elements in its output, the
  reported count is `N * reduce_flops + M * (finalize_flops - reduce_flops)`.

  Args:
    graph: graph that contains `node`.
    node: node definition of the reduction being measured.
    reduce_flops: flops charged per input element for the reduction itself.
    finalize_flops: flops charged per output element for finalization.

  Returns:
    An `ops.OpStats` "flops" statistic holding the count described above.
  """
  input_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                           node.input[0])
  input_shape.assert_is_fully_defined()
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  reduction_cost = input_shape.num_elements() * reduce_flops
  per_output_adjustment = (
      output_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", reduction_cost + per_output_adjustment)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Mean", "flops")
def _mean_flops(graph, node):
  """Flops statistic for Mean: a sum reduction finalized by a division."""
  # Reduction step is the sum; finalization step is the divide.
  return _reduction_op_flops(graph, node, finalize_flops=1, reduce_flops=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Sum", "flops")
def _sum_flops(graph, node):
  """Flops statistic for Sum: a sum reduction with no finalization step."""
  return _reduction_op_flops(graph, node, finalize_flops=0, reduce_flops=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("ArgMax", "flops")
def _arg_max_flops(graph, node):
  """Flops statistic for ArgMax: a comparison reduction, no finalization."""
  return _reduction_op_flops(graph, node, finalize_flops=0, reduce_flops=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("ArgMin", "flops")
def _arg_min_flops(graph, node):
  """Flops statistic for ArgMin: a comparison reduction, no finalization."""
  return _reduction_op_flops(graph, node, finalize_flops=0, reduce_flops=1)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("BiasAddGrad", "flops")
def _bias_add_grad_flops(graph, node):
  """Flops statistic for BiasAddGrad, counted the same way as Sum."""
  # BiasAddGrad is essentially a reduce-sum followed by a reshape, so it is
  # costed exactly like the "Sum" reduction.
  return _reduction_op_flops(graph, node, finalize_flops=0, reduce_flops=1)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Convolution and pooling
|
||||||
|
# Note: all flops statistics are implemented only for NHWC data format
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_conv_data_format(node):
  """Checks that a pooling or convolution node uses the NHWC data format.

  Args:
    node: node definition whose "data_format" attribute is inspected.

  Raises:
    ValueError: if the attribute is anything other than b"NHWC".
  """
  # TODO(xpan): P1: Support NCHW
  data_format = node.attr["data_format"].s
  if data_format == b"NHWC":
    return
  raise ValueError("Only NHWC format is supported in flops computations")
|
||||||
|
|
||||||
|
|
||||||
|
def _pool_flops(graph, node):
  """Shared flops computation for average and max pooling.

  Args:
    graph: graph that contains `node`.
    node: node definition of the pooling operation (NHWC only).

  Returns:
    An `ops.OpStats` "flops" statistic: the product of all "ksize" entries
    multiplied by the number of elements in the node's output.
  """
  _verify_conv_data_format(node)
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation: every output element is charged one flop per
  # element of the pooling window.
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  window_dims = list(node.attr["ksize"].list.i)
  window_area = _list_product(window_dims)
  return ops.OpStats("flops", window_area * output_shape.num_elements())
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("AvgPool", "flops")
def _avg_pool_flops(graph, node):
  """Flops statistic for AvgPool, using the shared pooling computation."""
  pooling_stats = _pool_flops(graph, node)
  return pooling_stats
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("MaxPool", "flops")
def _max_pool_flops(graph, node):
  """Flops statistic for MaxPool, using the shared pooling computation."""
  pooling_stats = _pool_flops(graph, node)
  return pooling_stats
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("AvgPoolGrad", "flops")
def _avg_pool_grad_flops(graph, node):
  """Flops statistic for AvgPoolGrad.

  Args:
    graph: graph that contains `node`.
    node: node definition of the AvgPoolGrad operation (NHWC only).

  Returns:
    An `ops.OpStats` "flops" statistic: two flops per pooling-window element
    for every element of the node's second input.
  """
  _verify_conv_data_format(node)
  # Pooling gradient implementation: the second input is treated as the
  # back-propagated output gradient.
  grad_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[1])
  grad_shape.assert_is_fully_defined()
  window_dims = list(node.attr["ksize"].list.i)
  window_area = _list_product(window_dims)
  # TensorFlow multiplies each element of the pooling window by a coefficient
  # and then sums them all up, hence 2 flops per element. A more optimal
  # implementation would perform the division afterwards.
  total_flops = window_area * grad_shape.num_elements() * 2
  return ops.OpStats("flops", total_flops)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("MaxPoolGrad", "flops")
def _max_pool_grad_flops(graph, node):
  """Compute flops for MaxPoolGrad operation.

  Args:
    graph: graph that contains `node`.
    node: node definition of the MaxPoolGrad operation (NHWC only).

  Returns:
    An `ops.OpStats` "flops" statistic: the cost of the forward max pooling
    (product of "ksize" entries times the number of elements of the original
    output) plus one flop per element of the original output.

  Raises:
    ValueError: if the data format is not NHWC. A shape that is not fully
      defined is rejected by `assert_is_fully_defined`, as in the other flops
      functions of this module.
  """
  _verify_conv_data_format(node)
  #
  # MaxPoolGrad declaration:
  #   Inputs:
  #     - orig_input  -- original input tensor (of max_pool)
  #     - orig_output -- original output tensor (of max_pool)
  #     - grad        -- gradient with respect to output of max_pool
  #   Outputs:
  #     - output      -- gradient with respect to input of max_pool
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  # It computes MaxPool first, then one flop per each element of original
  # output.
  #
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                              node.input[1])
  # Without this check a partially known shape would only surface later, as
  # an arithmetic error on a missing element count, instead of failing the
  # same way every other shape-based flops function here does.
  orig_out_shape.assert_is_fully_defined()
  orig_out_elements = orig_out_shape.num_elements()
  max_pool_ops = kernel_area * orig_out_elements
  return ops.OpStats("flops", max_pool_ops + orig_out_elements)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
def _conv_2d_backprop_input_flops(graph, node):
  """Flops statistic for Conv2DBackpropInput (NHWC only).

  Formula:
    batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
    * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)

  Where:
    image_x_dim, image_y_dim and input_depth --- size of input to source (no
      backprop) convolution, in other words they are sizes of backprop output.
    output_depth --- number of filters in the original convolution, thus
      depth of backprop input.
    kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension
    image_x_stride and image_y_stride --- strides of the convolution

  Args:
    graph: graph that contains `node`.
    node: node definition of the Conv2DBackpropInput operation.

  Returns:
    An `ops.OpStats` "flops" statistic computed with the formula above.
  """
  _verify_conv_data_format(node)
  # backprop_out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  backprop_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                                  node.name)
  backprop_out_shape.assert_is_fully_defined()
  # filter_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
  filter_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                            node.input[1])
  filter_shape.assert_is_fully_defined()
  # Only the two spatial strides (positions 1 and 2) enter the formula.
  strides = list(node.attr["strides"].list.i)
  spatial_stride = strides[1] * strides[2]
  # input_depth is present in both element counts, so it is divided out once.
  numerator = (2 * backprop_out_shape.num_elements()
               * filter_shape.num_elements())
  denominator = backprop_out_shape[-1].value * spatial_stride
  return ops.OpStats("flops", numerator / denominator)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
def _conv_2d_backprop_filter_flops(graph, node):
  """Flops statistic for Conv2DBackpropFilter (NHWC only).

  Formula, same as for Conv2DBackpropInput:
    batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
    * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)

  Args:
    graph: graph that contains `node`.
    node: node definition of the Conv2DBackpropFilter operation.

  Returns:
    An `ops.OpStats` "flops" statistic computed with the formula above.
  """
  _verify_conv_data_format(node)
  # input_image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
  input_image_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                                 node.input[0])
  input_image_shape.assert_is_fully_defined()
  # filter_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth];
  # here the filter is the node's own output.
  filter_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  filter_shape.assert_is_fully_defined()
  # Only the two spatial strides (positions 1 and 2) enter the formula.
  strides = list(node.attr["strides"].list.i)
  spatial_stride = strides[1] * strides[2]
  # input_depth is present in both element counts, so it is divided out once.
  numerator = (2 * input_image_shape.num_elements()
               * filter_shape.num_elements())
  denominator = input_image_shape[-1].value * spatial_stride
  return ops.OpStats("flops", numerator / denominator)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Other ops
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterStatistics("AddN", "flops")
def _add_n_flops(graph, node):
  """Flops statistic for AddN.

  Args:
    graph: graph that contains `node`.
    node: node definition of the AddN operation.

  Returns:
    An `ops.OpStats` "flops" statistic: zero when the node has no inputs,
    otherwise the element count of the first input multiplied by one less
    than the number of inputs.
  """
  input_names = node.input
  if not input_names:
    return _zero_flops(graph, node)
  first_shape = graph_util.tensor_shape_from_node_def_name(graph,
                                                           input_names[0])
  first_shape.assert_is_fully_defined()
  # Adding k tensors element-wise takes k - 1 additions per element.
  additions_per_element = len(input_names) - 1
  return ops.OpStats("flops",
                     first_shape.num_elements() * additions_per_element)
|
@ -35,6 +35,7 @@ from tensorflow.python.profiler import model_analyzer
|
|||||||
from tensorflow.python.profiler import option_builder
|
from tensorflow.python.profiler import option_builder
|
||||||
from tensorflow.python.profiler import profile_context
|
from tensorflow.python.profiler import profile_context
|
||||||
from tensorflow.python.profiler.internal import model_analyzer_testlib as lib
|
from tensorflow.python.profiler.internal import model_analyzer_testlib as lib
|
||||||
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
builder = option_builder.ProfileOptionBuilder
|
builder = option_builder.ProfileOptionBuilder
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ class PrintModelAnalysisTest(test.TestCase):
|
|||||||
with gfile.Open(outfile, 'r') as f:
|
with gfile.Open(outfile, 'r') as f:
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/10.44k flops, _kTFScopeParent, --/8|--/36, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n DW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n 
DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/0 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n _retval_Conv2D_1_0_0 (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|RunTimeOp, 1/1|1/1, )\n init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, 1/1|1/1, )\n',
|
'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/11.34k flops, _kTFScopeParent, --/8|--/36, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/324 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/324 flops, _kTFScopeParent, 0/0|1/7, )\n DW/Initializer/random_normal (0/0 params, 162/324 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/mul (0/0 params, 162/162 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/576 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/576 flops, _kTFScopeParent, 0/0|1/7, )\n DW2/Initializer/random_normal (0/0 params, 288/576 flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 
0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/mul (0/0 params, 288/288 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/2 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/2 flops, _kTFScopeParent, 0/0|1/7, )\n ScalarW/Initializer/random_normal (0/0 params, 1/2 flops, Add, 0/0|1/6, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 1/1 flops, Mul, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n _retval_Conv2D_1_0_0 (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|RunTimeOp, 1/1|1/1, )\n init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, 1/1|1/1, )\n',
|
||||||
f.read())
|
f.read())
|
||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
|
|
||||||
@ -221,12 +222,12 @@ class PrintModelAnalysisTest(test.TestCase):
|
|||||||
with gfile.Open(outfile, 'r') as f:
|
with gfile.Open(outfile, 'r') as f:
|
||||||
lines = f.read().split('\n')
|
lines = f.read().split('\n')
|
||||||
result = '\n'.join([l[:min(len(l), 80)] for l in lines])
|
result = '\n'.join([l[:min(len(l), 80)] for l in lines])
|
||||||
self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/16.51k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/32.77k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/0 flops)\n',
|
self.assertEqual(compat.as_bytes('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.85k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.57k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/130 flop\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n'),
|
||||||
result)
|
compat.as_bytes(result))
|
||||||
|
|
||||||
self.assertLess(0, tfprof_node.total_exec_micros)
|
self.assertLess(0, tfprof_node.total_exec_micros)
|
||||||
self.assertEqual(2844, tfprof_node.total_parameters)
|
self.assertEqual(2844, tfprof_node.total_parameters)
|
||||||
self.assertEqual(91040, tfprof_node.total_float_ops)
|
self.assertEqual(168855, tfprof_node.total_float_ops)
|
||||||
self.assertEqual(8, len(tfprof_node.children))
|
self.assertEqual(8, len(tfprof_node.children))
|
||||||
self.assertEqual('_TFProfRoot', tfprof_node.name)
|
self.assertEqual('_TFProfRoot', tfprof_node.name)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.core.profiler import tfprof_log_pb2
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import
|
||||||
|
|
||||||
TRAINABLE_VARIABLES = '_trainable_variables'
|
TRAINABLE_VARIABLES = '_trainable_variables'
|
||||||
REGISTERED_FLOP_STATS = 'flops'
|
REGISTERED_FLOP_STATS = 'flops'
|
||||||
|
Loading…
Reference in New Issue
Block a user