Merge pull request #30575 from csarron:patch-1
PiperOrigin-RevId: 257869411
This commit is contained in:
commit
73de47a88f
@@ -2795,6 +2795,22 @@ def _calc_mat_mul_flops(graph, node):
|
||||
return ops.OpStats("flops", (k * output_count * 2))
|
||||
|
||||
|
||||
@ops.RegisterStatistics("BatchMatMul", "flops")
def _calc_batch_mat_mul_flops(graph, node):
  """Calculates the compute resources needed for BatchMatMul.

  Every output element is counted as `k` multiply-add pairs, where `k` is the
  contracted (inner) dimension of the first input, so the reported total is
  `2 * k * output_count`.

  Args:
    graph: Graph used to resolve the shapes of `node`'s first input and of
      its output by name.
    node: NodeDef of the BatchMatMul op whose cost is being estimated.

  Returns:
    An `ops.OpStats` holding the "flops" statistic for `node`.

  Raises:
    Whatever `assert_is_fully_defined()` raises when the shape of the first
    input or of the output is not fully defined (presumably ValueError;
    confirm against TensorShape).
  """
  # BatchMatMul declares `adj_x`/`adj_y`, not MatMul's `transpose_a`/
  # `transpose_b`, so the adjoint flag has to be read from `adj_x`.
  # `transpose_a` is still honoured in case a caller sets it explicitly.
  # Membership is tested first because indexing a protobuf map with a missing
  # key inserts a default entry, which would silently mutate the NodeDef.
  attrs = node.attr
  first_input_adjointed = any(
      name in attrs and attrs[name].b for name in ("adj_x", "transpose_a"))

  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  a_shape.assert_is_fully_defined()
  # When the first input is adjointed its last two dimensions are swapped, so
  # the contracted dimension is the second-to-last one instead of the last.
  k = int(a_shape[-2] if first_input_adjointed else a_shape[-1])

  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  output_count = np.prod(output_shape.as_list())
  return ops.OpStats("flops", (k * output_count * 2))
|
||||
|
||||
|
||||
def _as_indexed_slices(x, optimize=True):
|
||||
"""Convert 'x' to IndexedSlices.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user