Use scatter_nd to compute confusion_matrix and enable TPU compatibility.
PiperOrigin-RevId: 327122675 Change-Id: I6d6a6d093ffe45e1658c43c3120684db4eafebb5
This commit is contained in:
parent
fb78e1b6c1
commit
66d54d7de2
@ -20,12 +20,10 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -194,13 +192,10 @@ def confusion_matrix(labels,
|
||||
indices = array_ops.stack([labels, predictions], axis=1)
|
||||
values = (array_ops.ones_like(predictions, dtype)
|
||||
if weights is None else weights)
|
||||
cm_sparse = sparse_tensor.SparseTensor(
|
||||
return array_ops.scatter_nd(
|
||||
indices=indices,
|
||||
values=values,
|
||||
dense_shape=math_ops.cast(shape, dtypes.int64))
|
||||
zero_matrix = array_ops.zeros(math_ops.cast(shape, dtypes.int32), dtype)
|
||||
|
||||
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
|
||||
updates=values,
|
||||
shape=math_ops.cast(shape, dtypes.int64))
|
||||
|
||||
|
||||
@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
|
||||
|
Loading…
Reference in New Issue
Block a user