[tf.data] Changing the tf.data.experimental.rejection_resample implementation to avoid relying on the assumption that a dataset copy produces elements in the same order as the original dataset -- which is not guaranteed to be true (e.g. for shuffled datasets).

PiperOrigin-RevId: 289165771
Change-Id: I430aed5aee8e58e29e2af6292ebf1cc81b2068db
This commit is contained in:
Jiri Simsa 2020-01-10 14:27:21 -08:00 committed by TensorFlower Gardener
parent be948988a2
commit f313fdce45
2 changed files with 24 additions and 15 deletions

View File

@@ -44,8 +44,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
initial_dist = [0.2] * 5 if initial_known else None
classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build.
dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
200, seed=21, reshuffle_each_iteration=False).map(
lambda c: (c, string_ops.as_string(c))).repeat()
200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
get_next = self.getNext(
dataset.apply(

View File

@@ -56,7 +56,6 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
class_values_ds = dataset.map(class_func)
# Get initial distribution.
if initial_dist is not None:
@@ -71,8 +70,8 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
prob_of_original_ds = dataset_ops.Dataset.from_tensors(
prob_of_original).repeat()
else:
initial_dist_ds = _estimate_initial_dist_ds(
target_dist_t, class_values_ds)
initial_dist_ds = _estimate_initial_dist_ds(target_dist_t,
dataset.map(class_func))
acceptance_and_original_prob_ds = initial_dist_ds.map(
lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
initial, target_dist_t))
@@ -81,19 +80,26 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
prob_of_original_ds = acceptance_and_original_prob_ds.map(
lambda _, prob_original: prob_original)
filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
class_values_ds, seed)
class_func, seed)
# Prefetch filtered dataset for speed.
filtered_ds = filtered_ds.prefetch(3)
prob_original_static = _get_prob_original_static(
initial_dist_t, target_dist_t) if initial_dist is not None else None
def add_class_value(*x):
if len(x) == 1:
return class_func(*x), x[0]
else:
return class_func(*x), x
if prob_original_static == 1:
return dataset_ops.Dataset.zip((class_values_ds, dataset))
return dataset.map(add_class_value)
elif prob_original_static == 0:
return filtered_ds
else:
return interleave_ops.sample_from_datasets(
[dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
[dataset.map(add_class_value), filtered_ds],
weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
seed=seed)
@@ -123,8 +129,7 @@ def _get_prob_original_static(initial_dist_t, target_dist_t):
return np.min(target_static / init_static)
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
seed):
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed):
"""Filters a dataset based on per-class acceptance probabilities.
Args:
@@ -132,7 +137,8 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
acceptance_dist_ds: A dataset of acceptance probabilities.
initial_dist_ds: A dataset of the initial probability distribution, given or
estimated.
class_values_ds: A dataset of the corresponding classes.
class_func: A function mapping an element of the input dataset to a scalar
`tf.int32` tensor. Values should be in `[0, num_classes)`.
seed: (Optional.) Python integer seed for the resampler.
Returns:
@@ -153,14 +159,18 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
initial_dist_ds))
.map(maybe_warn_on_large_rejection))
def _gather_and_copy(class_val, acceptance_prob, data):
def _gather_and_copy(acceptance_prob, data):
if isinstance(data, tuple):
class_val = class_func(*data)
else:
class_val = class_func(data)
return class_val, array_ops.gather(acceptance_prob, class_val), data
current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
(class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
(acceptance_dist_ds, dataset)).map(_gather_and_copy)
filtered_ds = (
current_probabilities_and_class_and_data_ds
.filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
current_probabilities_and_class_and_data_ds.filter(
lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
return filtered_ds.map(lambda class_value, _, data: (class_value, data))