STT-tensorflow/tensorflow/python/data/experimental/ops/resampling.py

305 lines
12 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resampling dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.rejection_resample")
def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
  """Builds a transformation that resamples a dataset toward `target_dist`.

  **NOTE** The resampling relies on rejection sampling, so a fraction of the
  input elements is dropped.

  Args:
    class_func: A function mapping an element of the input dataset to a scalar
      `tf.int32` tensor. Values should be in `[0, num_classes)`.
    target_dist: A floating point type tensor, shaped `[num_classes]`.
    initial_dist: (Optional.) A floating point type tensor, shaped
      `[num_classes]`. When omitted, the class distribution of the input is
      estimated on the fly as elements stream through.
    seed: (Optional.) Python integer seed for the resampler.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`. The transformed dataset yields
    `(class value, data)` pairs.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")

    if initial_dist is None:
      # No prior given: estimate the class distribution from the stream and
      # derive the acceptance / mixing probabilities per element.
      def _acceptance_and_mixing(estimated_initial):
        return _calculate_acceptance_probs_with_mixing(
            estimated_initial, target_dist_t)

      initial_dist_ds = _estimate_initial_dist_ds(
          target_dist_t, dataset.map(class_func))
      acceptance_and_original_prob_ds = initial_dist_ds.map(
          _acceptance_and_mixing)
      acceptance_dist_ds = acceptance_and_original_prob_ds.map(
          lambda acceptance, _: acceptance)
      prob_of_original_ds = acceptance_and_original_prob_ds.map(
          lambda _, prob_original: prob_original)
    else:
      # Prior given: the probabilities are computed once and repeated forever.
      initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
      acceptance_dist, prob_of_original = (
          _calculate_acceptance_probs_with_mixing(initial_dist_t,
                                                  target_dist_t))
      initial_dist_ds = dataset_ops.Dataset.from_tensors(
          initial_dist_t).repeat()
      acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
          acceptance_dist).repeat()
      prob_of_original_ds = dataset_ops.Dataset.from_tensors(
          prob_of_original).repeat()

    # Prefetch the rejection-sampled dataset for speed.
    filtered_ds = _filter_ds(
        dataset, acceptance_dist_ds, initial_dist_ds, class_func,
        seed).prefetch(3)

    # The mixing probability is only known statically when a prior was given.
    prob_original_static = None
    if initial_dist is not None:
      prob_original_static = _get_prob_original_static(
          initial_dist_t, target_dist_t)

    def _with_class_value(*x):
      # Single-component elements are passed through unwrapped; multi-component
      # elements are kept as a tuple.
      data = x[0] if len(x) == 1 else x
      return class_func(*x), data

    # Statically known extremes need no mixing at all.
    if prob_original_static == 1:
      return dataset.map(_with_class_value)
    if prob_original_static == 0:
      return filtered_ds

    original_with_class_ds = dataset.map(_with_class_value)
    mixing_weights_ds = prob_of_original_ds.map(
        lambda prob: [(prob, 1.0 - prob)])
    return interleave_ops.sample_from_datasets(
        [original_with_class_ds, filtered_ds],
        weights=mixing_weights_ds,
        seed=seed)

  return _apply_fn
def _get_prob_original_static(initial_dist_t, target_dist_t):
  """Returns the static probability of sampling from the original.

  `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
  an Op that it isn't defined for. We have some custom logic to avoid this.

  The ratio is computed the same way as the graph-side computation in
  `_get_target_to_initial_ratio`: the denominator is padded with the dtype's
  `tiny` value. Without the padding, a class with zero initial probability
  makes the static value `nan`/`inf` (with a NumPy RuntimeWarning) while the
  dynamic value stays finite, so the two disagree.

  Args:
    initial_dist_t: A tensor of the initial distribution.
    target_dist_t: A tensor of the target distribution.

  Returns:
    The probability of sampling from the original distribution as a constant,
    if it is a constant, or `None`.
  """
  init_static = tensor_util.constant_value(initial_dist_t)
  target_static = tensor_util.constant_value(target_dist_t)
  if init_static is None or target_static is None:
    return None

  denom = init_static
  # `np.finfo` only accepts inexact dtypes; other dtypes keep the plain
  # division so previously working inputs behave as before.
  if np.issubdtype(init_static.dtype, np.floating):
    # Add tiny to the initial distribution to avoid divide by zero.
    denom = init_static + np.finfo(init_static.dtype).tiny
  return np.min(target_static / denom)
def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed):
  """Rejection-samples `dataset` using per-class acceptance probabilities.

  Args:
    dataset: The dataset to be filtered.
    acceptance_dist_ds: A dataset of acceptance probabilities.
    initial_dist_ds: A dataset of the initial probability distribution, given or
      estimated.
    class_func: A function mapping an element of the input dataset to a scalar
      `tf.int32` tensor. Values should be in `[0, num_classes)`.
    seed: (Optional.) Python integer seed for the resampler.

  Returns:
    A dataset of (class value, data) after filtering.
  """

  def _warn_if_mostly_rejected(acceptance_probs, initial_probs):
    # Expected fraction of inputs dropped: sum_i p_i * (1 - a_i).
    rejected_fraction = math_ops.reduce_sum(
        (1 - acceptance_probs) * initial_probs)

    def _unchanged():
      return acceptance_probs

    def _logged():
      # Same values as `_unchanged`, but routed through a Print op so the
      # high rejection rate is reported.
      return logging_ops.Print(
          acceptance_probs,
          [rejected_fraction, initial_probs, acceptance_probs],
          message="Proportion of examples rejected by sampler is high: ",
          summarize=100,
          first_n=10)

    return control_flow_ops.cond(
        math_ops.less(rejected_fraction, .5), _unchanged, _logged)

  checked_acceptance_ds = dataset_ops.Dataset.zip(
      (acceptance_dist_ds, initial_dist_ds)).map(_warn_if_mostly_rejected)

  def _annotate(acceptance_probs, data):
    # Tuple elements are splatted into `class_func`; anything else is passed
    # as a single argument.
    if isinstance(data, tuple):
      class_val = class_func(*data)
    else:
      class_val = class_func(data)
    return class_val, array_ops.gather(acceptance_probs, class_val), data

  annotated_ds = dataset_ops.Dataset.zip(
      (checked_acceptance_ds, dataset)).map(_annotate)

  def _accept(unused_class_value, accept_prob, unused_data):
    # Keep the element when a uniform draw falls below the acceptance
    # probability gathered for its class.
    return random_ops.random_uniform([], seed=seed) < accept_prob

  def _drop_probability(class_value, unused_accept_prob, data):
    return class_value, data

  return annotated_ds.filter(_accept).map(_drop_probability)
def _estimate_initial_dist_ds(
    target_dist_t, class_values_ds, dist_estimation_batch_size=32,
    smoothing_constant=10):
  """Builds a dataset of running estimates of the class distribution.

  Class values are consumed in batches. After each batch the per-class counts
  are updated and the resulting distribution is repeated
  `dist_estimation_batch_size` times, so that after unbatching there is one
  estimate per position of a full batch.

  Args:
    target_dist_t: A tensor of the target distribution; only its first
      dimension is used, to determine the number of classes.
    class_values_ds: A dataset of class values.
    dist_estimation_batch_size: Number of class values folded into the counts
      per update.
    smoothing_constant: Count each class starts with before any element has
      been seen.

  Returns:
    A dataset of estimated class distributions.
  """
  # Prefer the static dimension; fall back to the dynamic shape otherwise.
  num_classes = target_dist_t.shape[0]
  if not num_classes:
    num_classes = array_ops.shape(target_dist_t)[0]

  starting_counts = array_ops.fill(
      [num_classes], np.int64(smoothing_constant))

  def _update_and_tile(counts_so_far, class_batch):
    new_counts, estimate = _estimate_data_distribution(
        class_batch, counts_so_far)
    # One copy of the estimate per slot of a full batch.
    estimate_per_slot = array_ops.tile(
        array_ops.expand_dims(estimate, 0), [dist_estimation_batch_size, 1])
    return new_counts, estimate_per_slot

  batched_classes_ds = class_values_ds.batch(dist_estimation_batch_size)
  tiled_estimates_ds = batched_classes_ds.apply(
      scan_ops.scan(starting_counts, _update_and_tile))
  return tiled_estimates_ds.unbatch()
def _get_target_to_initial_ratio(initial_probs, target_probs):
  """Returns the element-wise ratio `target_probs / initial_probs`.

  Args:
    initial_probs: A tensor of the initial distribution; its dtype selects the
      padding value added to the denominator.
    target_probs: A tensor of the target distribution.

  Returns:
    `target_probs` divided by the padded `initial_probs`.
  """
  # Pad the denominator with the dtype's `tiny` value so that a zero initial
  # probability does not cause a divide by zero.
  numpy_dtype = initial_probs.dtype.as_numpy_dtype
  padded_initial = initial_probs + np.finfo(numpy_dtype).tiny
  return target_probs / padded_initial
def _estimate_data_distribution(c, num_examples_per_class_seen):
  """Updates the running class counts and derives the class distribution.

  Args:
    c: The class labels. Type `int32`, shape `[batch_size]`.
    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
      containing counts.

  Returns:
    num_examples_per_class_seen: Updated counts. Type `int64`, shape
      `[num_classes]`.
    dist: The updated distribution. Type `float32`, shape `[num_classes]`.
  """
  num_classes = num_examples_per_class_seen.get_shape()[0]

  # Count how often each class occurs in this batch and fold that into the
  # running totals.
  one_hot_labels = array_ops.one_hot(c, num_classes, dtype=dtypes.int64)
  batch_counts = math_ops.reduce_sum(one_hot_labels, 0)
  updated_counts = math_ops.add(num_examples_per_class_seen, batch_counts)

  # Normalise the counts into a probability estimate.
  total_seen = math_ops.reduce_sum(updated_counts)
  init_prob_estimate = math_ops.truediv(updated_counts, total_seen)
  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
  return updated_counts, dist
def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
  """Calculates the acceptance probabilities and mixing ratio.

  In this case, we assume that we can *either* sample from the original data
  distribution with probability `m`, or sample from a reshaped distribution
  that comes from rejection sampling on the original distribution. This
  rejection sampling is done on a per-class basis, with `a_i` representing the
  probability of accepting data from class `i`.

  This method is based on solving the following analysis for the reshaped
  distribution:

  Let F be the probability of a rejection (on any example).
  Let p_i be the proportion of examples in the data in class i (init_probs).
  Let a_i be the rate at which the rejection sampler should *accept* class i.
  Let t_i be the target proportion in the minibatches for class i
  (target_probs).

  ```
  F = sum_i(p_i * (1-a_i))
    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
  ```

  An example with class `i` will be accepted if `k` rejections occur, then an
  example with class `i` is seen by the rejector, and it is accepted. This can
  be written as follows:

  ```
  t_i = sum_k=0^inf(F^k * p_i * a_i)
      = p_i * a_i / (1 - F)    using geometric series identity, since 0 <= F < 1
      = p_i * a_i / sum_j(p_j * a_j)        using F from above
  ```

  Note that the following constraints hold:

  ```
  0 <= p_i <= 1, sum_i(p_i) = 1
  0 <= a_i <= 1
  0 <= t_i <= 1, sum_i(t_i) = 1
  ```

  A solution for a_i in terms of the other variables is the following:
    ```a_i = (t_i / p_i) / max_i[t_i / p_i]```

  If we try to minimize the amount of data rejected, we get the following:

  M_max = max_i [ t_i / p_i ]
  M_min = min_i [ t_i / p_i ]

  The desired probability of accepting data if it comes from class `i`:

  a_i = (t_i/p_i - m) / (M_max - m)

  The desired probability of pulling a data element from the original dataset,
  rather than the filtered one:

  m = M_min

  NOTE: when every class has the same ratio (M_max == M_min) the denominator
  of `a_i` is zero.

  Args:
    initial_probs: A Tensor of the initial probability distribution, given or
      estimated.
    target_probs: A Tensor of the target probability distribution over the
      same classes.

  Returns:
    (A 1D Tensor with the per-class acceptance probabilities, the desired
    probability of pull from the original distribution.)
  """
  target_over_initial = _get_target_to_initial_ratio(initial_probs,
                                                     target_probs)
  largest_ratio = math_ops.reduce_max(target_over_initial)
  smallest_ratio = math_ops.reduce_min(target_over_initial)

  # Target prob to sample from original distribution.
  prob_of_original = smallest_ratio

  # TODO(joelshor): Simplify fraction, if possible.
  numerator = target_over_initial - prob_of_original
  denominator = largest_ratio - prob_of_original
  acceptance_probs = numerator / denominator
  return acceptance_probs, prob_of_original