PR #37378: Add "AddV2" to tensorflow v1 flops calculation
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/37378 Copybara import of the project: --4e517f884c
by tigertang <tigertang.zju@gmail.com>: Add "AddV2" to tensorflow v1 flops calculation --a74820d5ae
by tigertang <tigertang.zju@gmail.com>: Add another space line PiperOrigin-RevId: 300333508 Change-Id: If104ba82b7831ab5dd4ca0ea574aec36b5394fde
This commit is contained in:
parent
2394cfb44e
commit
0dc61704dd
@ -25,53 +25,21 @@ from tensorflow.python.framework import ops
|
||||
# List of all ops which have implemented flops statistics.
|
||||
IMPLEMENTED_OPS = set([
|
||||
# Unary ops
|
||||
"Reciprocal",
|
||||
"Square",
|
||||
"Rsqrt",
|
||||
"Log",
|
||||
"Neg",
|
||||
"AssignSub",
|
||||
"AssignAdd",
|
||||
"L2Loss",
|
||||
"Softmax",
|
||||
"Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
|
||||
"L2Loss", "Softmax",
|
||||
# Binary ops
|
||||
"Add",
|
||||
"AddV2",
|
||||
"Sub",
|
||||
"Mul",
|
||||
"RealDiv",
|
||||
"Maximum",
|
||||
"Minimum",
|
||||
"Pow",
|
||||
"RsqrtGrad",
|
||||
"GreaterEqual",
|
||||
"Greater",
|
||||
"LessEqual",
|
||||
"Less",
|
||||
"Equal",
|
||||
"NotEqual",
|
||||
"Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
|
||||
"GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
|
||||
"SquaredDifference",
|
||||
# Reduction ops
|
||||
"Mean",
|
||||
"Sum",
|
||||
"ArgMax",
|
||||
"ArgMin",
|
||||
"BiasAddGrad",
|
||||
"Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
|
||||
# Convolution and pooling
|
||||
"AvgPool",
|
||||
"MaxPool",
|
||||
"AvgPoolGrad",
|
||||
"MaxPoolGrad",
|
||||
"Conv2DBackpropInput",
|
||||
"AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
|
||||
"Conv2DBackpropFilter",
|
||||
# Other ops
|
||||
"AddN",
|
||||
# Ops implemented in core tensorflow:
|
||||
"MatMul",
|
||||
"Conv2D",
|
||||
"DepthwiseConv2dNative",
|
||||
"BiasAdd",
|
||||
"Dilation2D",
|
||||
"MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
|
||||
])
|
||||
|
||||
|
||||
@ -177,7 +145,6 @@ def _binary_per_element_op_flops(graph, node, ops_per_element=1):
|
||||
|
||||
|
||||
@ops.RegisterStatistics("Add", "flops")
|
||||
@ops.RegisterStatistics("AddV2", "flops")
|
||||
def _add_flops(graph, node):
|
||||
"""Compute flops for Add operation."""
|
||||
return _binary_per_element_op_flops(graph, node)
|
||||
|
Loading…
Reference in New Issue
Block a user