ragged tensor support for some of the keras metrics

PiperOrigin-RevId: 250652803
This commit is contained in:
A. Unique TensorFlower 2019-05-30 02:21:14 -07:00 committed by TensorFlower Gardener
parent 61357a807c
commit babc6fd531
5 changed files with 436 additions and 0 deletions

View File

@ -202,6 +202,8 @@ py_library(
"//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer",
"//tensorflow/python/keras/mixed_precision/experimental:policy",
"//tensorflow/python/module",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_util",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",
@ -1509,3 +1511,19 @@ tf_py_test(
],
tags = ["notsan"],
)
# Small-sized test target that runs utils/metrics_utils_test.py.
tf_py_test(
name = "metrics_utils_test",
size = "small",
srcs = ["utils/metrics_utils_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:framework_ops",
"//tensorflow/python:ops",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:context",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_test_util",
],
)

View File

@ -294,6 +294,9 @@ class Reduce(Metric):
Returns:
Update op.
"""
[values], sample_weight = \
metrics_utils.ragged_assert_compatible_and_get_flat_values(
[values], sample_weight)
values = math_ops.cast(values, self._dtype)
if sample_weight is not None:
sample_weight = math_ops.cast(sample_weight, self._dtype)
@ -497,6 +500,9 @@ class MeanRelativeError(Mean):
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
[y_pred, y_true], sample_weight = \
metrics_utils.ragged_assert_compatible_and_get_flat_values(
[y_pred, y_true], sample_weight)
y_pred, y_true = squeeze_or_expand_dimensions(y_pred, y_true)
y_pred, self.normalizer = confusion_matrix.remove_squeezable_dimensions(
@ -549,6 +555,9 @@ class MeanMetricWrapper(Mean):
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
[y_true, y_pred], sample_weight = \
metrics_utils.ragged_assert_compatible_and_get_flat_values(
[y_true, y_pred], sample_weight)
y_pred, y_true = squeeze_or_expand_dimensions(y_pred, y_true)
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
@ -2715,6 +2724,9 @@ class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
def accuracy(y_true, y_pred):
[y_pred, y_true], _ = \
metrics_utils.ragged_assert_compatible_and_get_flat_values(
[y_pred, y_true])
y_pred.shape.assert_is_compatible_with(y_true.shape)
if y_true.dtype != y_pred.dtype:
y_pred = math_ops.cast(y_pred, y_true.dtype)

View File

@ -27,6 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
@ -38,6 +39,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils
@ -363,6 +365,26 @@ class KerasAccuracyTest(test.TestCase):
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.96, 2) # 4.5/4.7
def test_accuracy_ragged(self):
  """Checks metrics.Accuracy on ragged inputs, unweighted and weighted."""
  accuracy = metrics.Accuracy(name='my_acc')
  self.evaluate(variables.variables_initializer(accuracy.variables))

  # Unweighted update: both ragged tensors hold the same four entries, so
  # every comparison matches (4/4).
  y_true = ragged_factory_ops.constant([[1], [2], [3], [4]])
  y_pred = ragged_factory_ops.constant([[1], [2], [3], [4]])
  self.evaluate(accuracy.update_state(y_true, y_pred))
  self.assertEqual(self.evaluate(accuracy.result()), 1)

  # Second call on the same metric object, this time with a ragged
  # sample_weight whose row structure matches the inputs.
  y_true = ragged_factory_ops.constant([[2], [1]])
  y_pred = ragged_factory_ops.constant([[2], [0]])
  weights = ragged_factory_ops.constant([[0.5], [0.2]])
  weighted_result = accuracy(y_true, y_pred, sample_weight=weights)
  self.assertAlmostEqual(self.evaluate(weighted_result), 0.96, 2)  # 4.5/4.7
def test_binary_accuracy(self):
acc_obj = metrics.BinaryAccuracy(name='my_acc')
@ -395,6 +417,26 @@ class KerasAccuracyTest(test.TestCase):
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.67, 2) # 4.5/6.7
def test_binary_accuracy_ragged(self):
  """Checks metrics.BinaryAccuracy on ragged inputs and on a rank mismatch."""
  binary_acc = metrics.BinaryAccuracy(name='my_acc')
  self.evaluate(variables.variables_initializer(binary_acc.variables))

  # Matching ragged inputs: both entries agree (2/2).
  y_true = ragged_factory_ops.constant([[1], [0]])
  y_pred = ragged_factory_ops.constant([[1], [0]])
  self.evaluate(binary_acc.update_state(y_true, y_pred))
  self.assertEqual(self.evaluate(binary_acc.result()), 1)

  # Squeezing y_true is only supported for dense tensors. Ragged inputs of
  # different ranks are expected to be rejected with a ValueError.
  y_true = ragged_factory_ops.constant([[[1], [1]]])
  y_pred = ragged_factory_ops.constant([[1], [0]])
  with self.assertRaises(ValueError):
    self.evaluate(binary_acc(y_true, y_pred))
def test_binary_accuracy_threshold(self):
acc_obj = metrics.BinaryAccuracy(threshold=0.7)
self.evaluate(variables.variables_initializer(acc_obj.variables))
@ -402,6 +444,15 @@ class KerasAccuracyTest(test.TestCase):
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.5, 2)
def test_binary_accuracy_threshold_ragged(self):
  """Checks BinaryAccuracy with a 0.7 threshold on ragged inputs."""
  thresholded_acc = metrics.BinaryAccuracy(threshold=0.7)
  self.evaluate(variables.variables_initializer(thresholded_acc.variables))

  y_true = ragged_factory_ops.constant([[1], [1], [0], [0]])
  y_pred = ragged_factory_ops.constant([[0.9], [0.6], [0.4], [0.8]])
  # Against the 0.7 threshold only 0.9 and 0.8 exceed it, so predictions
  # 0.9 (label 1) and 0.4 (label 0) agree with their labels: 2 of 4.
  self.assertAlmostEqual(
      self.evaluate(thresholded_acc(y_true, y_pred)), 0.5, 2)
def test_categorical_accuracy(self):
acc_obj = metrics.CategoricalAccuracy(name='my_acc')
@ -425,6 +476,26 @@ class KerasAccuracyTest(test.TestCase):
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
def test_categorical_accuracy_ragged(self):
  """Checks CategoricalAccuracy on ragged inputs and a mismatched weight."""
  cat_acc = metrics.CategoricalAccuracy(name='my_acc')
  self.evaluate(variables.variables_initializer(cat_acc.variables))

  # Unweighted update: the largest prediction in each row sits at the same
  # index as the one-hot label (2/2).
  y_true = ragged_factory_ops.constant([[0, 0, 1], [0, 1, 0]])
  y_pred = ragged_factory_ops.constant([[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
  self.evaluate(cat_acc.update_state(y_true, y_pred))
  self.assertEqual(self.evaluate(cat_acc.result()), 1)

  # The ragged sample_weight has rows of length 1 while the inputs have
  # rows of length 3; this is expected to fail with InvalidArgumentError.
  y_true = ragged_factory_ops.constant([[0, 0, 1], [0, 1, 0]])
  y_pred = ragged_factory_ops.constant([[0.1, 0.1, 0.8], [0.05, 0, 0.95]])
  sample_weight = ragged_factory_ops.constant([[0.5], [0.2]])
  with self.assertRaises(errors_impl.InvalidArgumentError):
    self.evaluate(cat_acc(y_true, y_pred, sample_weight))
def test_sparse_categorical_accuracy(self):
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
@ -448,6 +519,19 @@ class KerasAccuracyTest(test.TestCase):
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
def test_sparse_categorical_accuracy_ragged(self):
  """Checks that SparseCategoricalAccuracy rejects ragged inputs."""
  sparse_acc = metrics.SparseCategoricalAccuracy(name='my_acc')

  y_true = ragged_factory_ops.constant([[2], [1]])
  y_pred = ragged_factory_ops.constant([[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
  # sparse_categorical_accuracy is not supported for composite/ragged
  # tensors, so building or running the update is expected to fail.
  with self.assertRaises(errors_impl.InvalidArgumentError):
    self.evaluate(sparse_acc.update_state(y_true, y_pred))
def test_sparse_categorical_accuracy_mismatched_dims(self):
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')

View File

@ -36,6 +36,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import tf_decorator
NEG_INF = -1e10
@ -265,6 +267,9 @@ def update_confusion_matrix_variables(variables_to_update,
return
y_true = math_ops.cast(y_true, dtype=dtypes.float32)
y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
[y_pred,
y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
sample_weight)
y_pred.shape.assert_is_compatible_with(y_true.shape)
if not any(
@ -387,3 +392,67 @@ def _filter_top_k(x, k):
top_k_mask = math_ops.reduce_sum(
array_ops.one_hot(top_k_idx, x.shape[-1], axis=-1), axis=-2)
return x * top_k_mask + NEG_INF * (1 - top_k_mask)
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """Checks ragged inputs for matching splits and returns their flat values.

  If every element of `values` (and `mask`, when given) is a `RaggedTensor`,
  their nested row splits are asserted to match and each input is replaced by
  its `flat_values` with a trailing dimension of size 1 appended. Dense inputs
  are returned untouched.

  Note: If the tensors are dense, their compatibility is not checked here.
  Note: Although two ragged tensors with different ragged ranks could have
  identical overall rank and dimension sizes and hence be compatible,
  those cases are not supported.

  Args:
    values: A potentially ragged tensor, or a list of potentially ragged
      tensors of the same ragged_rank.
    mask: An optional, potentially ragged tensor of the same ragged_rank as
      the elements in `values`.

  Returns:
    A tuple `(values, mask)`. `values` has the same container form as the
    argument (a list stays a list, a single tensor stays a single tensor).
    For ragged inputs, `mask` and each element of `values` are the
    `flat_values` of the corresponding argument, expanded with a trailing
    axis; otherwise the arguments are returned unchanged.

  Raises:
    TypeError: If only some of `values` are ragged, if `values` are ragged
      but `mask` is a non-ragged tensor, or if `mask` is ragged while
      `values` are dense or empty.
    ValueError: Presumably propagated from `ragged_util.assert_splits_match`
      when the ragged ranks differ (the unit tests in this change expect it).
  """
  ragged_cls = ragged_tensor.RaggedTensor
  is_list = isinstance(values, list)
  candidates = values if is_list else [values]

  ragged_flags = [isinstance(rt, ragged_cls) for rt in candidates]
  is_any_ragged = any(ragged_flags)
  # `all([])` is True, so an empty list must not count as "all ragged":
  # otherwise an empty `values` with a ragged mask would index into an empty
  # splits list below instead of reaching the TypeError at the end.
  is_all_ragged = bool(ragged_flags) and all(ragged_flags)
  mask_is_ragged = isinstance(mask, ragged_cls)

  if is_all_ragged and (mask is None or mask_is_ragged):
    # NOTE: the flat_values compatibility itself is left to
    # tf.TensorShape `assert_is_compatible_with` at the call sites; here we
    # only check that the dynamic (ragged) dimensions are equal.
    nested_row_split_list = [rt.nested_row_splits for rt in candidates]
    assertion_list = ragged_util.assert_splits_match(nested_row_split_list)

    # If the values are ragged, the mask (sample weights) must be ragged with
    # the same row splits as well.
    if mask_is_ragged:
      assertion_list_for_mask = ragged_util.assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      guarded_mask = control_flow_ops.with_dependencies(
          assertion_list_for_mask, mask.flat_values)
      mask = array_ops.expand_dims(guarded_mask, -1)

    # `candidates` has at least one element at this point.
    flat_values = []
    for rt in candidates:
      guarded = control_flow_ops.with_dependencies(assertion_list,
                                                   rt.flat_values)
      flat_values.append(array_ops.expand_dims(guarded, -1))
    # Hand back the same container form the caller passed in.
    values = flat_values if is_list else flat_values[0]
  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # values are empty or values are not ragged and mask is ragged.
  elif mask_is_ragged:
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask

View File

@ -0,0 +1,253 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedSizeOpTest(ragged_test_util.RaggedTensorTestCase,
                       parameterized.TestCase):
  """Tests for metrics_utils.ragged_assert_compatible_and_get_flat_values."""

  # NOTE(review): the class name mentions a size op, but every test below
  # exercises ragged_assert_compatible_and_get_flat_values -- consider
  # renaming in a follow-up.

  @parameterized.parameters([
      {'x_list': [1], 'y_list': [2]},
      {'x_list': [1, 2], 'y_list': [2, 3]},
      {'x_list': [1, 2, 4], 'y_list': [2, 3, 5]},
      {'x_list': [[1, 2], [3, 4]], 'y_list': [[2, 3], [5, 6]]},
  ])
  def test_passing_dense_tensors(self, x_list, y_list):
    """Two dense tensors pass through with compatible shapes."""
    dense_pair = [constant_op.constant(x_list), constant_op.constant(y_list)]
    (x, y), _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        dense_pair)
    x.shape.assert_is_compatible_with(y.shape)

  @parameterized.parameters([
      {'x_list': [1]},
      {'x_list': [1, 2]},
      {'x_list': [1, 2, 4]},
      {'x_list': [[1, 2], [3, 4]]},
  ])
  def test_passing_one_dense_tensor(self, x_list):
    """A single dense tensor in a list is accepted without error."""
    dense = constant_op.constant(x_list)
    metrics_utils.ragged_assert_compatible_and_get_flat_values([dense])

  @parameterized.parameters([
      {'x_list': [1], 'y_list': [2]},
      {'x_list': [1, 2], 'y_list': [2, 3]},
      {'x_list': [1, 2, 4], 'y_list': [2, 3, 5]},
      {'x_list': [[1, 2], [3, 4]], 'y_list': [[2, 3], [5, 6]]},
      {'x_list': [[1, 2], [3, 4], [1]], 'y_list': [[2, 3], [5, 6], [3]]},
      {'x_list': [[1, 2], [], [1]], 'y_list': [[2, 3], [], [3]]},
  ])
  def test_passing_both_ragged(self, x_list, y_list):
    """Two inputs built by the ragged factory come back shape-compatible."""
    ragged_pair = [
        ragged_factory_ops.constant(x_list),
        ragged_factory_ops.constant(y_list),
    ]
    (x, y), _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        ragged_pair)
    x.shape.assert_is_compatible_with(y.shape)

  @parameterized.parameters([
      {'x_list': [1]},
      {'x_list': [1, 2]},
      {'x_list': [1, 2, 4]},
      {'x_list': [[1, 2], [3, 4]]},
      {'x_list': [[1, 2], [3, 4], [1]]},
      {'x_list': [[1, 2], [], [1]]},
  ])
  def test_passing_one_ragged(self, x_list):
    """A single ragged-factory input in a list is accepted without error."""
    ragged = ragged_factory_ops.constant(x_list)
    metrics_utils.ragged_assert_compatible_and_get_flat_values([ragged])

  @parameterized.parameters([
      {'x_list': [1], 'y_list': [2], 'mask_list': [0]},
      {'x_list': [1, 2], 'y_list': [2, 3], 'mask_list': [0, 1]},
      {'x_list': [1, 2, 4], 'y_list': [2, 3, 5], 'mask_list': [1, 1, 1]},
      {
          'x_list': [[1, 2], [3, 4]],
          'y_list': [[2, 3], [5, 6]],
          'mask_list': [[1, 1], [0, 1]],
      },
      {
          'x_list': [[1, 2], [3, 4], [1]],
          'y_list': [[2, 3], [5, 6], [3]],
          'mask_list': [[1, 1], [0, 0], [1]],
      },
      {
          'x_list': [[1, 2], [], [1]],
          'y_list': [[2, 3], [], [3]],
          'mask_list': [[1, 1], [], [0]],
      },
  ])
  def test_passing_both_ragged_with_mask(self, x_list, y_list, mask_list):
    """Two inputs plus a matching mask all come back shape-compatible."""
    ragged_x = ragged_factory_ops.constant(x_list)
    ragged_y = ragged_factory_ops.constant(y_list)
    ragged_mask = ragged_factory_ops.constant(mask_list)
    (x, y), mask = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [ragged_x, ragged_y], ragged_mask)
    x.shape.assert_is_compatible_with(y.shape)
    y.shape.assert_is_compatible_with(mask.shape)

  @parameterized.parameters([
      {'x_list': [1], 'mask_list': [0]},
      {'x_list': [1, 2], 'mask_list': [0, 1]},
      {'x_list': [1, 2, 4], 'mask_list': [1, 1, 1]},
      {'x_list': [[1, 2], [3, 4]], 'mask_list': [[1, 1], [0, 1]]},
      {'x_list': [[1, 2], [3, 4], [1]], 'mask_list': [[1, 1], [0, 0], [1]]},
      {'x_list': [[1, 2], [], [1]], 'mask_list': [[1, 1], [], [0]]},
  ])
  def test_passing_one_ragged_with_mask(self, x_list, mask_list):
    """One input plus a matching mask come back shape-compatible."""
    ragged_x = ragged_factory_ops.constant(x_list)
    ragged_mask = ragged_factory_ops.constant(mask_list)
    (x,), mask = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [ragged_x], ragged_mask)
    x.shape.assert_is_compatible_with(mask.shape)

  @parameterized.parameters([
      {'x_list': [[[1, 3]]], 'y_list': [[2, 3]]},
  ])
  def test_failing_different_ragged_and_dense_ranks(self, x_list, y_list):
    """Inputs whose nesting depth differs are rejected with ValueError."""
    ragged_x = ragged_factory_ops.constant(x_list)
    ragged_y = ragged_factory_ops.constant(y_list)
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      metrics_utils.ragged_assert_compatible_and_get_flat_values(
          [ragged_x, ragged_y])

  @parameterized.parameters([
      {'x_list': [[[1, 3]]], 'y_list': [[[2, 3]]], 'mask_list': [[0, 1]]},
  ])
  def test_failing_different_mask_ranks(self, x_list, y_list, mask_list):
    """A mask nested less deeply than the inputs is rejected."""
    ragged_x = ragged_factory_ops.constant(x_list)
    ragged_y = ragged_factory_ops.constant(y_list)
    ragged_mask = ragged_factory_ops.constant(mask_list)
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      metrics_utils.ragged_assert_compatible_and_get_flat_values(
          [ragged_x, ragged_y], ragged_mask)

  # Cases where the ragged_ranks differ but the overall dimension shapes and
  # sizes are identical are deliberately unsupported: handling them would add
  # too much performance overhead to the common use cases.
  def test_failing_different_ragged_ranks(self):
    """Same nesting depth but different ragged_rank is rejected."""
    dense_inner = constant_op.constant([[[1, 2]]])
    # Wrap the dense tensor in a single ragged dimension.
    one_ragged_dim = ragged_tensor.RaggedTensor.from_row_splits(
        dense_inner, row_splits=[0, 1])
    fully_ragged = ragged_factory_ops.constant([[[[1, 2]]]])
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      metrics_utils.ragged_assert_compatible_and_get_flat_values(
          [one_ragged_dim, fully_ragged])
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
googletest.main()