[tf.data] Changing the `tf.data.experimental.rejection_resample` implementation to avoid relying on the assumption that a copy of a dataset produces elements in the same order as the original dataset, which is not guaranteed to be true (e.g. for shuffled datasets).
PiperOrigin-RevId: 289165771
Change-Id: I430aed5aee8e58e29e2af6292ebf1cc81b2068db
This commit is contained in:
parent
be948988a2
commit
f313fdce45
@ -44,8 +44,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
initial_dist = [0.2] * 5 if initial_known else None
|
||||
classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build.
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
|
||||
200, seed=21, reshuffle_each_iteration=False).map(
|
||||
lambda c: (c, string_ops.as_string(c))).repeat()
|
||||
200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
|
||||
|
||||
get_next = self.getNext(
|
||||
dataset.apply(
|
||||
|
@ -56,7 +56,6 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
||||
def _apply_fn(dataset):
|
||||
"""Function from `Dataset` to `Dataset` that applies the transformation."""
|
||||
target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
|
||||
class_values_ds = dataset.map(class_func)
|
||||
|
||||
# Get initial distribution.
|
||||
if initial_dist is not None:
|
||||
@ -71,8 +70,8 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
||||
prob_of_original_ds = dataset_ops.Dataset.from_tensors(
|
||||
prob_of_original).repeat()
|
||||
else:
|
||||
initial_dist_ds = _estimate_initial_dist_ds(
|
||||
target_dist_t, class_values_ds)
|
||||
initial_dist_ds = _estimate_initial_dist_ds(target_dist_t,
|
||||
dataset.map(class_func))
|
||||
acceptance_and_original_prob_ds = initial_dist_ds.map(
|
||||
lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
|
||||
initial, target_dist_t))
|
||||
@ -81,19 +80,26 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
||||
prob_of_original_ds = acceptance_and_original_prob_ds.map(
|
||||
lambda _, prob_original: prob_original)
|
||||
filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
|
||||
class_values_ds, seed)
|
||||
class_func, seed)
|
||||
# Prefetch filtered dataset for speed.
|
||||
filtered_ds = filtered_ds.prefetch(3)
|
||||
|
||||
prob_original_static = _get_prob_original_static(
|
||||
initial_dist_t, target_dist_t) if initial_dist is not None else None
|
||||
|
||||
def add_class_value(*x):
|
||||
if len(x) == 1:
|
||||
return class_func(*x), x[0]
|
||||
else:
|
||||
return class_func(*x), x
|
||||
|
||||
if prob_original_static == 1:
|
||||
return dataset_ops.Dataset.zip((class_values_ds, dataset))
|
||||
return dataset.map(add_class_value)
|
||||
elif prob_original_static == 0:
|
||||
return filtered_ds
|
||||
else:
|
||||
return interleave_ops.sample_from_datasets(
|
||||
[dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
|
||||
[dataset.map(add_class_value), filtered_ds],
|
||||
weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
|
||||
seed=seed)
|
||||
|
||||
@ -123,8 +129,7 @@ def _get_prob_original_static(initial_dist_t, target_dist_t):
|
||||
return np.min(target_static / init_static)
|
||||
|
||||
|
||||
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
|
||||
seed):
|
||||
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed):
|
||||
"""Filters a dataset based on per-class acceptance probabilities.
|
||||
|
||||
Args:
|
||||
@ -132,7 +137,8 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
|
||||
acceptance_dist_ds: A dataset of acceptance probabilities.
|
||||
initial_dist_ds: A dataset of the initial probability distribution, given or
|
||||
estimated.
|
||||
class_values_ds: A dataset of the corresponding classes.
|
||||
class_func: A function mapping an element of the input dataset to a scalar
|
||||
`tf.int32` tensor. Values should be in `[0, num_classes)`.
|
||||
seed: (Optional.) Python integer seed for the resampler.
|
||||
|
||||
Returns:
|
||||
@ -153,14 +159,18 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
|
||||
initial_dist_ds))
|
||||
.map(maybe_warn_on_large_rejection))
|
||||
|
||||
def _gather_and_copy(class_val, acceptance_prob, data):
|
||||
def _gather_and_copy(acceptance_prob, data):
|
||||
if isinstance(data, tuple):
|
||||
class_val = class_func(*data)
|
||||
else:
|
||||
class_val = class_func(data)
|
||||
return class_val, array_ops.gather(acceptance_prob, class_val), data
|
||||
|
||||
current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
|
||||
(class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
|
||||
(acceptance_dist_ds, dataset)).map(_gather_and_copy)
|
||||
filtered_ds = (
|
||||
current_probabilities_and_class_and_data_ds
|
||||
.filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
|
||||
current_probabilities_and_class_and_data_ds.filter(
|
||||
lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
|
||||
return filtered_ds.map(lambda class_value, _, data: (class_value, data))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user