Add initial version of tf.contrib.integrate

Includes tf.contrib.integrate.odeint, an integrator for ODEs patterned off of
scipy.integrate.odeint.
Change: 137973526
This commit is contained in:
Stephan Hoyer 2016-11-02 10:48:36 -08:00 committed by TensorFlower Gardener
parent 0b730a0d93
commit 65b1677070
9 changed files with 855 additions and 0 deletions

View File

@ -97,6 +97,7 @@ filegroup(
"//tensorflow/contrib/framework:all_files",
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/integrate:all_files",
"//tensorflow/contrib/layers:all_files",
"//tensorflow/contrib/layers/kernels:all_files",
"//tensorflow/contrib/learn:all_files",

View File

@ -23,6 +23,7 @@ py_library(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",

View File

@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
from tensorflow.contrib import framework
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import integrate
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import linear_optimizer

View File

@ -0,0 +1,38 @@
# Description:
# Integration and ODE solvers for TensorFlow.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
py_library(
name = "integrate_py",
srcs = [
"__init__.py",
"python/ops/odes.py",
],
srcs_version = "PY2AND3",
)
py_test(
name = "odes_test",
srcs = ["python/ops/odes_test.py"],
srcs_version = "PY2AND3",
deps = [
":integrate_py",
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)

View File

@ -0,0 +1,9 @@
# Integration and ODE solvers for TensorFlow
TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
contains a single function, `odeint`, for integrating ordinary differential
equations.
Maintainers:
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
- Marc Coram (mcoram@google.com, github.com/mcoram)

View File

@ -0,0 +1,68 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration and ODE solvers for TensorFlow.
## Example: Lorenz attractor
We can use `odeint` to solve the
[Lorentz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
differential equations, a prototypical example of chaotic dynamics:
```python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
rho = 28.0
sigma = 10.0
beta = 8.0/3.0
def lorenz_equation(state, t):
x, y, z = tf.unpack(state)
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
return tf.pack([dx, dy, dz])
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
t = np.linspace(0, 50, num=5000)
tensor_state, tensor_info = tf.contrib.integrate.odeint(
lorenz_equation, init_state, t, full_output=True)
sess = tf.Session()
state, info = sess.run([tensor_state, tensor_info])
x, y, z = state.T
plt.plot(x, z)
```
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
</div>
## Ops
@@odeint
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.integrate.python.ops.odes import *
from tensorflow.python.util.all_util import make_all
__all__ = make_all(__name__)

View File

@ -0,0 +1,503 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ODE solvers for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
_ButcherTableau = collections.namedtuple(
'_ButcherTableau', 'alpha beta c_sol c_mid c_error')
# Parameters from Shampine (1986), section 4.
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
beta=[[1/5],
[3/40, 9/40],
[44/45, -56/15, 32/9],
[19372/6561, -25360/2187, 64448/6561, -212/729],
[9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
-2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
-1776094331/19743644256 / 2, 11237099/235043384 / 2],
c_error=[1951/21600 - 35/384,
0,
22642/50085 - 500/1113,
451/720 - 125/192,
-12231/42400 - -2187/6784,
649/6300 - 11/84,
1/60],
)
def _possibly_nonzero(x):
return isinstance(x, ops.Tensor) or x != 0
def _scaled_dot_product(scale, xs, ys, name=None):
"""Calculate a scaled, vector inner product between lists of Tensors."""
with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
# Some of the parameters in our Butcher tableau include zeros. Using
# _possibly_nonzero lets us avoid wasted computation.
return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys)
if _possibly_nonzero(x) or _possibly_nonzero(y)],
name=scope)
def _dot_product(xs, ys, name=None):
"""Calculate the vector inner product between two lists of Tensors."""
with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope)
def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
name=None):
"""Take an arbitrary Runge-Kutta step and estimate error.
Args:
func: Function to evaluate like `func(y, t)` to compute the time derivative
of `y`.
y0: Tensor initial value for the state.
f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
t0: float64 scalar Tensor giving the initial time.
dt: float64 scalar Tensor giving the size of the desired time step.
tableau: optional _ButcherTableau describing how to take the Runge-Kutta
step.
name: optional name for the operation.
Returns:
Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
calculating these terms.
"""
with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
y0 = ops.convert_to_tensor(y0, name='y0')
f0 = ops.convert_to_tensor(f0, name='f0')
t0 = ops.convert_to_tensor(t0, name='t0')
dt = ops.convert_to_tensor(dt, name='dt')
dt_cast = math_ops.cast(dt, y0.dtype)
k = [f0]
for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
ti = t0 + alpha_i * dt
yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
k.append(func(yi, ti))
if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]):
# This property (true for Dormand-Prince) lets us save a few FLOPs.
yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)
y1 = array_ops.identity(yi, name='%s/y1' % scope)
f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
name='%s/y1_error' % scope)
return (y1, f1, y1_error, k)
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
"""Fit coefficients for 4th order polynomial interpolation.
Args:
y0: function value at the start of the interval.
y1: function value at the end of the interval.
y_mid: function value at the mid-point of the interval.
f0: derivative value at the start of the interval.
f1: derivative value at the end of the interval.
dt: width of the interval.
Returns:
List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
`p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
between 0 (start of interval) and 1 (end of interval).
"""
# a, b, c, d, e = sympy.symbols('a b c d e')
# x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
# p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
# sympy.solve([p.subs(x, 0) - y0,
# p.subs(x, 1 / 2) - y_mid,
# p.subs(x, 1) - y1,
# (p.diff(x) / dt).subs(x, 0) - f0,
# (p.diff(x) / dt).subs(x, 1) - f1],
# [a, b, c, d, e])
# {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
# b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
# c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
# d: dt*f0,
# e: y0}
a = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0, f1, y0, y1, y_mid])
b = _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0, f1, y0, y1, y_mid])
c = _dot_product([-4 * dt, dt, -11, -5, 16], [f0, f1, y0, y1, y_mid])
d = dt * f0
e = y0
return [a, b, c, d, e]
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
with ops.name_scope('interp_fit_rk'):
dt = math_ops.cast(dt, y0.dtype)
y_mid = y0 + _scaled_dot_product(dt, tableau.c_mid, k)
f0 = k[0]
f1 = k[-1]
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
def _interp_evaluate(coefficients, t0, t1, t):
"""Evaluate polynomial interpolation at the given time point.
Args:
coefficients: list of Tensor coefficients as created by `interp_fit`.
t0: scalar float64 Tensor giving the start of the interval.
t1: scalar float64 Tensor giving the end of the interval.
t: scalar float64 Tensor giving the desired interpolation point.
Returns:
Polynomial interpolation of the coefficients at time `t`.
"""
with ops.name_scope('interp_evaluate'):
t0 = ops.convert_to_tensor(t0)
t1 = ops.convert_to_tensor(t1)
t = ops.convert_to_tensor(t)
dtype = coefficients[0].dtype
assert_op = control_flow_ops.Assert(
(t0 <= t) & (t <= t1),
['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
with ops.control_dependencies([assert_op]):
x = math_ops.cast((t - t0) / (t1 - t0), dtype)
xs = [constant_op.constant(1, dtype), x]
for _ in range(2, len(coefficients)):
xs.append(xs[-1] * x)
return _dot_product(coefficients, reversed(xs))
def _optimal_step_size(last_step,
error_ratio,
safety=0.9,
ifactor=10.0,
dfactor=0.2,
order=5,
name=None):
"""Calculate the optimal size for the next Runge-Kutta step."""
with ops.name_scope(
name, 'optimal_step_size', [last_step, error_ratio]) as scope:
error_ratio = math_ops.cast(error_ratio, last_step.dtype)
exponent = math_ops.cast(1 / order, last_step.dtype)
# this looks more complex than necessary, but importantly it keeps
# error_ratio in the numerator so we can't divide by zero:
factor = math_ops.maximum(
1 / ifactor,
math_ops.minimum(error_ratio ** exponent / safety, 1 / dfactor))
return math_ops.div(last_step, factor, name=scope)
def _abs_square(x):
if x.dtype.is_complex:
return math_ops.square(math_ops.real(x)) + math_ops.square(math_ops.imag(x))
else:
return math_ops.square(x)
def _ta_append(tensor_array, value):
"""Append a value to the end of a tf.TensorArray."""
return tensor_array.write(tensor_array.size(), value)
class _RungeKuttaState(collections.namedtuple(
'_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
"""Saved state of the Runge Kutta solver.
Attributes:
y1: Tensor giving the function value at the end of the last time step.
f1: Tensor giving derivative at the end of the last time step.
t0: scalar float64 Tensor giving start of the last time step.
t1: scalar float64 Tensor giving end of the last time step.
dt: scalar float64 Tensor giving the size for the next time step.
interp_coef: list of Tensors giving coefficients for polynomial
interpolation between `t0` and `t1`.
"""
class _History(collections.namedtuple(
'_History', 'integrate_points, error_ratio')):
"""Saved integration history for use in `info_dict`.
Attributes:
integrate_points: tf.TensorArray storing integrating time points.
error_ratio: tf.TensorArray storing computed error ratios at each
integration step.
"""
def _dopri5(func,
y0,
t,
rtol,
atol,
full_output=False,
first_step=None,
safety=0.9,
ifactor=10.0,
dfactor=0.2,
max_num_steps=1000,
name=None):
"""Solve an ODE for `odeint` using method='dopri5'."""
if first_step is None:
# at some point, we might want to switch to picking the step size
# automatically
first_step = 1.0
with ops.name_scope(
name, 'dopri5',
[y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:
first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
name='first_step')
safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
name='max_num_steps')
def adaptive_runge_kutta_step(rk_state, history, n_steps):
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
y0, f0, _, t0, dt, interp_coeff = rk_state
with ops.name_scope('assertions'):
check_underflow = control_flow_ops.Assert(
t0 + dt > t0, ['underflow in dt', dt])
check_max_num_steps = control_flow_ops.Assert(
n_steps < max_num_steps, ['max_num_steps exceeded'])
check_numerics = control_flow_ops.Assert(
math_ops.reduce_all(math_ops.is_finite(abs(y0))),
['non-finite values in state `y`', y0])
with ops.control_dependencies(
[check_underflow, check_max_num_steps, check_numerics]):
y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)
with ops.name_scope('error_ratio'):
# We use the same approach as the dopri5 fortran code.
error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
# Could also use reduce_maximum here.
error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
accept_step = error_ratio <= 1
with ops.name_scope('update/rk_state'):
# If we don't accept the step, the _RungeKuttaState will be useless
# (covering a time-interval of size 0), but that's OK, because in such
# cases we always immediately take another Runge-Kutta step.
y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
interp_coeff = control_flow_ops.cond(
accept_step,
lambda: _interp_fit_rk(y0, y1, k, dt),
lambda: interp_coeff)
dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
rk_state = _RungeKuttaState(
y_next, f_next, t0, t_next, dt_next, interp_coeff)
with ops.name_scope('update/history'):
history = _History(_ta_append(history.integrate_points, t0 + dt),
_ta_append(history.error_ratio, error_ratio))
return rk_state, history, n_steps + 1
def interpolate(solution, history, rk_state, i):
"""Interpolate through the next time point, integrating as necessary."""
with ops.name_scope('interpolate'):
rk_state, history, _ = control_flow_ops.while_loop(
lambda rk_state, *_: t[i] > rk_state.t1,
adaptive_runge_kutta_step,
(rk_state, history, 0),
name='integrate_loop')
y = _interp_evaluate(
rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
solution = solution.write(i, y)
return solution, history, rk_state, i + 1
assert_increasing = control_flow_ops.Assert(
math_ops.reduce_all(t[1:] > t[:-1]),
['`t` must be monotonic increasing'])
with ops.control_dependencies([assert_increasing]):
num_times = array_ops.size(t)
solution = tensor_array_ops.TensorArray(
y0.dtype, size=num_times).write(0, y0)
history = _History(
integrate_points=tensor_array_ops.TensorArray(
t.dtype, size=0, dynamic_size=True),
error_ratio=tensor_array_ops.TensorArray(
rtol.dtype, size=0, dynamic_size=True))
rk_state = _RungeKuttaState(
y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)
solution, history, _, _ = control_flow_ops.while_loop(
lambda _, __, ___, i: i < num_times,
interpolate,
(solution, history, rk_state, 1),
name='interpolate_loop')
y = solution.pack(name=scope)
y.set_shape(t.get_shape().concatenate(y0.get_shape()))
if not full_output:
return y
else:
integrate_points = history.integrate_points.pack()
info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
'integrate_points': integrate_points,
'error_ratio': history.error_ratio.pack()}
return (y, info_dict)
def odeint(func,
y0,
t,
rtol=1e-6,
atol=1e-12,
method=None,
options=None,
full_output=False,
name=None):
"""Integrate a system of ordinary differential equations.
Solves the initial value problem for a non-stiff system of first order ode-s:
```
dy/dt = func(y, t), y(t[0]) = y0
```
where y is a Tensor of any shape.
For example:
```
# solve `dy/dt = -y`, corresponding to exponential decay
tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
=> [1, exp(-1), exp(-2)]
```
Output dtypes and numerical precision are based on the dtypes of the inputs
`y0` and `t`.
Currently, implements 5th order Runge-Kutta with adaptive step size control
and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
method of `scipy.integrate.ode` and MATLAB's `ode45`.
Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
doi:10.2307/2008219
Args:
func: Function that maps a Tensor holding the state `y` and a scalar Tensor
`t` into a Tensor of state derivatives with respect to time.
y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
have any floating point or complex dtype.
t: 1-D Tensor holding a sequence of time points for which to solve for
`y`. The initial time point should be the first element of this sequence,
and each time must be larger than the previous time. May have any floating
point dtype. If not provided as a Tensor, converted to a Tensor with
float64 dtype.
rtol: optional float64 Tensor specifying an upper bound on relative error,
per element of `y`.
atol: optional float64 Tensor specifying an upper bound on absolute error,
per element of `y`.
method: optional string indicating the integration method to use. Currently,
the only valid option is `'dopri5'`.
options: optional dict of configuring options for the indicated integration
method. Can only be provided if a `method` is explicitly set. For
`'dopri5'`, valid options include:
* first_step: an initial guess for the size of the first integration
(current default: 1.0, but may later be changed to use heuristics based
on the gradient).
* safety: safety factor for adaptive step control, generally a constant
in the range 0.8-1 (default: 0.9).
* ifactor: maximum factor by which the adaptive step may be increased
(default: 10.0).
* dfactor: maximum factor by which the adpative step may be decreased
(default: 0.2).
* max_num_steps: integer maximum number of integrate steps between time
points in `t` (default: 1000).
full_output: optional boolean. If True, `odeint` returns a tuple
`(y, info_dict)` describing the integration process.
name: Optional name for this operation.
Returns:
y: (N+1)-D tensor, where the first dimension corresponds to different
time points. Contains the solved value of y for each desired time point in
`t`, with the initial value `y0` being the first element along the first
dimension.
info_dict: only if `full_output == True`. A dict with the following values:
* num_func_evals: integer Tensor counting the number of function
evaluations.
* integrate_points: 1D float64 Tensor with the upper bound of each
integration time step.
* error_ratio: 1D float Tensor with the estimated ratio of the integration
error to the error tolerance at each integration step. An ratio greater
than 1 corresponds to rejected steps.
Raises:
ValueError: if an invalid `method` is provided.
TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
an invalid dtype.
"""
if method is not None and method != 'dopri5':
raise ValueError('invalid method: %r' % method)
if options is None:
options = {}
elif method is None:
raise ValueError('cannot supply `options` without specifying `method`')
with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
# TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
# arbitrarily nested tuple. This will help performance and usability by
# avoiding the need to pack/unpack in user functions.
y0 = ops.convert_to_tensor(y0, name='y0')
if not (y0.dtype.is_floating or y0.dtype.is_complex):
raise TypeError('`y0` must have a floating point or complex floating '
'point dtype')
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
if not t.dtype.is_floating:
raise TypeError('`t` must have a floating point dtype')
error_dtype = abs(y0).dtype
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')
return _dopri5(func, y0, t,
rtol=rtol,
atol=atol,
full_output=full_output,
name=scope,
**options)

View File

@ -0,0 +1,232 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ODE solvers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.integrate.python.ops import odes
class OdeIntTest(tf.test.TestCase):
def setUp(self):
super(OdeIntTest, self).setUp()
# simple defaults (solution is a sin-wave)
matrix = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
self.func = lambda y, t: tf.matmul(matrix, y)
self.y0 = np.array([[1.0], [0.0]])
def test_odeint_exp(self):
# Test odeint by an exponential function:
# dy / dt = y, y(0) = 1.0.
# Its analytical solution is y = exp(t).
func = lambda y, t: y
y0 = tf.constant(1.0, dtype=tf.float64)
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, y0, t)
self.assertIn('odeint', y_solved.name)
self.assertEqual(y_solved.get_shape(), tf.TensorShape([11]))
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(t)
self.assertAllClose(y_true, y_solved)
def test_odeint_complex(self):
# Test a complex, linear ODE:
# dy / dt = k * y, y(0) = 1.0.
# Its analytical solution is y = exp(k * t).
k = 1j - 0.1
func = lambda y, t: k * y
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(k * t)
self.assertAllClose(y_true, y_solved)
def test_odeint_riccati(self):
# The Ricatti equation is:
# dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
# Its analytical solution is y = 1.0 / (2.0 - t) + t.
func = lambda t, y: (y - t)**2 + 1.0
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = 1.0 / (2.0 - t) + t
self.assertAllClose(y_true, y_solved)
def test_odeint_2d_linear(self):
# Solve the 2D linear differential equation:
# dy1 / dt = 3.0 * y1 + 4.0 * y2,
# dy2 / dt = -4.0 * y1 + 3.0 * y2,
# y1(0) = 0.0,
# y2(0) = 1.0.
# Its analytical solution is
# y1 = sin(4.0 * t) * exp(3.0 * t),
# y2 = cos(4.0 * t) * exp(3.0 * t).
matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
func = lambda y, t: tf.matmul(matrix, y)
y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
t = np.linspace(0.0, 1.0, 11)
y_solved = tf.contrib.integrate.odeint(func, y0, t)
with self.test_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.zeros((len(t), 2, 1))
y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
self.assertAllClose(y_true, y_solved, atol=1e-5)
def test_odeint_higher_rank(self):
func = lambda y, t: y
y0 = tf.constant(1.0, dtype=tf.float64)
t = np.linspace(0.0, 1.0, 11)
for shape in [(), (1,), (1, 1)]:
expected_shape = (len(t),) + shape
y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
with self.test_session() as sess:
y_solved = sess.run(y_solved)
self.assertEquals(y_solved.shape, expected_shape)
def test_odeint_all_dtypes(self):
func = lambda y, t: y
t = np.linspace(0.0, 1.0, 11)
for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
for t_dtype in [tf.float32, tf.float64]:
y0 = tf.cast(1.0, y0_dtype)
y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
with self.test_session() as sess:
y_solved = sess.run(y_solved)
expected = np.asarray(np.exp(t))
self.assertAllClose(y_solved, expected, rtol=1e-5)
self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)
def test_odeint_required_dtypes(self):
with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])
with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))
def test_odeint_runtime_errors(self):
with self.assertRaisesRegexp(
ValueError, 'cannot supply `options` without'):
tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
options={'first_step': 1.0})
y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
options={'max_num_steps': 0})
with self.test_session() as sess:
with self.assertRaisesRegexp(
tf.errors.InvalidArgumentError, 'max_num_steps'):
sess.run(y)
y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
with self.test_session() as sess:
with self.assertRaisesRegexp(
tf.errors.InvalidArgumentError, 'monotonic increasing'):
sess.run(y)
def test_odeint_different_times(self):
# integrate steps should be independent of interpolation times
times0 = np.linspace(0, 10, num=11, dtype=float)
times1 = np.linspace(0, 10, num=101, dtype=float)
with self.test_session() as sess:
y_solved_0, info_0 = sess.run(
tf.contrib.integrate.odeint(
self.func, self.y0, times0, full_output=True))
y_solved_1, info_1 = sess.run(
tf.contrib.integrate.odeint(
self.func, self.y0, times1, full_output=True))
self.assertAllClose(y_solved_0, y_solved_1[::10])
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
def test_odeint_5th_order_accuracy(self):
t = [0, 20]
kwargs = dict(full_output=True,
method='dopri5',
options=dict(max_num_steps=2000))
with self.test_session() as sess:
_, info_0 = sess.run(tf.contrib.integrate.odeint(
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
_, info_1 = sess.run(tf.contrib.integrate.odeint(
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
float(info_1['integrate_points'].size),
rtol=0.01)
class StepSizeTest(tf.test.TestCase):
def test_error_ratio_one(self):
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
error_ratio=tf.constant(1.0))
with self.test_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.9)
def test_ifactor(self):
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
error_ratio=tf.constant(0.0))
with self.test_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 10.0)
def test_dfactor(self):
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
error_ratio=tf.constant(1e6))
with self.test_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.2)
class InterpolationTest(tf.test.TestCase):
def test_5th_order_polynomial(self):
# this should be an exact fit
f = lambda x: x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5
f_prime = lambda x: 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4
coeffs = odes._interp_fit(
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
times = np.linspace(0, 10, dtype=np.float32)
y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
for t in times])
y_expected = f(times)
with self.test_session() as sess:
y_actual = sess.run(y_fit)
self.assertAllClose(y_expected, y_actual)
# attempt interpolation outside bounds
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(y_invalid)
if __name__ == '__main__':
tf.test.main()

View File

@ -67,6 +67,7 @@ def module_names():
"tf.contrib.ffmpeg",
"tf.contrib.framework",
"tf.contrib.graph_editor",
"tf.contrib.integrate",
"tf.contrib.layers",
"tf.contrib.learn",
"tf.contrib.learn.monitors",
@ -220,6 +221,7 @@ def all_libraries(module_to_name, members, documented):
library("contrib.framework", "Framework (contrib)", tf.contrib.framework),
library("contrib.graph_editor", "Graph Editor (contrib)",
tf.contrib.graph_editor),
library("contrib.integrate", "Integrate (contrib)", tf.contrib.integrate),
library("contrib.layers", "Layers (contrib)", tf.contrib.layers),
library("contrib.learn", "Learn (contrib)", tf.contrib.learn),
library("contrib.learn.monitors", "Monitors (contrib)",