diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 5c95c29c32d..898958faac2 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2871,6 +2871,7 @@ def _calc_mat_mul_flops(graph, node): @ops.RegisterStatistics("BatchMatMul", "flops") +@ops.RegisterStatistics("BatchMatMulV2", "flops") def _calc_batch_mat_mul_flops(graph, node): """Calculates the compute resources needed for BatchMatMul.""" transpose_a = node.attr["transpose_a"].b