[framework] Add 'most_specific_compatible_shape' for the most specific TensorShape compatible with two given TensorShapes.

PiperOrigin-RevId: 161112884
A. Unique TensorFlower 2017-07-06 12:14:45 -07:00 committed by TensorFlower Gardener
parent 27b341c800
commit 70aa8daacf
4 changed files with 100 additions and 66 deletions

tensorflow/contrib/data/python/ops/dataset_ops.py

@@ -350,8 +350,7 @@ def _estimate_data_distribution(c, num_examples_per_class_seen):
   # cross-device round-trip. Just use the cached value.
   num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
       math_ops.reduce_sum(
-          array_ops.one_hot(c, num_classes, dtype=dtypes.int64),
-          0))
+          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
   init_prob_estimate = math_ops.truediv(
       num_examples_per_class_seen,
       math_ops.reduce_sum(num_examples_per_class_seen))
@@ -445,8 +444,8 @@ class Dataset(object):
     output_shapes = str(output_shapes).replace("'", "")
     output_types = nest.map_structure(repr, self.output_types)
     output_types = str(output_types).replace("'", "")
-    return ("<%s shapes: %s, types: %s>"
-            % (type(self).__name__, output_shapes, output_types))
+    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
+                                            output_types))
 
   @staticmethod
   def from_tensors(tensors):
@@ -595,9 +594,7 @@ class Dataset(object):
     dataset = dataset.repeat(num_epochs)
     if randomize_input:
       dataset = dataset.shuffle(capacity)
-    dataset = dataset.map(
-        lambda x: _parse_example(nest.flatten(x), features)
-    )
+    dataset = dataset.map(lambda x: _parse_example(nest.flatten(x), features))
     dataset = dataset.batch(batch_size)
     return dataset
 
@@ -1031,12 +1028,14 @@ class ZipDataset(Dataset):
   @property
   def output_shapes(self):
     return nest.pack_sequence_as(self._datasets, [
-        ds.output_shapes for ds in nest.flatten(self._datasets)])
+        ds.output_shapes for ds in nest.flatten(self._datasets)
+    ])
 
   @property
   def output_types(self):
     return nest.pack_sequence_as(self._datasets, [
-        ds.output_types for ds in nest.flatten(self._datasets)])
+        ds.output_types for ds in nest.flatten(self._datasets)
+    ])
 
 
 class RepeatDataset(Dataset):
@@ -1049,8 +1048,8 @@ class RepeatDataset(Dataset):
     if count is None:
       self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
     else:
-      self._count = ops.convert_to_tensor(count, dtype=dtypes.int64,
-                                          name="count")
+      self._count = ops.convert_to_tensor(
+          count, dtype=dtypes.int64, name="count")
 
   def make_dataset_resource(self):
     return gen_dataset_ops.repeat_dataset(
@@ -1155,8 +1154,8 @@ class ShuffleDataset(Dataset):
     if seed2 is None:
       self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
     else:
-      self._seed2 = ops.convert_to_tensor(seed2, dtype=dtypes.int64,
-                                          name="seed2")
+      self._seed2 = ops.convert_to_tensor(
+          seed2, dtype=dtypes.int64, name="seed2")
 
   def make_dataset_resource(self):
     return gen_dataset_ops.shuffle_dataset(
@@ -1310,12 +1309,10 @@ def _padding_value_to_tensor(value, output_type):
   """
   value = ops.convert_to_tensor(value, name="padding_value")
   if not value.shape.is_compatible_with(tensor_shape.scalar()):
-    raise ValueError(
-        "Padding value should be a scalar, but is not: %s" % value)
+    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
   if value.dtype != output_type:
-    raise TypeError(
-        "Padding value tensor (%s) does not match output type: %s"
-        % (value, output_type))
+    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
+                    (value, output_type))
   return value
 
 
@@ -1329,20 +1326,20 @@ class PaddedBatchDataset(Dataset):
     self._batch_size = batch_size
     padding_values = (padding_values if padding_values is not None else
                       self._default_padding(input_dataset))
-    self._padded_shapes = nest.map_structure_up_to(input_dataset.output_shapes,
-                                                   _partial_shape_to_tensor,
-                                                   padded_shapes)
-    self._padding_values = nest.map_structure_up_to(input_dataset.output_shapes,
-                                                    _padding_value_to_tensor,
-                                                    padding_values,
-                                                    input_dataset.output_types)
+    self._padded_shapes = nest.map_structure_up_to(
+        input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes)
+    self._padding_values = nest.map_structure_up_to(
+        input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
+        input_dataset.output_types)
 
   def _default_padding(self, input_dataset):
+
     def make_zero(t):
       if t.base_dtype == dtypes.string:
         return ""
       else:
         return np.zeros_like(t.as_numpy_dtype())
+
     return nest.map_structure(make_zero, input_dataset.output_types)
 
   def make_dataset_resource(self):
@@ -1358,9 +1355,11 @@ class PaddedBatchDataset(Dataset):
 
   @property
   def output_shapes(self):
+
     def _padded_shape_to_batch_shape(s):
       return tensor_shape.vector(None).concatenate(
           tensor_util.constant_value_as_shape(s))
+
     return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes)
 
   @property
@@ -1376,8 +1375,8 @@ class DenseToSparseBatchDataset(Dataset):
     super(DenseToSparseBatchDataset, self).__init__()
     if not isinstance(input_dataset.output_types, dtypes.DType):
       raise TypeError("DenseToSparseDataset requires an input whose elements "
-                      "have a single component, whereas the input has %r."
-                      % input_dataset.output_types)
+                      "have a single component, whereas the input has %r." %
+                      input_dataset.output_types)
     self._input_dataset = input_dataset
     self._batch_size = batch_size
     self._row_shape = _partial_shape_to_tensor(row_shape)
@@ -1493,22 +1492,6 @@ class GroupByWindowDataset(Dataset):
     return self._output_types
 
 
-def _most_specific_compatible_shape(s1, s2):
-  """Returns the most specific shape compatible with `s1` and `s2`."""
-  if s1.dims is None:
-    return s1
-  if s2.dims is None:
-    return s2
-  s1.assert_same_rank(s2)
-  dims = []
-  for dim1, dim2 in zip(s1, s2):
-    if dim1.value is None or dim2.value is None or dim1.value != dim2.value:
-      dims.append(tensor_shape.Dimension(None))
-    else:
-      dims.append(dim1.value)
-  return tensor_shape.TensorShape(dims)
-
-
 class MapDataset(Dataset):
   """A `Dataset` that maps a function over elements in its input."""
 
@@ -1593,9 +1576,7 @@ class MapDataset(Dataset):
 class FlatMapDataset(Dataset):
   """A `Dataset` that maps a function over its input and flattens the result."""
 
-  def __init__(self,
-               input_dataset,
-               map_func):
+  def __init__(self, input_dataset, map_func):
     """See `Dataset.flat_map()` for details."""
     super(FlatMapDataset, self).__init__()
     self._input_dataset = input_dataset
@@ -1799,8 +1780,11 @@ class FixedLengthRecordDataset(Dataset):
     return dtypes.string
 
 
-def rejection_resample(dataset, class_func, target_dist,
-                       initial_dist=None, seed=None):
+def rejection_resample(dataset,
+                       class_func,
+                       target_dist,
+                       initial_dist=None,
+                       seed=None):
   """Resamples this dataset to achieve a target class distribution.
 
   **NOTE** Resampling is performed via rejection sampling; some fraction
@@ -1825,36 +1809,34 @@ def rejection_resample(dataset, class_func, target_dist,
   target_dist = ops.convert_to_tensor(target_dist, name="initial_dist")
   class_values_ds = dataset.map(class_func)
   if initial_dist is not None:
-    initial_dist = ops.convert_to_tensor(
-        initial_dist, name="initial_dist")
+    initial_dist = ops.convert_to_tensor(initial_dist, name="initial_dist")
     acceptance_dist = _calculate_acceptance_probs(initial_dist, target_dist)
     initial_dist_ds = Dataset.from_tensors(initial_dist).repeat()
     acceptance_dist_ds = Dataset.from_tensors(acceptance_dist).repeat()
   else:
-    num_classes = (target_dist.shape[0].value
-                   or array_ops.shape(target_dist)[0])
+    num_classes = (target_dist.shape[0].value or
+                   array_ops.shape(target_dist)[0])
     smoothing_constant = 10
     num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
-        initial_value=array_ops.fill(
-            [num_classes], np.int64(smoothing_constant)),
+        initial_value=array_ops.fill([num_classes],
+                                     np.int64(smoothing_constant)),
         trainable=False,
         name="class_count",
        dtype=dtypes.int64)
+
     def update_estimate_and_tile(c):
       return array_ops.tile(
           array_ops.expand_dims(
               _estimate_data_distribution(c, num_examples_per_class_seen), 0),
           [dist_estimation_batch_size, 1])
-    initial_dist_ds = (class_values_ds
-                       .batch(dist_estimation_batch_size)
-                       .map(update_estimate_and_tile)
-                       .unbatch())
+
+    initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
+                       .map(update_estimate_and_tile).unbatch())
     acceptance_dist_ds = initial_dist_ds.map(
         lambda initial: _calculate_acceptance_probs(initial, target_dist))
 
   def maybe_warn_on_large_rejection(accept_dist, initial_dist):
-    proportion_rejected = math_ops.reduce_sum(
-        (1 - accept_dist) * initial_dist)
+    proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
     return control_flow_ops.cond(
         math_ops.less(proportion_rejected, .5),
         lambda: accept_dist,
@@ -1864,12 +1846,10 @@ def rejection_resample(dataset, class_func, target_dist,
             summarize=100,
             first_n=10))
 
-  acceptance_dist_ds = (
-      Dataset.zip((acceptance_dist_ds, initial_dist_ds))
-      .map(maybe_warn_on_large_rejection))
-  current_probabilities_ds = (Dataset
-                              .zip((acceptance_dist_ds, class_values_ds))
-                              .map(array_ops.gather))
+  acceptance_dist_ds = (Dataset.zip((acceptance_dist_ds, initial_dist_ds))
+                        .map(maybe_warn_on_large_rejection))
+  current_probabilities_ds = (Dataset.zip((acceptance_dist_ds, class_values_ds))
+                              .map(array_ops.gather))
 
   filtered_ds = (
       Dataset.zip((class_values_ds, current_probabilities_ds, dataset))
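
Aside (not part of this commit's diff): the private _most_specific_compatible_shape() helper removed above is superseded by the public TensorShape.most_specific_compatible_shape() method added in tensor_shape.py below. The hunks here do not show the call sites, but shape-combining code of this kind can now go through the public API. The following is a hypothetical sketch only; combined_output_shape is not a name from this change.

# Hypothetical sketch: fold a list of shapes down to the most specific
# TensorShape compatible with all of them, using the new public method.
from tensorflow.python.framework import tensor_shape

def combined_output_shape(shapes):
  result = tensor_shape.as_shape(shapes[0])
  for s in shapes[1:]:
    result = result.most_specific_compatible_shape(tensor_shape.as_shape(s))
  return result

print(combined_output_shape([[2, 1], [5, 1]]))        # prints (?, 1)
print(combined_output_shape([[1, 2, 3], [1, 2, 3]]))  # prints (1, 2, 3)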

tensorflow/python/framework/tensor_shape.py

@@ -736,6 +736,36 @@ class TensorShape(object):
     if not self.is_compatible_with(other):
       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
 
+  def most_specific_compatible_shape(self, other):
+    """Returns the most specific TensorShape compatible with `self` and `other`.
+
+    * TensorShape([None, 1]) is the most specific TensorShape compatible with
+      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
+      TensorShape(None) is also compatible with the above-mentioned TensorShapes.
+
+    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
+      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). Less specific
+      TensorShapes are also compatible with the above-mentioned TensorShapes,
+      e.g. TensorShape([1, 2, None]) and TensorShape(None).
+
+    Args:
+      other: Another `TensorShape`.
+
+    Returns:
+      A `TensorShape` which is the most specific compatible shape of `self`
+      and `other`.
+    """
+    other = as_shape(other)
+    if self._dims is None or other.dims is None or self.ndims != other.ndims:
+      return unknown_shape()
+
+    dims = [Dimension(None)] * self.ndims
+    for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
+      if d1 is not None and d2 is not None and d1 == d2:
+        dims[i] = d1
+
+    return TensorShape(dims)
+
   def is_fully_defined(self):
     """Returns True iff `self` is fully defined in every dimension."""
     return (self._dims is not None and all(dim.value is not None
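
Aside (not part of this commit's diff): a minimal usage sketch of the new method through the public tf.TensorShape API (the golden-file update below shows it is part of the exported surface). The commented outputs assume TF 1.x-style TensorShape printing.

import tensorflow as tf

a = tf.TensorShape([2, 1])
b = tf.TensorShape([5, 1])
print(a.most_specific_compatible_shape(b))  # (?, 1): differing dims become unknown

c = tf.TensorShape([1, 2, 3])
print(c.most_specific_compatible_shape(c))  # (1, 2, 3): identical shapes are kept

# Unknown rank or mismatched rank yields a completely unknown shape.
print(a.most_specific_compatible_shape(tf.TensorShape(None)))       # <unknown>
print(a.most_specific_compatible_shape(tf.TensorShape([1, 2, 3])))  # <unknown>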

tensorflow/python/framework/tensor_shape_test.py

@@ -275,6 +275,26 @@ class ShapeTest(test_util.TensorFlowTestCase):
         tensor_shape.TensorShape([1, 2]).concatenate(
             tensor_shape.Dimension(3)))
 
+  def _testMostSpecificCompatibleShapeHelper(self, x, y, expected):
+    mcs = tensor_shape.TensorShape(x).most_specific_compatible_shape(
+        tensor_shape.TensorShape(y))
+    mcs_dims = mcs.dims
+    if expected is None or mcs_dims is None:
+      self.assertIs(expected, mcs_dims)
+    else:
+      self.assertEqual(expected, mcs.as_list())
+
+  def testMostSpecificCompatibleShape(self):
+    self._testMostSpecificCompatibleShapeHelper([1, 2], None, None)
+    self._testMostSpecificCompatibleShapeHelper(None, [1, 2], None)
+    self._testMostSpecificCompatibleShapeHelper([1, 2], [1, 2, 3, 4], None)
+    self._testMostSpecificCompatibleShapeHelper([1, 2, 3, 4], [1, 2], None)
+    self._testMostSpecificCompatibleShapeHelper([1, 2], [1, 2], [1, 2])
+    self._testMostSpecificCompatibleShapeHelper([None, 2, 3], [1, 1, 3],
+                                                [None, None, 3])
+    self._testMostSpecificCompatibleShapeHelper([1, 1, 3], [None, 2, 3],
+                                                [None, None, 3])
+
   def testHelpers(self):
     tensor_shape.TensorShape([]).assert_is_compatible_with(
         tensor_shape.scalar())
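
Aside (not part of this commit's diff): the new method complements the existing TensorShape.merge_with(), which appears in the golden file below. merge_with() tightens two compatible shapes and raises ValueError on a mismatch, whereas most_specific_compatible_shape() loosens mismatched dimensions to unknown instead of failing. A quick contrast, assuming a build that includes this change:

import tensorflow as tf

a = tf.TensorShape([None, 2])
b = tf.TensorShape([3, 2])
print(a.merge_with(b))                      # (3, 2): unknown dims get filled in
print(a.most_specific_compatible_shape(b))  # (?, 2): differing dims become unknown

x = tf.TensorShape([2, 1])
y = tf.TensorShape([5, 1])
print(x.most_specific_compatible_shape(y))  # (?, 1)
# x.merge_with(y) raises ValueError because the shapes are incompatible.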

tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt

@@ -54,6 +54,10 @@ tf_class {
     name: "merge_with"
     argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "most_specific_compatible_shape"
+    argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "num_elements"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"