Add initial version of tf.contrib.integrate
Includes tf.contrib.integrate.odeint, an integrator for ODEs patterned off of scipy.integrate.odeint. Change: 137973526
This commit is contained in:
parent
0b730a0d93
commit
65b1677070
@ -97,6 +97,7 @@ filegroup(
|
||||
"//tensorflow/contrib/framework:all_files",
|
||||
"//tensorflow/contrib/graph_editor:all_files",
|
||||
"//tensorflow/contrib/grid_rnn:all_files",
|
||||
"//tensorflow/contrib/integrate:all_files",
|
||||
"//tensorflow/contrib/layers:all_files",
|
||||
"//tensorflow/contrib/layers/kernels:all_files",
|
||||
"//tensorflow/contrib/learn:all_files",
|
||||
|
@ -23,6 +23,7 @@ py_library(
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||
"//tensorflow/contrib/integrate:integrate_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib import factorization
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import graph_editor
|
||||
from tensorflow.contrib import grid_rnn
|
||||
from tensorflow.contrib import integrate
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
|
38
tensorflow/contrib/integrate/BUILD
Normal file
38
tensorflow/contrib/integrate/BUILD
Normal file
@ -0,0 +1,38 @@
|
||||
# Description:
|
||||
# Integration and ODE solvers for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
py_library(
|
||||
name = "integrate_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/ops/odes.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "odes_test",
|
||||
srcs = ["python/ops/odes_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":integrate_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
)
|
9
tensorflow/contrib/integrate/README.md
Normal file
9
tensorflow/contrib/integrate/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Integration and ODE solvers for TensorFlow
|
||||
|
||||
TensorFlow equivalents to the routines provided by `scipy.integrate`. Currently
|
||||
contains a single function, `odeint`, for integrating ordinary differential
|
||||
equations.
|
||||
|
||||
Maintainers:
|
||||
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
|
||||
- Marc Coram (mcoram@google.com, github.com/mcoram)
|
68
tensorflow/contrib/integrate/__init__.py
Normal file
68
tensorflow/contrib/integrate/__init__.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Integration and ODE solvers for TensorFlow.
|
||||
|
||||
## Example: Lorenz attractor
|
||||
|
||||
We can use `odeint` to solve the
|
||||
[Lorentz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
|
||||
differential equations, a prototypical example of chaotic dynamics:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import tensorflow as tf
|
||||
|
||||
rho = 28.0
|
||||
sigma = 10.0
|
||||
beta = 8.0/3.0
|
||||
|
||||
def lorenz_equation(state, t):
|
||||
x, y, z = tf.unpack(state)
|
||||
dx = sigma * (y - x)
|
||||
dy = x * (rho - z) - y
|
||||
dz = x * y - beta * z
|
||||
return tf.pack([dx, dy, dz])
|
||||
|
||||
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
|
||||
t = np.linspace(0, 50, num=5000)
|
||||
tensor_state, tensor_info = tf.contrib.integrate.odeint(
|
||||
lorenz_equation, init_state, t, full_output=True)
|
||||
|
||||
sess = tf.Session()
|
||||
state, info = sess.run([tensor_state, tensor_info])
|
||||
x, y, z = state.T
|
||||
plt.plot(x, z)
|
||||
```
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
|
||||
</div>
|
||||
|
||||
## Ops
|
||||
|
||||
@@odeint
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.integrate.python.ops.odes import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
503
tensorflow/contrib/integrate/python/ops/odes.py
Normal file
@ -0,0 +1,503 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""ODE solvers for TensorFlow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
|
||||
|
||||
_ButcherTableau = collections.namedtuple(
|
||||
'_ButcherTableau', 'alpha beta c_sol c_mid c_error')
|
||||
|
||||
# Parameters from Shampine (1986), section 4.
|
||||
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
|
||||
alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.],
|
||||
beta=[[1/5],
|
||||
[3/40, 9/40],
|
||||
[44/45, -56/15, 32/9],
|
||||
[19372/6561, -25360/2187, 64448/6561, -212/729],
|
||||
[9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
|
||||
[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]],
|
||||
c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0],
|
||||
c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2,
|
||||
-2691868925/45128329728 / 2, 187940372067/1594534317056 / 2,
|
||||
-1776094331/19743644256 / 2, 11237099/235043384 / 2],
|
||||
c_error=[1951/21600 - 35/384,
|
||||
0,
|
||||
22642/50085 - 500/1113,
|
||||
451/720 - 125/192,
|
||||
-12231/42400 - -2187/6784,
|
||||
649/6300 - 11/84,
|
||||
1/60],
|
||||
)
|
||||
|
||||
|
||||
def _possibly_nonzero(x):
|
||||
return isinstance(x, ops.Tensor) or x != 0
|
||||
|
||||
|
||||
def _scaled_dot_product(scale, xs, ys, name=None):
|
||||
"""Calculate a scaled, vector inner product between lists of Tensors."""
|
||||
with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
|
||||
# Some of the parameters in our Butcher tableau include zeros. Using
|
||||
# _possibly_nonzero lets us avoid wasted computation.
|
||||
return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys)
|
||||
if _possibly_nonzero(x) or _possibly_nonzero(y)],
|
||||
name=scope)
|
||||
|
||||
|
||||
def _dot_product(xs, ys, name=None):
|
||||
"""Calculate the vector inner product between two lists of Tensors."""
|
||||
with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
|
||||
return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope)
|
||||
|
||||
|
||||
def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU,
|
||||
name=None):
|
||||
"""Take an arbitrary Runge-Kutta step and estimate error.
|
||||
|
||||
Args:
|
||||
func: Function to evaluate like `func(y, t)` to compute the time derivative
|
||||
of `y`.
|
||||
y0: Tensor initial value for the state.
|
||||
f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
|
||||
t0: float64 scalar Tensor giving the initial time.
|
||||
dt: float64 scalar Tensor giving the size of the desired time step.
|
||||
tableau: optional _ButcherTableau describing how to take the Runge-Kutta
|
||||
step.
|
||||
name: optional name for the operation.
|
||||
|
||||
Returns:
|
||||
Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
|
||||
the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
|
||||
estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
|
||||
calculating these terms.
|
||||
"""
|
||||
with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
|
||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||
f0 = ops.convert_to_tensor(f0, name='f0')
|
||||
t0 = ops.convert_to_tensor(t0, name='t0')
|
||||
dt = ops.convert_to_tensor(dt, name='dt')
|
||||
dt_cast = math_ops.cast(dt, y0.dtype)
|
||||
|
||||
k = [f0]
|
||||
for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
|
||||
ti = t0 + alpha_i * dt
|
||||
yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
|
||||
k.append(func(yi, ti))
|
||||
|
||||
if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]):
|
||||
# This property (true for Dormand-Prince) lets us save a few FLOPs.
|
||||
yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)
|
||||
|
||||
y1 = array_ops.identity(yi, name='%s/y1' % scope)
|
||||
f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
|
||||
y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k,
|
||||
name='%s/y1_error' % scope)
|
||||
return (y1, f1, y1_error, k)
|
||||
|
||||
|
||||
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
|
||||
"""Fit coefficients for 4th order polynomial interpolation.
|
||||
|
||||
Args:
|
||||
y0: function value at the start of the interval.
|
||||
y1: function value at the end of the interval.
|
||||
y_mid: function value at the mid-point of the interval.
|
||||
f0: derivative value at the start of the interval.
|
||||
f1: derivative value at the end of the interval.
|
||||
dt: width of the interval.
|
||||
|
||||
Returns:
|
||||
List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
|
||||
`p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
|
||||
between 0 (start of interval) and 1 (end of interval).
|
||||
"""
|
||||
# a, b, c, d, e = sympy.symbols('a b c d e')
|
||||
# x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
|
||||
# p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
|
||||
# sympy.solve([p.subs(x, 0) - y0,
|
||||
# p.subs(x, 1 / 2) - y_mid,
|
||||
# p.subs(x, 1) - y1,
|
||||
# (p.diff(x) / dt).subs(x, 0) - f0,
|
||||
# (p.diff(x) / dt).subs(x, 1) - f1],
|
||||
# [a, b, c, d, e])
|
||||
# {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
|
||||
# b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
|
||||
# c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
|
||||
# d: dt*f0,
|
||||
# e: y0}
|
||||
a = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0, f1, y0, y1, y_mid])
|
||||
b = _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0, f1, y0, y1, y_mid])
|
||||
c = _dot_product([-4 * dt, dt, -11, -5, 16], [f0, f1, y0, y1, y_mid])
|
||||
d = dt * f0
|
||||
e = y0
|
||||
return [a, b, c, d, e]
|
||||
|
||||
|
||||
def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
|
||||
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
|
||||
with ops.name_scope('interp_fit_rk'):
|
||||
dt = math_ops.cast(dt, y0.dtype)
|
||||
y_mid = y0 + _scaled_dot_product(dt, tableau.c_mid, k)
|
||||
f0 = k[0]
|
||||
f1 = k[-1]
|
||||
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
|
||||
|
||||
|
||||
def _interp_evaluate(coefficients, t0, t1, t):
|
||||
"""Evaluate polynomial interpolation at the given time point.
|
||||
|
||||
Args:
|
||||
coefficients: list of Tensor coefficients as created by `interp_fit`.
|
||||
t0: scalar float64 Tensor giving the start of the interval.
|
||||
t1: scalar float64 Tensor giving the end of the interval.
|
||||
t: scalar float64 Tensor giving the desired interpolation point.
|
||||
|
||||
Returns:
|
||||
Polynomial interpolation of the coefficients at time `t`.
|
||||
"""
|
||||
with ops.name_scope('interp_evaluate'):
|
||||
t0 = ops.convert_to_tensor(t0)
|
||||
t1 = ops.convert_to_tensor(t1)
|
||||
t = ops.convert_to_tensor(t)
|
||||
|
||||
dtype = coefficients[0].dtype
|
||||
|
||||
assert_op = control_flow_ops.Assert(
|
||||
(t0 <= t) & (t <= t1),
|
||||
['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
|
||||
with ops.control_dependencies([assert_op]):
|
||||
x = math_ops.cast((t - t0) / (t1 - t0), dtype)
|
||||
|
||||
xs = [constant_op.constant(1, dtype), x]
|
||||
for _ in range(2, len(coefficients)):
|
||||
xs.append(xs[-1] * x)
|
||||
|
||||
return _dot_product(coefficients, reversed(xs))
|
||||
|
||||
|
||||
def _optimal_step_size(last_step,
|
||||
error_ratio,
|
||||
safety=0.9,
|
||||
ifactor=10.0,
|
||||
dfactor=0.2,
|
||||
order=5,
|
||||
name=None):
|
||||
"""Calculate the optimal size for the next Runge-Kutta step."""
|
||||
with ops.name_scope(
|
||||
name, 'optimal_step_size', [last_step, error_ratio]) as scope:
|
||||
error_ratio = math_ops.cast(error_ratio, last_step.dtype)
|
||||
exponent = math_ops.cast(1 / order, last_step.dtype)
|
||||
# this looks more complex than necessary, but importantly it keeps
|
||||
# error_ratio in the numerator so we can't divide by zero:
|
||||
factor = math_ops.maximum(
|
||||
1 / ifactor,
|
||||
math_ops.minimum(error_ratio ** exponent / safety, 1 / dfactor))
|
||||
return math_ops.div(last_step, factor, name=scope)
|
||||
|
||||
|
||||
def _abs_square(x):
|
||||
if x.dtype.is_complex:
|
||||
return math_ops.square(math_ops.real(x)) + math_ops.square(math_ops.imag(x))
|
||||
else:
|
||||
return math_ops.square(x)
|
||||
|
||||
|
||||
def _ta_append(tensor_array, value):
|
||||
"""Append a value to the end of a tf.TensorArray."""
|
||||
return tensor_array.write(tensor_array.size(), value)
|
||||
|
||||
|
||||
class _RungeKuttaState(collections.namedtuple(
|
||||
'_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')):
|
||||
"""Saved state of the Runge Kutta solver.
|
||||
|
||||
Attributes:
|
||||
y1: Tensor giving the function value at the end of the last time step.
|
||||
f1: Tensor giving derivative at the end of the last time step.
|
||||
t0: scalar float64 Tensor giving start of the last time step.
|
||||
t1: scalar float64 Tensor giving end of the last time step.
|
||||
dt: scalar float64 Tensor giving the size for the next time step.
|
||||
interp_coef: list of Tensors giving coefficients for polynomial
|
||||
interpolation between `t0` and `t1`.
|
||||
"""
|
||||
|
||||
|
||||
class _History(collections.namedtuple(
|
||||
'_History', 'integrate_points, error_ratio')):
|
||||
"""Saved integration history for use in `info_dict`.
|
||||
|
||||
Attributes:
|
||||
integrate_points: tf.TensorArray storing integrating time points.
|
||||
error_ratio: tf.TensorArray storing computed error ratios at each
|
||||
integration step.
|
||||
"""
|
||||
|
||||
|
||||
def _dopri5(func,
|
||||
y0,
|
||||
t,
|
||||
rtol,
|
||||
atol,
|
||||
full_output=False,
|
||||
first_step=None,
|
||||
safety=0.9,
|
||||
ifactor=10.0,
|
||||
dfactor=0.2,
|
||||
max_num_steps=1000,
|
||||
name=None):
|
||||
"""Solve an ODE for `odeint` using method='dopri5'."""
|
||||
|
||||
if first_step is None:
|
||||
# at some point, we might want to switch to picking the step size
|
||||
# automatically
|
||||
first_step = 1.0
|
||||
|
||||
with ops.name_scope(
|
||||
name, 'dopri5',
|
||||
[y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope:
|
||||
|
||||
first_step = ops.convert_to_tensor(first_step, dtype=t.dtype,
|
||||
name='first_step')
|
||||
safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
|
||||
ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
|
||||
dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
|
||||
max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32,
|
||||
name='max_num_steps')
|
||||
|
||||
def adaptive_runge_kutta_step(rk_state, history, n_steps):
|
||||
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
|
||||
y0, f0, _, t0, dt, interp_coeff = rk_state
|
||||
with ops.name_scope('assertions'):
|
||||
check_underflow = control_flow_ops.Assert(
|
||||
t0 + dt > t0, ['underflow in dt', dt])
|
||||
check_max_num_steps = control_flow_ops.Assert(
|
||||
n_steps < max_num_steps, ['max_num_steps exceeded'])
|
||||
check_numerics = control_flow_ops.Assert(
|
||||
math_ops.reduce_all(math_ops.is_finite(abs(y0))),
|
||||
['non-finite values in state `y`', y0])
|
||||
with ops.control_dependencies(
|
||||
[check_underflow, check_max_num_steps, check_numerics]):
|
||||
y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)
|
||||
|
||||
with ops.name_scope('error_ratio'):
|
||||
# We use the same approach as the dopri5 fortran code.
|
||||
error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
|
||||
tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
|
||||
# Could also use reduce_maximum here.
|
||||
error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
|
||||
accept_step = error_ratio <= 1
|
||||
|
||||
with ops.name_scope('update/rk_state'):
|
||||
# If we don't accept the step, the _RungeKuttaState will be useless
|
||||
# (covering a time-interval of size 0), but that's OK, because in such
|
||||
# cases we always immediately take another Runge-Kutta step.
|
||||
y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
|
||||
f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
|
||||
t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
|
||||
interp_coeff = control_flow_ops.cond(
|
||||
accept_step,
|
||||
lambda: _interp_fit_rk(y0, y1, k, dt),
|
||||
lambda: interp_coeff)
|
||||
dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
|
||||
rk_state = _RungeKuttaState(
|
||||
y_next, f_next, t0, t_next, dt_next, interp_coeff)
|
||||
|
||||
with ops.name_scope('update/history'):
|
||||
history = _History(_ta_append(history.integrate_points, t0 + dt),
|
||||
_ta_append(history.error_ratio, error_ratio))
|
||||
return rk_state, history, n_steps + 1
|
||||
|
||||
def interpolate(solution, history, rk_state, i):
|
||||
"""Interpolate through the next time point, integrating as necessary."""
|
||||
with ops.name_scope('interpolate'):
|
||||
rk_state, history, _ = control_flow_ops.while_loop(
|
||||
lambda rk_state, *_: t[i] > rk_state.t1,
|
||||
adaptive_runge_kutta_step,
|
||||
(rk_state, history, 0),
|
||||
name='integrate_loop')
|
||||
y = _interp_evaluate(
|
||||
rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i])
|
||||
solution = solution.write(i, y)
|
||||
return solution, history, rk_state, i + 1
|
||||
|
||||
assert_increasing = control_flow_ops.Assert(
|
||||
math_ops.reduce_all(t[1:] > t[:-1]),
|
||||
['`t` must be monotonic increasing'])
|
||||
with ops.control_dependencies([assert_increasing]):
|
||||
num_times = array_ops.size(t)
|
||||
|
||||
solution = tensor_array_ops.TensorArray(
|
||||
y0.dtype, size=num_times).write(0, y0)
|
||||
history = _History(
|
||||
integrate_points=tensor_array_ops.TensorArray(
|
||||
t.dtype, size=0, dynamic_size=True),
|
||||
error_ratio=tensor_array_ops.TensorArray(
|
||||
rtol.dtype, size=0, dynamic_size=True))
|
||||
rk_state = _RungeKuttaState(
|
||||
y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)
|
||||
|
||||
solution, history, _, _ = control_flow_ops.while_loop(
|
||||
lambda _, __, ___, i: i < num_times,
|
||||
interpolate,
|
||||
(solution, history, rk_state, 1),
|
||||
name='interpolate_loop')
|
||||
|
||||
y = solution.pack(name=scope)
|
||||
y.set_shape(t.get_shape().concatenate(y0.get_shape()))
|
||||
if not full_output:
|
||||
return y
|
||||
else:
|
||||
integrate_points = history.integrate_points.pack()
|
||||
info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
|
||||
'integrate_points': integrate_points,
|
||||
'error_ratio': history.error_ratio.pack()}
|
||||
return (y, info_dict)
|
||||
|
||||
|
||||
def odeint(func,
|
||||
y0,
|
||||
t,
|
||||
rtol=1e-6,
|
||||
atol=1e-12,
|
||||
method=None,
|
||||
options=None,
|
||||
full_output=False,
|
||||
name=None):
|
||||
"""Integrate a system of ordinary differential equations.
|
||||
|
||||
Solves the initial value problem for a non-stiff system of first order ode-s:
|
||||
|
||||
```
|
||||
dy/dt = func(y, t), y(t[0]) = y0
|
||||
```
|
||||
|
||||
where y is a Tensor of any shape.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# solve `dy/dt = -y`, corresponding to exponential decay
|
||||
tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
|
||||
=> [1, exp(-1), exp(-2)]
|
||||
```
|
||||
|
||||
Output dtypes and numerical precision are based on the dtypes of the inputs
|
||||
`y0` and `t`.
|
||||
|
||||
Currently, implements 5th order Runge-Kutta with adaptive step size control
|
||||
and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
|
||||
method of `scipy.integrate.ode` and MATLAB's `ode45`.
|
||||
|
||||
Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
|
||||
Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
|
||||
doi:10.2307/2008219
|
||||
|
||||
Args:
|
||||
func: Function that maps a Tensor holding the state `y` and a scalar Tensor
|
||||
`t` into a Tensor of state derivatives with respect to time.
|
||||
y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
|
||||
have any floating point or complex dtype.
|
||||
t: 1-D Tensor holding a sequence of time points for which to solve for
|
||||
`y`. The initial time point should be the first element of this sequence,
|
||||
and each time must be larger than the previous time. May have any floating
|
||||
point dtype. If not provided as a Tensor, converted to a Tensor with
|
||||
float64 dtype.
|
||||
rtol: optional float64 Tensor specifying an upper bound on relative error,
|
||||
per element of `y`.
|
||||
atol: optional float64 Tensor specifying an upper bound on absolute error,
|
||||
per element of `y`.
|
||||
method: optional string indicating the integration method to use. Currently,
|
||||
the only valid option is `'dopri5'`.
|
||||
options: optional dict of configuring options for the indicated integration
|
||||
method. Can only be provided if a `method` is explicitly set. For
|
||||
`'dopri5'`, valid options include:
|
||||
* first_step: an initial guess for the size of the first integration
|
||||
(current default: 1.0, but may later be changed to use heuristics based
|
||||
on the gradient).
|
||||
* safety: safety factor for adaptive step control, generally a constant
|
||||
in the range 0.8-1 (default: 0.9).
|
||||
* ifactor: maximum factor by which the adaptive step may be increased
|
||||
(default: 10.0).
|
||||
* dfactor: maximum factor by which the adpative step may be decreased
|
||||
(default: 0.2).
|
||||
* max_num_steps: integer maximum number of integrate steps between time
|
||||
points in `t` (default: 1000).
|
||||
full_output: optional boolean. If True, `odeint` returns a tuple
|
||||
`(y, info_dict)` describing the integration process.
|
||||
name: Optional name for this operation.
|
||||
|
||||
Returns:
|
||||
y: (N+1)-D tensor, where the first dimension corresponds to different
|
||||
time points. Contains the solved value of y for each desired time point in
|
||||
`t`, with the initial value `y0` being the first element along the first
|
||||
dimension.
|
||||
info_dict: only if `full_output == True`. A dict with the following values:
|
||||
* num_func_evals: integer Tensor counting the number of function
|
||||
evaluations.
|
||||
* integrate_points: 1D float64 Tensor with the upper bound of each
|
||||
integration time step.
|
||||
* error_ratio: 1D float Tensor with the estimated ratio of the integration
|
||||
error to the error tolerance at each integration step. An ratio greater
|
||||
than 1 corresponds to rejected steps.
|
||||
|
||||
Raises:
|
||||
ValueError: if an invalid `method` is provided.
|
||||
TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
|
||||
an invalid dtype.
|
||||
"""
|
||||
if method is not None and method != 'dopri5':
|
||||
raise ValueError('invalid method: %r' % method)
|
||||
|
||||
if options is None:
|
||||
options = {}
|
||||
elif method is None:
|
||||
raise ValueError('cannot supply `options` without specifying `method`')
|
||||
|
||||
with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
|
||||
# TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
|
||||
# arbitrarily nested tuple. This will help performance and usability by
|
||||
# avoiding the need to pack/unpack in user functions.
|
||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||
if not (y0.dtype.is_floating or y0.dtype.is_complex):
|
||||
raise TypeError('`y0` must have a floating point or complex floating '
|
||||
'point dtype')
|
||||
|
||||
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
||||
if not t.dtype.is_floating:
|
||||
raise TypeError('`t` must have a floating point dtype')
|
||||
|
||||
error_dtype = abs(y0).dtype
|
||||
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
|
||||
atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')
|
||||
|
||||
return _dopri5(func, y0, t,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
full_output=full_output,
|
||||
name=scope,
|
||||
**options)
|
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
232
tensorflow/contrib/integrate/python/ops/odes_test.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for ODE solvers."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.integrate.python.ops import odes
|
||||
|
||||
|
||||
class OdeIntTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(OdeIntTest, self).setUp()
|
||||
# simple defaults (solution is a sin-wave)
|
||||
matrix = tf.constant([[0, 1], [-1, 0]], dtype=tf.float64)
|
||||
self.func = lambda y, t: tf.matmul(matrix, y)
|
||||
self.y0 = np.array([[1.0], [0.0]])
|
||||
|
||||
def test_odeint_exp(self):
|
||||
# Test odeint by an exponential function:
|
||||
# dy / dt = y, y(0) = 1.0.
|
||||
# Its analytical solution is y = exp(t).
|
||||
func = lambda y, t: y
|
||||
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||
self.assertIn('odeint', y_solved.name)
|
||||
self.assertEqual(y_solved.get_shape(), tf.TensorShape([11]))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = np.exp(t)
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_complex(self):
|
||||
# Test a complex, linear ODE:
|
||||
# dy / dt = k * y, y(0) = 1.0.
|
||||
# Its analytical solution is y = exp(k * t).
|
||||
k = 1j - 0.1
|
||||
func = lambda y, t: k * y
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, 1.0 + 0.0j, t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = np.exp(k * t)
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_riccati(self):
|
||||
# The Ricatti equation is:
|
||||
# dy / dt = (y - t) ** 2 + 1.0, y(0) = 0.5.
|
||||
# Its analytical solution is y = 1.0 / (2.0 - t) + t.
|
||||
func = lambda t, y: (y - t)**2 + 1.0
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
y_solved = tf.contrib.integrate.odeint(func, np.float64(0.5), t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
y_true = 1.0 / (2.0 - t) + t
|
||||
self.assertAllClose(y_true, y_solved)
|
||||
|
||||
def test_odeint_2d_linear(self):
|
||||
# Solve the 2D linear differential equation:
|
||||
# dy1 / dt = 3.0 * y1 + 4.0 * y2,
|
||||
# dy2 / dt = -4.0 * y1 + 3.0 * y2,
|
||||
# y1(0) = 0.0,
|
||||
# y2(0) = 1.0.
|
||||
# Its analytical solution is
|
||||
# y1 = sin(4.0 * t) * exp(3.0 * t),
|
||||
# y2 = cos(4.0 * t) * exp(3.0 * t).
|
||||
matrix = tf.constant([[3.0, 4.0], [-4.0, 3.0]], dtype=tf.float64)
|
||||
func = lambda y, t: tf.matmul(matrix, y)
|
||||
|
||||
y0 = tf.constant([[0.0], [1.0]], dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, t)
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
|
||||
y_true = np.zeros((len(t), 2, 1))
|
||||
y_true[:, 0, 0] = np.sin(4.0 * t) * np.exp(3.0 * t)
|
||||
y_true[:, 1, 0] = np.cos(4.0 * t) * np.exp(3.0 * t)
|
||||
self.assertAllClose(y_true, y_solved, atol=1e-5)
|
||||
|
||||
def test_odeint_higher_rank(self):
|
||||
func = lambda y, t: y
|
||||
y0 = tf.constant(1.0, dtype=tf.float64)
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
for shape in [(), (1,), (1, 1)]:
|
||||
expected_shape = (len(t),) + shape
|
||||
y_solved = tf.contrib.integrate.odeint(func, tf.reshape(y0, shape), t)
|
||||
self.assertEqual(y_solved.get_shape(), tf.TensorShape(expected_shape))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
self.assertEquals(y_solved.shape, expected_shape)
|
||||
|
||||
def test_odeint_all_dtypes(self):
|
||||
func = lambda y, t: y
|
||||
t = np.linspace(0.0, 1.0, 11)
|
||||
for y0_dtype in [tf.float32, tf.float64, tf.complex64, tf.complex128]:
|
||||
for t_dtype in [tf.float32, tf.float64]:
|
||||
y0 = tf.cast(1.0, y0_dtype)
|
||||
y_solved = tf.contrib.integrate.odeint(func, y0, tf.cast(t, t_dtype))
|
||||
with self.test_session() as sess:
|
||||
y_solved = sess.run(y_solved)
|
||||
expected = np.asarray(np.exp(t))
|
||||
self.assertAllClose(y_solved, expected, rtol=1e-5)
|
||||
self.assertEqual(tf.as_dtype(y_solved.dtype), y0_dtype)
|
||||
|
||||
def test_odeint_required_dtypes(self):
|
||||
with self.assertRaisesRegexp(TypeError, '`y0` must have a floating point'):
|
||||
tf.contrib.integrate.odeint(self.func, tf.cast(self.y0, tf.int32), [0, 1])
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, '`t` must have a floating point'):
|
||||
tf.contrib.integrate.odeint(self.func, self.y0, tf.cast([0, 1], tf.int32))
|
||||
|
||||
def test_odeint_runtime_errors(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'cannot supply `options` without'):
|
||||
tf.contrib.integrate.odeint(self.func, self.y0, [0, 1],
|
||||
options={'first_step': 1.0})
|
||||
|
||||
y = tf.contrib.integrate.odeint(self.func, self.y0, [0, 1], method='dopri5',
|
||||
options={'max_num_steps': 0})
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, 'max_num_steps'):
|
||||
sess.run(y)
|
||||
|
||||
y = tf.contrib.integrate.odeint(self.func, self.y0, [1, 0])
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, 'monotonic increasing'):
|
||||
sess.run(y)
|
||||
|
||||
def test_odeint_different_times(self):
|
||||
# integrate steps should be independent of interpolation times
|
||||
times0 = np.linspace(0, 10, num=11, dtype=float)
|
||||
times1 = np.linspace(0, 10, num=101, dtype=float)
|
||||
|
||||
with self.test_session() as sess:
|
||||
y_solved_0, info_0 = sess.run(
|
||||
tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, times0, full_output=True))
|
||||
y_solved_1, info_1 = sess.run(
|
||||
tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, times1, full_output=True))
|
||||
|
||||
self.assertAllClose(y_solved_0, y_solved_1[::10])
|
||||
self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals'])
|
||||
self.assertAllEqual(info_0['integrate_points'], info_1['integrate_points'])
|
||||
self.assertAllEqual(info_0['error_ratio'], info_1['error_ratio'])
|
||||
|
||||
def test_odeint_5th_order_accuracy(self):
|
||||
t = [0, 20]
|
||||
kwargs = dict(full_output=True,
|
||||
method='dopri5',
|
||||
options=dict(max_num_steps=2000))
|
||||
with self.test_session() as sess:
|
||||
_, info_0 = sess.run(tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
|
||||
_, info_1 = sess.run(tf.contrib.integrate.odeint(
|
||||
self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs))
|
||||
self.assertAllClose(info_0['integrate_points'].size * 1000 ** 0.2,
|
||||
float(info_1['integrate_points'].size),
|
||||
rtol=0.01)
|
||||
|
||||
|
||||
class StepSizeTest(tf.test.TestCase):
|
||||
|
||||
def test_error_ratio_one(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(1.0))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 0.9)
|
||||
|
||||
def test_ifactor(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(0.0))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 10.0)
|
||||
|
||||
def test_dfactor(self):
|
||||
new_step = odes._optimal_step_size(last_step=tf.constant(1.0),
|
||||
error_ratio=tf.constant(1e6))
|
||||
with self.test_session() as sess:
|
||||
new_step = sess.run(new_step)
|
||||
self.assertAllClose(new_step, 0.2)
|
||||
|
||||
|
||||
class InterpolationTest(tf.test.TestCase):
|
||||
|
||||
def test_5th_order_polynomial(self):
|
||||
# this should be an exact fit
|
||||
f = lambda x: x ** 4 + x ** 3 - 2 * x ** 2 + 4 * x + 5
|
||||
f_prime = lambda x: 4 * x ** 3 + 3 * x ** 2 - 4 * x + 4
|
||||
coeffs = odes._interp_fit(
|
||||
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
|
||||
times = np.linspace(0, 10, dtype=np.float32)
|
||||
y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
|
||||
for t in times])
|
||||
y_expected = f(times)
|
||||
with self.test_session() as sess:
|
||||
y_actual = sess.run(y_fit)
|
||||
self.assertAllClose(y_expected, y_actual)
|
||||
|
||||
# attempt interpolation outside bounds
|
||||
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
sess.run(y_invalid)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -67,6 +67,7 @@ def module_names():
|
||||
"tf.contrib.ffmpeg",
|
||||
"tf.contrib.framework",
|
||||
"tf.contrib.graph_editor",
|
||||
"tf.contrib.integrate",
|
||||
"tf.contrib.layers",
|
||||
"tf.contrib.learn",
|
||||
"tf.contrib.learn.monitors",
|
||||
@ -220,6 +221,7 @@ def all_libraries(module_to_name, members, documented):
|
||||
library("contrib.framework", "Framework (contrib)", tf.contrib.framework),
|
||||
library("contrib.graph_editor", "Graph Editor (contrib)",
|
||||
tf.contrib.graph_editor),
|
||||
library("contrib.integrate", "Integrate (contrib)", tf.contrib.integrate),
|
||||
library("contrib.layers", "Layers (contrib)", tf.contrib.layers),
|
||||
library("contrib.learn", "Learn (contrib)", tf.contrib.learn),
|
||||
library("contrib.learn.monitors", "Monitors (contrib)",
|
||||
|
Loading…
Reference in New Issue
Block a user