# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flops statistics for various TensorFlow operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops

# List of all ops which have implemented flops statistics.
IMPLEMENTED_OPS = set([
    # Unary ops
    "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
    "L2Loss", "Softmax",
    # Binary ops
    "Add", "AddV2", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow",
    "RsqrtGrad", "GreaterEqual", "Greater", "LessEqual", "Less", "Equal",
    "NotEqual", "SquaredDifference",
    # Reduction ops
    "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
    # Convolution and pooling
    "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    # Other ops
    "AddN",
    # Ops implemented in core TensorFlow:
    "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])
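
# How these registrations are consumed (a minimal sketch, not part of this
# module): given a graph `g` and a NodeDef `node` whose op is listed above,
# the registered flops function can be looked up and invoked via
#
#   stats = ops.get_stats_for_node_def(g, node, "flops")
#   print(stats.value)  # estimated flop count for this node
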
def _zero_flops(graph, node):
  """Returns zero flops."""
  del graph, node  # graph and node are unused
  return ops.OpStats("flops", 0)

def _list_product(lst):
  """Computes the product of the elements of a list."""
  result = 1
  for item in lst:
    result *= item
  return result
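
# Illustrative sanity check (not part of the module):
#   _list_product([1, 3, 3, 1]) == 9, i.e. the number of elements in a 3x3
#   pooling window given an NHWC ksize of [1, 3, 3, 1].
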
################################################################################
# Unary operations
################################################################################

def _unary_op_flops(graph, node, ops_per_element=1):
  """Common code that computes flops for unary operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)
@ops.RegisterStatistics("Reciprocal", "flops")
|
|
def _reciprocal_flops(graph, node):
|
|
"""Compute flops for Reciprocal operation."""
|
|
return _unary_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Square", "flops")
|
|
def _square_flops(graph, node):
|
|
"""Compute flops for Square operation."""
|
|
return _unary_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Rsqrt", "flops")
|
|
def _rsqrt_flops(graph, node):
|
|
"""Compute flops for Rsqrt operation."""
|
|
# Rsqrt(x) = 1 / sqrt(x)
|
|
return _unary_op_flops(graph, node, ops_per_element=2)
|
|
|
|
|
|
@ops.RegisterStatistics("Log", "flops")
|
|
def _log_flops(graph, node):
|
|
"""Compute flops for Log operation."""
|
|
return _unary_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Neg", "flops")
|
|
def _neg_flops(graph, node):
|
|
"""Compute flops for Neg operation."""
|
|
return _unary_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("AssignSub", "flops")
|
|
def _assign_sub_flops(graph, node):
|
|
"""Compute flops for AssignSub operation."""
|
|
return _unary_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("AssignAdd", "flops")
|
|
def _assign_add_flops(graph, node):
|
|
"""Compute flops for AssignAdd operation."""
|
|
return _unary_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("L2Loss", "flops")
|
|
def _l2_loss_flops(graph, node):
|
|
"""Compute flops for L2Loss operation."""
|
|
in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
|
|
in_shape.assert_is_fully_defined()
|
|
# Tensorflow uses inefficient implementation, with (3*N-1) flops:
|
|
# Optimal implementation is 2*N flops
|
|
return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)
|
|
|
|
|
|
@ops.RegisterStatistics("Softmax", "flops")
|
|
def _softmax_flops(graph, node):
|
|
"""Compute flops for Softmax operation."""
|
|
# Softmax implemetation:
|
|
#
|
|
# Approximate flops breakdown:
|
|
# 2*n -- compute shifted logits
|
|
# n -- exp of shifted logits
|
|
# 2*n -- compute softmax from exp of shifted logits
|
|
return _unary_op_flops(graph, node, ops_per_element=5)
|
|
|
|
################################################################################
# Binary operations
################################################################################

def _binary_per_element_op_flops(graph, node, ops_per_element=1):
  """Common code that computes flops for binary operations."""
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)
@ops.RegisterStatistics("Add", "flops")
|
|
@ops.RegisterStatistics("AddV2", "flops")
|
|
def _add_flops(graph, node):
|
|
"""Compute flops for Add operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Sub", "flops")
|
|
def _sub_flops(graph, node):
|
|
"""Compute flops for Sub operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Mul", "flops")
|
|
def _mul_flops(graph, node):
|
|
"""Compute flops for Mul operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("RealDiv", "flops")
|
|
def _real_div_flops(graph, node):
|
|
"""Compute flops for RealDiv operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Maximum", "flops")
|
|
def _maximum_flops(graph, node):
|
|
"""Compute flops for Maximum operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Minimum", "flops")
|
|
def _minimum_flops(graph, node):
|
|
"""Compute flops for Minimum operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Pow", "flops")
|
|
def _pow_flops(graph, node):
|
|
"""Compute flops for Pow operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("RsqrtGrad", "flops")
|
|
def _rsqrt_grad_flops(graph, node):
|
|
"""Compute flops for RsqrtGrad operation."""
|
|
return _binary_per_element_op_flops(graph, node, ops_per_element=4)
|
|
|
|
|
|
@ops.RegisterStatistics("GreaterEqual", "flops")
|
|
def _greater_equal_flops(graph, node):
|
|
"""Compute flops for GreaterEqual operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Greater", "flops")
|
|
def _greater_flops(graph, node):
|
|
"""Compute flops for Greater operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("LessEqual", "flops")
|
|
def _less_equal_flops(graph, node):
|
|
"""Compute flops for LessEqual operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Less", "flops")
|
|
def _less_flops(graph, node):
|
|
"""Compute flops for Less operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("Equal", "flops")
|
|
def _equal_flops(graph, node):
|
|
"""Compute flops for Equal operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("NotEqual", "flops")
|
|
def _not_equal_flops(graph, node):
|
|
"""Compute flops for NotEqual operation."""
|
|
return _binary_per_element_op_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("SquaredDifference", "flops")
|
|
def _squared_difference_flops(graph, node):
|
|
"""Compute flops for SquaredDifference operation."""
|
|
return _binary_per_element_op_flops(graph, node, ops_per_element=2)
|
|
|
|
################################################################################
# Reduction ops
################################################################################

def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
  """Common code that computes flops for reduction operations."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  in_shape.assert_is_fully_defined()
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  num_flops = (in_shape.num_elements() * reduce_flops
               + out_shape.num_elements() * (finalize_flops - reduce_flops))
  return ops.OpStats("flops", num_flops)
@ops.RegisterStatistics("Mean", "flops")
|
|
def _mean_flops(graph, node):
|
|
"""Compute flops for Mean operation."""
|
|
# reduction - sum, finalization - divide
|
|
return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)
|
|
|
|
|
|
@ops.RegisterStatistics("Sum", "flops")
|
|
def _sum_flops(graph, node):
|
|
"""Compute flops for Sum operation."""
|
|
# reduction - sum, no finalization
|
|
return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
|
|
|
|
|
|
@ops.RegisterStatistics("ArgMax", "flops")
|
|
def _arg_max_flops(graph, node):
|
|
"""Compute flops for ArgMax operation."""
|
|
# reduction - comparison, no finalization
|
|
return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
|
|
|
|
|
|
@ops.RegisterStatistics("ArgMin", "flops")
|
|
def _arg_min_flops(graph, node):
|
|
"""Compute flops for ArgMin operation."""
|
|
# reduction - comparison, no finalization
|
|
return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
|
|
|
|
|
|
@ops.RegisterStatistics("BiasAddGrad", "flops")
|
|
def _bias_add_grad_flops(graph, node):
|
|
"""Compute flops for BiasAddGrad operation."""
|
|
# Implementation of BiasAddGrad, essentially it's a reduce sum and reshaping:
|
|
# So computing flops same way as for "Sum"
|
|
return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
|
|
|
|
################################################################################
# Convolution and pooling
# Note: all flops statistics are implemented only for NHWC data format
################################################################################

def _verify_conv_data_format(node):
  """Verifies data format for pooling and convolutional operations."""
  # TODO(xpan): P1: Support NCHW
  if node.attr["data_format"].s != b"NHWC":
    raise ValueError("Only NHWC format is supported in flops computations")

def _pool_flops(graph, node):
  """Common code that computes flops for pooling operations."""
  # Compute flops for average and max pooling.
  _verify_conv_data_format(node)
  #
  # Pooling declaration:
  #   Inputs:
  #     - value
  #   Outputs:
  #     - output
  #   Attributes:
  #     - ksize
  #     - strides
  #     - padding
  #     - data_format
  #
  # Pooling implementation:
  out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  out_shape.assert_is_fully_defined()
  kernel_shape = list(node.attr["ksize"].list.i)
  kernel_area = _list_product(kernel_shape)
  return ops.OpStats("flops", kernel_area * out_shape.num_elements())
@ops.RegisterStatistics("AvgPool", "flops")
|
|
def _avg_pool_flops(graph, node):
|
|
"""Compute flops for AvgPool operation."""
|
|
return _pool_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("MaxPool", "flops")
|
|
def _max_pool_flops(graph, node):
|
|
"""Compute flops for MaxPool operation."""
|
|
return _pool_flops(graph, node)
|
|
|
|
|
|
@ops.RegisterStatistics("AvgPoolGrad", "flops")
|
|
def _avg_pool_grad_flops(graph, node):
|
|
"""Compute flops for AvgPoolGrad operation."""
|
|
_verify_conv_data_format(node)
|
|
# Pooling gradient implementation:
|
|
out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
|
|
node.input[1])
|
|
out_backprop_shape.assert_is_fully_defined()
|
|
kernel_shape = list(node.attr["ksize"].list.i)
|
|
kernel_area = _list_product(kernel_shape)
|
|
# TensorFlow multiply each element of pooling window by coefficient,
|
|
# then sum up all of them, thus we have 2 flops per element:
|
|
# More optimal implementation - if division is done after.
|
|
return ops.OpStats("flops",
|
|
kernel_area * out_backprop_shape.num_elements() * 2)
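
# Example (illustrative): with ksize [1, 2, 2, 1] and an out_backprop of shape
# [1, 56, 56, 64], the count is 4 * (56 * 56 * 64) * 2 = 1,605,632 flops.
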
@ops.RegisterStatistics("MaxPoolGrad", "flops")
|
|
def _max_pool_grad_flops(graph, node):
|
|
"""Compute flops for MaxPoolGrad operation."""
|
|
_verify_conv_data_format(node)
|
|
#
|
|
# MaxPoolGrad declaration:
|
|
# Inputs:
|
|
# - orig_input -- original input tensor (of max_pool)
|
|
# - orig_output -- original output tensor (of max_pool)
|
|
# - grad -- gradient with respect to output of max_pool
|
|
# Outputs:
|
|
# - output -- gradient with respect to input of max_pool
|
|
# Attributes:
|
|
# - ksize
|
|
# - strides
|
|
# - padding
|
|
# - data_format
|
|
# It computes MaxPool first, then one flop per each element of original output
|
|
#
|
|
kernel_shape = list(node.attr["ksize"].list.i)
|
|
kernel_area = _list_product(kernel_shape)
|
|
orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
|
|
node.input[1])
|
|
orig_out_shape.assert_is_fully_defined()
|
|
max_pool_ops = kernel_area * orig_out_shape.num_elements()
|
|
return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
|
|
def _conv_2d_backprop_input_flops(graph, node):
|
|
"""Compute flops for Conv2DBackpropInput operation."""
|
|
# Formula:
|
|
# batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
|
|
# * input_depth * output_depth * 2 / (image_x_stride * image_x_stride)
|
|
#
|
|
# Where:
|
|
# image_x_dim, image_y_dim and input_depth --- size of input to source (no
|
|
# backprop) convolution, in other words they are sizes of backprop output.
|
|
# output_depth --- number of filters in the original convolution, thus
|
|
# depth of backprop input.
|
|
# kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension
|
|
# image_x_stride and image_x_stride --- strides of the convolution
|
|
#
|
|
_verify_conv_data_format(node)
|
|
# out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
|
|
out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
|
|
out_shape.assert_is_fully_defined()
|
|
# kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
|
|
kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
|
|
node.input[1])
|
|
kernel_shape.assert_is_fully_defined()
|
|
# strides
|
|
strides_shape = list(node.attr["strides"].list.i)
|
|
strides_product = strides_shape[1] * strides_shape[2]
|
|
return ops.OpStats("flops",
|
|
(2 * out_shape.num_elements()
|
|
* kernel_shape.num_elements()
|
|
/ (out_shape.dims[-1].value * strides_product)))
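
# Worked example (illustrative): for out_shape [1, 56, 56, 64] (so
# input_depth = 64), kernel_shape [3, 3, 64, 128] and strides [1, 1, 1, 1]:
#   2 * 200704 * 73728 / (64 * 1) = 462,422,016 flops,
# the same count the formula above gives directly.
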
@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
|
|
def _conv_2d_backprop_filter_flops(graph, node):
|
|
"""Compute flops for Conv2DBackpropFilter operation."""
|
|
# Formula same as for Conv2DBackpropInput:
|
|
# batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
|
|
# * input_depth * output_depth * 2 / (image_x_stride * image_x_stride)
|
|
#
|
|
_verify_conv_data_format(node)
|
|
# image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
|
|
image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
|
|
image_shape.assert_is_fully_defined()
|
|
# kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
|
|
kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
|
|
kernel_shape.assert_is_fully_defined()
|
|
# strides
|
|
strides_shape = list(node.attr["strides"].list.i)
|
|
strides_product = strides_shape[1] * strides_shape[2]
|
|
return ops.OpStats("flops",
|
|
(2 * image_shape.num_elements()
|
|
* kernel_shape.num_elements()
|
|
/ (image_shape.dims[-1].value * strides_product)))
|
|
|
|
################################################################################
# Other ops
################################################################################

@ops.RegisterStatistics("AddN", "flops")
|
|
def _add_n_flops(graph, node):
|
|
"""Compute flops for AddN operation."""
|
|
if not node.input:
|
|
return _zero_flops(graph, node)
|
|
in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
|
|
in_shape.assert_is_fully_defined()
|
|
return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))