Update StructuredTensor to use merged RowPartitions.

- StructuredTensor constructor is now private; please use StructuredTensor.from_fields() instead.
- StructuredTensor.nested_row_splits was replaced by StructuredTensor.row_partitions.
- New method: StructuredTensor.partition_outer_dimension()
- New method: StructuredTensor.merge_dims()

Also added tests to bring test coverage up to 100%.

PiperOrigin-RevId: 308609833
Change-Id: I2c475a405708c9b8e44b3c559b33e442767ae596
This commit is contained in:
Edward Loper 2020-04-27 06:39:32 -07:00 committed by TensorFlower Gardener
parent 7436755e80
commit 6d34ddd808
5 changed files with 1324 additions and 342 deletions

View File

@ -1374,7 +1374,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
if not outer_axis < inner_axis:
raise ValueError("Expected outer_axis (%d) to be less than "
"inner_axis (%d)" % (outer_axis, inner_axis))
return _merge_dims(self, outer_axis, inner_axis)
return merge_dims(self, outer_axis, inner_axis)
def _set_shape(self, shape):
"""Updates the static shape of `self` to be `shape`.
@ -2491,7 +2491,7 @@ def _nrows(tensor, out_type=dtypes.int32):
return array_ops.shape(tensor, out_type=out_type)[0]
def _merge_dims(value, outer_axis, inner_axis):
def merge_dims(value, outer_axis, inner_axis):
"""Merges value[outer_axis...inner_axis] into a single dimension.
See `RaggedTensor.merge_dims()` for more details. This helper differs from
@ -2529,7 +2529,7 @@ def _merge_dims(value, outer_axis, inner_axis):
# Handle outer_axis>1 via recursion.
if outer_axis > 1:
return value.with_values(
_merge_dims(value.values, outer_axis - 1, inner_axis - 1))
merge_dims(value.values, outer_axis - 1, inner_axis - 1))
# At this point, we know outer_axis == 1, and value is a RaggedTensor.
# So we need to flatten the values and build a corresponding splits tensor.

View File

@ -30,9 +30,12 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@ -86,143 +89,161 @@ class StructuredTensor(composite_tensor.CompositeTensor):
# Constructor & Factory Methods
#=============================================================================
# TODO(edloper): Add optional shape validation:
# Check that the fields all have the same runtime-shape. (We check static
# shape now, but that doesn't capture ragged shapes or shapes that aren't
# statically known.) I.e., if shape validation is turned on, then check that
# the outer shape.rank dimensions of each value in fields is the same. For
# ragged tensors, this means checking their row-splits.
def __init__(self, shape, fields):
def __init__(self, fields, shape, nrows, row_partitions, internal=False):
"""Private constructor -- use factory methods to create StructuredTensors.
This constructor builds a `StructuredTensor` from the given attributes,
performing minimal validation.
Args:
fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
`StructuredTensor`. (This dict is not copied, so the caller must ensure
that it does not get mutated via leaked references.)
shape: `tf.TensorShape` with statically known rank.
nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`.
internal: Private key value, required to ensure that this private
constructor is *only* called from the factory methods.
"""
if internal is not _structured_tensor_factory_key:
raise ValueError('StructuredTensor constructor is private; please use '
'one of the factory methods instead (e.g., '
'StructuredTensor.from_fields())')
assert isinstance(fields, dict), fields
assert isinstance(shape, tensor_shape.TensorShape), shape
assert nrows is None or isinstance(nrows, ops.Tensor), nrows
assert isinstance(row_partitions, tuple), row_partitions
self._fields = fields
self._shape = shape
self._nrows = nrows
self._row_partitions = row_partitions
@classmethod
def from_fields(cls,
fields,
shape=(),
nrows=None,
row_partitions=None,
validate=False):
"""Creates a `StructuredTensor` from a dictionary of fields.
Args:
shape: A `TensorShape`: static information about the shape of the
`StructuredTensor`. Must have a known `rank`.
fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
`StructuredTensor`, providing the values for individual fields in each
structure. If `ndims > 0`, then every tensor in `fields` must have the
same shape in the first `shape.rank` dimensions; and that shape must be
compatible with `shape`.
structure. If `shape.rank > 0`, then every tensor in `fields` must have
the same shape in the first `shape.rank` dimensions; and that shape must
be compatible with `shape`; and
`result[i1...iN][key] = fields[key][i1...iN]` (where `N==shape.rank`).
shape: A `TensorShape`: static information about the shape of the
`StructuredTensor`. Must have a known `rank`. Defaults to scalar
shape (i.e. `rank=0`).
nrows: scalar integer tensor containing the number of rows in this
`StructuredTensor`. Should only be specified if `shape.rank > 0`.
Default value is inferred from the `fields` values. If `fields` is
empty, then this must be specified.
row_partitions: A list of `RowPartition`s describing the (possibly ragged)
shape of this `StructuredTensor`. Should only be specified if
`shape.rank > 1`. Default value is inferred from the `fields` values.
If `fields` is empty, then this must be specified.
validate: If true, then add runtime validation ops that check that the
field values all have compatible shapes in the outer `shape.rank`
dimensions.
Returns:
A `StructuredTensor`.
Examples:
>>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
(FILL THIS IN)
>>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]},
... shape=[2])
(FILL THIS IN)
"""
shape = tensor_shape.as_shape(shape)
rank = shape.ndims
rank = shape.rank
if rank is None:
raise ValueError("StructuredTensor's shape must have known rank.")
if not isinstance(fields, dict):
raise TypeError('fields must be a dictionary, got %s' %
type(fields).__name__)
self._fields = {}
if rank < 2 and row_partitions:
raise ValueError('row_partitions must be None or [] if shape.rank<2')
if rank == 0 and nrows is not None:
raise ValueError('nrows must be None if shape.rank==0')
if row_partitions is not None:
row_partitions = tuple(row_partitions)
if len(row_partitions) != max(0, rank - 1):
raise ValueError('len(row_partitions) must be shape.rank-1')
elif rank < 2:
row_partitions = ()
fields = dict(fields) # Make a private copy.
with ops.name_scope(None, 'StructuredTensor', fields.values()):
for (key, value) in fields.items():
# Validate keys and convert field values to tensors.
for key, value in fields.items():
if not isinstance(key, str):
raise TypeError('Unexpected type for key in `fields`: %r' % key)
if not _FIELD_NAME_RE.match(key):
raise ValueError('Field name %r is not currently allowed.' % key)
if not isinstance(
value, (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
if ragged_tensor.is_ragged(value):
value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
fields[key] = _convert_to_structured_field_value(value)
# Determine dtype for row_partitions and nrows.
shape_dtype = _find_shape_dtype(fields, nrows, row_partitions)
if nrows is not None:
nrows = ops.convert_to_tensor(nrows, shape_dtype)
# Get the static TensorShape for this StructuredTensor.
if rank > 0:
for key, value in fields.items():
if not shape.is_compatible_with(value.shape[:rank]):
raise ValueError('Field {} has shape {}, which is incompatible '
'with the shape that was specified or inferred '
'from other fields: {}'.format(
key, value.shape[:rank], shape))
shape = shape.merge_with(value.shape[:rank])
if rank == 1:
# Find a consistent value for `nrows`.
static_nrows = tensor_shape.dimension_at_index(shape, 0)
for value in fields.values():
nrows, static_nrows = _merge_nrows(nrows, static_nrows, value,
shape_dtype, validate)
if nrows is None:
if static_nrows.value is None:
raise ValueError('nrows must be specified if rank==1 '
'and `fields` is empty.')
else:
try:
value = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError('Unexpected type for value in `fields`: %r' %
value)
self._fields[key] = value
nrows = constant_op.constant(static_nrows.value, shape_dtype)
# Check the static TensorShape for this StructuredTensor.
self._static_shape = shape
if rank > 0:
for value in self._fields.values():
self._static_shape = self._static_shape.merge_with(value.shape[:rank])
if rank > 1:
# Find a consistent list of RowPartitions.
for value in fields.values():
row_partitions = _merge_row_partitions(row_partitions, value, rank,
shape_dtype, validate)
if row_partitions is None:
if not shape.is_fully_defined():
raise ValueError('row_partitions must be specified if rank>1 '
'and `fields` is empty.')
else:
row_partitions = _row_partitions_for_uniform_shape(
np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype),
shape.rank)
assert len(row_partitions) == rank - 1
nrows = row_partitions[0].nrows()
# Update all field values to use the shared RowPartition objects.
fields = dict([(k, _replace_row_partitions(v, row_partitions))
for (k, v) in fields.items()])
self._nested_row_splits = []
if rank > 1:
# If any fields are ragged, then check that all row-splits match.
shared_row_splits = []
for field in self._fields.values():
# TODO(edloper): A field shouldn't count as ragged if it has
# uniform_row_length defined for all the dimensions in question.
if isinstance(field, ragged_tensor.RaggedTensor):
shared_row_splits.append(field.nested_row_splits[:rank - 1])
elif isinstance(field, StructuredTensor) and field.ragged_rank > 0:
shared_row_splits.append(field.nested_row_splits[:rank - 1])
if shared_row_splits:
if len(shared_row_splits) != len(self._fields):
raise ValueError('Ragged StructuredTensor contains non-ragged fields')
# Check if the splits are identical. This should be the common case.
identical_splits = True
for splits in shared_row_splits[1:]:
if len(splits) != len(shared_row_splits[0]):
raise ValueError('Fields have inconsistent ragged_rank')
for (s1, s2) in zip(splits, shared_row_splits[0]):
if s1 is not s2:
identical_splits = False
if identical_splits:
self._nested_row_splits = shared_row_splits[0]
else:
# If splits aren't identical, then add assertions to check that they
# match.
with ops.control_dependencies(
ragged_util.assert_splits_match(shared_row_splits)):
self._nested_row_splits = [array_ops.identity(splits)
for splits in shared_row_splits[0]]
# TODO(edloper): Rebuild all fields to ensure that they use the
# identical row_splits tensor.
@classmethod
def from_row_splits(cls, values, row_splits, validate=True):
"""Creates a ragged StructuredTensor with rows partitioned by `row_splits`.
See `tf.RaggedTensor` for information about row_splits.
Args:
values: A `StructuredTensor` with shape `[nvals, ...]`.
row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
empty, and must be sorted in ascending order. `row_splits[0]` must be
zero and `row_splits[-1]` must be `nvals`.
validate: If true, then use assertions to check that the arguments form
a valid ragged `StructuredTensor`.
Returns:
A ragged `StructuredTensor`. `result.rank = values.rank + 1`.
"""
if not isinstance(values, StructuredTensor):
raise TypeError('values must be a StructuredTensor.')
if values.shape.rank == 0:
raise ValueError('Shape %s must have rank at least 1' % values.shape)
row_splits = ops.convert_to_tensor(row_splits, name='row_splits')
row_splits.shape.assert_has_rank(1)
if tensor_shape.dimension_value(row_splits.shape[0]) == 0:
raise ValueError('row_splits may not be empty')
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
raise ValueError('Row-partitioning tensors must have dtype '
'int32 or int64')
if (row_splits.shape and
tensor_shape.dimension_value(row_splits.shape[0]) is not None):
nrows = tensor_shape.dimension_value(row_splits.shape[0]) - 1
else:
nrows = None
result_shape = tensor_shape.TensorShape([nrows, None
]).concatenate(values.shape[1:])
result_fields = {}
for (name, field) in values._fields.items():
if isinstance(field, StructuredTensor):
result_fields[name] = StructuredTensor.from_row_splits(
field, row_splits)
else:
result_fields[name] = ragged_tensor.RaggedTensor.from_row_splits(
field, row_splits, validate=validate)
return cls(result_shape, result_fields)
# @TODO(edloper): Add from_row_lengths, etc.
return cls(
fields,
shape,
nrows,
row_partitions,
internal=_structured_tensor_factory_key)
#=============================================================================
# Properties
@ -231,7 +252,7 @@ class StructuredTensor(composite_tensor.CompositeTensor):
@property
def rank(self):
"""The rank of this StructuredTensor. Guaranteed not to be `None`."""
return self._static_shape.rank
return self._shape.rank
@property
def shape(self):
@ -243,33 +264,34 @@ class StructuredTensor(composite_tensor.CompositeTensor):
Returns:
`tf.TensorShape`
"""
return self._static_shape
return self._shape
# TODO(edloper): Make this a func instead of a property? Or make nrows
# a property instead of a func? Seems like these should be consistent.
@property
def nested_row_splits(self):
"""A tuple containing the row_splits for all ragged dimensions.
def row_partitions(self):
"""A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.
If non-empty, then every `field` in this StructuredTensor is ragged, and
has these `nested_row_splits` as their outermost row-splits tensors.
If this `StructuredTensor` has a ragged shape, then all fields will be
encoded as either `RaggedTensor`s or `StructuredTensor`s with these
`RowPartition`s used to define their outermost `self.rank` dimensions.
If this `StructuredTensor` has a uniform (non-ragged) shape, then these
row partitions will all be defined using `uniform_row_length`.
Returns:
A `tuple` of 1-D integer `Tensor`s. The length of this tuple will
always be less than `self.rank`.
A `tuple` of `RowPartition` objects with length `self.rank - 1`
(or `0` if `self.rank < 2`).
"""
return self._nested_row_splits
return self._row_partitions
@property
def ragged_rank(self):
"""The number of ragged dimensions in this StructuredTensor.
See `tf.RaggedTensor` for more information about ragged dimensions and
`ragged_rank`.
def nrows(self):
"""The number of rows in this StructuredTensor (if rank>0).
Returns:
A Python `int` indicating the number of ragged dimensions in this ragged
tensor. The outermost dimension is not considered ragged.
A scalar integer `Tensor` (or `None` if `self.rank == 0`).
"""
return len(self._nested_row_splits)
return self._nrows
def _is_eager(self):
"""True if all fields are composed of eager tensors."""
@ -304,10 +326,16 @@ class StructuredTensor(composite_tensor.CompositeTensor):
Returns:
`Tensor`, `StructuredTensor`, or `RaggedTensor`.
Raises:
KeyError: If the given field_name is not found.
"""
if isinstance(field_name, (list, tuple)):
value = self
for f in field_name:
if not isinstance(value, StructuredTensor):
raise KeyError('Field path {} not found in {}'.format(
field_name, self))
value = value.field_value(f)
return value
return self._fields[field_name]
@ -356,17 +384,17 @@ class StructuredTensor(composite_tensor.CompositeTensor):
if not key:
return self
if self._static_shape.ndims == 0:
if self._shape.rank == 0:
return self._scalar_getitem(key)
else:
return self._tensor_getitem(key)
def _scalar_getitem(self, key):
if (isinstance(key[0], slice) and slice.start is None and
slice.stop is None and slice.step is None):
if (isinstance(key[0], slice) and key[0].start is None and
key[0].stop is None and key[0].step is None):
fields = dict((field_name, field_value.__getitem__(key[1:]))
for (field_name, field_value) in self._fields.items())
return StructuredTensor(self._static_shape[1:], fields)
return StructuredTensor.from_fields(fields, self._shape)
elif not isinstance(key[0], compat.bytes_or_text_types):
raise ValueError('Key for indexing a StructuredTensor must be a '
@ -375,7 +403,7 @@ class StructuredTensor(composite_tensor.CompositeTensor):
return self._fields[key[0]].__getitem__(key[1:])
def _tensor_getitem(self, key):
rank = self._static_shape.ndims
rank = self._shape.rank
if len(key) <= rank:
new_fields = dict((field_name, field_value.__getitem__(key))
for (field_name, field_value) in self._fields.items())
@ -387,20 +415,23 @@ class StructuredTensor(composite_tensor.CompositeTensor):
result_shape[d] = None
elif isinstance(k, (int, ops.Tensor)):
result_shape[d] = -1 # mark for deletion
elif k is None:
raise ValueError('Slicing not supported for tf.newaxis')
else:
# Ellipsis, tf.newaxis:
raise ValueError('Slicing not supported for %r' % k)
result_shape = [d for d in result_shape if d != -1]
return StructuredTensor(result_shape, new_fields)
return StructuredTensor.from_fields(new_fields, result_shape)
else:
if not isinstance(key[rank], compat.bytes_or_text_types):
# TODO(edloper): Also support full slice here?
raise ValueError('Key for indexing a StructuredTensor must be a string')
return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:])
def __repr__(self):
return '<StructuredTensor(shape=%s, fields=%r)>' % (self._static_shape,
self._fields)
return '<StructuredTensor(fields={%s}, shape=%s)>' % (', '.join(
'%r' % k for k in sorted(self._fields)), self._shape)
#=============================================================================
# Conversion
@ -427,7 +458,8 @@ class StructuredTensor(composite_tensor.CompositeTensor):
Requires that all fields are Eager tensors.
>>> print(StructuredTensor([3], {'a': [1, 2, 3]}).to_pyval())
>>> print(StructuredTensor.from_fields(
... {'a': [1, 2, 3]}, [3]).to_pyval())
[{b'a': 1}, {b'a': 2}, {b'a': 3}]
Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
@ -454,10 +486,9 @@ class StructuredTensor(composite_tensor.CompositeTensor):
result[key] = value
# If rank>0, then re-group each value from dict-of-list to list-of-dict.
if len(self._static_shape) > 0: # pylint: disable=g-explicit-length-test
return _pyval_field_major_to_node_major(list(result.keys()),
list(result.values()),
self._static_shape.as_list())
if len(self._shape) > 0: # pylint: disable=g-explicit-length-test
return _pyval_field_major_to_node_major(
list(result.keys()), list(result.values()), self._shape.as_list())
else:
return result
@ -503,12 +534,12 @@ class StructuredTensor(composite_tensor.CompositeTensor):
spec_shape = typespec._shape # pylint: disable=protected-access
field_specs = typespec._field_specs # pylint: disable=protected-access
if not (isinstance(typespec, StructuredTensorSpec) and
spec_shape.ndims == 0 and set(pyval) == set(field_specs)):
spec_shape.rank == 0 and set(pyval) == set(field_specs)):
raise ValueError('Value does not match typespec: %r vs %r' %
(pyval, typespec))
fields = dict(
(k, cls.from_pyval(v, field_specs[k])) for (k, v) in pyval.items())
return StructuredTensor(shape=(), fields=fields)
return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)
@classmethod
def _from_pylist_of_dict(cls, pyval, keys, rank, typespec):
@ -532,10 +563,8 @@ class StructuredTensor(composite_tensor.CompositeTensor):
'%r vs %r' % (pyval, typespec))
for (key, spec) in field_specs.items():
fields[key] = cls.from_pyval(fields.get(key, []), spec)
if not spec.is_compatible_with(fields[key]):
raise ValueError('Value does not match typespec: %r vs %r' %
(spec, fields[key]))
return StructuredTensor(shape=shape, fields=fields)
return StructuredTensor.from_fields(
fields=fields, shape=shape, validate=False)
@classmethod
def _from_pylist_of_value(cls, pyval, typespec):
@ -543,8 +572,11 @@ class StructuredTensor(composite_tensor.CompositeTensor):
if typespec is None:
return ragged_factory_ops.constant(pyval)
elif isinstance(typespec, tensor_spec.TensorSpec):
# TODO(edloper): Check that typespec.shape matches.
return constant_op.constant(pyval, typespec.dtype)
result = constant_op.constant(pyval, typespec.dtype)
if not typespec.shape.is_compatible_with(result.shape):
raise ValueError('Value does not match typespec: %r vs %r' %
(typespec, pyval))
return result
elif isinstance(typespec, ragged_tensor.RaggedTensorSpec):
# pylint: disable=protected-access
return ragged_factory_ops.constant(
@ -571,11 +603,83 @@ class StructuredTensor(composite_tensor.CompositeTensor):
return constant_op.constant(pyval)
else:
if not (isinstance(typespec, tensor_spec.TensorSpec) and
typespec.shape.ndims == 0):
raise ValueError('Value does not match typespec.')
typespec.shape.rank == 0):
raise ValueError('Value does not match typespec: %r vs %r' %
(typespec, pyval))
# TODO(edloper): Check that typespec.shape matches.
return constant_op.constant(pyval, typespec.dtype)
#=============================================================================
# Transforms
#=============================================================================
# TODO(edloper): Add a 'validate' option here?
# TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor
# have a partition_outer_dimension method?
def partition_outer_dimension(self, row_partition):
  """Splits the outermost dimension of this StructuredTensor in two.

  The result contains the same values as `self`, but its outer dimension is
  replaced by two (possibly ragged) dimensions whose layout is given by
  `row_partition`.  `self` must have an outer dimension, i.e.
  `self.shape.rank > 0`.

  >>> st = StructuredTensor.from_pyval(
  ...     [{'foo': 12}, {'foo': 33}, {'foo': 99}])
  >>> partition = RowPartition.from_row_lengths([2, 0, 1])
  >>> result = st.partition_outer_dimension(partition)

  With the row lengths `[2, 0, 1]` used above, the first row of `result`
  groups the first two structures, the second row is empty, and the third row
  holds the last structure.

  Args:
    row_partition: A `RowPartition`.

  Returns:
    A `StructuredTensor` with rank `self.rank + 1`.

  Raises:
    TypeError: If `row_partition` is not a `RowPartition`.
    ValueError: If `self.shape.rank == 0`.
  """
  # The type check comes first so that a bad argument is always reported as a
  # TypeError, regardless of the rank of `self`.
  if not isinstance(row_partition, RowPartition):
    raise TypeError('row_partition must be a RowPartition.')
  elif self.shape.rank == 0:
    raise ValueError('Shape %s must have rank at least 1' % self.shape)
  else:
    return _partition_outer_dimension(self, row_partition)
def merge_dims(self, outer_axis, inner_axis):
  """Flattens the dimensions `outer_axis...inner_axis` into one dimension.

  Returns a copy of this StructuredTensor in which the specified range of
  dimensions has been merged into a single dimension, with elements kept in
  row-major order.

  >>> st = StructuredTensor.from_pyval(
  ...     [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
  >>> merged = st.merge_dims(0, 1)

  Here `merged` is a rank-1 StructuredTensor holding the three structures
  `{'foo': 12}`, `{'foo': 33}` and `{'foo': 99}`, in that order.

  Args:
    outer_axis: `int`: The first dimension in the range of dimensions to
      merge. May be negative (to index from the last dimension).
    inner_axis: `int`: The last dimension in the range of dimensions to merge.
      May be negative (to index from the last dimension).

  Returns:
    A copy of this tensor, with the specified dimensions merged into a
    single dimension.  The shape of the returned tensor will be
    `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
    is the total number of slices in the merged dimensions.

  Raises:
    ValueError: If, after normalization, `outer_axis` is not strictly less
      than `inner_axis`.
  """
  rank = self.shape.rank
  # Normalize both axes to non-negative indices relative to this tensor's rank.
  outer_axis = array_ops.get_positive_axis(
      outer_axis, rank, axis_name='outer_axis', ndims_name='rank(self)')
  inner_axis = array_ops.get_positive_axis(
      inner_axis, rank, axis_name='inner_axis', ndims_name='rank(self)')
  if outer_axis >= inner_axis:
    raise ValueError('Expected outer_axis (%d) to be less than '
                     'inner_axis (%d)' % (outer_axis, inner_axis))
  return _merge_dims(self, outer_axis, inner_axis)
#=============================================================================
# Composite Tensor
#=============================================================================
@ -594,7 +698,7 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
"""Build a type specification for a StructuredTensor.
Args:
shape: The shape of the StructuredTensor. shape.ndims must not be None.
shape: The shape of the StructuredTensor. shape.rank must not be None.
field_specs: A dictionary mapping from field name to TypeSpec, specifying
the tensor type used to encode each field. These TypeSpecs should
specify the type of the entire field (including outer dimensions which
@ -602,20 +706,23 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
contains an int32 vector of size `10` for each structure, then
`field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`.
"""
self._shape = tensor_shape.as_shape(shape)
self._field_specs = dict(field_specs)
shape = tensor_shape.as_shape(shape)
# Perform a few sanity checks on the inputs.
if self._shape.ndims is None:
if shape.rank is None:
raise TypeError("StructuredTensor's shape must have known rank.")
if not isinstance(self._field_specs, dict):
raise TypeError('field_specs must be a dictionary')
for key, value in self._field_specs.items():
if not isinstance(field_specs, dict):
raise TypeError('field_specs must be a dictionary.')
for key, value in field_specs.items():
if not isinstance(key, str):
raise TypeError('field_specs must be a dictionary with string keys.')
if not isinstance(value, (StructuredTensorSpec, tensor_spec.TensorSpec,
ragged_tensor.RaggedTensorSpec)):
raise TypeError('field_spec must be a dictionary with TypeSpec values.')
raise TypeError('field_specs must be a dictionary with '
'TypeSpec values.')
self._shape = shape
self._field_specs = dict(field_specs)
@property
def value_type(self):
@ -625,7 +732,7 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
return value._fields
def _from_components(self, components):
return StructuredTensor(self._shape, components)
return StructuredTensor.from_fields(components, self._shape, validate=False)
@property
def _component_specs(self):
@ -659,6 +766,143 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$')
#=============================================================================
# Helper functions
#=============================================================================
# TODO(edloper): Move some of these helpers to row_partition.py?
def _convert_to_structured_field_value(value):
  """Coerces `value` into a type that may be stored as a field value.

  Args:
    value: A `Tensor`, `RaggedTensor`, `StructuredTensor`, or any value that
      can be converted to a `Tensor` or `RaggedTensor`.

  Returns:
    `value` itself if it is already a `Tensor`, `RaggedTensor`, or
    `StructuredTensor`; otherwise the result of converting it.

  Raises:
    TypeError: If `value` cannot be converted to a tensor.
  """
  field_types = (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)
  if isinstance(value, field_types):
    return value
  if ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  try:
    converted = ops.convert_to_tensor(value)
  except (ValueError, TypeError):
    # Report every conversion failure uniformly as a TypeError.
    raise TypeError('Unexpected type for value in `fields`: %r' % value)
  return converted
def _find_shape_dtype(fields, nrows, row_partitions):
  """Picks the single dtype shared by all row-partitioning inputs.

  Args:
    fields: Dictionary of field values.  Only `RaggedTensor` fields and
      `StructuredTensor` fields with `rank > 0` contribute a dtype.
    nrows: Optional value for nrows; its dtype is only considered if it is
      already a `Tensor`.
    row_partitions: Optional iterable of `RowPartition`s (or `None`).

  Returns:
    The dtype that all contributing inputs agree on, or `dtypes.int64` if no
    input contributes a dtype.

  Raises:
    ValueError: If the inputs use more than one dtype.
  """
  found = set()
  for field in fields.values():
    if isinstance(field, ragged_tensor.RaggedTensor):
      found.add(field.row_splits.dtype)
    elif isinstance(field, StructuredTensor) and field.rank > 0:
      found.add(field.nrows().dtype)
  if isinstance(nrows, ops.Tensor):
    found.add(nrows.dtype)
  if row_partitions is not None:
    found.update(partition.dtype for partition in row_partitions)

  if not found:
    return dtypes.int64
  if len(found) > 1:
    raise ValueError('field values have incompatible row_partition dtypes.')
  return found.pop()
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Combines `nrows` with the number of rows found in `value`.

  If `nrows` is `None`, the row count of `value` is adopted.  If both static
  row counts are known, they are compared immediately (no runtime op is
  needed) and the row count of `value` is adopted.  Otherwise, when `validate`
  is true, `nrows` is wrapped so that it depends on a runtime equality
  assertion; when `validate` is false, `nrows` is returned unchanged.

  Args:
    nrows: scalar integer Tensor, or `None` if not yet known.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`, where `static_nrows` has been merged with
    the static row count of `value`.

  Raises:
    ValueError: If the static row counts are both known and incompatible.
  """
  value_static_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  # Dense tensors report their row count through their dynamic shape; the
  # composite types expose it via `nrows()`.
  value_nrows = (
      array_ops.shape(value, out_type=dtype)[0]
      if isinstance(value, ops.Tensor) else value.nrows())

  if nrows is None:
    nrows = value_nrows
  elif (value_static_nrows.value is not None and
        static_nrows.value is not None):
    if not value_static_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    # Both sizes are statically known to agree, so no assertion op is needed.
    nrows = value_nrows
  elif validate:
    size_check = check_ops.assert_equal(
        nrows, value_nrows, message='fields have incompatible nrows')
    nrows = control_flow_ops.with_dependencies([size_check], nrows)

  merged_static_nrows = static_nrows.merge_with(value_static_nrows)
  return nrows, merged_static_nrows
def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Combines `row_partitions` with the outer row partitions of `value`.

  Args:
    row_partitions: Sequence of `RowPartition`s collected so far, or `None` if
      none have been collected yet.
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    rank: Number of shared outer dimensions; `rank - 1` partitions are used.
    dtype: dtype used when partitions must be built from a tensor's shape.
    validate: Passed through to `RowPartition.merge_precomputed_encodings`.

  Returns:
    A tuple of `rank - 1` `RowPartition`s.
  """
  if isinstance(value, ops.Tensor):
    theirs = _row_partitions_for_tensor(value, rank, dtype)
  elif isinstance(value, ragged_tensor.RaggedTensor):
    theirs = _row_partitions_for_ragged_tensor(value, rank, dtype)
  else:
    assert isinstance(value, StructuredTensor), type(value)
    theirs = value.row_partitions[:rank - 1]
  assert len(theirs) == rank - 1

  # Nothing to merge with yet: adopt the partitions of `value` as-is.
  if row_partitions is None:
    return tuple(theirs)

  merged = []
  for ours, other in zip(row_partitions, theirs):
    merged.append(ours.merge_precomputed_encodings(other, validate))
  return tuple(merged)
def _row_partitions_for_tensor(value, rank, dtype):
  """Builds uniform row partitions from the dynamic shape of a tf.Tensor.

  Args:
    value: A `tf.Tensor`.
    rank: Number of outer dimensions to generate row partitions for.
    dtype: Integer dtype used for the shape tensor.

  Returns:
    The result of `_row_partitions_for_uniform_shape` for `value`'s shape.
  """
  return _row_partitions_for_uniform_shape(
      array_ops.shape(value, out_type=dtype), rank)
def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Collects the outer `rank - 1` row partitions of a tf.RaggedTensor.

  Args:
    value: A `tf.RaggedTensor`.
    rank: Number of outer dimensions to describe; must be greater than 1.
    dtype: Integer dtype used if partitions must be built from `flat_values`.

  Returns:
    A sequence of `rank - 1` `RowPartition`s.
  """
  assert rank > 1
  wanted = rank - 1
  partitions = value._nested_row_partitions[:wanted]  # pylint: disable=protected-access
  found = len(partitions)
  if found < wanted:
    # The ragged dimensions do not cover every requested dimension, so the
    # remaining partitions are derived from the shape of the flat values.
    partitions += _row_partitions_for_tensor(
        value.flat_values, rank - found, dtype)
  assert len(partitions) == wanted
  return partitions
def _row_partitions_for_uniform_shape(shape, rank):
  """Builds uniform-row-length partitions describing a dense shape.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for.

  Returns:
    A tuple of `rank - 1` `RowPartition`s with uniform row length.
  """
  # cumulative[i] is the total number of slices in dimensions 0..i.
  cumulative = math_ops.cumprod(shape[:rank])
  partitions = []
  for axis in range(rank - 1):
    partitions.append(
        RowPartition.from_uniform_row_length(
            uniform_row_length=shape[axis + 1],
            nvals=cumulative[axis + 1],
            nrows=cumulative[axis]))
  return tuple(partitions)
def _pyval_field_major_to_node_major(keys, values, shape):
"""Regroup each field (k, v) from dict-of-list to list-of-dict.
@ -767,3 +1011,140 @@ def _pyval_empty_list_depth(pyval):
return max(depths) + 1
else:
return None
def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions. In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A sequence of row-partitions that should be used by
      `value`.  Must be equivalent to `value`'s current outer row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.  Dense `Tensor`s, and any value when
    `new_partitions` is empty, are returned unchanged.
  """
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value
  elif isinstance(value, ragged_tensor.RaggedTensor):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=_replace_row_partitions(value.values, new_partitions[1:]),
        row_partition=new_partitions[0])
  else:
    assert isinstance(value, StructuredTensor)
    new_fields = dict((k, _replace_row_partitions(v, new_partitions))
                      for (k, v) in value._fields.items())
    # Only the outer `len(new_partitions)` partitions are shared with the
    # parent.  A nested StructuredTensor may have a higher rank than its
    # parent, so its remaining inner partitions must be kept; otherwise the
    # result would have fewer than `shape.rank - 1` row partitions.
    outer_partitions = tuple(new_partitions)
    inner_partitions = tuple(value.row_partitions[len(outer_partitions):])
    return StructuredTensor(
        fields=new_fields,
        shape=value.shape,
        nrows=value.nrows(),
        row_partitions=outer_partitions + inner_partitions,
        internal=_structured_tensor_factory_key)
def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partition`.

  Examples:

    >>> partition = row_partition.RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    [[1, 2], [], [3]]

    >>> struct_value = StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    [[{'x': 1}, {'x': 2}], [], [{'x': 3}]]

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value with the same type as `value`, where
    `result.rank = value.rank + 1`.
  """
  is_ragged = row_partition.uniform_row_length() is None
  if isinstance(value, ops.Tensor) and not is_ragged:
    # Uniform partition of a dense tensor: a plain reshape suffices.  The
    # outer dimension is split into [nrows, uniform_row_length]; every other
    # dimension of `value` (i.e. shape[1:]) is kept unchanged.
    new_shape = array_ops.concat(
        [[row_partition.nrows(),
          row_partition.uniform_row_length()],
         array_ops.shape(value, out_type=row_partition.dtype)[1:]],
        axis=0)
    return array_ops.reshape(value, new_shape)
  elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)
  else:
    assert isinstance(value, StructuredTensor)
    nrows = row_partition.static_nrows
    ncols = row_partition.static_uniform_row_length
    # Static shape: the old outer dimension is replaced by [nrows, ncols];
    # the remaining dimensions of `value` (shape[1:]) are preserved.
    shape = tensor_shape.TensorShape([nrows, ncols]).concatenate(
        value.shape[1:])
    fields = {
        k: _partition_outer_dimension(v, row_partition)
        for (k, v) in value._fields.items()
    }
    return StructuredTensor(
        fields,
        shape,
        row_partition.nrows(), (row_partition,) + value.row_partitions,
        internal=_structured_tensor_factory_key)
def _merge_dims(value, outer_axis, inner_axis):
  """Merges `outer_axis...inner_axis` of `value` into a single dimension.

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    outer_axis: Index of the first dimension to merge (inclusive).  Negative
      indices are not normalized here.
    inner_axis: Index of the last dimension to merge (inclusive).  Must be
      greater than `outer_axis` (asserted).

  Returns:
    A value with the same type as `value`, in which dimensions
    `outer_axis` through `inner_axis` have been collapsed into one.
  """
  assert outer_axis < inner_axis
  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
  else:
    assert isinstance(value, StructuredTensor)

    # Build the new fields.
    fields = {
        k: _merge_dims(v, outer_axis, inner_axis)
        for (k, v) in value._fields.items()
    }

    # Build the new shape.  The merged dimension's static size is the product
    # of all merged dimensions, *including* `inner_axis`; `num_elements()`
    # yields None if any of them is statically unknown.
    value_shape = value.shape
    shape = (
        value_shape[:outer_axis] +
        [value_shape[outer_axis:inner_axis + 1].num_elements()] +
        value_shape[inner_axis + 1:])

    # Build the new row_partitions & nrows
    if outer_axis == 0:
      if inner_axis == value.shape.rank - 1:
        # Every dimension was merged: the result is a vector, so it has no
        # row partitions and its length is the innermost value count.
        partitions = ()
        nrows = value.row_partitions[-1].nvals()
      else:
        partitions = value.row_partitions[inner_axis:]
        nrows = partitions[0].nrows()
    else:
      # Use tf.gather to merge row_splits from the merged row partitions.
      merged_splits = value.row_partitions[outer_axis - 1].row_splits()
      for dim in range(outer_axis, inner_axis):
        merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(),
                                         merged_splits)

      partitions = (
          value.row_partitions[:outer_axis - 1] +
          (RowPartition.from_row_splits(merged_splits),) +
          value.row_partitions[inner_axis:])
      nrows = partitions[0].nrows()

    return StructuredTensor(
        fields,
        shape,
        nrows,
        partitions,
        internal=_structured_tensor_factory_key)
# Sentinel passed as the `internal` argument whenever code in this module
# calls the `StructuredTensor` constructor directly.
# NOTE(review): presumably the constructor checks for this object to reject
# direct construction by external callers -- confirm against `__init__`.
_structured_tensor_factory_key = object() # unique private object

View File

@ -201,6 +201,10 @@ class StructuredTensorSliceTest(test_util.TensorFlowTestCase,
(SLICE_BUILDER["f4", 1:, "f4_2"], [b"b"]),
(SLICE_BUILDER["f4", :, "f4_2"], [b"a", b"b"]),
(SLICE_BUILDER["f5", :, :, "f5_1"], [[1, 2], [3, 4]]),
# Slicing over multiple keys
(SLICE_BUILDER[:], EXAMPLE_STRUCT),
# List-valued key.
(["f2", 1], EXAMPLE_STRUCT["f2"][1]),
])
def testGetitemFromScalarStruct(self, slice_spec, expected):
# By default, lists are converted to RaggedTensors.
@ -242,6 +246,29 @@ class StructuredTensorSliceTest(test_util.TensorFlowTestCase,
# TODO(edloper): Add tests for slicing from matrix StructuredTensors.
@parameterized.parameters([
    (SLICE_BUILDER[:2], r"Key for indexing a StructuredTensor must be "
     r"a string or a full slice \(':'\)"),
    (SLICE_BUILDER["f4", ...], r"Slicing not supported for Ellipsis"),
    (SLICE_BUILDER["f4", None], r"Slicing not supported for tf.newaxis"),
    (SLICE_BUILDER["f4", :, 0],
     r"Key for indexing a StructuredTensor must be a string"),
])
def testGetItemError(self, slice_spec, error, exception=ValueError):
  """Checks that an invalid key on the example struct raises `exception`.

  Args:
    slice_spec: The key passed to `__getitem__`.
    error: Regular expression the raised message must match.
    exception: Expected exception type.
  """
  example = structured_tensor.StructuredTensor.from_pyval(EXAMPLE_STRUCT)
  with self.assertRaisesRegexp(exception, error):
    # Call `__getitem__` explicitly so `slice_spec` is passed through as-is.
    example.__getitem__(slice_spec)
@parameterized.parameters([
    (SLICE_BUILDER[:, 1],
     r"Key for indexing a StructuredTensor must be a string"),
])
def testGetItemFromVectorError(self, slice_spec, error, exception=ValueError):
  """Checks that an invalid key on the example vector raises `exception`.

  Args:
    slice_spec: The key passed to `__getitem__`.
    error: Regular expression the raised message must match.
    exception: Expected exception type.
  """
  example_vector = structured_tensor.StructuredTensor.from_pyval(
      EXAMPLE_STRUCT_VECTOR)
  with self.assertRaisesRegexp(exception, error):
    # Call `__getitem__` explicitly so `slice_spec` is passed through as-is.
    example_vector.__getitem__(slice_spec)
if __name__ == "__main__":
googletest.main()

View File

@ -81,6 +81,18 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
self.assertEqual(spec2._shape, (1, 2))
self.assertEqual(spec2._field_specs, spec2_fields)
@parameterized.parameters([
    (None, {}, r"StructuredTensor's shape must have known rank\."),
    ([], None, r'field_specs must be a dictionary\.'),
    ([], {1: tensor_spec.TensorSpec(None)},
     r'field_specs must be a dictionary with string keys\.'),
    ([], {'x': 0},
     r'field_specs must be a dictionary with TypeSpec values\.'),
])
def testConstructionErrors(self, shape, field_specs, error):
  """Checks that invalid `StructuredTensorSpec` arguments raise `TypeError`.

  Args:
    shape: The `shape` argument passed to the spec constructor.
    field_specs: The `field_specs` argument passed to the spec constructor.
    error: Regular expression the raised message must match.
  """
  with self.assertRaisesRegexp(TypeError, error):
    structured_tensor.StructuredTensorSpec(shape, field_specs)
def testValueType(self):
spec1 = StructuredTensorSpec([1, 2, 3], dict(a=T_1_2))
self.assertEqual(spec1.value_type, StructuredTensor)
@ -118,11 +130,13 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
'fields': dict(x=[[1.0, 2.0]]),
'field_specs': dict(x=T_1_2),
},
{
'shape': [1, 2, 3],
'fields': {},
'field_specs': {},
},
# TODO(edloper): Enable this test once we update StructuredTensorSpec
# to contain the shared row partitions.
#{
# 'shape': [1, 2, 3],
# 'fields': {},
# 'field_specs': {},
#},
{
'shape': [2],
'fields': dict(
@ -133,7 +147,7 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
]) # pyformat: disable
def testToFromComponents(self, shape, fields, field_specs):
components = fields
struct = StructuredTensor(shape, fields)
struct = StructuredTensor.from_fields(fields, shape)
spec = StructuredTensorSpec(shape, field_specs)
actual_components = spec._to_components(struct)
self.assertAllTensorsEqual(actual_components, components)
@ -164,39 +178,40 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
@parameterized.parameters([
{
'unbatched': lambda: [
StructuredTensor([], {'a': 1, 'b': [5, 6]}),
StructuredTensor([], {'a': 2, 'b': [7, 8]})],
StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
StructuredTensor.from_fields({'a': 2, 'b': [7, 8]})],
'batch_size': 2,
'batched': lambda: StructuredTensor([2], {
'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
'a': [1, 2],
'b': [[5, 6], [7, 8]]}),
},
{
'unbatched': lambda: [
StructuredTensor([3], {
StructuredTensor.from_fields(shape=[3], fields={
'a': [1, 2, 3],
'b': [[5, 6], [6, 7], [7, 8]]}),
StructuredTensor([3], {
StructuredTensor.from_fields(shape=[3], fields={
'a': [2, 3, 4],
'b': [[2, 2], [3, 3], [4, 4]]})],
'batch_size': 2,
'batched': lambda: StructuredTensor([2, 3], {
'batched': lambda: StructuredTensor.from_fields(shape=[2, 3], fields={
'a': [[1, 2, 3], [2, 3, 4]],
'b': [[[5, 6], [6, 7], [7, 8]],
[[2, 2], [3, 3], [4, 4]]]}),
},
{
'unbatched': lambda: [
StructuredTensor([], {
StructuredTensor.from_fields(shape=[], fields={
'a': 1,
'b': StructuredTensor([], {'x': [5]})}),
StructuredTensor([], {
'b': StructuredTensor.from_fields({'x': [5]})}),
StructuredTensor.from_fields(shape=[], fields={
'a': 2,
'b': StructuredTensor([], {'x': [6]})})],
'b': StructuredTensor.from_fields({'x': [6]})})],
'batch_size': 2,
'batched': lambda: StructuredTensor([2], {
'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
'a': [1, 2],
'b': StructuredTensor([2], {'x': [[5], [6]]})}),
'b': StructuredTensor.from_fields(shape=[2], fields={
'x': [[5], [6]]})}),
},
]) # pyformat: disable
def testBatchUnbatchValues(self, unbatched, batch_size, batched):
@ -225,5 +240,6 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
for (actual, expected) in zip(actual_unbatched, unbatched):
self.assertAllEqual(actual, expected)
if __name__ == '__main__':
googletest.main()

File diff suppressed because it is too large Load Diff