ragged tensor support for some of the keras metrics
PiperOrigin-RevId: 250652803
This commit is contained in:
parent
61357a807c
commit
babc6fd531
@ -202,6 +202,8 @@ py_library(
|
||||
"//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer",
|
||||
"//tensorflow/python/keras/mixed_precision/experimental:policy",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_util",
|
||||
"//tensorflow/python/training/tracking:data_structures",
|
||||
"//tensorflow/tools/docs:doc_controls",
|
||||
"@six_archive//:six",
|
||||
@ -1509,3 +1511,19 @@ tf_py_test(
|
||||
],
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "metrics_utils_test",
|
||||
size = "small",
|
||||
srcs = ["utils/metrics_utils_test.py"],
|
||||
additional_deps = [
|
||||
":keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
],
|
||||
)
|
||||
|
@ -294,6 +294,9 @@ class Reduce(Metric):
|
||||
Returns:
|
||||
Update op.
|
||||
"""
|
||||
[values], sample_weight = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values(
|
||||
[values], sample_weight)
|
||||
values = math_ops.cast(values, self._dtype)
|
||||
if sample_weight is not None:
|
||||
sample_weight = math_ops.cast(sample_weight, self._dtype)
|
||||
@ -497,6 +500,9 @@ class MeanRelativeError(Mean):
|
||||
"""
|
||||
y_true = math_ops.cast(y_true, self._dtype)
|
||||
y_pred = math_ops.cast(y_pred, self._dtype)
|
||||
[y_pred, y_true], sample_weight = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values(
|
||||
[y_pred, y_true], sample_weight)
|
||||
y_pred, y_true = squeeze_or_expand_dimensions(y_pred, y_true)
|
||||
|
||||
y_pred, self.normalizer = confusion_matrix.remove_squeezable_dimensions(
|
||||
@ -549,6 +555,9 @@ class MeanMetricWrapper(Mean):
|
||||
"""
|
||||
y_true = math_ops.cast(y_true, self._dtype)
|
||||
y_pred = math_ops.cast(y_pred, self._dtype)
|
||||
[y_true, y_pred], sample_weight = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values(
|
||||
[y_true, y_pred], sample_weight)
|
||||
y_pred, y_true = squeeze_or_expand_dimensions(y_pred, y_true)
|
||||
|
||||
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
|
||||
@ -2715,6 +2724,9 @@ class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
|
||||
|
||||
|
||||
def accuracy(y_true, y_pred):
|
||||
[y_pred, y_true], _ = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values(
|
||||
[y_pred, y_true])
|
||||
y_pred.shape.assert_is_compatible_with(y_true.shape)
|
||||
if y_true.dtype != y_pred.dtype:
|
||||
y_pred = math_ops.cast(y_pred, y_true.dtype)
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
@ -38,6 +39,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import weights_broadcast_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
@ -363,6 +365,26 @@ class KerasAccuracyTest(test.TestCase):
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.96, 2) # 4.5/4.7
|
||||
|
||||
def test_accuracy_ragged(self):
|
||||
acc_obj = metrics.Accuracy(name='my_acc')
|
||||
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
||||
|
||||
# verify that correct value is returned
|
||||
rt1 = ragged_factory_ops.constant([[1], [2], [3], [4]])
|
||||
rt2 = ragged_factory_ops.constant([[1], [2], [3], [4]])
|
||||
update_op = acc_obj.update_state(rt1, rt2)
|
||||
self.evaluate(update_op)
|
||||
result = self.evaluate(acc_obj.result())
|
||||
self.assertEqual(result, 1) # 2/2
|
||||
|
||||
# check with sample_weight
|
||||
rt1 = ragged_factory_ops.constant([[2], [1]])
|
||||
rt2 = ragged_factory_ops.constant([[2], [0]])
|
||||
sw_ragged = ragged_factory_ops.constant([[0.5], [0.2]])
|
||||
result_t = acc_obj(rt1, rt2, sample_weight=sw_ragged)
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.96, 2) # 4.5/4.7
|
||||
|
||||
def test_binary_accuracy(self):
|
||||
acc_obj = metrics.BinaryAccuracy(name='my_acc')
|
||||
|
||||
@ -395,6 +417,26 @@ class KerasAccuracyTest(test.TestCase):
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.67, 2) # 4.5/6.7
|
||||
|
||||
def test_binary_accuracy_ragged(self):
|
||||
acc_obj = metrics.BinaryAccuracy(name='my_acc')
|
||||
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
||||
|
||||
# verify that correct value is returned
|
||||
rt1 = ragged_factory_ops.constant([[1], [0]])
|
||||
rt2 = ragged_factory_ops.constant([[1], [0]])
|
||||
update_op = acc_obj.update_state(rt1, rt2)
|
||||
self.evaluate(update_op)
|
||||
result = self.evaluate(acc_obj.result())
|
||||
self.assertEqual(result, 1) # 2/2
|
||||
|
||||
# check y_true squeeze only supported for dense tensors and is
|
||||
# not supported by ragged tensor (different ranks). --> error
|
||||
rt1 = ragged_factory_ops.constant([[[1], [1]]])
|
||||
rt2 = ragged_factory_ops.constant([[1], [0]])
|
||||
with self.assertRaises(ValueError):
|
||||
result_t = acc_obj(rt1, rt2)
|
||||
result = self.evaluate(result_t)
|
||||
|
||||
def test_binary_accuracy_threshold(self):
|
||||
acc_obj = metrics.BinaryAccuracy(threshold=0.7)
|
||||
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
||||
@ -402,6 +444,15 @@ class KerasAccuracyTest(test.TestCase):
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.5, 2)
|
||||
|
||||
def test_binary_accuracy_threshold_ragged(self):
|
||||
acc_obj = metrics.BinaryAccuracy(threshold=0.7)
|
||||
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
||||
rt1 = ragged_factory_ops.constant([[1], [1], [0], [0]])
|
||||
rt2 = ragged_factory_ops.constant([[0.9], [0.6], [0.4], [0.8]])
|
||||
result_t = acc_obj(rt1, rt2)
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.5, 2)
|
||||
|
||||
def test_categorical_accuracy(self):
|
||||
acc_obj = metrics.CategoricalAccuracy(name='my_acc')
|
||||
|
||||
@ -425,6 +476,26 @@ class KerasAccuracyTest(test.TestCase):
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
|
||||
|
||||
def test_categorical_accuracy_ragged(self):
|
||||
acc_obj = metrics.CategoricalAccuracy(name='my_acc')
|
||||
self.evaluate(variables.variables_initializer(acc_obj.variables))
|
||||
|
||||
# verify that correct value is returned
|
||||
rt1 = ragged_factory_ops.constant([[0, 0, 1], [0, 1, 0]])
|
||||
rt2 = ragged_factory_ops.constant([[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
|
||||
update_op = acc_obj.update_state(rt1, rt2)
|
||||
self.evaluate(update_op)
|
||||
result = self.evaluate(acc_obj.result())
|
||||
self.assertEqual(result, 1) # 2/2
|
||||
|
||||
# check with sample_weight
|
||||
rt1 = ragged_factory_ops.constant([[0, 0, 1], [0, 1, 0]])
|
||||
rt2 = ragged_factory_ops.constant([[0.1, 0.1, 0.8], [0.05, 0, 0.95]])
|
||||
sample_weight = ragged_factory_ops.constant([[0.5], [0.2]])
|
||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||
result_t = acc_obj(rt1, rt2, sample_weight)
|
||||
result = self.evaluate(result_t)
|
||||
|
||||
def test_sparse_categorical_accuracy(self):
|
||||
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
|
||||
|
||||
@ -448,6 +519,19 @@ class KerasAccuracyTest(test.TestCase):
|
||||
result = self.evaluate(result_t)
|
||||
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
|
||||
|
||||
def test_sparse_categorical_accuracy_ragged(self):
|
||||
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
|
||||
|
||||
# verify that correct value is returned
|
||||
rt1 = ragged_factory_ops.constant([[2], [1]])
|
||||
rt2 = ragged_factory_ops.constant([[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
|
||||
|
||||
with self.assertRaises(errors_impl.InvalidArgumentError):
|
||||
# sparse_categorical_accuracy is not supported for composite/ragged
|
||||
# tensors.
|
||||
update_op = acc_obj.update_state(rt1, rt2)
|
||||
self.evaluate(update_op)
|
||||
|
||||
def test_sparse_categorical_accuracy_mismatched_dims(self):
|
||||
acc_obj = metrics.SparseCategoricalAccuracy(name='my_acc')
|
||||
|
||||
|
@ -36,6 +36,8 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import weights_broadcast_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
NEG_INF = -1e10
|
||||
@ -265,6 +267,9 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||
return
|
||||
y_true = math_ops.cast(y_true, dtype=dtypes.float32)
|
||||
y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
|
||||
[y_pred,
|
||||
y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
|
||||
sample_weight)
|
||||
y_pred.shape.assert_is_compatible_with(y_true.shape)
|
||||
|
||||
if not any(
|
||||
@ -387,3 +392,67 @@ def _filter_top_k(x, k):
|
||||
top_k_mask = math_ops.reduce_sum(
|
||||
array_ops.one_hot(top_k_idx, x.shape[-1], axis=-1), axis=-2)
|
||||
return x * top_k_mask + NEG_INF * (1 - top_k_mask)
|
||||
|
||||
|
||||
def ragged_assert_compatible_and_get_flat_values(values, mask=None):
|
||||
"""If ragged, it checks the compatibility and then returns the flat_values.
|
||||
|
||||
Note: If two tensors are dense, it does not check their compatibility.
|
||||
Note: Although two ragged tensors with different ragged ranks could have
|
||||
identical overall rank and dimension sizes and hence be compatible,
|
||||
we do not support those cases.
|
||||
Args:
|
||||
values: A list of potentially ragged tensor of the same ragged_rank.
|
||||
mask: A potentially ragged tensor of the same ragged_rank as elements in
|
||||
Values.
|
||||
|
||||
Returns:
|
||||
A tuple in which the first element is the list of tensors and the second
|
||||
is the mask tensor. ([Values], mask). Mask and the element in Values
|
||||
are equal to the flat_values of the input arguments (if they were ragged).
|
||||
"""
|
||||
if isinstance(values, list):
|
||||
is_all_ragged = \
|
||||
all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
|
||||
is_any_ragged = \
|
||||
any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
|
||||
else:
|
||||
is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
|
||||
is_any_ragged = is_all_ragged
|
||||
if (is_all_ragged and
|
||||
((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
|
||||
to_be_stripped = False
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
to_be_stripped = True
|
||||
|
||||
# NOTE: we leave the flat_values compatiblity to
|
||||
# tf.TensorShape `assert_is_compatible_with`
|
||||
# check if both dynamic dimensions are equal and then use the flat_values.
|
||||
nested_row_split_list = [rt.nested_row_splits for rt in values]
|
||||
assertion_list = ragged_util.assert_splits_match(nested_row_split_list)
|
||||
|
||||
# if both are ragged sample_weights also should be ragged with same dims.
|
||||
if isinstance(mask, ragged_tensor.RaggedTensor):
|
||||
assertion_list_for_mask = ragged_util.assert_splits_match(
|
||||
[nested_row_split_list[0], mask.nested_row_splits])
|
||||
tmp = control_flow_ops.with_dependencies(assertion_list_for_mask,
|
||||
mask.flat_values)
|
||||
mask = array_ops.expand_dims(tmp, -1)
|
||||
|
||||
# values has at least 1 element.
|
||||
flat_values = []
|
||||
for value in values:
|
||||
tmp = control_flow_ops.with_dependencies(assertion_list,
|
||||
value.flat_values)
|
||||
flat_values.append(array_ops.expand_dims(tmp, -1))
|
||||
|
||||
values = flat_values[0] if to_be_stripped else flat_values
|
||||
|
||||
elif is_any_ragged:
|
||||
raise TypeError('One of the inputs does not have acceptable types.')
|
||||
# values are empty or value are not ragged and mask is ragged.
|
||||
elif isinstance(mask, ragged_tensor.RaggedTensor):
|
||||
raise TypeError('Ragged mask is not allowed with non-ragged inputs.')
|
||||
|
||||
return values, mask
|
||||
|
253
tensorflow/python/keras/utils/metrics_utils_test.py
Normal file
253
tensorflow/python/keras/utils/metrics_utils_test.py
Normal file
@ -0,0 +1,253 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metrics_utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.utils import metrics_utils
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSizeOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [1],
|
||||
'y_list': [2]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2],
|
||||
'y_list': [2, 3]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2, 4],
|
||||
'y_list': [2, 3, 5]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4]],
|
||||
'y_list': [[2, 3], [5, 6]]
|
||||
},
|
||||
])
|
||||
def test_passing_dense_tensors(self, x_list, y_list):
|
||||
x = constant_op.constant(x_list)
|
||||
y = constant_op.constant(y_list)
|
||||
[x,
|
||||
y], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])
|
||||
x.shape.assert_is_compatible_with(y.shape)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [1],
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2],
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2, 4],
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4]],
|
||||
},
|
||||
])
|
||||
def test_passing_one_dense_tensor(self, x_list):
|
||||
x = constant_op.constant(x_list)
|
||||
[x], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x])
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [1],
|
||||
'y_list': [2]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2],
|
||||
'y_list': [2, 3]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2, 4],
|
||||
'y_list': [2, 3, 5]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4]],
|
||||
'y_list': [[2, 3], [5, 6]]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4], [1]],
|
||||
'y_list': [[2, 3], [5, 6], [3]]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [], [1]],
|
||||
'y_list': [[2, 3], [], [3]]
|
||||
},
|
||||
])
|
||||
def test_passing_both_ragged(self, x_list, y_list):
|
||||
x = ragged_factory_ops.constant(x_list)
|
||||
y = ragged_factory_ops.constant(y_list)
|
||||
[x,
|
||||
y], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])
|
||||
x.shape.assert_is_compatible_with(y.shape)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [1],
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2],
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2, 4],
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4]],
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4], [1]],
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [], [1]],
|
||||
},
|
||||
])
|
||||
def test_passing_one_ragged(self, x_list):
|
||||
x = ragged_factory_ops.constant(x_list)
|
||||
[x], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x])
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [1],
|
||||
'y_list': [2],
|
||||
'mask_list': [0]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2],
|
||||
'y_list': [2, 3],
|
||||
'mask_list': [0, 1]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2, 4],
|
||||
'y_list': [2, 3, 5],
|
||||
'mask_list': [1, 1, 1]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4]],
|
||||
'y_list': [[2, 3], [5, 6]],
|
||||
'mask_list': [[1, 1], [0, 1]]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4], [1]],
|
||||
'y_list': [[2, 3], [5, 6], [3]],
|
||||
'mask_list': [[1, 1], [0, 0], [1]]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [], [1]],
|
||||
'y_list': [[2, 3], [], [3]],
|
||||
'mask_list': [[1, 1], [], [0]]
|
||||
},
|
||||
])
|
||||
def test_passing_both_ragged_with_mask(self, x_list, y_list, mask_list):
|
||||
x = ragged_factory_ops.constant(x_list)
|
||||
y = ragged_factory_ops.constant(y_list)
|
||||
mask = ragged_factory_ops.constant(mask_list)
|
||||
[x, y], mask = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y], mask)
|
||||
x.shape.assert_is_compatible_with(y.shape)
|
||||
y.shape.assert_is_compatible_with(mask.shape)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [1],
|
||||
'mask_list': [0]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2],
|
||||
'mask_list': [0, 1]
|
||||
},
|
||||
{
|
||||
'x_list': [1, 2, 4],
|
||||
'mask_list': [1, 1, 1]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4]],
|
||||
'mask_list': [[1, 1], [0, 1]]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [3, 4], [1]],
|
||||
'mask_list': [[1, 1], [0, 0], [1]]
|
||||
},
|
||||
{
|
||||
'x_list': [[1, 2], [], [1]],
|
||||
'mask_list': [[1, 1], [], [0]]
|
||||
},
|
||||
])
|
||||
def test_passing_one_ragged_with_mask(self, x_list, mask_list):
|
||||
x = ragged_factory_ops.constant(x_list)
|
||||
mask = ragged_factory_ops.constant(mask_list)
|
||||
[x], mask = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values([x], mask)
|
||||
x.shape.assert_is_compatible_with(mask.shape)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [[[1, 3]]],
|
||||
'y_list': [[2, 3]]
|
||||
},
|
||||
])
|
||||
def test_failing_different_ragged_and_dense_ranks(self, x_list, y_list):
|
||||
x = ragged_factory_ops.constant(x_list)
|
||||
y = ragged_factory_ops.constant(y_list)
|
||||
with self.assertRaises(ValueError): # pylint: disable=g-error-prone-assert-raises
|
||||
[x, y
|
||||
], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'x_list': [[[1, 3]]],
|
||||
'y_list': [[[2, 3]]],
|
||||
'mask_list': [[0, 1]]
|
||||
},
|
||||
])
|
||||
def test_failing_different_mask_ranks(self, x_list, y_list, mask_list):
|
||||
x = ragged_factory_ops.constant(x_list)
|
||||
y = ragged_factory_ops.constant(y_list)
|
||||
mask = ragged_factory_ops.constant(mask_list)
|
||||
with self.assertRaises(ValueError): # pylint: disable=g-error-prone-assert-raises
|
||||
[x, y
|
||||
], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y],
|
||||
mask)
|
||||
|
||||
# we do not support such cases that ragged_ranks are different but overall
|
||||
# dimension shapes and sizes are identical due to adding too much performance
|
||||
# overheads to the overall use cases.
|
||||
def test_failing_different_ragged_ranks(self):
|
||||
dt = constant_op.constant([[[1, 2]]])
|
||||
# adding a ragged dimension
|
||||
x = ragged_tensor.RaggedTensor.from_row_splits(dt, row_splits=[0, 1])
|
||||
y = ragged_factory_ops.constant([[[[1, 2]]]])
|
||||
with self.assertRaises(ValueError): # pylint: disable=g-error-prone-assert-raises
|
||||
[x, y], _ = \
|
||||
metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
Loading…
x
Reference in New Issue
Block a user