From 66d54d7de287da2fc9d4066002e8bf2f1a012a86 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Aug 2020 16:24:48 -0700 Subject: [PATCH] Use scatter_nd to compute confusion_matrix and enable TPU compatibility. PiperOrigin-RevId: 327122675 Change-Id: I6d6a6d093ffe45e1658c43c3120684db4eafebb5 --- tensorflow/python/ops/confusion_matrix.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index 39177defe57..38d3461bc0b 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -20,12 +20,10 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -194,13 +192,10 @@ def confusion_matrix(labels, indices = array_ops.stack([labels, predictions], axis=1) values = (array_ops.ones_like(predictions, dtype) if weights is None else weights) - cm_sparse = sparse_tensor.SparseTensor( + return array_ops.scatter_nd( indices=indices, - values=values, - dense_shape=math_ops.cast(shape, dtypes.int64)) - zero_matrix = array_ops.zeros(math_ops.cast(shape, dtypes.int32), dtype) - - return sparse_ops.sparse_add(zero_matrix, cm_sparse) + updates=values, + shape=math_ops.cast(shape, dtypes.int64)) @tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])