diff --git a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py index e9cefb2c616..bc1bbc45ffe 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py @@ -44,8 +44,7 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): initial_dist = [0.2] * 5 if initial_known else None classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build. dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( - 200, seed=21, reshuffle_each_iteration=False).map( - lambda c: (c, string_ops.as_string(c))).repeat() + 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() get_next = self.getNext( dataset.apply( diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py index a9da1a7d092..87d7f8429eb 100644 --- a/tensorflow/python/data/experimental/ops/resampling.py +++ b/tensorflow/python/data/experimental/ops/resampling.py @@ -56,7 +56,6 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") - class_values_ds = dataset.map(class_func) # Get initial distribution. if initial_dist is not None: @@ -71,8 +70,8 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): prob_of_original_ds = dataset_ops.Dataset.from_tensors( prob_of_original).repeat() else: - initial_dist_ds = _estimate_initial_dist_ds( - target_dist_t, class_values_ds) + initial_dist_ds = _estimate_initial_dist_ds(target_dist_t, + dataset.map(class_func)) acceptance_and_original_prob_ds = initial_dist_ds.map( lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda initial, target_dist_t)) @@ -81,19 +80,26 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): prob_of_original_ds = acceptance_and_original_prob_ds.map( lambda _, prob_original: prob_original) filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, - class_values_ds, seed) + class_func, seed) # Prefetch filtered dataset for speed. filtered_ds = filtered_ds.prefetch(3) prob_original_static = _get_prob_original_static( initial_dist_t, target_dist_t) if initial_dist is not None else None + + def add_class_value(*x): + if len(x) == 1: + return class_func(*x), x[0] + else: + return class_func(*x), x + if prob_original_static == 1: - return dataset_ops.Dataset.zip((class_values_ds, dataset)) + return dataset.map(add_class_value) elif prob_original_static == 0: return filtered_ds else: return interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], + [dataset.map(add_class_value), filtered_ds], weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), seed=seed) @@ -123,8 +129,7 @@ def _get_prob_original_static(initial_dist_t, target_dist_t): return np.min(target_static / init_static) -def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, - seed): +def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed): """Filters a dataset based on per-class acceptance probabilities. Args: @@ -132,7 +137,8 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, acceptance_dist_ds: A dataset of acceptance probabilities. initial_dist_ds: A dataset of the initial probability distribution, given or estimated. - class_values_ds: A dataset of the corresponding classes. + class_func: A function mapping an element of the input dataset to a scalar + `tf.int32` tensor. Values should be in `[0, num_classes)`. seed: (Optional.) Python integer seed for the resampler. Returns: @@ -153,14 +159,18 @@ def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, initial_dist_ds)) .map(maybe_warn_on_large_rejection)) - def _gather_and_copy(class_val, acceptance_prob, data): + def _gather_and_copy(acceptance_prob, data): + if isinstance(data, tuple): + class_val = class_func(*data) + else: + class_val = class_func(data) return class_val, array_ops.gather(acceptance_prob, class_val), data current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( - (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) + (acceptance_dist_ds, dataset)).map(_gather_and_copy) filtered_ds = ( - current_probabilities_and_class_and_data_ds - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) + current_probabilities_and_class_and_data_ds.filter( + lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) return filtered_ds.map(lambda class_value, _, data: (class_value, data))