Add "AddV2" to tensorflow v1 flops calculation
This commit is contained in:
parent
6272b68ae7
commit
4e517f884c
@ -28,9 +28,9 @@ IMPLEMENTED_OPS = set([
|
||||
"Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
|
||||
"L2Loss", "Softmax",
|
||||
# Binary ops
|
||||
"Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
|
||||
"GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
|
||||
"SquaredDifference",
|
||||
"Add", "AddV2", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow",
|
||||
"RsqrtGrad", "GreaterEqual", "Greater", "LessEqual", "Less", "Equal",
|
||||
"NotEqual", "SquaredDifference",
|
||||
# Reduction ops
|
||||
"Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
|
||||
# Convolution and pooling
|
||||
@ -145,11 +145,11 @@ def _binary_per_element_op_flops(graph, node, ops_per_element=1):
|
||||
|
||||
|
||||
@ops.RegisterStatistics("Add", "flops")
|
||||
@ops.RegisterStatistics("AddV2", "flops")
|
||||
def _add_flops(graph, node):
|
||||
"""Compute flops for Add operation."""
|
||||
return _binary_per_element_op_flops(graph, node)
|
||||
|
||||
|
||||
@ops.RegisterStatistics("Sub", "flops")
|
||||
def _sub_flops(graph, node):
|
||||
"""Compute flops for Sub operation."""
|
||||
|
Loading…
Reference in New Issue
Block a user