Score function estimators with baselines

Change: 130443765
A. Unique TensorFlower 2016-08-16 13:05:03 -08:00 committed by TensorFlower Gardener
parent ce57cb7882
commit 9d62d40f9f
4 changed files with 237 additions and 40 deletions


@@ -21,4 +21,5 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators
from tensorflow.contrib.bayesflow.python.ops import stochastic_graph


@@ -22,6 +22,7 @@ import numpy as np
import tensorflow as tf
sg = tf.contrib.bayesflow.stochastic_graph
sge = tf.contrib.bayesflow.stochastic_gradient_estimators
distributions = tf.contrib.distributions
@@ -177,7 +178,7 @@ class DistributionTensorTest(tf.test.TestCase):
mu=mu,
sigma=sigma,
dist_value_type=sg.MeanValue(stop_gradient=True),
loss_fn=sg.get_score_function_with_baseline(
loss_fn=sge.get_score_function_with_constant_baseline(
baseline=tf.constant(8.0)))
loss = dt.loss([tf.constant(2.0)])
self.assertTrue(loss is not None)


@@ -0,0 +1,193 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stochastic gradient estimators.
These functions are meant to be used in conjunction with `StochasticTensor`
(`loss_fn` parameter) and `surrogate_loss`.
See Gradient Estimation Using Stochastic Computation Graphs
(http://arxiv.org/abs/1506.05254) by Schulman et al., eq. 1 and section 4, for
mathematical details.
## Score function estimator
The score function is an unbiased estimator of the gradient of `E_p(x)[f(x)]`,
where `f(x)` can be considered to be a "loss" term. It is computed as
`E_p(x)[f(x) grad(log p(x))]`. A constant `b`, referred to here as the
"baseline", can be subtracted from `f(x)` without biasing the estimator,
because `E_p(x)[grad(log p(x))] = 0`. The term `(f(x) - b)` is referred to here
as the "advantage".
Note that the methods defined in this module actually compute the integrand of
the score function, such that when taking the gradient, the true score function
is computed.
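For intuition, here is a minimal sketch of the quantity these estimators build
(`p`, `f_x`, and `b` below are hypothetical placeholders, not part of this
module):

  x = p.sample()                          # x ~ p(x)
  advantage = tf.stop_gradient(f_x - b)   # treated as a constant
  surrogate = p.log_prob(x) * advantage   # differentiating `surrogate` gives
                                          # the score function estimator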
@@score_function
@@get_score_function_with_baseline
@@get_score_function_with_constant_baseline
@@get_score_function_with_advantage
## Baseline functions
Baselines reduce the variance of a Monte Carlo estimate of an expectation. The
baseline for a stochastic node can be a function of all non-influenced nodes
(see section 4 of Schulman et al., linked above). Baselines are also known as
"control variates."
In the context of an MC estimate of `E_p(x)[f(x) - b]`, baseline functions have
the signature `(st, fx) => Tensor`, where `st` is a `StochasticTensor` backed by
the distribution `p(x)` and `fx` is the influenced loss.
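For example, a baseline function matching this signature might look like the
following sketch (`half_mean_baseline` is hypothetical, not one of the
baselines provided by this module):

  def half_mean_baseline(st, fx):
    del st  # this particular baseline ignores the StochasticTensor
    return 0.5 * tf.reduce_mean(fx)

  loss_fn = get_score_function_with_baseline(half_mean_baseline)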
@@get_mean_baseline
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training
from tensorflow.python.util.all_util import make_all
def score_function(dist_tensor, value, loss, baseline=None,
name="ScoreFunction"):
"""Score function estimator.
Computes the integrand of the score function with a baseline:
`p.log_prob(value) * (loss - baseline)`.
A `stop_gradient` is applied to the advantage `(loss - baseline)`, so it is
treated as a constant with respect to the gradient.
Args:
dist_tensor: `DistributionTensor` p(x).
value: `Tensor` x. Samples from p(x).
loss: `Tensor`.
baseline: `Tensor` broadcastable to `loss`.
name: name to prepend ops with.
Returns:
`Tensor` `p.log_prob(x) * (loss - b)`. Taking the gradient yields the score
function estimator.
"""
with ops.name_scope(name, values=[value, loss, baseline]):
value = ops.convert_to_tensor(value)
loss = ops.convert_to_tensor(loss)
if baseline is not None:
baseline = ops.convert_to_tensor(baseline)
advantage = loss - baseline
else:
advantage = loss
advantage = array_ops.stop_gradient(advantage)
return dist_tensor.distribution.log_prob(value) * advantage
def get_score_function_with_advantage(advantage_fn=None,
name="ScoreFunctionWithAdvantage"):
"""Score function estimator with advantage function.
Args:
advantage_fn: callable that takes the `DistributionTensor` and the
downstream `loss` and returns a `Tensor` advantage
(e.g. `loss - baseline`).
name: name to prepend ops with.
Returns:
Callable score function estimator that takes the `DistributionTensor`, the
sampled `value`, and the downstream `loss`, and uses the provided advantage.
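Example (a rough sketch; `centered_advantage` below is hypothetical):

  def centered_advantage(dist_tensor, loss):
    del dist_tensor  # unused; this advantage depends only on the loss
    return loss - tf.reduce_mean(loss)

  loss_fn = get_score_function_with_advantage(centered_advantage)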
"""
def score_function_with_advantage(dist_tensor, value, loss):
with ops.name_scope(name, values=[value, loss]):
advantage = advantage_fn(dist_tensor, loss)
advantage = array_ops.stop_gradient(advantage)
return dist_tensor.distribution.log_prob(value) * advantage
return score_function_with_advantage
def get_score_function_with_constant_baseline(baseline, name="ScoreFunction"):
"""Score function estimator with constant baseline.
Args:
baseline: `Tensor` to be subtracted from loss.
name: name to prepend ops with.
Returns:
Callable score function estimator that takes the `DistributionTensor`, the
sampled `value`, and the downstream `loss`, and subtracts the provided
`baseline` from the `loss`.
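Example (e.g. for use as the `loss_fn` of a `DistributionTensor`):

  # Subtract a fixed scalar from the downstream loss.
  loss_fn = get_score_function_with_constant_baseline(tf.constant(8.0))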
"""
def score_function_with_constant_baseline(dist_tensor, value, loss):
return score_function(dist_tensor, value, loss, baseline, name)
return score_function_with_constant_baseline
def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"):
"""Score function estimator with baseline function.
Args:
baseline_fn: callable that takes the `DistributionTensor` and the downstream
`loss` and returns a `Tensor` baseline to be subtracted from the `loss`.
If None, defaults to `get_mean_baseline`, which is an EMA of the loss.
name: name to prepend ops with.
Returns:
Callable score function estimator that takes the `DistributionTensor`, the
sampled `value`, and the downstream `loss`, and subtracts the baseline
computed by `baseline_fn` from the `loss`.
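Example (a small sketch):

  # With no argument, the baseline defaults to `get_mean_baseline()`, an
  # exponential moving average of the downstream loss.
  loss_fn = get_score_function_with_baseline()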
"""
if baseline_fn is None:
baseline_fn = get_mean_baseline()
def score_function_with_baseline(dist_tensor, value, loss):
with ops.name_scope(name):
b = baseline_fn(dist_tensor, loss)
return score_function(dist_tensor, value, loss, b)
return score_function_with_baseline
def get_mean_baseline(ema_decay=0.99, name="MeanBaseline"):
"""ExponentialMovingAverage baseline.
Args:
ema_decay: decay rate for the ExponentialMovingAverage.
name: name to prepend ops with.
Returns:
Callable baseline function that takes the `DistributionTensor` (unused) and
the downstream `loss`, and returns an EMA of the loss.
"""
  def mean_baseline(_, loss):
    with ops.name_scope(name):
      reduced_loss = math_ops.reduce_mean(loss)
      ema = training.ExponentialMovingAverage(decay=ema_decay)
      # `ExponentialMovingAverage.apply` expects a list of tensors/variables;
      # track the moving average of the scalar mean loss.
      update_op = ema.apply([reduced_loss])
      with ops.control_dependencies([update_op]):
        # TODO(rsepassi): Possibly implement the initialization bias correction
        # term from Adam (section 3 of https://arxiv.org/pdf/1412.6980v8.pdf).
        baseline = ema.average(reduced_loss)
      return baseline
  return mean_baseline
__all__ = make_all(__name__)


@@ -27,11 +27,6 @@
@@value_type
@@get_current_value_type
## Stochastic Computation Surrogate Loss Functions
@@score_function
@@get_score_function_with_baseline
## Stochastic Computation Graph Helper Functions
@@surrogate_loss
@@ -48,6 +43,7 @@ import threading
import six
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -87,19 +83,17 @@ class StochasticTensor(object):
pass
@abc.abstractmethod
def loss(self, sample_losses):
def loss(self, sample_loss):
"""Returns the term to add to the surrogate loss.
This method is called by `surrogate_loss`. The input `sample_losses`
should have already had `stop_gradient` applied to them. This is
because the surrogate_loss usually provides a monte carlo sample term
of the form `differentiable_surrogate * sum(sample_losses)` where
`sample_losses` is considered constant with respect to the input
for purposes of the gradient.
This method is called by `surrogate_loss`. The input `sample_loss` should
have already had `stop_gradient` applied to it. This is because the
surrogate_loss usually provides a Monte Carlo sample term of the form
`differentiable_surrogate * sample_loss` where `sample_loss` is considered
constant with respect to the input for purposes of the gradient.
Args:
sample_losses: a list of Tensors, the sample losses downstream of this
`StochasticTensor`.
sample_loss: `Tensor`, sample loss downstream of this `StochasticTensor`.
Returns:
Either `None` or a `Tensor`.
@@ -309,19 +303,6 @@ def get_current_value_type():
return _STOCHASTIC_VALUE_STACK[thread_id][-1]
def get_score_function_with_baseline(baseline):
def score_function_with_baseline(dist_tensor, value, losses):
advantage = math_ops.add_n(losses) - baseline
return dist_tensor.distribution.log_prob(value) * advantage
return score_function_with_baseline
def score_function(dist_tensor, value, losses):
return dist_tensor.distribution.log_prob(value) * math_ops.add_n(losses)
class DistributionTensor(StochasticTensor):
"""DistributionTensor is a StochasticTensor backed by a distribution."""
@@ -329,7 +310,7 @@ class DistributionTensor(StochasticTensor):
dist_cls,
name=None,
dist_value_type=None,
loss_fn=score_function,
loss_fn=sge.score_function,
**dist_args):
"""Construct a `DistributionTensor`.
@@ -357,8 +338,12 @@ class DistributionTensor(StochasticTensor):
dist_value_type: a `_StochasticValueType`, which will determine what the
`value` of this `DistributionTensor` will be. If not provided, the
value type set with the `value_type` context manager will be used.
loss_fn: callable that takes `(dt, dt.value(), influenced_losses)`, where
`dt` is this `DistributionTensor`, and returns a `Tensor` loss.
loss_fn: callable that takes `(dt, dt.value(), influenced_loss)`, where
`dt` is this `DistributionTensor`, and returns a `Tensor` loss. By
default, `loss_fn` is the `score_function`, or more precisely, the
integral of the score function, such that when the gradient is taken,
the score function results. See the `stochastic_gradient_estimators`
module for additional loss functions and baselines.
**dist_args: keyword arguments to be passed through to `dist_cls` on
construction.
"""
@@ -469,8 +454,8 @@ class DistributionTensor(StochasticTensor):
def value(self, name="value"):
return self._value
def loss(self, final_losses, name="Loss"):
# Return a loss based on final_losses and the distribution. Returns
def loss(self, final_loss, name="Loss"):
# Return a loss based on final_loss and the distribution. Returns
# None if pathwise derivatives are supported, if the loss_fn
# was explicitly set to None, or if the value type is MeanValue.
if self._loss_fn is None:
@@ -481,12 +466,12 @@ class DistributionTensor(StochasticTensor):
# Can perform pathwise-derivative on this one; no additional loss needed.
return None
with ops.name_scope(self.name, values=final_losses):
with ops.name_scope(self.name, values=[final_loss]):
with ops.name_scope(name):
if (self._value_type.stop_gradient or
isinstance(self._value_type, SampleAndReshapeValue) or
isinstance(self._value_type, SampleValue)):
return self._loss_fn(self, self._value, final_losses)
return self._loss_fn(self, self._value, final_loss)
elif isinstance(self._value_type, MeanValue):
return None # MeanValue generally provides its own gradient
else:
@@ -564,7 +549,6 @@ def surrogate_loss(sample_losses,
or greater.
"""
with ops.name_scope(name, values=sample_losses):
fixed_losses = []
if not isinstance(sample_losses, (list, tuple)):
raise TypeError("sample_losses must be a list or tuple")
for loss in sample_losses:
@@ -574,10 +558,9 @@
if not (ndims is not None and ndims >= 1):
raise ValueError("loss must have dimensionality 1 or greater: %s" %
loss)
fixed_losses.append(array_ops.stop_gradient(loss))
stoch_dependencies_map = _stochastic_dependencies_map(
fixed_losses, stochastic_tensors=stochastic_tensors)
sample_losses, stochastic_tensors=stochastic_tensors)
if not stoch_dependencies_map:
logging.warn(
"No collection of Stochastic Tensors found for current graph.")
@@ -588,8 +571,27 @@
sample_losses = [ops.convert_to_tensor(loss) for loss in sample_losses]
loss_terms = sample_losses
for (stoch_node, dependent_losses) in stoch_dependencies_map.items():
loss_term = stoch_node.loss(list(dependent_losses))
dependent_losses = list(dependent_losses)
# Sum up the downstream losses for this ST
influenced_loss = _add_n_or_sum(dependent_losses)
# Compute surrogate loss term
loss_term = stoch_node.loss(array_ops.stop_gradient(influenced_loss))
if loss_term is not None:
loss_terms.append(loss_term)
return math_ops.add_n(loss_terms)
return _add_n_or_sum(loss_terms)
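# A rough end-to-end sketch of how the pieces above fit together (here
# `distributions` stands for tf.contrib.distributions, and `mu`, `sigma`, and
# `target` are hypothetical placeholders):
#
#   dt = DistributionTensor(distributions.Normal, mu=mu, sigma=sigma,
#                           loss_fn=sge.get_score_function_with_baseline())
#   loss = tf.square(tf.identity(dt) - target)  # sample loss downstream of dt
#   total_loss = surrogate_loss([loss])         # adds dt's score-function term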
def _add_n_or_sum(terms):
# add_n works for Tensors of the same dtype and shape
shape = terms[0].get_shape()
dtype = terms[0].dtype
if all(term.get_shape().is_fully_defined() and
term.get_shape().is_compatible_with(shape) and term.dtype == dtype
for term in terms):
return math_ops.add_n(terms)
else:
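    # Fall back to Python sum(), which chains `add` ops and handles terms
    # whose static shapes are unknown or differ but are broadcast-compatible.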
return sum(terms)