PR #37378: Add "AddV2" to tensorflow v1 flops calculation

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/37378

Copybara import of the project:

--
4e517f884c by tigertang <tigertang.zju@gmail.com>:

Add "AddV2" to tensorflow v1 flops calculation
--
a74820d5ae by tigertang <tigertang.zju@gmail.com>:

Add another space line

PiperOrigin-RevId: 300333508
Change-Id: If104ba82b7831ab5dd4ca0ea574aec36b5394fde
This commit is contained in:
A. Unique TensorFlower 2020-03-11 08:48:30 -07:00 committed by TensorFlower Gardener
parent 2394cfb44e
commit 0dc61704dd

View File

@ -25,53 +25,21 @@ from tensorflow.python.framework import ops
# List of all ops which have implemented flops statistics.
IMPLEMENTED_OPS = set([
# Unary ops
"Reciprocal",
"Square",
"Rsqrt",
"Log",
"Neg",
"AssignSub",
"AssignAdd",
"L2Loss",
"Softmax",
"Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
"L2Loss", "Softmax",
# Binary ops
"Add",
"AddV2",
"Sub",
"Mul",
"RealDiv",
"Maximum",
"Minimum",
"Pow",
"RsqrtGrad",
"GreaterEqual",
"Greater",
"LessEqual",
"Less",
"Equal",
"NotEqual",
"Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
"GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
"SquaredDifference",
# Reduction ops
"Mean",
"Sum",
"ArgMax",
"ArgMin",
"BiasAddGrad",
"Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
# Convolution and pooling
"AvgPool",
"MaxPool",
"AvgPoolGrad",
"MaxPoolGrad",
"Conv2DBackpropInput",
"AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
"Conv2DBackpropFilter",
# Other ops
"AddN",
# Ops implemented in core tensorflow:
"MatMul",
"Conv2D",
"DepthwiseConv2dNative",
"BiasAdd",
"Dilation2D",
"MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
])
@ -177,7 +145,6 @@ def _binary_per_element_op_flops(graph, node, ops_per_element=1):
@ops.RegisterStatistics("Add", "flops")
@ops.RegisterStatistics("AddV2", "flops")
def _add_flops(graph, node):
"""Compute flops for Add operation."""
return _binary_per_element_op_flops(graph, node)