Update StructuredTensor to use merged RowPartitions.
- StructuredTensor constructor is now private; please use StructuredTensor.from_fields() instead. - StructuredTensor.nested_row_splits was replaced by StructuredTensor.row_partitions. - New method: StructuredTensor.partition_outer_dimension() - New method: StructuredTensor.merge_dims() Also added tests to bring test coverage up to 100%. PiperOrigin-RevId: 308609833 Change-Id: I2c475a405708c9b8e44b3c559b33e442767ae596
This commit is contained in:
parent
7436755e80
commit
6d34ddd808
@ -1374,7 +1374,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
if not outer_axis < inner_axis:
|
||||
raise ValueError("Expected outer_axis (%d) to be less than "
|
||||
"inner_axis (%d)" % (outer_axis, inner_axis))
|
||||
return _merge_dims(self, outer_axis, inner_axis)
|
||||
return merge_dims(self, outer_axis, inner_axis)
|
||||
|
||||
def _set_shape(self, shape):
|
||||
"""Updates the static shape of `self` to be `shape`.
|
||||
@ -2491,7 +2491,7 @@ def _nrows(tensor, out_type=dtypes.int32):
|
||||
return array_ops.shape(tensor, out_type=out_type)[0]
|
||||
|
||||
|
||||
def _merge_dims(value, outer_axis, inner_axis):
|
||||
def merge_dims(value, outer_axis, inner_axis):
|
||||
"""Merges value[outer_axis...inner_axis] into a single dimension.
|
||||
|
||||
See `RaggedTensor.merge_dims()` for more details. This helper differs from
|
||||
@ -2529,7 +2529,7 @@ def _merge_dims(value, outer_axis, inner_axis):
|
||||
# Handle outer_axis>1 via recursion.
|
||||
if outer_axis > 1:
|
||||
return value.with_values(
|
||||
_merge_dims(value.values, outer_axis - 1, inner_axis - 1))
|
||||
merge_dims(value.values, outer_axis - 1, inner_axis - 1))
|
||||
|
||||
# At this point, we know outer_axis == 1, and value is a RaggedTensor.
|
||||
# So we need to flatten the values and build a corresponding splits tensor.
|
||||
|
@ -30,9 +30,12 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged.row_partition import RowPartition
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -86,143 +89,161 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
# Constructor & Factory Methods
|
||||
#=============================================================================
|
||||
|
||||
# TODO(edloper): Add optional shape validation:
|
||||
# Check that the fields all have the same runtime-shape. (We check static
|
||||
# shape now, but that doesn't capture ragged shapes or shapes that aren't
|
||||
# statically known.) I.e., if shape validation is turned on, then check that
|
||||
# the outer shape.rank dimensions of each value in fields is the same. For
|
||||
# ragged tensors, this means checking their row-splits.
|
||||
def __init__(self, shape, fields):
|
||||
def __init__(self, fields, shape, nrows, row_partitions, internal=False):
|
||||
"""Private constructor -- use factory methods to create StructuredTensors.
|
||||
|
||||
This constructor builds a `StructuredTensor` from the given attributes,
|
||||
performing minimal validation.
|
||||
|
||||
Args:
|
||||
fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
|
||||
`StructuredTensor`. (This dict is not copied, so the caller must ensure
|
||||
that it does not get mutated via leaked references.)
|
||||
shape: `tf.TensorShape` with statically known rank.
|
||||
nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
|
||||
row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`.
|
||||
internal: Private key value, required to ensure that this private
|
||||
constructor is *only* called from the factory methods.
|
||||
"""
|
||||
if internal is not _structured_tensor_factory_key:
|
||||
raise ValueError('StructuredTensor constructor is private; please use '
|
||||
'one of the factory methods instead (e.g., '
|
||||
'StructuredTensor.from_fields())')
|
||||
assert isinstance(fields, dict), fields
|
||||
assert isinstance(shape, tensor_shape.TensorShape), shape
|
||||
assert nrows is None or isinstance(nrows, ops.Tensor), nrows
|
||||
assert isinstance(row_partitions, tuple), row_partitions
|
||||
self._fields = fields
|
||||
self._shape = shape
|
||||
self._nrows = nrows
|
||||
self._row_partitions = row_partitions
|
||||
|
||||
@classmethod
|
||||
def from_fields(cls,
|
||||
fields,
|
||||
shape=(),
|
||||
nrows=None,
|
||||
row_partitions=None,
|
||||
validate=False):
|
||||
"""Creates a `StructuredTensor` from a dictionary of fields.
|
||||
|
||||
Args:
|
||||
shape: A `TensorShape`: static information about the shape of the
|
||||
`StructuredTensor`. Must have a known `rank`.
|
||||
fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
|
||||
`StructuredTensor`, providing the values for individual fields in each
|
||||
structure. If `ndims > 0`, then every tensor in `fields` must have the
|
||||
same shape in the first `shape.rank` dimensions; and that shape must be
|
||||
compatible with `shape`.
|
||||
structure. If `shape.rank > 0`, then every tensor in `fields` must have
|
||||
the same shape in the first `shape.rank` dimensions; and that shape must
|
||||
be compatible with `shape`; and
|
||||
`result[i1...iN][key] = fields[key][i1...iN]` (where `N==shape.rank`).
|
||||
shape: A `TensorShape`: static information about the shape of the
|
||||
`StructuredTensor`. Must have a known `rank`. Defaults to scalar
|
||||
shape (i.e. `rank=0`).
|
||||
nrows: scalar integer tensor containing the number of rows in this
|
||||
`StructuredTensor`. Should only be specified if `shape.rank > 0`.
|
||||
Default value is inferred from the `fields` values. If `fields` is
|
||||
empty, then this must be specified.
|
||||
row_partitions: A list of `RowPartition`s describing the (possibly ragged)
|
||||
shape of this `StructuredTensor`. Should only be specified if
|
||||
`shape.rank > 1`. Default value is inferred from the `fields` values.
|
||||
If `fields` is empty, then this must be specified.
|
||||
validate: If true, then add runtime validation ops that check that the
|
||||
field values all have compatible shapes in the outer `shape.rank`
|
||||
dimensions.
|
||||
|
||||
Returns:
|
||||
A `StructuredTensor`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
|
||||
(FILL THIS IN)
|
||||
|
||||
>>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]},
|
||||
... shape=[2])
|
||||
(FILL THIS IN)
|
||||
|
||||
"""
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
rank = shape.ndims
|
||||
rank = shape.rank
|
||||
if rank is None:
|
||||
raise ValueError("StructuredTensor's shape must have known rank.")
|
||||
if not isinstance(fields, dict):
|
||||
raise TypeError('fields must be a dictionary, got %s' %
|
||||
type(fields).__name__)
|
||||
self._fields = {}
|
||||
if rank < 2 and row_partitions:
|
||||
raise ValueError('row_partitions must be None or [] if shape.rank<2')
|
||||
if rank == 0 and nrows is not None:
|
||||
raise ValueError('nrows must be None if shape.rank==0')
|
||||
if row_partitions is not None:
|
||||
row_partitions = tuple(row_partitions)
|
||||
if len(row_partitions) != max(0, rank - 1):
|
||||
raise ValueError('len(row_partitions) must be shape.rank-1')
|
||||
elif rank < 2:
|
||||
row_partitions = ()
|
||||
|
||||
fields = dict(fields) # Make a private copy.
|
||||
with ops.name_scope(None, 'StructuredTensor', fields.values()):
|
||||
for (key, value) in fields.items():
|
||||
|
||||
# Validate keys and convert field values to tensors.
|
||||
for key, value in fields.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError('Unexpected type for key in `fields`: %r' % key)
|
||||
if not _FIELD_NAME_RE.match(key):
|
||||
raise ValueError('Field name %r is not currently allowed.' % key)
|
||||
if not isinstance(
|
||||
value, (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
|
||||
if ragged_tensor.is_ragged(value):
|
||||
value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
|
||||
fields[key] = _convert_to_structured_field_value(value)
|
||||
|
||||
# Determine dtype for row_partitions and nrows.
|
||||
shape_dtype = _find_shape_dtype(fields, nrows, row_partitions)
|
||||
if nrows is not None:
|
||||
nrows = ops.convert_to_tensor(nrows, shape_dtype)
|
||||
|
||||
# Get the static TensorShape for this StructuredTensor.
|
||||
if rank > 0:
|
||||
for key, value in fields.items():
|
||||
if not shape.is_compatible_with(value.shape[:rank]):
|
||||
raise ValueError('Field {} has shape {}, which is incompatible '
|
||||
'with the shape that was specified or inferred '
|
||||
'from other fields: {}'.format(
|
||||
key, value.shape[:rank], shape))
|
||||
shape = shape.merge_with(value.shape[:rank])
|
||||
|
||||
if rank == 1:
|
||||
# Find a consistent value for `nrows`.
|
||||
static_nrows = tensor_shape.dimension_at_index(shape, 0)
|
||||
for value in fields.values():
|
||||
nrows, static_nrows = _merge_nrows(nrows, static_nrows, value,
|
||||
shape_dtype, validate)
|
||||
if nrows is None:
|
||||
if static_nrows.value is None:
|
||||
raise ValueError('nrows must be specified if rank==1 '
|
||||
'and `fields` is empty.')
|
||||
else:
|
||||
try:
|
||||
value = ops.convert_to_tensor(value)
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError('Unexpected type for value in `fields`: %r' %
|
||||
value)
|
||||
self._fields[key] = value
|
||||
nrows = constant_op.constant(static_nrows.value, shape_dtype)
|
||||
|
||||
# Check the static TensorShape for this StructuredTensor.
|
||||
self._static_shape = shape
|
||||
if rank > 0:
|
||||
for value in self._fields.values():
|
||||
self._static_shape = self._static_shape.merge_with(value.shape[:rank])
|
||||
if rank > 1:
|
||||
# Find a consistent list of RowPartitions.
|
||||
for value in fields.values():
|
||||
row_partitions = _merge_row_partitions(row_partitions, value, rank,
|
||||
shape_dtype, validate)
|
||||
if row_partitions is None:
|
||||
if not shape.is_fully_defined():
|
||||
raise ValueError('row_partitions must be specified if rank>1 '
|
||||
'and `fields` is empty.')
|
||||
else:
|
||||
row_partitions = _row_partitions_for_uniform_shape(
|
||||
np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype),
|
||||
shape.rank)
|
||||
assert len(row_partitions) == rank - 1
|
||||
nrows = row_partitions[0].nrows()
|
||||
# Update all field values to use the shared RowPartition objects.
|
||||
fields = dict([(k, _replace_row_partitions(v, row_partitions))
|
||||
for (k, v) in fields.items()])
|
||||
|
||||
self._nested_row_splits = []
|
||||
if rank > 1:
|
||||
# If any fields are ragged, then check that all row-splits match.
|
||||
shared_row_splits = []
|
||||
for field in self._fields.values():
|
||||
# TODO(edloper): A field shouldn't count as ragged if it has
|
||||
# uniform_row_length defined for all the dimensions in question.
|
||||
if isinstance(field, ragged_tensor.RaggedTensor):
|
||||
shared_row_splits.append(field.nested_row_splits[:rank - 1])
|
||||
elif isinstance(field, StructuredTensor) and field.ragged_rank > 0:
|
||||
shared_row_splits.append(field.nested_row_splits[:rank - 1])
|
||||
if shared_row_splits:
|
||||
if len(shared_row_splits) != len(self._fields):
|
||||
raise ValueError('Ragged StructuredTensor contains non-ragged fields')
|
||||
|
||||
# Check if the splits are identical. This should be the common case.
|
||||
identical_splits = True
|
||||
for splits in shared_row_splits[1:]:
|
||||
if len(splits) != len(shared_row_splits[0]):
|
||||
raise ValueError('Fields have inconsistent ragged_rank')
|
||||
for (s1, s2) in zip(splits, shared_row_splits[0]):
|
||||
if s1 is not s2:
|
||||
identical_splits = False
|
||||
|
||||
if identical_splits:
|
||||
self._nested_row_splits = shared_row_splits[0]
|
||||
else:
|
||||
# If splits aren't identical, then add assertions to check that they
|
||||
# match.
|
||||
with ops.control_dependencies(
|
||||
ragged_util.assert_splits_match(shared_row_splits)):
|
||||
self._nested_row_splits = [array_ops.identity(splits)
|
||||
for splits in shared_row_splits[0]]
|
||||
|
||||
# TODO(edloper): Rebuild all fields to ensure that they use the
|
||||
# identical row_splits tensor.
|
||||
|
||||
@classmethod
|
||||
def from_row_splits(cls, values, row_splits, validate=True):
|
||||
"""Creates a ragged StructuredTensor with rows partitioned by `row_splits`.
|
||||
|
||||
See `tf.RaggedTensor` for information about row_splits.
|
||||
|
||||
Args:
|
||||
values: A `StructuredTensor` with shape `[nvals, ...]`.
|
||||
row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
|
||||
empty, and must be sorted in ascending order. `row_splits[0]` must be
|
||||
zero and `row_splits[-1]` must be `nvals`.
|
||||
validate: If true, then use assertions to check that the arguments form
|
||||
a valid ragged `StructuredTensor`.
|
||||
|
||||
Returns:
|
||||
A ragged `StructuredTensor`. `result.rank = values.rank + 1`.
|
||||
"""
|
||||
if not isinstance(values, StructuredTensor):
|
||||
raise TypeError('values must be a StructuredTensor.')
|
||||
if values.shape.rank == 0:
|
||||
raise ValueError('Shape %s must have rank at least 1' % values.shape)
|
||||
row_splits = ops.convert_to_tensor(row_splits, name='row_splits')
|
||||
row_splits.shape.assert_has_rank(1)
|
||||
if tensor_shape.dimension_value(row_splits.shape[0]) == 0:
|
||||
raise ValueError('row_splits may not be empty')
|
||||
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError('Row-partitioning tensors must have dtype '
|
||||
'int32 or int64')
|
||||
|
||||
if (row_splits.shape and
|
||||
tensor_shape.dimension_value(row_splits.shape[0]) is not None):
|
||||
nrows = tensor_shape.dimension_value(row_splits.shape[0]) - 1
|
||||
else:
|
||||
nrows = None
|
||||
result_shape = tensor_shape.TensorShape([nrows, None
|
||||
]).concatenate(values.shape[1:])
|
||||
result_fields = {}
|
||||
for (name, field) in values._fields.items():
|
||||
if isinstance(field, StructuredTensor):
|
||||
result_fields[name] = StructuredTensor.from_row_splits(
|
||||
field, row_splits)
|
||||
else:
|
||||
result_fields[name] = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
field, row_splits, validate=validate)
|
||||
return cls(result_shape, result_fields)
|
||||
|
||||
# @TODO(edloper): Add from_row_lengths, etc.
|
||||
return cls(
|
||||
fields,
|
||||
shape,
|
||||
nrows,
|
||||
row_partitions,
|
||||
internal=_structured_tensor_factory_key)
|
||||
|
||||
#=============================================================================
|
||||
# Properties
|
||||
@ -231,7 +252,7 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
@property
|
||||
def rank(self):
|
||||
"""The rank of this StructuredTensor. Guaranteed not to be `None`."""
|
||||
return self._static_shape.rank
|
||||
return self._shape.rank
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
@ -243,33 +264,34 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
Returns:
|
||||
`tf.TensorShape`
|
||||
"""
|
||||
return self._static_shape
|
||||
return self._shape
|
||||
|
||||
# TODO(edloper): Make this a func instead of a property? Or make nrows
|
||||
# a property instead of a func? Seems like these should be consistent.
|
||||
@property
|
||||
def nested_row_splits(self):
|
||||
"""A tuple containing the row_splits for all ragged dimensions.
|
||||
def row_partitions(self):
|
||||
"""A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.
|
||||
|
||||
If non-empty, then every `field` in this StructuredTensor is ragged, and
|
||||
has these `nested_row_splits` as their outermost row-splits tensors.
|
||||
If this `StructuredTensor` has a ragged shape, then all fields will be
|
||||
encoded as either `RaggedTensor`s or `StructuredTensor`s with these
|
||||
`RowPartition`s used to define their outermost `self.rank` dimensions.
|
||||
|
||||
If this `StructuredTensor` has a uniform (non-ragged) shape, then these
|
||||
row partitions will all be defined using `uniform_row_length`.
|
||||
|
||||
Returns:
|
||||
A `tuple` of 1-D integer `Tensor`s. The length of this tuple will
|
||||
always be less than `self.rank`.
|
||||
A `tuple` of `RowPartition` objects with length `self.rank - 1`
|
||||
(or `0` if `self.rank < 2`).
|
||||
"""
|
||||
return self._nested_row_splits
|
||||
return self._row_partitions
|
||||
|
||||
@property
|
||||
def ragged_rank(self):
|
||||
"""The number of ragged dimensions in this StructuredTensor.
|
||||
|
||||
See `tf.RaggedTensor` for more information about ragged dimensions and
|
||||
`ragged_rank`.
|
||||
def nrows(self):
|
||||
"""The number of rows in this StructuredTensor (if rank>0).
|
||||
|
||||
Returns:
|
||||
A Python `int` indicating the number of ragged dimensions in this ragged
|
||||
tensor. The outermost dimension is not considered ragged.
|
||||
A scalar integer `Tensor` (or `None` if `self.rank == 0`).
|
||||
"""
|
||||
return len(self._nested_row_splits)
|
||||
return self._nrows
|
||||
|
||||
def _is_eager(self):
|
||||
"""True if all fields are composed of eager tensors."""
|
||||
@ -304,10 +326,16 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
|
||||
Returns:
|
||||
`Tensor`, `StructuredTensor`, or `RaggedTensor`.
|
||||
|
||||
Raises:
|
||||
KeyError: If the given field_name is not found.
|
||||
"""
|
||||
if isinstance(field_name, (list, tuple)):
|
||||
value = self
|
||||
for f in field_name:
|
||||
if not isinstance(value, StructuredTensor):
|
||||
raise KeyError('Field path {} not found in {}'.format(
|
||||
field_name, self))
|
||||
value = value.field_value(f)
|
||||
return value
|
||||
return self._fields[field_name]
|
||||
@ -356,17 +384,17 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
if not key:
|
||||
return self
|
||||
|
||||
if self._static_shape.ndims == 0:
|
||||
if self._shape.rank == 0:
|
||||
return self._scalar_getitem(key)
|
||||
else:
|
||||
return self._tensor_getitem(key)
|
||||
|
||||
def _scalar_getitem(self, key):
|
||||
if (isinstance(key[0], slice) and slice.start is None and
|
||||
slice.stop is None and slice.step is None):
|
||||
if (isinstance(key[0], slice) and key[0].start is None and
|
||||
key[0].stop is None and key[0].step is None):
|
||||
fields = dict((field_name, field_value.__getitem__(key[1:]))
|
||||
for (field_name, field_value) in self._fields.items())
|
||||
return StructuredTensor(self._static_shape[1:], fields)
|
||||
return StructuredTensor.from_fields(fields, self._shape)
|
||||
|
||||
elif not isinstance(key[0], compat.bytes_or_text_types):
|
||||
raise ValueError('Key for indexing a StructuredTensor must be a '
|
||||
@ -375,7 +403,7 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
return self._fields[key[0]].__getitem__(key[1:])
|
||||
|
||||
def _tensor_getitem(self, key):
|
||||
rank = self._static_shape.ndims
|
||||
rank = self._shape.rank
|
||||
if len(key) <= rank:
|
||||
new_fields = dict((field_name, field_value.__getitem__(key))
|
||||
for (field_name, field_value) in self._fields.items())
|
||||
@ -387,20 +415,23 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
result_shape[d] = None
|
||||
elif isinstance(k, (int, ops.Tensor)):
|
||||
result_shape[d] = -1 # mark for deletion
|
||||
elif k is None:
|
||||
raise ValueError('Slicing not supported for tf.newaxis')
|
||||
else:
|
||||
# Ellipsis, tf.newaxis:
|
||||
raise ValueError('Slicing not supported for %r' % k)
|
||||
result_shape = [d for d in result_shape if d != -1]
|
||||
return StructuredTensor(result_shape, new_fields)
|
||||
return StructuredTensor.from_fields(new_fields, result_shape)
|
||||
|
||||
else:
|
||||
if not isinstance(key[rank], compat.bytes_or_text_types):
|
||||
# TODO(edloper): Also support full slice here?
|
||||
raise ValueError('Key for indexing a StructuredTensor must be a string')
|
||||
return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:])
|
||||
|
||||
def __repr__(self):
|
||||
return '<StructuredTensor(shape=%s, fields=%r)>' % (self._static_shape,
|
||||
self._fields)
|
||||
return '<StructuredTensor(fields={%s}, shape=%s)>' % (', '.join(
|
||||
'%r' % k for k in sorted(self._fields)), self._shape)
|
||||
|
||||
#=============================================================================
|
||||
# Conversion
|
||||
@ -427,7 +458,8 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
|
||||
Requires that all fields are Eager tensors.
|
||||
|
||||
>>> print(StructuredTensor([3], {'a': [1, 2, 3]}).to_pyval())
|
||||
>>> print(StructuredTensor.from_fields(
|
||||
... {'a': [1, 2, 3]}, [3]).to_pyval())
|
||||
[{b'a': 1}, {b'a': 2}, {b'a': 3}]
|
||||
|
||||
Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
|
||||
@ -454,10 +486,9 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
result[key] = value
|
||||
|
||||
# If rank>0, then re-group each value from dict-of-list to list-of-dict.
|
||||
if len(self._static_shape) > 0: # pylint: disable=g-explicit-length-test
|
||||
return _pyval_field_major_to_node_major(list(result.keys()),
|
||||
list(result.values()),
|
||||
self._static_shape.as_list())
|
||||
if len(self._shape) > 0: # pylint: disable=g-explicit-length-test
|
||||
return _pyval_field_major_to_node_major(
|
||||
list(result.keys()), list(result.values()), self._shape.as_list())
|
||||
else:
|
||||
return result
|
||||
|
||||
@ -503,12 +534,12 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
spec_shape = typespec._shape # pylint: disable=protected-access
|
||||
field_specs = typespec._field_specs # pylint: disable=protected-access
|
||||
if not (isinstance(typespec, StructuredTensorSpec) and
|
||||
spec_shape.ndims == 0 and set(pyval) == set(field_specs)):
|
||||
spec_shape.rank == 0 and set(pyval) == set(field_specs)):
|
||||
raise ValueError('Value does not match typespec: %r vs %r' %
|
||||
(pyval, typespec))
|
||||
fields = dict(
|
||||
(k, cls.from_pyval(v, field_specs[k])) for (k, v) in pyval.items())
|
||||
return StructuredTensor(shape=(), fields=fields)
|
||||
return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)
|
||||
|
||||
@classmethod
|
||||
def _from_pylist_of_dict(cls, pyval, keys, rank, typespec):
|
||||
@ -532,10 +563,8 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
'%r vs %r' % (pyval, typespec))
|
||||
for (key, spec) in field_specs.items():
|
||||
fields[key] = cls.from_pyval(fields.get(key, []), spec)
|
||||
if not spec.is_compatible_with(fields[key]):
|
||||
raise ValueError('Value does not match typespec: %r vs %r' %
|
||||
(spec, fields[key]))
|
||||
return StructuredTensor(shape=shape, fields=fields)
|
||||
return StructuredTensor.from_fields(
|
||||
fields=fields, shape=shape, validate=False)
|
||||
|
||||
@classmethod
|
||||
def _from_pylist_of_value(cls, pyval, typespec):
|
||||
@ -543,8 +572,11 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
if typespec is None:
|
||||
return ragged_factory_ops.constant(pyval)
|
||||
elif isinstance(typespec, tensor_spec.TensorSpec):
|
||||
# TODO(edloper): Check that typespec.shape matches.
|
||||
return constant_op.constant(pyval, typespec.dtype)
|
||||
result = constant_op.constant(pyval, typespec.dtype)
|
||||
if not typespec.shape.is_compatible_with(result.shape):
|
||||
raise ValueError('Value does not match typespec: %r vs %r' %
|
||||
(typespec, pyval))
|
||||
return result
|
||||
elif isinstance(typespec, ragged_tensor.RaggedTensorSpec):
|
||||
# pylint: disable=protected-access
|
||||
return ragged_factory_ops.constant(
|
||||
@ -571,11 +603,83 @@ class StructuredTensor(composite_tensor.CompositeTensor):
|
||||
return constant_op.constant(pyval)
|
||||
else:
|
||||
if not (isinstance(typespec, tensor_spec.TensorSpec) and
|
||||
typespec.shape.ndims == 0):
|
||||
raise ValueError('Value does not match typespec.')
|
||||
typespec.shape.rank == 0):
|
||||
raise ValueError('Value does not match typespec: %r vs %r' %
|
||||
(typespec, pyval))
|
||||
# TODO(edloper): Check that typespec.shape matches.
|
||||
return constant_op.constant(pyval, typespec.dtype)
|
||||
|
||||
#=============================================================================
|
||||
# Transforms
|
||||
#=============================================================================
|
||||
|
||||
# TODO(edloper): Add a 'validate' option here?
|
||||
# TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor
|
||||
# have a partition_outer_dimension method?
|
||||
def partition_outer_dimension(self, row_partition):
|
||||
"""Partitions the outer dimension of this StructuredTensor.
|
||||
|
||||
Returns a new `StructuredTensor` with the same values as `self`, where
|
||||
the outer dimension is partitioned into two (possibly ragged) dimensions.
|
||||
Requires that this StructuredTensor have an outer dimension (i.e.,
|
||||
`self.shape.rank > 0`).
|
||||
|
||||
>>> st = StructuredTensor.from_pyval(
|
||||
... [{'foo': 12}, {'foo': 33}, {'foo': 99}])
|
||||
>>> partition = RowPartition.from_row_lengths([2, 0, 1])
|
||||
>>> st.partition_outer_dimension(partition)
|
||||
<StructuredTensor [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]]>
|
||||
|
||||
Args:
|
||||
row_partition: A `RowPartition`.
|
||||
|
||||
Returns:
|
||||
A `StructuredTensor` with rank `values.rank + 1`.
|
||||
"""
|
||||
if not isinstance(row_partition, RowPartition):
|
||||
raise TypeError('row_partition must be a RowPartition.')
|
||||
if self.shape.rank == 0:
|
||||
raise ValueError('Shape %s must have rank at least 1' % self.shape)
|
||||
return _partition_outer_dimension(self, row_partition)
|
||||
|
||||
def merge_dims(self, outer_axis, inner_axis):
|
||||
"""Merges outer_axis...inner_axis into a single dimension.
|
||||
|
||||
Returns a copy of this RaggedTensor with the specified range of dimensions
|
||||
flattened into a single dimension, with elements in row-major order.
|
||||
|
||||
>>> st = StructuredTensor.from_pyval(
|
||||
... [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
|
||||
>>> st.merge_dims(0, 1)
|
||||
<StructuredTensor [{'foo': 12}, {'foo': 33}, {'foo': 99}]>
|
||||
|
||||
Args:
|
||||
outer_axis: `int`: The first dimension in the range of dimensions to
|
||||
merge. May be negative (to index from the last dimension).
|
||||
inner_axis: `int`: The last dimension in the range of dimensions to merge.
|
||||
May be negative (to index from the last dimension).
|
||||
|
||||
Returns:
|
||||
A copy of this tensor, with the specified dimensions merged into a
|
||||
single dimension. The shape of the returned tensor will be
|
||||
`self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
|
||||
is the total number of slices in the merged dimensions.
|
||||
"""
|
||||
outer_axis = array_ops.get_positive_axis(
|
||||
outer_axis,
|
||||
self.shape.rank,
|
||||
axis_name='outer_axis',
|
||||
ndims_name='rank(self)')
|
||||
inner_axis = array_ops.get_positive_axis(
|
||||
inner_axis,
|
||||
self.shape.rank,
|
||||
axis_name='inner_axis',
|
||||
ndims_name='rank(self)')
|
||||
if not outer_axis < inner_axis:
|
||||
raise ValueError('Expected outer_axis (%d) to be less than '
|
||||
'inner_axis (%d)' % (outer_axis, inner_axis))
|
||||
return _merge_dims(self, outer_axis, inner_axis)
|
||||
|
||||
#=============================================================================
|
||||
# Composite Tensor
|
||||
#=============================================================================
|
||||
@ -594,7 +698,7 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
|
||||
"""Build a type specification for a StructuredTensor.
|
||||
|
||||
Args:
|
||||
shape: The shape of the StructuredTensor. shape.ndims must not be None.
|
||||
shape: The shape of the StructuredTensor. shape.rank must not be None.
|
||||
field_specs: A dictionary mapping from field name to TypeSpec, specifying
|
||||
the tensor type used to encode each field. These TypeSpecs should
|
||||
specify the type of the entire field (including outer dimensions which
|
||||
@ -602,20 +706,23 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
|
||||
contains an int32 vector of size `10` for each structure, then
|
||||
`field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`.
|
||||
"""
|
||||
self._shape = tensor_shape.as_shape(shape)
|
||||
self._field_specs = dict(field_specs)
|
||||
shape = tensor_shape.as_shape(shape)
|
||||
|
||||
# Perform a few sanity checks on the inputs.
|
||||
if self._shape.ndims is None:
|
||||
if shape.rank is None:
|
||||
raise TypeError("StructuredTensor's shape must have known rank.")
|
||||
if not isinstance(self._field_specs, dict):
|
||||
raise TypeError('field_specs must be a dictionary')
|
||||
for key, value in self._field_specs.items():
|
||||
if not isinstance(field_specs, dict):
|
||||
raise TypeError('field_specs must be a dictionary.')
|
||||
for key, value in field_specs.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError('field_specs must be a dictionary with string keys.')
|
||||
if not isinstance(value, (StructuredTensorSpec, tensor_spec.TensorSpec,
|
||||
ragged_tensor.RaggedTensorSpec)):
|
||||
raise TypeError('field_spec must be a dictionary with TypeSpec values.')
|
||||
raise TypeError('field_specs must be a dictionary with '
|
||||
'TypeSpec values.')
|
||||
|
||||
self._shape = shape
|
||||
self._field_specs = dict(field_specs)
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
@ -625,7 +732,7 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
|
||||
return value._fields
|
||||
|
||||
def _from_components(self, components):
|
||||
return StructuredTensor(self._shape, components)
|
||||
return StructuredTensor.from_fields(components, self._shape, validate=False)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
@ -659,6 +766,143 @@ class StructuredTensorSpec(type_spec.BatchableTypeSpec):
|
||||
_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$')
|
||||
|
||||
|
||||
#=============================================================================
|
||||
# Helper funtions
|
||||
#=============================================================================
|
||||
# TODO(edloper): Move some of these helpers to row_partition.py?
|
||||
|
||||
|
||||
def _convert_to_structured_field_value(value):
|
||||
"""Converts `value` to a Tensor, RaggedTensor, or StructuredTensor."""
|
||||
if isinstance(value,
|
||||
(ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
|
||||
return value
|
||||
elif ragged_tensor.is_ragged(value):
|
||||
return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
|
||||
else:
|
||||
try:
|
||||
return ops.convert_to_tensor(value)
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError('Unexpected type for value in `fields`: %r' % value)
|
||||
|
||||
|
||||
def _find_shape_dtype(fields, nrows, row_partitions):
|
||||
"""Return a consistent dtype for fields, nrows, & row_partitions."""
|
||||
shape_dtypes = set()
|
||||
for value in fields.values():
|
||||
if isinstance(value, ragged_tensor.RaggedTensor):
|
||||
shape_dtypes.add(value.row_splits.dtype)
|
||||
elif isinstance(value, StructuredTensor) and value.rank > 0:
|
||||
shape_dtypes.add(value.nrows().dtype)
|
||||
if isinstance(nrows, ops.Tensor):
|
||||
shape_dtypes.add(nrows.dtype)
|
||||
if row_partitions is not None:
|
||||
for partition in row_partitions:
|
||||
shape_dtypes.add(partition.dtype)
|
||||
if len(shape_dtypes) > 1:
|
||||
raise ValueError('field values have incompatible row_partition dtypes.')
|
||||
elif shape_dtypes:
|
||||
return shape_dtypes.pop()
|
||||
else:
|
||||
return dtypes.int64
|
||||
|
||||
|
||||
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
|
||||
"""Merges `nrows` with `nrows(value)`.
|
||||
|
||||
Checks that `value` has the expected number of rows (`nrows`), and returns
|
||||
`nrows`. If `validate` is true, then add validation ops that check that
|
||||
the `nrows` values match.
|
||||
|
||||
Args:
|
||||
nrows: scalar integer Tensor.
|
||||
static_nrows: tf.Dimension: static value of nrows, if known.
|
||||
value: Tensor or RaggedTensor or StructuredTensor
|
||||
dtype: dtype for `nrows`.
|
||||
validate: bool -- whether to add validation ops.
|
||||
|
||||
Returns:
|
||||
A tuple `(nrows, static_nrows)`.
|
||||
"""
|
||||
static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0)
|
||||
if isinstance(value, ops.Tensor):
|
||||
value_nrows = array_ops.shape(value, out_type=dtype)[0]
|
||||
else:
|
||||
value_nrows = value.nrows()
|
||||
if nrows is None:
|
||||
nrows = value_nrows
|
||||
elif (static_value_nrows.value is not None and
|
||||
static_nrows.value is not None):
|
||||
if not static_value_nrows.is_compatible_with(static_nrows):
|
||||
raise ValueError('fields have incompatible nrows')
|
||||
nrows = value_nrows # No need to add an assertion op.
|
||||
elif validate:
|
||||
nrows = control_flow_ops.with_dependencies([
|
||||
check_ops.assert_equal(nrows, value_nrows,
|
||||
message='fields have incompatible nrows')
|
||||
], nrows)
|
||||
return nrows, static_nrows.merge_with(static_value_nrows)
|
||||
|
||||
|
||||
def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
|
||||
"""Merges `row_partitions` with `row_partitions(value)`."""
|
||||
if isinstance(value, ops.Tensor):
|
||||
value_row_partitions = _row_partitions_for_tensor(value, rank, dtype)
|
||||
|
||||
elif isinstance(value, ragged_tensor.RaggedTensor):
|
||||
value_row_partitions = _row_partitions_for_ragged_tensor(value, rank, dtype)
|
||||
|
||||
else:
|
||||
assert isinstance(value, StructuredTensor), type(value)
|
||||
value_row_partitions = value.row_partitions[:rank - 1]
|
||||
|
||||
assert len(value_row_partitions) == rank - 1
|
||||
if row_partitions is None:
|
||||
return tuple(value_row_partitions)
|
||||
else:
|
||||
return tuple([
|
||||
p1.merge_precomputed_encodings(p2, validate)
|
||||
for (p1, p2) in zip(row_partitions, value_row_partitions)
|
||||
])
|
||||
|
||||
|
||||
def _row_partitions_for_tensor(value, rank, dtype):
|
||||
"""Returns the row partitions for a tf.Tensor."""
|
||||
shape = array_ops.shape(value, out_type=dtype)
|
||||
return _row_partitions_for_uniform_shape(shape, rank)
|
||||
|
||||
|
||||
def _row_partitions_for_ragged_tensor(value, rank, dtype):
|
||||
"""Returns the row partitions for a tf.RaggedTensor."""
|
||||
assert rank > 1
|
||||
value_row_partitions = value._nested_row_partitions[:rank - 1] # pylint: disable=protected-access
|
||||
if len(value_row_partitions) < (rank - 1):
|
||||
value_row_partitions += _row_partitions_for_tensor(
|
||||
value.flat_values, rank - len(value_row_partitions), dtype)
|
||||
assert len(value_row_partitions) == rank - 1
|
||||
return value_row_partitions
|
||||
|
||||
|
||||
def _row_partitions_for_uniform_shape(shape, rank):
|
||||
"""Returns row partitions for the given shape Tensor.
|
||||
|
||||
Args:
|
||||
shape: A vector describing a uniform shape.
|
||||
rank: The number of dimensions to generate row partitions for
|
||||
|
||||
Returns:
|
||||
A list of (rank-1) `RowPartition`s with uniform row length.
|
||||
"""
|
||||
shape_cumprod = math_ops.cumprod(shape[:rank])
|
||||
# pylint: disable=g-complex-comprehension
|
||||
return tuple([
|
||||
RowPartition.from_uniform_row_length(
|
||||
uniform_row_length=shape[i + 1],
|
||||
nvals=shape_cumprod[i + 1],
|
||||
nrows=shape_cumprod[i]) for i in range(rank - 1)
|
||||
])
|
||||
|
||||
|
||||
def _pyval_field_major_to_node_major(keys, values, shape):
|
||||
"""Regroup each field (k, v) from dict-of-list to list-of-dict.
|
||||
|
||||
@ -767,3 +1011,140 @@ def _pyval_empty_list_depth(pyval):
|
||||
return max(depths) + 1
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _replace_row_partitions(value, new_partitions):
|
||||
"""Updates `value` to use `new_partitions` as its (outer) row partitions.
|
||||
|
||||
This is used to ensure that all fields in a `StructuredTensor` use identical
|
||||
`RowPartition` objects for the shared dimensions. In particular,
|
||||
`StructuredTensor.from_fields` first merges all of the row partitions from
|
||||
any fields, and then replaces the outer row partitions of all fields with
|
||||
the merged row partitions (using this function).
|
||||
|
||||
Args:
|
||||
value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
|
||||
new_partitions: A list of row-partitions that should be used by `value`.
|
||||
Must be equivalent to `value`'s current row partitions.
|
||||
|
||||
Returns:
|
||||
A value that is equivalent to `value`, where outer row partitions have been
|
||||
replaced by `new_partitions`.
|
||||
"""
|
||||
if isinstance(value, ops.Tensor) or not new_partitions:
|
||||
return value
|
||||
|
||||
elif isinstance(value, ragged_tensor.RaggedTensor):
|
||||
return ragged_tensor.RaggedTensor._from_row_partition( # pylint: disable=protected-access
|
||||
values=_replace_row_partitions(value.values, new_partitions[1:]),
|
||||
row_partition=new_partitions[0])
|
||||
|
||||
else:
|
||||
assert isinstance(value, StructuredTensor)
|
||||
new_fields = dict((k, _replace_row_partitions(v, new_partitions))
|
||||
for (k, v) in value._fields.items())
|
||||
return StructuredTensor(
|
||||
fields=new_fields,
|
||||
shape=value.shape,
|
||||
nrows=value.nrows(),
|
||||
row_partitions=new_partitions,
|
||||
internal=_structured_tensor_factory_key)
|
||||
|
||||
|
||||
def _partition_outer_dimension(value, row_partition):
|
||||
"""Partitions the outer dimension of `value` using `row_partitions`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> partition = row_partition.RowPartition.from_row_lengths([2, 0, 1])
|
||||
>>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
|
||||
[[1, 2], [], [3]]
|
||||
|
||||
>>> struct_value = StructuredTensor.from_pyval(
|
||||
... [{'x': 1}, {'x': 2}, {'x': 3}])
|
||||
>>> _partition_outer_dimension(struct_value, partition)
|
||||
[[{'x': 1}, {'x': 2}], [], [{'x': 3}]])
|
||||
|
||||
Args:
|
||||
value: Tensor, RaggedTensor, or StructuredTensor
|
||||
row_partition: RowPartition
|
||||
|
||||
Returns:
|
||||
A value with the same type as `value`, where
|
||||
`result.rank = value.rank + 1`.
|
||||
"""
|
||||
is_ragged = row_partition.uniform_row_length() is None
|
||||
if isinstance(value, ops.Tensor) and not is_ragged:
|
||||
new_shape = array_ops.concat(
|
||||
[[row_partition.nrows(),
|
||||
row_partition.uniform_row_length()],
|
||||
array_ops.shape(value, out_type=row_partition.dtype)[2:]],
|
||||
axis=0)
|
||||
return array_ops.reshape(value, new_shape)
|
||||
elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
|
||||
return ragged_tensor.RaggedTensor._from_row_partition( # pylint: disable=protected-access
|
||||
value, row_partition)
|
||||
else:
|
||||
assert isinstance(value, StructuredTensor)
|
||||
nrows = row_partition.static_nrows
|
||||
ncols = row_partition.static_uniform_row_length
|
||||
shape = tensor_shape.TensorShape([nrows, ncols]).concatenate(
|
||||
value.shape[2:])
|
||||
fields = dict((k, _partition_outer_dimension(v, row_partition))
|
||||
for (k, v) in value._fields.items())
|
||||
return StructuredTensor(
|
||||
fields,
|
||||
shape,
|
||||
row_partition.nrows(), (row_partition,) + value.row_partitions,
|
||||
internal=_structured_tensor_factory_key)
|
||||
|
||||
|
||||
def _merge_dims(value, outer_axis, inner_axis):
|
||||
"""Merges `outer_axis...inner_axis` of `value` into a single dimension."""
|
||||
assert outer_axis < inner_axis
|
||||
if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
|
||||
return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
|
||||
else:
|
||||
assert isinstance(value, StructuredTensor)
|
||||
|
||||
# Build the new fields.
|
||||
fields = dict((k, _merge_dims(v, outer_axis, inner_axis))
|
||||
for (k, v) in value._fields.items())
|
||||
|
||||
# Build the new shape.
|
||||
value_shape = value.shape
|
||||
shape = (
|
||||
value_shape[:outer_axis] +
|
||||
[value_shape[outer_axis:inner_axis].num_elements()] +
|
||||
value_shape[inner_axis + 1:])
|
||||
|
||||
# Build the new row_partitions & nrows
|
||||
if outer_axis == 0:
|
||||
if inner_axis == value.shape.rank - 1:
|
||||
partitions = ()
|
||||
nrows = value.row_partitions[-1].nvals()
|
||||
else:
|
||||
partitions = value.row_partitions[inner_axis:]
|
||||
nrows = partitions[0].nrows()
|
||||
else:
|
||||
# Use tf.gather to merge row_splits from the merged row partitions.
|
||||
merged_splits = value.row_partitions[outer_axis - 1].row_splits()
|
||||
for dim in range(outer_axis, inner_axis):
|
||||
merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(),
|
||||
merged_splits)
|
||||
|
||||
partitions = (
|
||||
value.row_partitions[:outer_axis - 1] +
|
||||
(RowPartition.from_row_splits(merged_splits),) +
|
||||
value.row_partitions[inner_axis:])
|
||||
nrows = partitions[0].nrows()
|
||||
|
||||
return StructuredTensor(
|
||||
fields,
|
||||
shape,
|
||||
nrows,
|
||||
partitions,
|
||||
internal=_structured_tensor_factory_key)
|
||||
|
||||
|
||||
_structured_tensor_factory_key = object() # unique private object
|
||||
|
@ -201,6 +201,10 @@ class StructuredTensorSliceTest(test_util.TensorFlowTestCase,
|
||||
(SLICE_BUILDER["f4", 1:, "f4_2"], [b"b"]),
|
||||
(SLICE_BUILDER["f4", :, "f4_2"], [b"a", b"b"]),
|
||||
(SLICE_BUILDER["f5", :, :, "f5_1"], [[1, 2], [3, 4]]),
|
||||
# Slicing over multiple keys
|
||||
(SLICE_BUILDER[:], EXAMPLE_STRUCT),
|
||||
# List-valued key.
|
||||
(["f2", 1], EXAMPLE_STRUCT["f2"][1]),
|
||||
])
|
||||
def testGetitemFromScalarStruct(self, slice_spec, expected):
|
||||
# By default, lists are converted to RaggedTensors.
|
||||
@ -242,6 +246,29 @@ class StructuredTensorSliceTest(test_util.TensorFlowTestCase,
|
||||
|
||||
# TODO(edloper): Add tests for slicing from matrix StructuredTensors.
|
||||
|
||||
@parameterized.parameters([
|
||||
(SLICE_BUILDER[:2], r"Key for indexing a StructuredTensor must be "
|
||||
r"a string or a full slice \(':'\)"),
|
||||
(SLICE_BUILDER["f4", ...], r"Slicing not supported for Ellipsis"),
|
||||
(SLICE_BUILDER["f4", None], r"Slicing not supported for tf.newaxis"),
|
||||
(SLICE_BUILDER["f4", :, 0],
|
||||
r"Key for indexing a StructuredTensor must be a string"),
|
||||
])
|
||||
def testGetItemError(self, slice_spec, error, exception=ValueError):
|
||||
struct = structured_tensor.StructuredTensor.from_pyval(EXAMPLE_STRUCT)
|
||||
with self.assertRaisesRegexp(exception, error):
|
||||
struct.__getitem__(slice_spec)
|
||||
|
||||
@parameterized.parameters([
|
||||
(SLICE_BUILDER[:, 1],
|
||||
r"Key for indexing a StructuredTensor must be a string"),
|
||||
])
|
||||
def testGetItemFromVectorError(self, slice_spec, error, exception=ValueError):
|
||||
struct = structured_tensor.StructuredTensor.from_pyval(
|
||||
EXAMPLE_STRUCT_VECTOR)
|
||||
with self.assertRaisesRegexp(exception, error):
|
||||
struct.__getitem__(slice_spec)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -81,6 +81,18 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
self.assertEqual(spec2._shape, (1, 2))
|
||||
self.assertEqual(spec2._field_specs, spec2_fields)
|
||||
|
||||
@parameterized.parameters([
|
||||
(None, {}, r"StructuredTensor's shape must have known rank\."),
|
||||
([], None, r'field_specs must be a dictionary\.'),
|
||||
([], {1: tensor_spec.TensorSpec(None)},
|
||||
r'field_specs must be a dictionary with string keys\.'),
|
||||
([], {'x': 0},
|
||||
r'field_specs must be a dictionary with TypeSpec values\.'),
|
||||
])
|
||||
def testConstructionErrors(self, shape, field_specs, error):
|
||||
with self.assertRaisesRegexp(TypeError, error):
|
||||
structured_tensor.StructuredTensorSpec(shape, field_specs)
|
||||
|
||||
def testValueType(self):
|
||||
spec1 = StructuredTensorSpec([1, 2, 3], dict(a=T_1_2))
|
||||
self.assertEqual(spec1.value_type, StructuredTensor)
|
||||
@ -118,11 +130,13 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
'fields': dict(x=[[1.0, 2.0]]),
|
||||
'field_specs': dict(x=T_1_2),
|
||||
},
|
||||
{
|
||||
'shape': [1, 2, 3],
|
||||
'fields': {},
|
||||
'field_specs': {},
|
||||
},
|
||||
# TODO(edloper): Enable this test once we update StructuredTensorSpec
|
||||
# to contain the shared row partitions.
|
||||
#{
|
||||
# 'shape': [1, 2, 3],
|
||||
# 'fields': {},
|
||||
# 'field_specs': {},
|
||||
#},
|
||||
{
|
||||
'shape': [2],
|
||||
'fields': dict(
|
||||
@ -133,7 +147,7 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
]) # pyformat: disable
|
||||
def testToFromComponents(self, shape, fields, field_specs):
|
||||
components = fields
|
||||
struct = StructuredTensor(shape, fields)
|
||||
struct = StructuredTensor.from_fields(fields, shape)
|
||||
spec = StructuredTensorSpec(shape, field_specs)
|
||||
actual_components = spec._to_components(struct)
|
||||
self.assertAllTensorsEqual(actual_components, components)
|
||||
@ -164,39 +178,40 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'unbatched': lambda: [
|
||||
StructuredTensor([], {'a': 1, 'b': [5, 6]}),
|
||||
StructuredTensor([], {'a': 2, 'b': [7, 8]})],
|
||||
StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
|
||||
StructuredTensor.from_fields({'a': 2, 'b': [7, 8]})],
|
||||
'batch_size': 2,
|
||||
'batched': lambda: StructuredTensor([2], {
|
||||
'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
|
||||
'a': [1, 2],
|
||||
'b': [[5, 6], [7, 8]]}),
|
||||
},
|
||||
{
|
||||
'unbatched': lambda: [
|
||||
StructuredTensor([3], {
|
||||
StructuredTensor.from_fields(shape=[3], fields={
|
||||
'a': [1, 2, 3],
|
||||
'b': [[5, 6], [6, 7], [7, 8]]}),
|
||||
StructuredTensor([3], {
|
||||
StructuredTensor.from_fields(shape=[3], fields={
|
||||
'a': [2, 3, 4],
|
||||
'b': [[2, 2], [3, 3], [4, 4]]})],
|
||||
'batch_size': 2,
|
||||
'batched': lambda: StructuredTensor([2, 3], {
|
||||
'batched': lambda: StructuredTensor.from_fields(shape=[2, 3], fields={
|
||||
'a': [[1, 2, 3], [2, 3, 4]],
|
||||
'b': [[[5, 6], [6, 7], [7, 8]],
|
||||
[[2, 2], [3, 3], [4, 4]]]}),
|
||||
},
|
||||
{
|
||||
'unbatched': lambda: [
|
||||
StructuredTensor([], {
|
||||
StructuredTensor.from_fields(shape=[], fields={
|
||||
'a': 1,
|
||||
'b': StructuredTensor([], {'x': [5]})}),
|
||||
StructuredTensor([], {
|
||||
'b': StructuredTensor.from_fields({'x': [5]})}),
|
||||
StructuredTensor.from_fields(shape=[], fields={
|
||||
'a': 2,
|
||||
'b': StructuredTensor([], {'x': [6]})})],
|
||||
'b': StructuredTensor.from_fields({'x': [6]})})],
|
||||
'batch_size': 2,
|
||||
'batched': lambda: StructuredTensor([2], {
|
||||
'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
|
||||
'a': [1, 2],
|
||||
'b': StructuredTensor([2], {'x': [[5], [6]]})}),
|
||||
'b': StructuredTensor.from_fields(shape=[2], fields={
|
||||
'x': [[5], [6]]})}),
|
||||
},
|
||||
]) # pyformat: disable
|
||||
def testBatchUnbatchValues(self, unbatched, batch_size, batched):
|
||||
@ -225,5 +240,6 @@ class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
for (actual, expected) in zip(actual_unbatched, unbatched):
|
||||
self.assertAllEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user