Add "AddV2" to tensorflow v1 flops calculation

This commit is contained in:
tigertang 2020-03-06 12:33:07 +08:00 committed by GitHub
parent 6272b68ae7
commit 4e517f884c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,9 +28,9 @@ IMPLEMENTED_OPS = set([
"Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
"L2Loss", "Softmax",
# Binary ops
"Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
"GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
"SquaredDifference",
"Add", "AddV2", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow",
"RsqrtGrad", "GreaterEqual", "Greater", "LessEqual", "Less", "Equal",
"NotEqual", "SquaredDifference",
# Reduction ops
"Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
# Convolution and pooling
@ -145,11 +145,11 @@ def _binary_per_element_op_flops(graph, node, ops_per_element=1):
@ops.RegisterStatistics("Add", "flops")
@ops.RegisterStatistics("AddV2", "flops")
def _add_flops(graph, node):
"""Compute flops for Add operation."""
return _binary_per_element_op_flops(graph, node)
@ops.RegisterStatistics("Sub", "flops")
def _sub_flops(graph, node):
"""Compute flops for Sub operation."""