Add ragged tensor support to tf.squeeze
PiperOrigin-RevId: 244385979
This commit is contained in:
parent
6aca010fce
commit
58f67785f6
@ -3100,6 +3100,7 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
|
||||
|
||||
|
||||
@tf_export(v1=["squeeze"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
|
||||
"squeeze_dims")
|
||||
def squeeze(input, axis=None, name=None, squeeze_dims=None):
|
||||
@ -3125,12 +3126,18 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
|
||||
tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
|
||||
```
|
||||
|
||||
Note: When it comes to squeezing ragged tensors, it has O(number of elements).
|
||||
|
||||
Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
|
||||
time, where `N` is the number of elements in the squeezed dimensions.
|
||||
|
||||
Args:
|
||||
input: A `Tensor`. The `input` to squeeze.
|
||||
axis: An optional list of `ints`. Defaults to `[]`. If specified, only
|
||||
squeezes the dimensions listed. The dimension index starts at 0. It is an
|
||||
error to squeeze a dimension that is not 1. Must be in the range
|
||||
`[-rank(input), rank(input))`.
|
||||
Must be specified if `input` is a `RaggedTensor`.
|
||||
name: A name for the operation (optional).
|
||||
squeeze_dims: Deprecated keyword argument that is now axis.
|
||||
|
||||
@ -3150,6 +3157,7 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
|
||||
|
||||
|
||||
@tf_export("squeeze", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def squeeze_v2(input, axis=None, name=None):
|
||||
# pylint: disable=redefined-builtin
|
||||
return squeeze(input, axis, name)
|
||||
|
@ -269,6 +269,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_squeeze_op",
|
||||
srcs = ["ragged_squeeze_op.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_tensor",
|
||||
srcs = ["ragged_tensor.py"],
|
||||
@ -391,6 +404,7 @@ py_library(
|
||||
":ragged_array_ops",
|
||||
":ragged_batch_gather_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_squeeze_op",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_util",
|
||||
@ -1012,3 +1026,19 @@ py_test(
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ragged_squeeze_op_test",
|
||||
srcs = ["ragged_squeeze_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_conversion_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_squeeze_op",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -36,10 +36,12 @@ from tensorflow.python.ops.ragged import ragged_batch_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_squeeze_op
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_shape
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged import ragged_where_op
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_export
|
||||
@ -432,6 +434,11 @@ def _ragged_size_v1(input, name=None, out_type=dtypes.int32): # pylint: disable
|
||||
return ragged_array_ops.size(input=input, out_type=out_type, name=name)
|
||||
|
||||
|
||||
def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None): # pylint: disable=redefined-builtin
|
||||
axis = deprecation.deprecated_argument_lookup('axis', axis, 'squeeze_dims',
|
||||
squeeze_dims)
|
||||
return ragged_squeeze_op.squeeze(input, axis, name)
|
||||
|
||||
# (original_op, ragged_op, ragged_args)
|
||||
_RAGGED_DISPATCH_OPS = [
|
||||
(array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
|
||||
@ -442,11 +449,13 @@ _RAGGED_DISPATCH_OPS = [
|
||||
(array_ops.gather, _ragged_gather_v1, ['params', 'indices']),
|
||||
(array_ops.gather_v2, ragged_gather_ops.gather, ['params', 'indices']),
|
||||
(array_ops.gather_nd, _ragged_gather_nd_v1, ['params', 'indices']),
|
||||
(array_ops.gather_nd_v2, ragged_gather_ops.gather_nd,
|
||||
['params', 'indices']),
|
||||
(array_ops.gather_nd_v2, ragged_gather_ops.gather_nd, ['params',
|
||||
'indices']),
|
||||
(array_ops.rank, ragged_array_ops.rank, ['input']),
|
||||
(array_ops.size, _ragged_size_v1, ['input']),
|
||||
(array_ops.size_v2, ragged_array_ops.size, ['input']),
|
||||
(array_ops.squeeze, _ragged_squeeze_v1, ['input']),
|
||||
(array_ops.squeeze_v2, ragged_squeeze_op.squeeze, ['input']),
|
||||
(array_ops.stack, ragged_concat_ops.stack, ['[values]']),
|
||||
(array_ops.tile, ragged_array_ops.tile, ['input']),
|
||||
(array_ops.where, ragged_where_op.where, ['condition', 'x', 'y']),
|
||||
|
@ -695,6 +695,20 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
op=array_ops.size_v2,
|
||||
kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
|
||||
expected=3),
|
||||
dict(
|
||||
op=array_ops.squeeze,
|
||||
kwargs={
|
||||
'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
|
||||
'axis': [0]
|
||||
},
|
||||
expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
|
||||
dict(
|
||||
op=array_ops.squeeze_v2,
|
||||
kwargs={
|
||||
'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
|
||||
'axis': [0]
|
||||
},
|
||||
expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
|
||||
])
|
||||
def testRaggedDispatch(self, op, expected, args=(), kwargs=None):
|
||||
if kwargs is None: kwargs = {}
|
||||
|
120
tensorflow/python/ops/ragged/ragged_squeeze_op.py
Normal file
120
tensorflow/python/ops/ragged/ragged_squeeze_op.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Operator Squeeze for RaggedTensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
|
||||
|
||||
def squeeze(input, axis=None, name=None): # pylint: disable=redefined-builtin
|
||||
"""Ragged compatible squeeze.
|
||||
|
||||
If `input` is a `tf.Tensor`, then this calls `tf.squeeze`.
|
||||
|
||||
If `input` is a `tf.RaggedTensor`, then this operation takes `O(N)` time,
|
||||
where `N` is the number of elements in the squeezed dimensions.
|
||||
|
||||
Args:
|
||||
input: A potentially ragged tensor. The input to squeeze.
|
||||
axis: An optional list of ints. Defaults to `None`. If the `input` is
|
||||
ragged, it only squeezes the dimensions listed. It fails if `input` is
|
||||
ragged and axis is []. If `input` is not ragged it calls tf.squeeze. Note
|
||||
that it is an error to squeeze a dimension that is not 1. It must be in
|
||||
the range of [-rank(input), rank(input)).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A potentially ragged tensor. Contains the same data as input,
|
||||
but has one or more dimensions of size 1 removed.
|
||||
"""
|
||||
with ops.name_scope(name, 'RaggedSqueeze', [input]):
|
||||
input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
|
||||
if isinstance(input, ops.Tensor):
|
||||
return array_ops.squeeze(input, axis, name)
|
||||
|
||||
if axis is None:
|
||||
raise ValueError('Ragged.squeeze must have an axis argument.')
|
||||
if isinstance(axis, int):
|
||||
axis = [axis]
|
||||
elif ((not isinstance(axis, (list, tuple))) or
|
||||
(not all(isinstance(d, int) for d in axis))):
|
||||
raise TypeError('Axis must be a list or tuple of integers.')
|
||||
|
||||
dense_dims = []
|
||||
ragged_dims = []
|
||||
# Normalize all the dims in axis to be positive
|
||||
axis = [ragged_util.get_positive_axis(d, input.shape.ndims) for d in axis]
|
||||
for dim in axis:
|
||||
if dim > input.ragged_rank:
|
||||
dense_dims.append(dim - input.ragged_rank)
|
||||
else:
|
||||
ragged_dims.append(dim)
|
||||
|
||||
# Make sure the specified ragged dimensions are squeezable.
|
||||
assertion_list = []
|
||||
scalar_tensor_one = constant_op.constant(1, dtype=dtypes.int64)
|
||||
for i, r in enumerate(input.nested_row_lengths()):
|
||||
if i + 1 in ragged_dims:
|
||||
assertion_list.append(
|
||||
control_flow_ops.Assert(
|
||||
math_ops.reduce_all(math_ops.equal(r, scalar_tensor_one)),
|
||||
['the given axis (axis = %d) is not squeezable!' % (i + 1)]))
|
||||
if 0 in ragged_dims:
|
||||
scalar_tensor_two = constant_op.constant(2, dtype=dtypes.int32)
|
||||
assertion_list.append(
|
||||
control_flow_ops.Assert(
|
||||
math_ops.equal(
|
||||
array_ops.size(input.row_splits), scalar_tensor_two),
|
||||
['the given axis (axis = 0) is not squeezable!']))
|
||||
|
||||
# Till now, we are sure that the ragged dimensions are squeezable.
|
||||
squeezed_rt = None
|
||||
squeezed_rt = control_flow_ops.with_dependencies(assertion_list,
|
||||
input.flat_values)
|
||||
|
||||
if dense_dims:
|
||||
# Gives error if the dense dimension is not squeezable.
|
||||
squeezed_rt = array_ops.squeeze(squeezed_rt, dense_dims)
|
||||
|
||||
remaining_row_splits = []
|
||||
remaining_row_splits = list()
|
||||
for i, row_split in enumerate(input.nested_row_splits):
|
||||
# each row_splits tensor is for dimension #(i+1) .
|
||||
if (i + 1) not in ragged_dims:
|
||||
remaining_row_splits.append(row_split)
|
||||
# Take care of the first row if it is to be squeezed.
|
||||
if remaining_row_splits and 0 in ragged_dims:
|
||||
remaining_row_splits.pop(0)
|
||||
|
||||
squeezed_rt = RaggedTensor.from_nested_row_splits(squeezed_rt,
|
||||
remaining_row_splits)
|
||||
|
||||
# Corner case: when removing all the ragged dimensions and the output is
|
||||
# a scalar tensor e.g. ragged.squeeze(ragged.constant([[[1]]])).
|
||||
if set(range(0, input.ragged_rank + 1)).issubset(set(ragged_dims)):
|
||||
squeezed_rt = array_ops.squeeze(squeezed_rt, [0], name)
|
||||
|
||||
return squeezed_rt
|
292
tensorflow/python/ops/ragged/ragged_squeeze_op_test.py
Normal file
292
tensorflow/python/ops/ragged/ragged_squeeze_op_test.py
Normal file
@ -0,0 +1,292 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ragged.size."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_conversion_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_squeeze_op
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'input_list': []
|
||||
},
|
||||
{
|
||||
'input_list': [[]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[], []], [[], []]]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
])
|
||||
def test_passing_empty(self, input_list, squeeze_ranks=None):
|
||||
rt = ragged_squeeze_op.squeeze(
|
||||
ragged_factory_ops.constant(input_list), squeeze_ranks)
|
||||
dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
|
||||
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'input_list': [[1]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[1]],
|
||||
'squeeze_ranks': [0, 1]
|
||||
},
|
||||
{
|
||||
'input_list': [[1, 2]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[1], [2]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [1, 3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 1, 3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[1], [2]], [[3], [4]]],
|
||||
'squeeze_ranks': [2]
|
||||
},
|
||||
{
|
||||
'input_list': [[1], [2]],
|
||||
'squeeze_ranks': [-1]
|
||||
},
|
||||
])
|
||||
def test_passing_simple(self, input_list, squeeze_ranks=None):
|
||||
rt = ragged_squeeze_op.squeeze(
|
||||
ragged_factory_ops.constant(input_list), squeeze_ranks)
|
||||
dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
|
||||
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)
|
||||
|
||||
@parameterized.parameters([
|
||||
# ragged_conversion_ops.from_tensor does not work for this
|
||||
# {'input_list': [1]},
|
||||
{
|
||||
'input_list': [[1]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[1, 2]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[1], [2]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [1, 3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 1, 3]
|
||||
},
|
||||
{
|
||||
'input_list': [[[1], [2]], [[3], [4]]],
|
||||
'squeeze_ranks': [2]
|
||||
},
|
||||
])
|
||||
def test_passing_simple_from_dense(self, input_list, squeeze_ranks=None):
|
||||
dt = constant_op.constant(input_list)
|
||||
rt = ragged_conversion_ops.from_tensor(dt)
|
||||
rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
|
||||
dt_s = array_ops.squeeze(dt, squeeze_ranks)
|
||||
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt_s), dt_s)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'input_list': [[[[[[1]], [[1, 2]]]], [[[[]], [[]]]]]],
|
||||
'output_list': [[[1], [1, 2]], [[], []]],
|
||||
'squeeze_ranks': [0, 2, 4]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[[[1]], [[1, 2]]]], [[[[]], [[]]]]]],
|
||||
'output_list': [[[[[1]], [[1, 2]]]], [[[[]], [[]]]]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
])
|
||||
def test_passing_ragged(self, input_list, output_list, squeeze_ranks=None):
|
||||
rt = ragged_factory_ops.constant(input_list)
|
||||
rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
|
||||
ref = ragged_factory_ops.constant(output_list)
|
||||
self.assertRaggedEqual(rt_s, ref)
|
||||
|
||||
def test_passing_text(self):
|
||||
rt = ragged_factory_ops.constant([[[[[[[['H']], [['e']], [['l']], [['l']],
|
||||
[['o']]],
|
||||
[[['W']], [['o']], [['r']], [['l']],
|
||||
[['d']], [['!']]]]],
|
||||
[[[[['T']], [['h']], [['i']], [['s']]],
|
||||
[[['i']], [['s']]],
|
||||
[[['M']], [['e']], [['h']], [['r']],
|
||||
[['d']], [['a']], [['d']]],
|
||||
[[['.']]]]]]]])
|
||||
output_list = [[['H', 'e', 'l', 'l', 'o'], ['W', 'o', 'r', 'l', 'd', '!']],
|
||||
[['T', 'h', 'i', 's'], ['i', 's'],
|
||||
['M', 'e', 'h', 'r', 'd', 'a', 'd'], ['.']]]
|
||||
ref = ragged_factory_ops.constant(output_list)
|
||||
rt_s = ragged_squeeze_op.squeeze(rt, [0, 1, 3, 6, 7])
|
||||
self.assertRaggedEqual(rt_s, ref)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'input_list': [[]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[1, 2]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[1], [2]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 2]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [2]
|
||||
},
|
||||
{
|
||||
'input_list': [[[1], [2]], [[3], [4]]],
|
||||
'squeeze_ranks': [0]
|
||||
},
|
||||
{
|
||||
'input_list': [[[1], [2]], [[3], [4]]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[], []],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
{
|
||||
'input_list': [[[], []], [[], []]],
|
||||
'squeeze_ranks': [1]
|
||||
},
|
||||
])
|
||||
def test_failing_InvalidArgumentError(self, input_list, squeeze_ranks):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(
|
||||
ragged_squeeze_op.squeeze(
|
||||
ragged_factory_ops.constant(input_list), squeeze_ranks))
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'input_list': [[]]
|
||||
},
|
||||
{
|
||||
'input_list': [[1]]
|
||||
},
|
||||
{
|
||||
'input_list': [[1, 2]]
|
||||
},
|
||||
{
|
||||
'input_list': [[[1], [2]], [[3], [4]]]
|
||||
},
|
||||
{
|
||||
'input_list': [[1]]
|
||||
},
|
||||
{
|
||||
'input_list': [[[1], [2]], [[3], [4]]]
|
||||
},
|
||||
{
|
||||
'input_list': [[[[12], [11]]]]
|
||||
},
|
||||
])
|
||||
def test_failing_no_squeeze_dim_specified(self, input_list):
|
||||
with self.assertRaises(ValueError):
|
||||
ragged_squeeze_op.squeeze(ragged_factory_ops.constant(input_list))
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'input_list': [[[[12], [11]]]],
|
||||
'squeeze_ranks': [0, 1, 3]
|
||||
},
|
||||
])
|
||||
def test_failing_axis_is_not_a_list(self, input_list, squeeze_ranks):
|
||||
with self.assertRaises(TypeError):
|
||||
tensor_ranks = constant_op.constant(squeeze_ranks)
|
||||
ragged_squeeze_op.squeeze(
|
||||
ragged_factory_ops.constant(input_list), tensor_ranks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
Loading…
Reference in New Issue
Block a user