Use scatter_nd to compute confusion_matrix and enable TPU compatibility.

PiperOrigin-RevId: 327122675
Change-Id: I6d6a6d093ffe45e1658c43c3120684db4eafebb5
This commit is contained in:
A. Unique TensorFlower 2020-08-17 16:24:48 -07:00 committed by TensorFlower Gardener
parent fb78e1b6c1
commit 66d54d7de2

View File

@ -20,12 +20,10 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@ -194,13 +192,10 @@ def confusion_matrix(labels,
indices = array_ops.stack([labels, predictions], axis=1)
values = (array_ops.ones_like(predictions, dtype)
if weights is None else weights)
cm_sparse = sparse_tensor.SparseTensor(
return array_ops.scatter_nd(
indices=indices,
values=values,
dense_shape=math_ops.cast(shape, dtypes.int64))
zero_matrix = array_ops.zeros(math_ops.cast(shape, dtypes.int32), dtype)
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
updates=values,
shape=math_ops.cast(shape, dtypes.int64))
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])