Supporting expand_dims for StructuredTensor.
PiperOrigin-RevId: 343774969 Change-Id: Iaba40cbac16f85427ee4245014e063699b0ffe99
This commit is contained in:
parent
2ad6aa9304
commit
3ae4bd8bf6
@ -18,17 +18,60 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["nofixdeps"],
|
||||
deps = [
|
||||
":structured_array_ops",
|
||||
":structured_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "structured_tensor",
|
||||
srcs = ["structured_tensor.py"],
|
||||
srcs = [
|
||||
"structured_array_ops.py",
|
||||
"structured_tensor.py",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:row_partition",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "structured_array_ops",
|
||||
srcs = [
|
||||
"structured_array_ops.py",
|
||||
],
|
||||
deps = [
|
||||
":structured_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:row_partition",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -37,13 +80,23 @@ py_test(
|
||||
srcs = ["structured_tensor_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":structured_array_ops",
|
||||
":structured_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:row_partition",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
157
tensorflow/python/ops/structured/structured_array_ops.py
Normal file
157
tensorflow/python/ops/structured/structured_array_ops.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Lint as python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""StructuredTensor array ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged.row_partition import RowPartition
|
||||
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
|
||||
|
||||
@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
|
||||
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
|
||||
def expand_dims(input, axis=None, name=None, dim=None): # pylint: disable=redefined-builtin
|
||||
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
|
||||
|
||||
This is an implementation of tf.expand_dims for StructuredTensor. Note
|
||||
that the `axis` must be less than or equal to rank.
|
||||
|
||||
>>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
|
||||
>>> tf.expand_dims(st, 0).to_pyval()
|
||||
[[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, 1).to_pyval()
|
||||
[[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, 2).to_pyval()
|
||||
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
|
||||
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
|
||||
|
||||
Args:
|
||||
input: the original StructuredTensor.
|
||||
axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
|
||||
name: the name of the op.
|
||||
dim: deprecated: use axis.
|
||||
|
||||
Returns:
|
||||
a new structured tensor with larger rank.
|
||||
|
||||
Raises:
|
||||
an error if `axis < -(rank + 1)` or `rank < axis`.
|
||||
"""
|
||||
axis = deprecation.deprecated_argument_lookup('axis', axis, 'dim', dim)
|
||||
return _expand_dims_impl(input, axis, name=name)
|
||||
|
||||
|
||||
@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
|
||||
def expand_dims_v2(input, axis, name=None): # pylint: disable=redefined-builtin
|
||||
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
|
||||
|
||||
This is an implementation of tf.expand_dims for StructuredTensor. Note
|
||||
that the `axis` must be less than or equal to rank.
|
||||
|
||||
>>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
|
||||
>>> tf.expand_dims(st, 0).to_pyval()
|
||||
[[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, 1).to_pyval()
|
||||
[[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, 2).to_pyval()
|
||||
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
|
||||
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
|
||||
|
||||
Args:
|
||||
input: the original StructuredTensor.
|
||||
axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
|
||||
name: the name of the op.
|
||||
|
||||
Returns:
|
||||
a new structured tensor with larger rank.
|
||||
|
||||
Raises:
|
||||
an error if `axis < -(rank + 1)` or `rank < axis`.
|
||||
"""
|
||||
return _expand_dims_impl(input, axis, name=name)
|
||||
|
||||
|
||||
def _expand_dims_impl(st, axis, name=None): # pylint: disable=redefined-builtin
|
||||
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
|
||||
|
||||
This is an implementation of tf.expand_dims for StructuredTensor. Note
|
||||
that the `axis` must be less than or equal to rank.
|
||||
|
||||
>>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
|
||||
>>> tf.expand_dims(st, 0).to_pyval()
|
||||
[[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, 1).to_pyval()
|
||||
[[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, 2).to_pyval()
|
||||
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
|
||||
>>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
|
||||
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
|
||||
|
||||
Args:
|
||||
st: the original StructuredTensor.
|
||||
axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
|
||||
name: the name of the op.
|
||||
|
||||
Returns:
|
||||
a new structured tensor with larger rank.
|
||||
|
||||
Raises:
|
||||
an error if `axis < -(rank + 1)` or `rank < axis`.
|
||||
"""
|
||||
axis = array_ops.get_positive_axis(
|
||||
axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
|
||||
with ops.name_scope(name, 'ExpandDims', [st, axis]):
|
||||
new_fields = {
|
||||
k: array_ops.expand_dims(v, axis)
|
||||
for (k, v) in st._fields.items()
|
||||
}
|
||||
new_shape = st.shape[:axis] + (1,) + st.shape[axis:]
|
||||
new_row_partitions = _expand_st_row_partitions(st, axis)
|
||||
new_nrows = st.nrows() if (axis > 0) else 1
|
||||
return StructuredTensor.from_fields(
|
||||
new_fields,
|
||||
shape=new_shape,
|
||||
row_partitions=new_row_partitions,
|
||||
nrows=new_nrows)
|
||||
|
||||
|
||||
def _expand_st_row_partitions(st, axis):
|
||||
"""Create the row_partitions for expand_dims."""
|
||||
if axis == 0:
|
||||
if st.shape.rank == 0:
|
||||
return ()
|
||||
nvals = st.nrows()
|
||||
new_partition = RowPartition.from_uniform_row_length(
|
||||
nvals, nvals, nrows=1, validate=False)
|
||||
return (new_partition,) + st.row_partitions
|
||||
elif axis == st.rank:
|
||||
nvals = (
|
||||
st.row_partitions[axis - 2].nvals() if (axis - 2 >= 0) else st.nrows())
|
||||
return st.row_partitions + (RowPartition.from_uniform_row_length(
|
||||
1, nvals, nrows=nvals, validate=False),)
|
||||
else:
|
||||
nvals = (
|
||||
st.row_partitions[axis - 1].nrows() if (axis - 1 >= 0) else st.nrows())
|
||||
return st.row_partitions[:axis - 1] + (RowPartition.from_uniform_row_length(
|
||||
1, nvals, nrows=nvals, validate=False),) + st.row_partitions[axis - 1:]
|
@ -36,6 +36,10 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import row_partition
|
||||
|
||||
# TODO(b/173144447): remove when structured_array_ops is included in init.
|
||||
from tensorflow.python.ops.structured import structured_array_ops # pylint: disable=unused-import
|
||||
|
||||
from tensorflow.python.ops.structured import structured_tensor
|
||||
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -977,6 +981,114 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
|
||||
r"or equal to inner_axis \(1\)"):
|
||||
st.merge_dims(2, 1)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name="0D_0",
|
||||
st={"x": 1},
|
||||
axis=0,
|
||||
expected=[{"x": 1}]),
|
||||
dict(
|
||||
testcase_name="0D_minus_1",
|
||||
st={"x": 1},
|
||||
axis=-1,
|
||||
expected=[{"x": 1}]),
|
||||
dict(
|
||||
testcase_name="1D_0",
|
||||
st=[{"x": [1, 3]}, {"x": [2, 7, 9]}],
|
||||
axis=0,
|
||||
expected=[[{"x": [1, 3]}, {"x": [2, 7, 9]}]]),
|
||||
dict(
|
||||
testcase_name="1D_1",
|
||||
st=[{"x": [1]}, {"x": [2, 10]}],
|
||||
axis=1,
|
||||
expected=[[{"x": [1]}], [{"x": [2, 10]}]]),
|
||||
dict(
|
||||
testcase_name="2D_0",
|
||||
st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
|
||||
axis=0,
|
||||
expected=[[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]]),
|
||||
dict(
|
||||
testcase_name="2D_1",
|
||||
st=[[{"x": 1}, {"x": 2}], [{"x": 3}]],
|
||||
axis=1,
|
||||
expected=[[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]]),
|
||||
dict(
|
||||
testcase_name="2D_2",
|
||||
st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
|
||||
axis=2,
|
||||
expected=[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]]),
|
||||
dict(
|
||||
testcase_name="3D_0",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=0,
|
||||
expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
|
||||
[[{"x": [4, 5]}]]]]),
|
||||
dict(
|
||||
testcase_name="3D_minus_4",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=-4, # same as zero
|
||||
expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
|
||||
[[{"x": [4, 5]}]]]]),
|
||||
dict(
|
||||
testcase_name="3D_1",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=1,
|
||||
expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
|
||||
[[[{"x": [4, 5]}]]]]),
|
||||
dict(
|
||||
testcase_name="3D_minus_3",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=-3, # same as 1
|
||||
expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
|
||||
[[[{"x": [4, 5]}]]]]),
|
||||
dict(
|
||||
testcase_name="3D_2",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=2,
|
||||
expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
|
||||
[[[{"x": [4, 5]}]]]]),
|
||||
dict(
|
||||
testcase_name="3D_minus_2",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=-2, # same as 2
|
||||
expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
|
||||
[[[{"x": [4, 5]}]]]]),
|
||||
dict(
|
||||
testcase_name="3D_3",
|
||||
st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
|
||||
axis=3,
|
||||
expected=[[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]],
|
||||
[[[{"x": [4, 5]}]]]]),
|
||||
]) # pyformat: disable
|
||||
def testExpandDims(self, st, axis, expected):
|
||||
st = StructuredTensor.from_pyval(st)
|
||||
result = array_ops.expand_dims(st, axis)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
def testExpandDimsAxisTooBig(self):
|
||||
st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
|
||||
st = StructuredTensor.from_pyval(st)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"axis=4 out of bounds: expected -4<=axis<4"):
|
||||
array_ops.expand_dims(st, 4)
|
||||
|
||||
def testExpandDimsAxisTooSmall(self):
|
||||
st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
|
||||
st = StructuredTensor.from_pyval(st)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"axis=-5 out of bounds: expected -4<=axis<4"):
|
||||
array_ops.expand_dims(st, -5)
|
||||
|
||||
def testExpandDimsScalar(self):
|
||||
# Note that if we expand_dims for the final dimension and there are scalar
|
||||
# fields, then the shape is (2, None, None, 1), whereas if it is constructed
|
||||
# from pyval it is (2, None, None, None).
|
||||
st = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]
|
||||
st = StructuredTensor.from_pyval(st)
|
||||
result = array_ops.expand_dims(st, 3)
|
||||
expected_shape = tensor_shape.TensorShape([2, None, None, 1])
|
||||
self.assertEqual(repr(expected_shape), repr(result.shape))
|
||||
|
||||
def testTupleFieldValue(self):
|
||||
st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
|
||||
self.assertAllEqual(st.field_value(("a",)), 5)
|
||||
|
Loading…
x
Reference in New Issue
Block a user