Score function estimators with baselines
Change: 130443765
Parent: ce57cb7882
Commit: 9d62d40f9f
tensorflow/contrib/bayesflow
@@ -21,4 +21,5 @@ from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import,line-too-long
+from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators
from tensorflow.contrib.bayesflow.python.ops import stochastic_graph
@@ -22,6 +22,7 @@ import numpy as np
import tensorflow as tf

sg = tf.contrib.bayesflow.stochastic_graph
+sge = tf.contrib.bayesflow.stochastic_gradient_estimators
distributions = tf.contrib.distributions
@@ -177,7 +178,7 @@ class DistributionTensorTest(tf.test.TestCase):
          mu=mu,
          sigma=sigma,
          dist_value_type=sg.MeanValue(stop_gradient=True),
-         loss_fn=sg.get_score_function_with_baseline(
+         loss_fn=sge.get_score_function_with_constant_baseline(
              baseline=tf.constant(8.0)))
      loss = dt.loss([tf.constant(2.0)])
      self.assertTrue(loss is not None)
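For orientation, a minimal sketch of what the renamed factory produces, assuming only the `tf.contrib.bayesflow` aliases used in this test file; the sketch itself is not part of the change:

import tensorflow as tf

sge = tf.contrib.bayesflow.stochastic_gradient_estimators

# The factory returns a callable with the (dist_tensor, value, loss)
# signature that DistributionTensor expects for its loss_fn argument.
loss_fn = sge.get_score_function_with_constant_baseline(
    baseline=tf.constant(8.0))

# Conceptually, loss_fn(dt, dt.value(), loss) builds
#   dt.distribution.log_prob(value) * stop_gradient(loss - 8.0),
# whose gradient is the score function estimator with a constant baseline.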
@@ -0,0 +1,193 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stochastic gradient estimators.

These functions are meant to be used in conjunction with `StochasticTensor`
(`loss_fn` parameter) and `surrogate_loss`.

See Gradient Estimation Using Stochastic Computation Graphs
(http://arxiv.org/abs/1506.05254) by Schulman et al., eq. 1 and section 4, for
mathematical details.

## Score function estimator

The score function is an unbiased estimator of the gradient of `E_p(x)[f(x)]`,
where `f(x)` can be considered to be a "loss" term. It is computed as
`E_p(x)[f(x) grad(log p(x))]`. A constant `b`, referred to here as the
"baseline", can be subtracted from `f(x)` without affecting the expectation.
The term `(f(x) - b)` is referred to here as the "advantage".

Note that the methods defined in this module actually compute the integrand of
the score function, such that when taking the gradient, the true score function
is computed.

@@score_function
@@get_score_function_with_baseline
@@get_score_function_with_constant_baseline
@@get_score_function_with_advantage

## Baseline functions

Baselines reduce the variance of Monte Carlo estimates of an expectation. The
baseline for a stochastic node can be a function of all non-influenced nodes
(see section 4 of Schulman et al., linked above). Baselines are also known as
"control variates."

In the context of a MC estimate of `E_p(x)[f(x) - b]`, baseline functions have
the signature `(st, fx) => Tensor`, where `st` is a `StochasticTensor` backed
by the distribution `p(x)` and `fx` is the influenced loss.

@@get_mean_baseline
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training
from tensorflow.python.util.all_util import make_all

def score_function(dist_tensor, value, loss, baseline=None,
                   name="ScoreFunction"):
  """Score function estimator.

  Computes the integrand of the score function with a baseline:
  `p.log_prob(value) * (loss - baseline)`.

  It will add a `stop_gradient` to the advantage `(loss - baseline)`.

  Args:
    dist_tensor: `DistributionTensor` p(x).
    value: `Tensor` x. Samples from p(x).
    loss: `Tensor`.
    baseline: `Tensor` broadcastable to `loss`.
    name: name to prepend ops with.

  Returns:
    `Tensor` `p.log_prob(x) * (loss - b)`. Taking the gradient yields the score
    function estimator.
  """
  with ops.name_scope(name, values=[value, loss, baseline]):
    value = ops.convert_to_tensor(value)
    loss = ops.convert_to_tensor(loss)
    if baseline is not None:
      baseline = ops.convert_to_tensor(baseline)
      advantage = loss - baseline
    else:
      advantage = loss

    advantage = array_ops.stop_gradient(advantage)
    return dist_tensor.distribution.log_prob(value) * advantage

def get_score_function_with_advantage(advantage_fn=None,
                                      name="ScoreFunctionWithAdvantage"):
  """Score function estimator with advantage function.

  Args:
    advantage_fn: callable that takes the `DistributionTensor` and the
      downstream `loss` and returns a `Tensor` advantage
      (e.g. `loss - baseline`).
    name: name to prepend ops with.

  Returns:
    Callable score function estimator that takes the `DistributionTensor`, the
    sampled `value`, and the downstream `loss`, and uses the provided advantage.
  """

  def score_function_with_advantage(dist_tensor, value, loss):
    with ops.name_scope(name, values=[value, loss]):
      advantage = advantage_fn(dist_tensor, loss)
      advantage = array_ops.stop_gradient(advantage)
      return dist_tensor.distribution.log_prob(value) * advantage

  return score_function_with_advantage


def get_score_function_with_constant_baseline(baseline, name="ScoreFunction"):
  """Score function estimator with constant baseline.

  Args:
    baseline: `Tensor` to be subtracted from loss.
    name: name to prepend ops with.

  Returns:
    Callable score function estimator that takes the `DistributionTensor`, the
    sampled `value`, and the downstream `loss`, and subtracts the provided
    `baseline` from the `loss`.
  """

  def score_function_with_constant_baseline(dist_tensor, value, loss):
    return score_function(dist_tensor, value, loss, baseline, name)

  return score_function_with_constant_baseline

def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"):
  """Score function estimator with baseline function.

  Args:
    baseline_fn: callable that takes the `DistributionTensor` and the
      downstream `loss` and returns a `Tensor` baseline to be subtracted from
      the `loss`. If None, defaults to `get_mean_baseline`, which is an EMA of
      the loss.
    name: name to prepend ops with.

  Returns:
    Callable score function estimator that takes the `DistributionTensor`, the
    sampled `value`, and the downstream `loss`, and subtracts the baseline
    returned by `baseline_fn` from the `loss`.
  """
  if baseline_fn is None:
    baseline_fn = get_mean_baseline()

  def score_function_with_baseline(dist_tensor, value, loss):
    with ops.name_scope(name):
      b = baseline_fn(dist_tensor, loss)
      return score_function(dist_tensor, value, loss, b)

  return score_function_with_baseline

def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
  """ExponentialMovingAverage baseline.

  Args:
    ema_decay: decay rate for the ExponentialMovingAverage.
    name: name to prepend ops with.

  Returns:
    Callable baseline function that takes the `DistributionTensor` (unused) and
    the downstream `loss`, and returns an EMA of the loss.
  """

  def mean_baseline(_, loss):
    with ops.name_scope(name):
      ema = training.ExponentialMovingAverage(decay=ema_decay)
      # ema.apply expects a list; track the EMA of the scalar mean loss.
      reduced_loss = math_ops.reduce_mean(loss)
      update_op = ema.apply([reduced_loss])
      with control_flow_ops.control_dependencies([update_op]):
        # TODO(rsepassi): Possibly implement the initialization bias correction
        # term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf).
        # The identity ensures the EMA update runs before the baseline is read.
        baseline = array_ops.identity(ema.average(reduced_loss))
      return baseline

  return mean_baseline


__all__ = make_all(__name__)
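An illustrative end-to-end sketch, not taken from this diff, of how these estimators might plug into `DistributionTensor` and `surrogate_loss`, assuming the `sg`/`sge`/`distributions` aliases from the test file above and the `SampleAndReshapeValue` type referenced later in this change; the constants and the squared-error loss are made up for illustration:

import tensorflow as tf

sg = tf.contrib.bayesflow.stochastic_graph
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
distributions = tf.contrib.distributions

mu = tf.constant([0.0, 0.0, 0.0])
sigma = tf.constant([1.1, 1.2, 1.3])

# Sample-based value type, with an EMA-of-the-loss baseline wired in as the
# surrogate loss_fn for this stochastic node.
x = sg.DistributionTensor(
    distributions.Normal, mu=mu, sigma=sigma,
    dist_value_type=sg.SampleAndReshapeValue(),
    loss_fn=sge.get_score_function_with_baseline(
        sge.get_mean_baseline(ema_decay=0.9)))

# A downstream loss influenced by the sample (ndims >= 1, as surrogate_loss
# requires).
sample_loss = tf.square(x.value() - 3.0)

# surrogate_loss appends the score-function term for x to the sample loss.
surrogate = sg.surrogate_loss([sample_loss])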
@@ -27,11 +27,6 @@
@@value_type
@@get_current_value_type

-## Stochastic Computation Surrogate Loss Functions
-
-@@score_function
-@@get_score_function_with_baseline
-
## Stochastic Computation Graph Helper Functions

@@surrogate_loss
@@ -48,6 +43,7 @@ import threading

import six

+from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -87,19 +83,17 @@ class StochasticTensor(object):
    pass

  @abc.abstractmethod
-  def loss(self, sample_losses):
+  def loss(self, sample_loss):
    """Returns the term to add to the surrogate loss.

-    This method is called by `surrogate_loss`. The input `sample_losses`
-    should have already had `stop_gradient` applied to them. This is
-    because the surrogate_loss usually provides a monte carlo sample term
-    of the form `differentiable_surrogate * sum(sample_losses)` where
-    `sample_losses` is considered constant with respect to the input
-    for purposes of the gradient.
+    This method is called by `surrogate_loss`. The input `sample_loss` should
+    have already had `stop_gradient` applied to it. This is because the
+    surrogate_loss usually provides a Monte Carlo sample term of the form
+    `differentiable_surrogate * sample_loss` where `sample_loss` is considered
+    constant with respect to the input for purposes of the gradient.

    Args:
-      sample_losses: a list of Tensors, the sample losses downstream of this
-        `StochasticTensor`.
+      sample_loss: `Tensor`, sample loss downstream of this `StochasticTensor`.

    Returns:
      Either `None` or a `Tensor`.
@@ -309,19 +303,6 @@ def get_current_value_type():
  return _STOCHASTIC_VALUE_STACK[thread_id][-1]


-def get_score_function_with_baseline(baseline):
-
-  def score_function_with_baseline(dist_tensor, value, losses):
-    advantage = math_ops.add_n(losses) - baseline
-    return dist_tensor.distribution.log_prob(value) * advantage
-
-  return score_function_with_baseline
-
-
-def score_function(dist_tensor, value, losses):
-  return dist_tensor.distribution.log_prob(value) * math_ops.add_n(losses)
-
-
class DistributionTensor(StochasticTensor):
  """DistributionTensor is a StochasticTensor backed by a distribution."""
@@ -329,7 +310,7 @@ class DistributionTensor(StochasticTensor):
               dist_cls,
               name=None,
               dist_value_type=None,
-              loss_fn=score_function,
+              loss_fn=sge.score_function,
               **dist_args):
    """Construct a `DistributionTensor`.
@@ -357,8 +338,12 @@ class DistributionTensor(StochasticTensor):
      dist_value_type: a `_StochasticValueType`, which will determine what the
        `value` of this `DistributionTensor` will be. If not provided, the
        value type set with the `value_type` context manager will be used.
-      loss_fn: callable that takes `(dt, dt.value(), influenced_losses)`, where
-        `dt` is this `DistributionTensor`, and returns a `Tensor` loss.
+      loss_fn: callable that takes `(dt, dt.value(), influenced_loss)`, where
+        `dt` is this `DistributionTensor`, and returns a `Tensor` loss. By
+        default, `loss_fn` is the `score_function`, or more precisely, the
+        integral of the score function, such that when the gradient is taken,
+        the score function results. See the `stochastic_gradient_estimators`
+        module for additional loss functions and baselines.
      **dist_args: keyword arguments to be passed through to `dist_cls` on
        construction.
    """
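To illustrate the `(dt, dt.value(), influenced_loss)` contract documented above, a hypothetical custom `loss_fn` might look like the following; the function name, constants, and scaling are invented for illustration and are not part of this change:

import tensorflow as tf

sg = tf.contrib.bayesflow.stochastic_graph
distributions = tf.contrib.distributions

def halved_score_function(dist_tensor, value, loss):
  # Same integrand as score_function, but with the advantage scaled by 0.5
  # (purely illustrative). stop_gradient keeps the loss out of the gradient
  # path, so the gradient flows only through log_prob.
  advantage = 0.5 * tf.stop_gradient(loss)
  return dist_tensor.distribution.log_prob(value) * advantage

dt = sg.DistributionTensor(
    distributions.Normal,
    mu=tf.constant([0.0]), sigma=tf.constant([1.0]),
    dist_value_type=sg.MeanValue(stop_gradient=True),
    loss_fn=halved_score_function)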
@@ -469,8 +454,8 @@ class DistributionTensor(StochasticTensor):
  def value(self, name="value"):
    return self._value

-  def loss(self, final_losses, name="Loss"):
-    # Return a loss based on final_losses and the distribution. Returns
+  def loss(self, final_loss, name="Loss"):
+    # Return a loss based on final_loss and the distribution. Returns
    # None if pathwise derivatives are supported, if the loss_fn
    # was explicitly set to None, or if the value type is MeanValue.
    if self._loss_fn is None:
@@ -481,12 +466,12 @@ class DistributionTensor(StochasticTensor):
      # Can perform pathwise-derivative on this one; no additional loss needed.
      return None

-    with ops.name_scope(self.name, values=final_losses):
+    with ops.name_scope(self.name, values=[final_loss]):
      with ops.name_scope(name):
        if (self._value_type.stop_gradient or
            isinstance(self._value_type, SampleAndReshapeValue) or
            isinstance(self._value_type, SampleValue)):
-          return self._loss_fn(self, self._value, final_losses)
+          return self._loss_fn(self, self._value, final_loss)
        elif isinstance(self._value_type, MeanValue):
          return None  # MeanValue generally provides its own gradient
        else:
@@ -564,7 +549,6 @@ def surrogate_loss(sample_losses,
      or greater.
  """
  with ops.name_scope(name, values=sample_losses):
-    fixed_losses = []
    if not isinstance(sample_losses, (list, tuple)):
      raise TypeError("sample_losses must be a list or tuple")
    for loss in sample_losses:
@@ -574,10 +558,9 @@ def surrogate_loss(sample_losses,
      if not (ndims is not None and ndims >= 1):
        raise ValueError("loss must have dimensionality 1 or greater: %s" %
                         loss)
-      fixed_losses.append(array_ops.stop_gradient(loss))

    stoch_dependencies_map = _stochastic_dependencies_map(
-        fixed_losses, stochastic_tensors=stochastic_tensors)
+        sample_losses, stochastic_tensors=stochastic_tensors)
    if not stoch_dependencies_map:
      logging.warn(
          "No collection of Stochastic Tensors found for current graph.")
@@ -588,8 +571,27 @@ def surrogate_loss(sample_losses,
    sample_losses = [ops.convert_to_tensor(loss) for loss in sample_losses]
    loss_terms = sample_losses
    for (stoch_node, dependent_losses) in stoch_dependencies_map.items():
-      loss_term = stoch_node.loss(list(dependent_losses))
+      dependent_losses = list(dependent_losses)
+
+      # Sum up the downstream losses for this ST
+      influenced_loss = _add_n_or_sum(dependent_losses)
+
+      # Compute surrogate loss term
+      loss_term = stoch_node.loss(array_ops.stop_gradient(influenced_loss))
      if loss_term is not None:
        loss_terms.append(loss_term)

-    return math_ops.add_n(loss_terms)
+    return _add_n_or_sum(loss_terms)
+
+
+def _add_n_or_sum(terms):
+  # add_n works for Tensors of the same dtype and shape
+  shape = terms[0].get_shape()
+  dtype = terms[0].dtype
+
+  if all(term.get_shape().is_fully_defined() and
+         term.get_shape().is_compatible_with(shape) and term.dtype == dtype
+         for term in terms):
+    return math_ops.add_n(terms)
+  else:
+    return sum(terms)
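A small illustrative sketch, not from this change, of why `_add_n_or_sum` keeps the Python `sum` fallback: `add_n` requires identical, fully defined shapes and dtypes, while broadcasting `+` does not:

import tensorflow as tf

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
scalar_term = tf.constant(5.0)

matched = tf.add_n([a, b])     # OK: identical shapes and dtypes
# tf.add_n([a, scalar_term])   # would fail: shapes (2,) vs () differ
mixed = sum([a, scalar_term])  # OK: falls back to broadcasting `+`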