[tf.data] Changing the `tf.data.experimental.rejection_resample`
implementation to avoid relying on the assumption that a dataset copy produces elements in the same order as the original dataset, which is not guaranteed to be true (e.g. for shuffled datasets).
PiperOrigin-RevId: 289165771 Change-Id: I430aed5aee8e58e29e2af6292ebf1cc81b2068db
This commit is contained in:
parent
be948988a2
commit
f313fdce45
@ -44,8 +44,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
initial_dist = [0.2] * 5 if initial_known else None
|
initial_dist = [0.2] * 5 if initial_known else None
|
||||||
classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build.
|
classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build.
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
|
dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
|
||||||
200, seed=21, reshuffle_each_iteration=False).map(
|
200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
|
||||||
lambda c: (c, string_ops.as_string(c))).repeat()
|
|
||||||
|
|
||||||
get_next = self.getNext(
|
get_next = self.getNext(
|
||||||
dataset.apply(
|
dataset.apply(
|
||||||
|
@ -56,7 +56,6 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
|||||||
def _apply_fn(dataset):
|
def _apply_fn(dataset):
|
||||||
"""Function from `Dataset` to `Dataset` that applies the transformation."""
|
"""Function from `Dataset` to `Dataset` that applies the transformation."""
|
||||||
target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
|
target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
|
||||||
class_values_ds = dataset.map(class_func)
|
|
||||||
|
|
||||||
# Get initial distribution.
|
# Get initial distribution.
|
||||||
if initial_dist is not None:
|
if initial_dist is not None:
|
||||||
@ -71,8 +70,8 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
|||||||
prob_of_original_ds = dataset_ops.Dataset.from_tensors(
|
prob_of_original_ds = dataset_ops.Dataset.from_tensors(
|
||||||
prob_of_original).repeat()
|
prob_of_original).repeat()
|
||||||
else:
|
else:
|
||||||
initial_dist_ds = _estimate_initial_dist_ds(
|
initial_dist_ds = _estimate_initial_dist_ds(target_dist_t,
|
||||||
target_dist_t, class_values_ds)
|
dataset.map(class_func))
|
||||||
acceptance_and_original_prob_ds = initial_dist_ds.map(
|
acceptance_and_original_prob_ds = initial_dist_ds.map(
|
||||||
lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
|
lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
|
||||||
initial, target_dist_t))
|
initial, target_dist_t))
|
||||||
@ -81,19 +80,26 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
|
|||||||
prob_of_original_ds = acceptance_and_original_prob_ds.map(
|
prob_of_original_ds = acceptance_and_original_prob_ds.map(
|
||||||
lambda _, prob_original: prob_original)
|
lambda _, prob_original: prob_original)
|
||||||
filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
|
filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
|
||||||
class_values_ds, seed)
|
class_func, seed)
|
||||||
# Prefetch filtered dataset for speed.
|
# Prefetch filtered dataset for speed.
|
||||||
filtered_ds = filtered_ds.prefetch(3)
|
filtered_ds = filtered_ds.prefetch(3)
|
||||||
|
|
||||||
prob_original_static = _get_prob_original_static(
|
prob_original_static = _get_prob_original_static(
|
||||||
initial_dist_t, target_dist_t) if initial_dist is not None else None
|
initial_dist_t, target_dist_t) if initial_dist is not None else None
|
||||||
|
|
||||||
|
def add_class_value(*x):
|
||||||
|
if len(x) == 1:
|
||||||
|
return class_func(*x), x[0]
|
||||||
|
else:
|
||||||
|
return class_func(*x), x
|
||||||
|
|
||||||
if prob_original_static == 1:
|
if prob_original_static == 1:
|
||||||
return dataset_ops.Dataset.zip((class_values_ds, dataset))
|
return dataset.map(add_class_value)
|
||||||
elif prob_original_static == 0:
|
elif prob_original_static == 0:
|
||||||
return filtered_ds
|
return filtered_ds
|
||||||
else:
|
else:
|
||||||
return interleave_ops.sample_from_datasets(
|
return interleave_ops.sample_from_datasets(
|
||||||
[dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
|
[dataset.map(add_class_value), filtered_ds],
|
||||||
weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
|
weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
|
||||||
seed=seed)
|
seed=seed)
|
||||||
|
|
||||||
@ -123,8 +129,7 @@ def _get_prob_original_static(initial_dist_t, target_dist_t):
|
|||||||
return np.min(target_static / init_static)
|
return np.min(target_static / init_static)
|
||||||
|
|
||||||
|
|
||||||
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
|
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed):
|
||||||
seed):
|
|
||||||
"""Filters a dataset based on per-class acceptance probabilities.
|
"""Filters a dataset based on per-class acceptance probabilities.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -132,7 +137,8 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
|
|||||||
acceptance_dist_ds: A dataset of acceptance probabilities.
|
acceptance_dist_ds: A dataset of acceptance probabilities.
|
||||||
initial_dist_ds: A dataset of the initial probability distribution, given or
|
initial_dist_ds: A dataset of the initial probability distribution, given or
|
||||||
estimated.
|
estimated.
|
||||||
class_values_ds: A dataset of the corresponding classes.
|
class_func: A function mapping an element of the input dataset to a scalar
|
||||||
|
`tf.int32` tensor. Values should be in `[0, num_classes)`.
|
||||||
seed: (Optional.) Python integer seed for the resampler.
|
seed: (Optional.) Python integer seed for the resampler.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -153,14 +159,18 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
|
|||||||
initial_dist_ds))
|
initial_dist_ds))
|
||||||
.map(maybe_warn_on_large_rejection))
|
.map(maybe_warn_on_large_rejection))
|
||||||
|
|
||||||
def _gather_and_copy(class_val, acceptance_prob, data):
|
def _gather_and_copy(acceptance_prob, data):
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
class_val = class_func(*data)
|
||||||
|
else:
|
||||||
|
class_val = class_func(data)
|
||||||
return class_val, array_ops.gather(acceptance_prob, class_val), data
|
return class_val, array_ops.gather(acceptance_prob, class_val), data
|
||||||
|
|
||||||
current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
|
current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
|
||||||
(class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
|
(acceptance_dist_ds, dataset)).map(_gather_and_copy)
|
||||||
filtered_ds = (
|
filtered_ds = (
|
||||||
current_probabilities_and_class_and_data_ds
|
current_probabilities_and_class_and_data_ds.filter(
|
||||||
.filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
|
lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
|
||||||
return filtered_ds.map(lambda class_value, _, data: (class_value, data))
|
return filtered_ds.map(lambda class_value, _, data: (class_value, data))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user