Merge pull request #30575 from csarron:patch-1

PiperOrigin-RevId: 257869411
This commit is contained in:
TensorFlower Gardener 2019-07-12 14:36:25 -07:00
commit 73de47a88f

View File

@ -2795,6 +2795,22 @@ def _calc_mat_mul_flops(graph, node):
return ops.OpStats("flops", (k * output_count * 2))
@ops.RegisterStatistics("BatchMatMul", "flops")
def _calc_batch_mat_mul_flops(graph, node):
"""Calculates the compute resources needed for BatchMatMul."""
transpose_a = node.attr["transpose_a"].b
a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
a_shape.assert_is_fully_defined()
if transpose_a:
k = int(a_shape[-2])
else:
k = int(a_shape[-1])
output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
output_shape.assert_is_fully_defined()
output_count = np.prod(output_shape.as_list())
return ops.OpStats("flops", (k * output_count * 2))
def _as_indexed_slices(x, optimize=True):
"""Convert 'x' to IndexedSlices.