diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD index 33834f0e914..81e8b37dc7d 100644 --- a/tensorflow/python/ops/structured/BUILD +++ b/tensorflow/python/ops/structured/BUILD @@ -18,17 +18,60 @@ py_library( srcs_version = "PY2AND3", tags = ["nofixdeps"], deps = [ + ":structured_array_ops", ":structured_tensor", ], ) py_library( name = "structured_tensor", - srcs = ["structured_tensor.py"], + srcs = [ + "structured_array_ops.py", + "structured_tensor.py", + ], deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:type_spec", + "//tensorflow/python:util", + "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", + "//tensorflow/python/ops/ragged:row_partition", + "//third_party/py/numpy", + ], +) + +py_library( + name = "structured_array_ops", + srcs = [ + "structured_array_ops.py", + ], + deps = [ + ":structured_tensor", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:composite_tensor", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:type_spec", + "//tensorflow/python:util", + "//tensorflow/python/ops/ragged:ragged_factory_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//tensorflow/python/ops/ragged:row_partition", + "//third_party/py/numpy", ], ) @@ -37,13 +80,23 @@ py_test( srcs = ["structured_tensor_test.py"], python_version = "PY3", deps = [ + ":structured_array_ops", ":structured_tensor", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python/eager:context", "//tensorflow/python/ops/ragged:ragged_factory_ops", "//tensorflow/python/ops/ragged:ragged_tensor", + "//tensorflow/python/ops/ragged:row_partition", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py new file mode 100644 index 00000000000..dca8084575e --- /dev/null +++ b/tensorflow/python/ops/structured/structured_array_ops.py @@ -0,0 +1,157 @@ +# Lint as python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""StructuredTensor array ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.ragged.row_partition import RowPartition +from tensorflow.python.ops.structured.structured_tensor import StructuredTensor +from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch + + +@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor) +@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim') +def expand_dims(input, axis=None, name=None, dim=None): # pylint: disable=redefined-builtin + """Creates a StructuredTensor with a length 1 axis inserted at index `axis`. + + This is an implementation of tf.expand_dims for StructuredTensor. Note + that the `axis` must be less than or equal to rank. + + >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]]) + >>> tf.expand_dims(st, 0).to_pyval() + [[[{'x': 1}, {'x': 2}], [{'x': 3}]]] + >>> tf.expand_dims(st, 1).to_pyval() + [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, 2).to_pyval() + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2 + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + + Args: + input: the original StructuredTensor. + axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank` + name: the name of the op. + dim: deprecated: use axis. + + Returns: + a new structured tensor with larger rank. + + Raises: + an error if `axis < -(rank + 1)` or `rank < axis`. + """ + axis = deprecation.deprecated_argument_lookup('axis', axis, 'dim', dim) + return _expand_dims_impl(input, axis, name=name) + + +@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor) +def expand_dims_v2(input, axis, name=None): # pylint: disable=redefined-builtin + """Creates a StructuredTensor with a length 1 axis inserted at index `axis`. + + This is an implementation of tf.expand_dims for StructuredTensor. Note + that the `axis` must be less than or equal to rank. + + >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]]) + >>> tf.expand_dims(st, 0).to_pyval() + [[[{'x': 1}, {'x': 2}], [{'x': 3}]]] + >>> tf.expand_dims(st, 1).to_pyval() + [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, 2).to_pyval() + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2 + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + + Args: + input: the original StructuredTensor. + axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank` + name: the name of the op. + + Returns: + a new structured tensor with larger rank. + + Raises: + an error if `axis < -(rank + 1)` or `rank < axis`. + """ + return _expand_dims_impl(input, axis, name=name) + + +def _expand_dims_impl(st, axis, name=None): # pylint: disable=redefined-builtin + """Creates a StructuredTensor with a length 1 axis inserted at index `axis`. + + This is an implementation of tf.expand_dims for StructuredTensor. Note + that the `axis` must be less than or equal to rank. + + >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]]) + >>> tf.expand_dims(st, 0).to_pyval() + [[[{'x': 1}, {'x': 2}], [{'x': 3}]]] + >>> tf.expand_dims(st, 1).to_pyval() + [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, 2).to_pyval() + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + >>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2 + [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]] + + Args: + st: the original StructuredTensor. + axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank` + name: the name of the op. + + Returns: + a new structured tensor with larger rank. + + Raises: + an error if `axis < -(rank + 1)` or `rank < axis`. + """ + axis = array_ops.get_positive_axis( + axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)') + with ops.name_scope(name, 'ExpandDims', [st, axis]): + new_fields = { + k: array_ops.expand_dims(v, axis) + for (k, v) in st._fields.items() + } + new_shape = st.shape[:axis] + (1,) + st.shape[axis:] + new_row_partitions = _expand_st_row_partitions(st, axis) + new_nrows = st.nrows() if (axis > 0) else 1 + return StructuredTensor.from_fields( + new_fields, + shape=new_shape, + row_partitions=new_row_partitions, + nrows=new_nrows) + + +def _expand_st_row_partitions(st, axis): + """Create the row_partitions for expand_dims.""" + if axis == 0: + if st.shape.rank == 0: + return () + nvals = st.nrows() + new_partition = RowPartition.from_uniform_row_length( + nvals, nvals, nrows=1, validate=False) + return (new_partition,) + st.row_partitions + elif axis == st.rank: + nvals = ( + st.row_partitions[axis - 2].nvals() if (axis - 2 >= 0) else st.nrows()) + return st.row_partitions + (RowPartition.from_uniform_row_length( + 1, nvals, nrows=nvals, validate=False),) + else: + nvals = ( + st.row_partitions[axis - 1].nrows() if (axis - 1 >= 0) else st.nrows()) + return st.row_partitions[:axis - 1] + (RowPartition.from_uniform_row_length( + 1, nvals, nrows=nvals, validate=False),) + st.row_partitions[axis - 1:] diff --git a/tensorflow/python/ops/structured/structured_tensor_test.py b/tensorflow/python/ops/structured/structured_tensor_test.py index 28acfbb3304..7a15b67662e 100644 --- a/tensorflow/python/ops/structured/structured_tensor_test.py +++ b/tensorflow/python/ops/structured/structured_tensor_test.py @@ -36,6 +36,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import row_partition + +# TODO(b/173144447): remove when structured_array_ops is included in init. +from tensorflow.python.ops.structured import structured_array_ops # pylint: disable=unused-import + from tensorflow.python.ops.structured import structured_tensor from tensorflow.python.ops.structured.structured_tensor import StructuredTensor from tensorflow.python.platform import googletest @@ -977,6 +981,114 @@ class StructuredTensorTest(test_util.TensorFlowTestCase, r"or equal to inner_axis \(1\)"): st.merge_dims(2, 1) + @parameterized.named_parameters([ + dict( + testcase_name="0D_0", + st={"x": 1}, + axis=0, + expected=[{"x": 1}]), + dict( + testcase_name="0D_minus_1", + st={"x": 1}, + axis=-1, + expected=[{"x": 1}]), + dict( + testcase_name="1D_0", + st=[{"x": [1, 3]}, {"x": [2, 7, 9]}], + axis=0, + expected=[[{"x": [1, 3]}, {"x": [2, 7, 9]}]]), + dict( + testcase_name="1D_1", + st=[{"x": [1]}, {"x": [2, 10]}], + axis=1, + expected=[[{"x": [1]}], [{"x": [2, 10]}]]), + dict( + testcase_name="2D_0", + st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]], + axis=0, + expected=[[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]]), + dict( + testcase_name="2D_1", + st=[[{"x": 1}, {"x": 2}], [{"x": 3}]], + axis=1, + expected=[[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]]), + dict( + testcase_name="2D_2", + st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]], + axis=2, + expected=[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]]), + dict( + testcase_name="3D_0", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=0, + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], + [[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_minus_4", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=-4, # same as zero + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], + [[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_1", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=1, + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_minus_3", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=-3, # same as 1 + expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_2", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=2, + expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_minus_2", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=-2, # same as 2 + expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + dict( + testcase_name="3D_3", + st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]], + axis=3, + expected=[[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]], + [[[{"x": [4, 5]}]]]]), + ]) # pyformat: disable + def testExpandDims(self, st, axis, expected): + st = StructuredTensor.from_pyval(st) + result = array_ops.expand_dims(st, axis) + self.assertAllEqual(result, expected) + + def testExpandDimsAxisTooBig(self): + st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]] + st = StructuredTensor.from_pyval(st) + with self.assertRaisesRegex(ValueError, + "axis=4 out of bounds: expected -4<=axis<4"): + array_ops.expand_dims(st, 4) + + def testExpandDimsAxisTooSmall(self): + st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]] + st = StructuredTensor.from_pyval(st) + with self.assertRaisesRegex(ValueError, + "axis=-5 out of bounds: expected -4<=axis<4"): + array_ops.expand_dims(st, -5) + + def testExpandDimsScalar(self): + # Note that if we expand_dims for the final dimension and there are scalar + # fields, then the shape is (2, None, None, 1), whereas if it is constructed + # from pyval it is (2, None, None, None). + st = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]] + st = StructuredTensor.from_pyval(st) + result = array_ops.expand_dims(st, 3) + expected_shape = tensor_shape.TensorShape([2, None, None, 1]) + self.assertEqual(repr(expected_shape), repr(result.shape)) + def testTupleFieldValue(self): st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}}) self.assertAllEqual(st.field_value(("a",)), 5)