[framework] Add 'most_specific_compatible_shape' for the most specific TensorShape compatible with two given TensorShapes.

PiperOrigin-RevId: 161112884

commit 70aa8daacf
parent 27b341c800
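
Before the diff, a quick sketch of the new API in use. This example is not part of the commit; the expected results follow from the docstring and tests added below.

# Illustrative only -- behavior as documented by the new docstring and tests.
from tensorflow.python.framework import tensor_shape

a = tensor_shape.TensorShape([None, 2, 3])
b = tensor_shape.TensorShape([1, 1, 3])

# Dimensions that are known and equal are kept; all others become None.
print(a.most_specific_compatible_shape(b))  # (?, ?, 3)

# Unknown rank (or mismatched ranks) yields a completely unknown shape.
print(a.most_specific_compatible_shape(tensor_shape.TensorShape(None)))  # <unknown>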
@@ -350,8 +350,7 @@ def _estimate_data_distribution(c, num_examples_per_class_seen):
   # cross-device round-trip. Just use the cached value.
   num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
       math_ops.reduce_sum(
-          array_ops.one_hot(c, num_classes, dtype=dtypes.int64),
-          0))
+          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
   init_prob_estimate = math_ops.truediv(
       num_examples_per_class_seen,
       math_ops.reduce_sum(num_examples_per_class_seen))
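
For context, the function reformatted above maintains a smoothed running estimate of the class distribution: a one-hot count per class is accumulated, then normalized. A minimal NumPy sketch of the same arithmetic (illustrative only; the smoothing value mirrors the `smoothing_constant = 10` visible later in this diff):

import numpy as np

num_classes = 3
# Counts start at the smoothing constant so early estimates are not degenerate.
counts = np.full([num_classes], 10, dtype=np.int64)

def estimate_data_distribution(batch_classes):
  # Equivalent of assign_add(reduce_sum(one_hot(c, num_classes), 0)).
  global counts
  counts += np.bincount(batch_classes, minlength=num_classes)
  # Equivalent of truediv(counts, reduce_sum(counts)).
  return counts / counts.sum()

print(estimate_data_distribution(np.array([0, 0, 1])))  # ~[0.364 0.333 0.303]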
@@ -445,8 +444,8 @@ class Dataset(object):
     output_shapes = str(output_shapes).replace("'", "")
     output_types = nest.map_structure(repr, self.output_types)
     output_types = str(output_types).replace("'", "")
-    return ("<%s shapes: %s, types: %s>"
-            % (type(self).__name__, output_shapes, output_types))
+    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
+                                            output_types))
 
   @staticmethod
   def from_tensors(tensors):
@@ -595,9 +594,7 @@ class Dataset(object):
     dataset = dataset.repeat(num_epochs)
     if randomize_input:
       dataset = dataset.shuffle(capacity)
-    dataset = dataset.map(
-        lambda x: _parse_example(nest.flatten(x), features)
-    )
+    dataset = dataset.map(lambda x: _parse_example(nest.flatten(x), features))
     dataset = dataset.batch(batch_size)
     return dataset
 
@@ -897,7 +894,7 @@ class Dataset(object):
       A `Dataset`.
     """
     return self.flat_map(
-        map_func=lambda *args: Dataset.from_tensor_slices(args))
+        map_func=lambda *args: Dataset.from_tensor_slices(args))
 
   def filter(self, predicate):
     """Filters this dataset according to `predicate`.
@@ -1031,12 +1028,14 @@ class ZipDataset(Dataset):
   @property
   def output_shapes(self):
     return nest.pack_sequence_as(self._datasets, [
-        ds.output_shapes for ds in nest.flatten(self._datasets)])
+        ds.output_shapes for ds in nest.flatten(self._datasets)
+    ])
 
   @property
   def output_types(self):
     return nest.pack_sequence_as(self._datasets, [
-        ds.output_types for ds in nest.flatten(self._datasets)])
+        ds.output_types for ds in nest.flatten(self._datasets)
+    ])
 
 
 class RepeatDataset(Dataset):
@@ -1049,8 +1048,8 @@ class RepeatDataset(Dataset):
     if count is None:
       self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
     else:
-      self._count = ops.convert_to_tensor(count, dtype=dtypes.int64,
-                                          name="count")
+      self._count = ops.convert_to_tensor(
+          count, dtype=dtypes.int64, name="count")
 
   def make_dataset_resource(self):
     return gen_dataset_ops.repeat_dataset(
@@ -1155,8 +1154,8 @@ class ShuffleDataset(Dataset):
     if seed2 is None:
       self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
     else:
-      self._seed2 = ops.convert_to_tensor(seed2, dtype=dtypes.int64,
-                                          name="seed2")
+      self._seed2 = ops.convert_to_tensor(
+          seed2, dtype=dtypes.int64, name="seed2")
 
   def make_dataset_resource(self):
     return gen_dataset_ops.shuffle_dataset(
@@ -1310,12 +1309,10 @@ def _padding_value_to_tensor(value, output_type):
   """
   value = ops.convert_to_tensor(value, name="padding_value")
   if not value.shape.is_compatible_with(tensor_shape.scalar()):
-    raise ValueError(
-        "Padding value should be a scalar, but is not: %s" % value)
+    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
   if value.dtype != output_type:
-    raise TypeError(
-        "Padding value tensor (%s) does not match output type: %s"
-        % (value, output_type))
+    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
+                    (value, output_type))
   return value
 
 
@@ -1329,20 +1326,20 @@ class PaddedBatchDataset(Dataset):
     self._batch_size = batch_size
     padding_values = (padding_values if padding_values is not None else
                       self._default_padding(input_dataset))
-    self._padded_shapes = nest.map_structure_up_to(input_dataset.output_shapes,
-                                                   _partial_shape_to_tensor,
-                                                   padded_shapes)
-    self._padding_values = nest.map_structure_up_to(input_dataset.output_shapes,
-                                                    _padding_value_to_tensor,
-                                                    padding_values,
-                                                    input_dataset.output_types)
+    self._padded_shapes = nest.map_structure_up_to(
+        input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes)
+    self._padding_values = nest.map_structure_up_to(
+        input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
+        input_dataset.output_types)
 
   def _default_padding(self, input_dataset):
 
     def make_zero(t):
       if t.base_dtype == dtypes.string:
         return ""
       else:
         return np.zeros_like(t.as_numpy_dtype())
 
     return nest.map_structure(make_zero, input_dataset.output_types)
 
   def make_dataset_resource(self):
@@ -1358,9 +1355,11 @@ class PaddedBatchDataset(Dataset):
 
   @property
   def output_shapes(self):
+
     def _padded_shape_to_batch_shape(s):
       return tensor_shape.vector(None).concatenate(
           tensor_util.constant_value_as_shape(s))
+
     return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes)
 
   @property
@@ -1376,8 +1375,8 @@ class DenseToSparseBatchDataset(Dataset):
     super(DenseToSparseBatchDataset, self).__init__()
     if not isinstance(input_dataset.output_types, dtypes.DType):
       raise TypeError("DenseToSparseDataset requires an input whose elements "
-                      "have a single component, whereas the input has %r."
-                      % input_dataset.output_types)
+                      "have a single component, whereas the input has %r." %
+                      input_dataset.output_types)
     self._input_dataset = input_dataset
     self._batch_size = batch_size
     self._row_shape = _partial_shape_to_tensor(row_shape)
@@ -1493,22 +1492,6 @@ class GroupByWindowDataset(Dataset):
     return self._output_types
 
 
-def _most_specific_compatible_shape(s1, s2):
-  """Returns the most specific shape compatible with `s1` and `s2`."""
-  if s1.dims is None:
-    return s1
-  if s2.dims is None:
-    return s2
-  s1.assert_same_rank(s2)
-  dims = []
-  for dim1, dim2 in zip(s1, s2):
-    if dim1.value is None or dim2.value is None or dim1.value != dim2.value:
-      dims.append(tensor_shape.Dimension(None))
-    else:
-      dims.append(dim1.value)
-  return tensor_shape.TensorShape(dims)
-
-
 class MapDataset(Dataset):
   """A `Dataset` that maps a function over elements in its input."""
 
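
The module-private helper deleted above is superseded by the public `TensorShape.most_specific_compatible_shape` method added in tensor_shape.py near the end of this diff. Call sites (not shown in this excerpt) would migrate along these lines (illustrative):

# Before: module-private helper in dataset_ops.py.
shape = _most_specific_compatible_shape(s1, s2)
# After: public TensorShape method.
shape = s1.most_specific_compatible_shape(s2)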
@@ -1593,9 +1576,7 @@ class MapDataset(Dataset):
 class FlatMapDataset(Dataset):
   """A `Dataset` that maps a function over its input and flattens the result."""
 
-  def __init__(self,
-               input_dataset,
-               map_func):
+  def __init__(self, input_dataset, map_func):
     """See `Dataset.flat_map()` for details."""
     super(FlatMapDataset, self).__init__()
     self._input_dataset = input_dataset
@@ -1799,8 +1780,11 @@ class FixedLengthRecordDataset(Dataset):
     return dtypes.string
 
 
-def rejection_resample(dataset, class_func, target_dist,
-                       initial_dist=None, seed=None):
+def rejection_resample(dataset,
+                       class_func,
+                       target_dist,
+                       initial_dist=None,
+                       seed=None):
   """Resamples this dataset to achieve a target class distribution.
 
   **NOTE** Resampling is performed via rejection sampling; some fraction
@@ -1825,36 +1809,34 @@ def rejection_resample(dataset, class_func, target_dist,
   target_dist = ops.convert_to_tensor(target_dist, name="initial_dist")
   class_values_ds = dataset.map(class_func)
   if initial_dist is not None:
-    initial_dist = ops.convert_to_tensor(
-        initial_dist, name="initial_dist")
+    initial_dist = ops.convert_to_tensor(initial_dist, name="initial_dist")
     acceptance_dist = _calculate_acceptance_probs(initial_dist, target_dist)
     initial_dist_ds = Dataset.from_tensors(initial_dist).repeat()
     acceptance_dist_ds = Dataset.from_tensors(acceptance_dist).repeat()
   else:
-    num_classes = (target_dist.shape[0].value
-                   or array_ops.shape(target_dist)[0])
+    num_classes = (target_dist.shape[0].value or
+                   array_ops.shape(target_dist)[0])
     smoothing_constant = 10
     num_examples_per_class_seen = resource_variable_ops.ResourceVariable(
-        initial_value=array_ops.fill(
-            [num_classes], np.int64(smoothing_constant)),
+        initial_value=array_ops.fill([num_classes],
+                                     np.int64(smoothing_constant)),
         trainable=False,
         name="class_count",
         dtype=dtypes.int64)
 
     def update_estimate_and_tile(c):
       return array_ops.tile(
           array_ops.expand_dims(
               _estimate_data_distribution(c, num_examples_per_class_seen), 0),
           [dist_estimation_batch_size, 1])
-    initial_dist_ds = (class_values_ds
-                       .batch(dist_estimation_batch_size)
-                       .map(update_estimate_and_tile)
-                       .unbatch())
+
+    initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
+                       .map(update_estimate_and_tile).unbatch())
     acceptance_dist_ds = initial_dist_ds.map(
         lambda initial: _calculate_acceptance_probs(initial, target_dist))
 
   def maybe_warn_on_large_rejection(accept_dist, initial_dist):
-    proportion_rejected = math_ops.reduce_sum(
-        (1 - accept_dist) * initial_dist)
+    proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
     return control_flow_ops.cond(
         math_ops.less(proportion_rejected, .5),
         lambda: accept_dist,
@@ -1864,12 +1846,10 @@ def rejection_resample(dataset, class_func, target_dist,
         summarize=100,
         first_n=10))
 
-  acceptance_dist_ds = (
-      Dataset.zip((acceptance_dist_ds, initial_dist_ds))
-      .map(maybe_warn_on_large_rejection))
+  acceptance_dist_ds = (Dataset.zip((acceptance_dist_ds, initial_dist_ds))
+                        .map(maybe_warn_on_large_rejection))
 
-  current_probabilities_ds = (Dataset
-                              .zip((acceptance_dist_ds, class_values_ds))
+  current_probabilities_ds = (Dataset.zip((acceptance_dist_ds, class_values_ds))
                               .map(array_ops.gather))
   filtered_ds = (
       Dataset.zip((class_values_ds, current_probabilities_ds, dataset))
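
`_calculate_acceptance_probs` itself is not shown in this diff. For intuition only, a standard construction of per-class acceptance probabilities for rejection sampling scales the target/initial ratios so that the most under-represented class is always accepted; a hedged NumPy sketch (assumed, not taken from this commit):

import numpy as np

def calculate_acceptance_probs(initial_dist, target_dist):
  # Accept class c with probability proportional to target[c] / initial[c],
  # normalized so the largest ratio maps to probability 1.
  ratio = target_dist / initial_dist
  return ratio / ratio.max()

print(calculate_acceptance_probs(np.array([0.5, 0.5]), np.array([0.9, 0.1])))
# -> [1.         0.11111111]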
@@ -736,6 +736,36 @@ class TensorShape(object):
     if not self.is_compatible_with(other):
       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
 
+  def most_specific_compatible_shape(self, other):
+    """Returns the most specific TensorShape compatible with `self` and `other`.
+
+    * TensorShape([None, 1]) is the most specific TensorShape compatible with
+      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
+      TensorShape(None) is also compatible with the above-mentioned TensorShapes.
+
+    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
+      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). Less specific
+      TensorShapes are also compatible with the above-mentioned TensorShapes,
+      e.g. TensorShape([1, 2, None]), TensorShape(None).
+
+    Args:
+      other: Another `TensorShape`.
+
+    Returns:
+      A `TensorShape` which is the most specific compatible shape of `self`
+      and `other`.
+    """
+
+    other = as_shape(other)
+    if self._dims is None or other.dims is None or self.ndims != other.ndims:
+      return unknown_shape()
+
+    dims = [Dimension(None)] * self.ndims
+    for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
+      if d1 is not None and d2 is not None and d1 == d2:
+        dims[i] = d1
+    return TensorShape(dims)
+
   def is_fully_defined(self):
     """Returns True iff `self` is fully defined in every dimension."""
     return (self._dims is not None and all(dim.value is not None
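
For contrast with the existing `merge_with` (listed alongside the new method in the API golden file below): `merge_with` returns the more specific combination of two compatible shapes and raises if they are incompatible, whereas `most_specific_compatible_shape` goes the other way and never raises. Illustrative:

from tensorflow.python.framework import tensor_shape

s1 = tensor_shape.TensorShape([2, None])
s2 = tensor_shape.TensorShape([2, 3])

print(s1.merge_with(s2))                      # (2, 3) -- most specific merge
print(s1.most_specific_compatible_shape(s2))  # (2, ?) -- most specific common shape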
@@ -275,6 +275,26 @@ class ShapeTest(test_util.TensorFlowTestCase):
     tensor_shape.TensorShape([1, 2]).concatenate(
         tensor_shape.Dimension(3)))
 
+  def _testMostSpecificCompatibleShapeHelper(self, x, y, expected):
+    mcs = tensor_shape.TensorShape(x).most_specific_compatible_shape(
+        tensor_shape.TensorShape(y))
+    mcs_dims = mcs.dims
+    if expected is None or mcs_dims is None:
+      self.assertIs(expected, mcs_dims)
+    else:
+      self.assertEqual(expected, mcs.as_list())
+
+  def testMostSpecificCompatibleShape(self):
+    self._testMostSpecificCompatibleShapeHelper([1, 2], None, None)
+    self._testMostSpecificCompatibleShapeHelper(None, [1, 2], None)
+    self._testMostSpecificCompatibleShapeHelper([1, 2], [1, 2, 3, 4], None)
+    self._testMostSpecificCompatibleShapeHelper([1, 2, 3, 4], [1, 2], None)
+    self._testMostSpecificCompatibleShapeHelper([1, 2], [1, 2], [1, 2])
+    self._testMostSpecificCompatibleShapeHelper([None, 2, 3], [1, 1, 3],
+                                                [None, None, 3])
+    self._testMostSpecificCompatibleShapeHelper([1, 1, 3], [None, 2, 3],
+                                                [None, None, 3])
+
   def testHelpers(self):
     tensor_shape.TensorShape([]).assert_is_compatible_with(
         tensor_shape.scalar())
@@ -54,6 +54,10 @@ tf_class {
     name: "merge_with"
     argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "most_specific_compatible_shape"
+    argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "num_elements"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"