Supporting expand_dims for StructuredTensor.

PiperOrigin-RevId: 343774969
Change-Id: Iaba40cbac16f85427ee4245014e063699b0ffe99
This commit is contained in:
A. Unique TensorFlower 2020-11-22 19:29:26 -08:00 committed by TensorFlower Gardener
parent 2ad6aa9304
commit 3ae4bd8bf6
3 changed files with 323 additions and 1 deletions

View File

@ -18,17 +18,60 @@ py_library(
srcs_version = "PY2AND3",
tags = ["nofixdeps"],
deps = [
":structured_array_ops",
":structured_tensor",
],
)
# Library target named "structured_tensor"; its sources and dependencies are
# the entries listed in `srcs` and `deps` below.
# NOTE(review): two `srcs` attributes appear in this rule. The single-line form
# looks like the pre-change line left behind by the diff rendering -- confirm
# which one the real BUILD file keeps.
py_library(
name = "structured_tensor",
srcs = ["structured_tensor.py"],
srcs = [
"structured_array_ops.py",
"structured_tensor.py",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:row_partition",
"//third_party/py/numpy",
],
)
# Library target for structured_array_ops.py. It depends on :structured_tensor
# in addition to the //tensorflow/python targets listed in `deps`.
py_library(
name = "structured_array_ops",
srcs = [
"structured_array_ops.py",
],
deps = [
":structured_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:row_partition",
"//third_party/py/numpy",
],
)
@ -37,13 +80,23 @@ py_test(
srcs = ["structured_tensor_test.py"],
python_version = "PY3",
deps = [
":structured_array_ops",
":structured_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:context",
"//tensorflow/python/ops/ragged:ragged_factory_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:row_partition",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -0,0 +1,157 @@
# Lint as python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StructuredTensor array ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
def expand_dims(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  """Inserts a length-1 axis into a StructuredTensor at index `axis`.

  This is the StructuredTensor implementation that `tf.expand_dims` dispatches
  to. The `axis` must be less than or equal to the rank of `input`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the StructuredTensor to expand.
    axis: where to insert the new dimension; must satisfy
      `-(rank + 1) <= axis <= rank`.
    name: the name of the op.
    dim: deprecated spelling of `axis`; use `axis` instead.

  Returns:
    A new StructuredTensor whose rank is one larger than that of `input`.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Fold the deprecated `dim` keyword into `axis` before delegating.
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'dim', dim)
  return _expand_dims_impl(input, resolved_axis, name=name)
@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
def expand_dims_v2(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a length-1 axis into a StructuredTensor at index `axis`.

  This is the StructuredTensor implementation that the v2 `tf.expand_dims`
  dispatches to. The `axis` must be less than or equal to the rank of `input`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the StructuredTensor to expand.
    axis: where to insert the new dimension; must satisfy
      `-(rank + 1) <= axis <= rank`.
    name: the name of the op.

  Returns:
    A new StructuredTensor whose rank is one larger than that of `input`.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  expanded = _expand_dims_impl(input, axis, name=name)
  return expanded
def _expand_dims_impl(st, axis, name=None):  # pylint: disable=redefined-builtin
  """Shared implementation of expand_dims for StructuredTensor.

  Produces a StructuredTensor with a length-1 axis inserted at index `axis`.
  The `axis` must be less than or equal to the rank of `st`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    st: the StructuredTensor to expand.
    axis: where to insert the new dimension; must satisfy
      `-(rank + 1) <= axis <= rank`.
    name: the name of the op.

  Returns:
    A new StructuredTensor whose rank is one larger than that of `st`.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Normalize `axis` against rank + 1, the rank of the result.
  axis = array_ops.get_positive_axis(
      axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')

  with ops.name_scope(name, 'ExpandDims', [st, axis]):
    # Every field gets the same axis inserted.
    expanded_fields = {}
    for field_name, field_value in st._fields.items():
      expanded_fields[field_name] = array_ops.expand_dims(field_value, axis)

    leading_dims = st.shape[:axis]
    trailing_dims = st.shape[axis:]
    expanded_shape = leading_dims + (1,) + trailing_dims

    expanded_partitions = _expand_st_row_partitions(st, axis)

    # Inserting at axis 0 leaves a single outer row; otherwise the outer
    # dimension is unchanged.
    if axis > 0:
      expanded_nrows = st.nrows()
    else:
      expanded_nrows = 1

    return StructuredTensor.from_fields(
        expanded_fields,
        shape=expanded_shape,
        row_partitions=expanded_partitions,
        nrows=expanded_nrows)
def _expand_st_row_partitions(st, axis):
  """Builds the row partitions for `st` after inserting a size-1 axis.

  Args:
    st: the StructuredTensor being expanded.
    axis: index at which the new dimension is inserted. The caller in this
      module passes a value already normalized by `get_positive_axis`.

  Returns:
    A tuple of `RowPartition`s: the partitions of `st` plus one new uniform
    partition for the inserted dimension (or an empty tuple when `st` has
    rank 0 and `axis` is 0).
  """

  def _singleton_rows(count):
    # `count` rows, each holding exactly one value.
    return RowPartition.from_uniform_row_length(
        1, count, nrows=count, validate=False)

  if axis == 0:
    if st.shape.rank == 0:
      # Expanding a scalar yields no row partitions.
      return ()
    total_rows = st.nrows()
    # A single outer row that spans all of the existing rows.
    outermost = RowPartition.from_uniform_row_length(
        total_rows, total_rows, nrows=1, validate=False)
    return (outermost,) + st.row_partitions

  if axis == st.rank:
    # New innermost dimension: one row per value of the last partition, or per
    # row of `st` when there is no partition at index `axis - 2`.
    if axis - 2 >= 0:
      count = st.row_partitions[axis - 2].nvals()
    else:
      count = st.nrows()
    return st.row_partitions + (_singleton_rows(count),)

  # New dimension in the middle: one row per row of the partition that
  # currently sits at index `axis - 1`.
  if axis - 1 >= 0:
    count = st.row_partitions[axis - 1].nrows()
  else:
    count = st.nrows()
  return (st.row_partitions[:axis - 1] + (_singleton_rows(count),) +
          st.row_partitions[axis - 1:])

View File

@ -36,6 +36,10 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import row_partition
# TODO(b/173144447): remove when structured_array_ops is included in init.
from tensorflow.python.ops.structured import structured_array_ops # pylint: disable=unused-import
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.platform import googletest
@ -977,6 +981,114 @@ class StructuredTensorTest(test_util.TensorFlowTestCase,
r"or equal to inner_axis \(1\)"):
st.merge_dims(2, 1)
@parameterized.named_parameters([
    {
        "testcase_name": "0D_0",
        "st": {"x": 1},
        "axis": 0,
        "expected": [{"x": 1}],
    },
    {
        "testcase_name": "0D_minus_1",
        "st": {"x": 1},
        "axis": -1,
        "expected": [{"x": 1}],
    },
    {
        "testcase_name": "1D_0",
        "st": [{"x": [1, 3]}, {"x": [2, 7, 9]}],
        "axis": 0,
        "expected": [[{"x": [1, 3]}, {"x": [2, 7, 9]}]],
    },
    {
        "testcase_name": "1D_1",
        "st": [{"x": [1]}, {"x": [2, 10]}],
        "axis": 1,
        "expected": [[{"x": [1]}], [{"x": [2, 10]}]],
    },
    {
        "testcase_name": "2D_0",
        "st": [[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
        "axis": 0,
        "expected": [[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]],
    },
    {
        "testcase_name": "2D_1",
        "st": [[{"x": 1}, {"x": 2}], [{"x": 3}]],
        "axis": 1,
        "expected": [[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]],
    },
    {
        "testcase_name": "2D_2",
        "st": [[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
        "axis": 2,
        "expected": [[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]],
    },
    {
        "testcase_name": "3D_0",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": 0,
        "expected": [[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
                      [[{"x": [4, 5]}]]]],
    },
    {
        "testcase_name": "3D_minus_4",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": -4,  # same as zero
        "expected": [[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
                      [[{"x": [4, 5]}]]]],
    },
    {
        "testcase_name": "3D_1",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": 1,
        "expected": [[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
                     [[[{"x": [4, 5]}]]]],
    },
    {
        "testcase_name": "3D_minus_3",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": -3,  # same as 1
        "expected": [[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
                     [[[{"x": [4, 5]}]]]],
    },
    {
        "testcase_name": "3D_2",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": 2,
        "expected": [[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
                     [[[{"x": [4, 5]}]]]],
    },
    {
        "testcase_name": "3D_minus_2",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": -2,  # same as 2
        "expected": [[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
                     [[[{"x": [4, 5]}]]]],
    },
    {
        "testcase_name": "3D_3",
        "st": [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
        "axis": 3,
        "expected": [[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]],
                     [[[{"x": [4, 5]}]]]],
    },
])  # pyformat: disable
def testExpandDims(self, st, axis, expected):
  """Checks expand_dims on a StructuredTensor against the expected pyval.

  Cases cover ranks 0 through 3 with both non-negative and negative axes.
  """
  structured = StructuredTensor.from_pyval(st)
  expanded = array_ops.expand_dims(structured, axis)
  self.assertAllEqual(expanded, expected)
def testExpandDimsAxisTooBig(self):
  """An axis above the rank of a 3D StructuredTensor raises ValueError."""
  pyval = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
  structured = StructuredTensor.from_pyval(pyval)
  expected_message = "axis=4 out of bounds: expected -4<=axis<4"
  with self.assertRaisesRegex(ValueError, expected_message):
    array_ops.expand_dims(structured, 4)
def testExpandDimsAxisTooSmall(self):
  """An axis below -(rank + 1) of a 3D StructuredTensor raises ValueError."""
  pyval = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
  structured = StructuredTensor.from_pyval(pyval)
  expected_message = "axis=-5 out of bounds: expected -4<=axis<4"
  with self.assertRaisesRegex(ValueError, expected_message):
    array_ops.expand_dims(structured, -5)
def testExpandDimsScalar(self):
  """Checks the static shape when expanding the final dim over scalar fields."""
  # Expanding the final dimension when the fields are scalars gives a shape of
  # (2, None, None, 1); building the same value from a pyval would instead give
  # (2, None, None, None).
  pyval = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]
  structured = StructuredTensor.from_pyval(pyval)
  expanded = array_ops.expand_dims(structured, 3)
  want_shape = tensor_shape.TensorShape([2, None, None, 1])
  self.assertEqual(repr(want_shape), repr(expanded.shape))
def testTupleFieldValue(self):
st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
self.assertAllEqual(st.field_value(("a",)), 5)