Allows tfe.enable_eager_execution(device_policy=tfe.DEVICE_POLICY_WARN).
PiperOrigin-RevId: 172943398
This commit is contained in:
parent
f5b14e496f
commit
985031a101
tensorflow
@ -1,15 +0,0 @@
|
||||
TensorFlow has many kernels for doing (deep) learning and data manipulation.
|
||||
There are typically assembled into computational graphs which can run
|
||||
efficiently in a variety of environments.
|
||||
|
||||
We are exploring an alternative interaction, where kernels are invoked
|
||||
immediately and call this "eager execution". We are hoping to retain the
|
||||
benefits of graphs while improving usability with benefits like:
|
||||
|
||||
- Immediate error messages and easier debugging
|
||||
- Flexibility to use Python datastructures and control flow
|
||||
- Reduced boilerplate
|
||||
|
||||
Eager execution is under active development.
|
||||
There are not many developer-facing materials yet, but stay tuned for updates
|
||||
in this directory.
|
@ -1,65 +1,15 @@
|
||||
# TensorFlow Eager Execution
|
||||
TensorFlow has many kernels for doing (deep) learning and data manipulation.
|
||||
There are typically assembled into computational graphs which can run
|
||||
efficiently in a variety of environments.
|
||||
|
||||
> *WARNING*: This is a preview/pre-alpha version. The API and performance
|
||||
> characteristics are subject to change.
|
||||
We are exploring an alternative interaction, where kernels are invoked
|
||||
immediately and call this "eager execution". We are hoping to retain the
|
||||
benefits of graphs while improving usability with benefits like:
|
||||
|
||||
- Immediate error messages and easier debugging
|
||||
- Flexibility to use Python datastructures and control flow
|
||||
- Reduced boilerplate
|
||||
|
||||
Eager execution is an experimental interface to TensorFlow that provides an
|
||||
imperative programming style (à la [NumPy](http://www.numpy.org)). When you
|
||||
enable eager execution, TensorFlow operations execute immediately; you do not
|
||||
execute a pre-constructed graph with
|
||||
[`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session).
|
||||
|
||||
For example, consider a simple computation in TensorFlow:
|
||||
|
||||
```python
|
||||
x = tf.placeholder(tf.float32, shape=[1, 1])
|
||||
m = tf.matmul(x, x)
|
||||
|
||||
with tf.Session() as sess:
|
||||
print(sess.run(m, feed_dict={x: [[2.]]}))
|
||||
|
||||
# Will print [[4.]]
|
||||
```
|
||||
|
||||
Eager execution makes this much simpler:
|
||||
|
||||
```python
|
||||
x = [[2.]]
|
||||
m = tf.matmul(x, x)
|
||||
|
||||
print(m)
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
Since eager execution is not yet part of a TensorFlow release, using it requires
|
||||
either [building from source](https://www.tensorflow.org/install/install_sources)
|
||||
or the latest nightly builds. The nightly builds are available as:
|
||||
|
||||
- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and
|
||||
|
||||
- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images.
|
||||
|
||||
For example, to run the latest nightly docker image:
|
||||
|
||||
```sh
|
||||
# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker
|
||||
nvidia-docker pull tensorflow/tensorflow:nightly-gpu
|
||||
nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu
|
||||
|
||||
# If you do not have a GPU, use the CPU-only image
|
||||
docker pull tensorflow/tensorflow:nightly
|
||||
docker run -it -p 8888:8888 tensorflow/tensorflow:nightly
|
||||
```
|
||||
|
||||
And then visit http://localhost:8888 in your browser for a Jupyter notebook
|
||||
environment. Try out the notebooks below.
|
||||
|
||||
## Documentation
|
||||
|
||||
For an introduction to TensorFlow eager execution, see the Jupyter notebooks:
|
||||
|
||||
- [Basic Usage](examples/notebooks/1_basics.ipynb)
|
||||
- [Gradients](examples/notebooks/2_gradients.ipynb)
|
||||
- [Importing Data](examples/notebooks/3_datasets.ipynb)
|
||||
Eager execution is under active development.
|
||||
There are not many developer-facing materials yet, but stay tuned for updates
|
||||
in this directory.
|
||||
|
@ -1,134 +0,0 @@
|
||||
# Description:
|
||||
# Open-source examples and tutorials for TensorFlow Eager Execution.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_binary(
|
||||
name = "linear_regression",
|
||||
srcs = ["linear_regression.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cart_pole_helper",
|
||||
srcs = ["cart_pole_helper.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "spinn",
|
||||
srcs = ["spinn.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "cart_pole",
|
||||
srcs = ["cart_pole.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cart_pole_helper",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "spinn_prep_data",
|
||||
srcs = ["spinn_prep_data.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "spinn_train",
|
||||
srcs = ["spinn_train.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":spinn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "linear_regression_test",
|
||||
size = "small",
|
||||
srcs = ["tests/linear_regression_test.py"],
|
||||
additional_deps = [
|
||||
":linear_regression",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cart_pole_helper_test",
|
||||
srcs = ["tests/cart_pole_helper_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cart_pole_helper",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "cart_pole_test",
|
||||
size = "small",
|
||||
srcs = ["tests/cart_pole_test.py"],
|
||||
additional_deps = [
|
||||
":cart_pole",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "spinn_test",
|
||||
size = "medium",
|
||||
srcs = ["tests/spinn_test.py"],
|
||||
additional_deps = [
|
||||
":spinn",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -1,282 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""TensorFlow Eager Execution Example: OpenAI Gym CartPole.
|
||||
|
||||
Solves the cart-pole problem with policy gradient-based reinforcement learning.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from six.moves import input # pylint: disable=redefined-builtin
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
from tensorflow.contrib.eager.python.examples import cart_pole_helper
|
||||
|
||||
|
||||
class PolicyNetwork(object):
|
||||
"""Policy network for the cart-pole reinforcement learning problem.
|
||||
|
||||
The forward path of the network takes an observation from the cart-pole
|
||||
environment (length-4 vector) and outputs an action.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, train_logdir=None):
|
||||
"""Constructor of PolicyNetwork.
|
||||
|
||||
Args:
|
||||
hidden_size: Size of the hidden layer, as an `int`.
|
||||
train_logdir: The directory in which summaries will be written for
|
||||
TensorBoard during training (optional).
|
||||
"""
|
||||
self._hidden_layer = tf.layers.Dense(hidden_size, activation=tf.nn.elu)
|
||||
self._output_layer = tf.layers.Dense(1)
|
||||
|
||||
# Gradient function.
|
||||
self._grad_fn = tfe.implicit_gradients(
|
||||
self._get_cross_entropy_and_save_actions)
|
||||
|
||||
# Support for TensorBoard summaries. Once training has started, use:
|
||||
# tensorboard --logdir=<train_logdir>
|
||||
self._summary_writer = (tfe.SummaryWriter(train_logdir) if train_logdir
|
||||
else None)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Given inputs, calculate logits and action.
|
||||
|
||||
Args:
|
||||
inputs: Observations from a step in the cart-pole environment, of shape
|
||||
`(batch_size, input_size)`
|
||||
|
||||
Returns:
|
||||
logits: the logits output by the output layer. This can be viewed as the
|
||||
likelihood vales of choosing the left (0) action. Shape:
|
||||
`(batch_size, 1)`.
|
||||
actions: randomly selected actions ({0, 1}) based on the logits. Shape:
|
||||
`(batch_size, 1)`.
|
||||
"""
|
||||
hidden = self._hidden_layer(inputs)
|
||||
logits = self._output_layer(hidden)
|
||||
|
||||
# Probability of selecting the left action.
|
||||
left_p = tf.nn.sigmoid(logits)
|
||||
# Probabilities of selecting the left and right actions.
|
||||
left_right_ps = tf.concat([left_p, 1.0 - left_p], 1)
|
||||
# Randomly-generated actions based on the probabilities.
|
||||
actions = tf.multinomial(tf.log(left_right_ps), 1)
|
||||
return logits, actions
|
||||
|
||||
def _get_cross_entropy_and_save_actions(self, inputs):
|
||||
"""Given inputs, get the sigmoid cross entropy and save selection action.
|
||||
|
||||
Args:
|
||||
inputs: Observation from a step in the cart-pole environment.
|
||||
|
||||
Returns:
|
||||
The sigmoid cross-entropy loss given the selected action and logits, based
|
||||
on the assumption that the selected action was rewarded by the
|
||||
environment.
|
||||
"""
|
||||
logits, actions = self.forward(inputs)
|
||||
|
||||
# N.B.: This is an important step. We save the value of the `actions` in a
|
||||
# member variable for use with the RL environment. In classic TensorFlow
|
||||
# (non-eager execution), it is less straightfoward to access intermediate
|
||||
# computation results in this manner (c.f., `tf.Session.partial_run()`).
|
||||
self._current_actions = actions
|
||||
|
||||
labels = 1.0 - tf.cast(actions, tf.float32)
|
||||
return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
|
||||
|
||||
def train(self,
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
discount_rate,
|
||||
num_games,
|
||||
max_steps_per_game):
|
||||
"""Train the PolicyNetwork by playing `num_games` games in `cart_pole_env`.
|
||||
|
||||
Arguments:
|
||||
cart_pole_env: The cart-pole gym environment object.
|
||||
optimizer: A TensorFlow `Optimizer` object to be used in this training
|
||||
(e.g., `tf.train.AdamOptimizer`).
|
||||
discount_rate: Reward discounting rate.
|
||||
num_games: Number of games to run per parameter update.
|
||||
max_steps_per_game: Maximum number of steps to run in each game.
|
||||
|
||||
Returns:
|
||||
Step counts from all games, as a `list` of `int`.
|
||||
"""
|
||||
all_gradient_lists = []
|
||||
all_rewards = []
|
||||
for _ in xrange(num_games):
|
||||
obs = cart_pole_env.reset()
|
||||
game_rewards = []
|
||||
game_gradient_lists = []
|
||||
for _ in xrange(max_steps_per_game):
|
||||
# TODO(cais): Can we save the tf.constant() call?
|
||||
grad_list, var_list = zip(*self._grad_fn(tf.constant([obs])))
|
||||
game_gradient_lists.append(grad_list)
|
||||
|
||||
action = self._current_actions.numpy()[0][0]
|
||||
obs, reward, done, _ = cart_pole_env.step(action)
|
||||
game_rewards.append(reward)
|
||||
if reward != 1.0 or done:
|
||||
break
|
||||
|
||||
all_gradient_lists.append(game_gradient_lists)
|
||||
all_rewards.append(game_rewards)
|
||||
|
||||
normalized_rewards = cart_pole_helper.discount_and_normalize_rewards(
|
||||
all_rewards, discount_rate)
|
||||
all_grads_and_vars = self._scale_and_average_gradients(var_list,
|
||||
all_gradient_lists,
|
||||
normalized_rewards)
|
||||
optimizer.apply_gradients(all_grads_and_vars)
|
||||
step_counts = [len(rewards) for rewards in all_rewards]
|
||||
|
||||
if self._summary_writer:
|
||||
self._summary_writer.scalar("mean_step_count", np.mean(step_counts))
|
||||
self._summary_writer.step()
|
||||
|
||||
return step_counts
|
||||
|
||||
def _scale_and_average_gradients(self,
|
||||
variable_list,
|
||||
all_gradient_lists,
|
||||
normalized_rewards):
|
||||
"""Scale gradient tensors with normalized rewards."""
|
||||
num_games = len(all_gradient_lists)
|
||||
grads_and_vars = []
|
||||
for j, var in enumerate(variable_list):
|
||||
scaled_gradients = []
|
||||
for g in xrange(int(num_games)):
|
||||
num_steps = len(all_gradient_lists[g])
|
||||
for s in xrange(num_steps):
|
||||
scaled_gradients.append(
|
||||
all_gradient_lists[g][s][j] * normalized_rewards[g][s])
|
||||
mean_scaled_gradients = sum(scaled_gradients) / len(scaled_gradients)
|
||||
grads_and_vars.append((mean_scaled_gradients, var))
|
||||
return grads_and_vars
|
||||
|
||||
def play(self, cart_pole_env, max_steps=None, render=False):
|
||||
"""Play a game in the cart-pole gym environment.
|
||||
|
||||
Args:
|
||||
cart_pole_env: The cart-pole gym environment object.
|
||||
max_steps: Maximum number of steps to run in the game.
|
||||
render: Whether the game state is to be rendered on the screen.
|
||||
"""
|
||||
if render:
|
||||
input("\nAbout to play a game with rendering. Press Enter to continue: ")
|
||||
|
||||
steps = 0
|
||||
obs = cart_pole_env.reset()
|
||||
while True:
|
||||
# TODO(cais): Can we save the tf.constant() call?
|
||||
_, actions = self.forward(tf.constant([obs]))
|
||||
if render:
|
||||
cart_pole_env.render()
|
||||
obs, reward, done, _ = cart_pole_env.step(actions.numpy()[0][0])
|
||||
steps += 1
|
||||
if done or reward != 1.0 or max_steps is not None and steps >= max_steps:
|
||||
break
|
||||
|
||||
|
||||
def main(_):
|
||||
tf.set_random_seed(0)
|
||||
|
||||
cart_pole_env = gym.make("CartPole-v0")
|
||||
cart_pole_env.seed(0)
|
||||
cart_pole_env.reset()
|
||||
|
||||
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
|
||||
print("Using device: %s" % device)
|
||||
|
||||
with tf.device(device):
|
||||
policy_network = PolicyNetwork(FLAGS.hidden_size, train_logdir=FLAGS.logdir)
|
||||
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
|
||||
|
||||
# Training loop.
|
||||
for i in xrange(FLAGS.num_iterations):
|
||||
step_counts = policy_network.train(
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
FLAGS.discount_rate,
|
||||
FLAGS.num_games_per_iteration,
|
||||
FLAGS.max_steps_per_game)
|
||||
print("Iteration %d: step counts = %s; mean = %g" % (
|
||||
i, step_counts, np.mean(step_counts)))
|
||||
sys.stdout.flush()
|
||||
|
||||
# Optional playing after training, with rendering.
|
||||
if FLAGS.play_after_training:
|
||||
policy_network.play(cart_pole_env,
|
||||
max_steps=FLAGS.max_steps_per_game,
|
||||
render=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--hidden_size",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Size of the hidden layer of the policy network.")
|
||||
parser.add_argument(
|
||||
"--discount_rate",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Reward discounting rate.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Learning rate to be used during training.")
|
||||
parser.add_argument(
|
||||
"--num_iterations",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of training iterations.")
|
||||
parser.add_argument(
|
||||
"--num_games_per_iteration",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of games to run in each training iteration.")
|
||||
parser.add_argument(
|
||||
"--max_steps_per_game",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of steps to run in each game.")
|
||||
parser.add_argument(
|
||||
"--logdir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="logdir in which TensorBoard summaries will be written (optional).")
|
||||
parser.add_argument(
|
||||
"--play_after_training",
|
||||
action="store_true",
|
||||
help="Play a game after training (with rendering).")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tfe.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,60 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helper functions for reinforcement learning in the cart-pole problem."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def discount_rewards(rewards, discount_rate):
|
||||
"""Discout reward values with discount rate.
|
||||
|
||||
Args:
|
||||
rewards: A sequence of reward values in time.
|
||||
discount_rate: (`float`) reward discounting rate (e.g., 0.95).
|
||||
|
||||
Returns:
|
||||
Discounted reward values.
|
||||
"""
|
||||
discounted = []
|
||||
for reward in reversed(rewards):
|
||||
discounted.append(
|
||||
(discounted[-1] if discounted else 0.0) * discount_rate + reward)
|
||||
return list(reversed(discounted))
|
||||
|
||||
|
||||
def discount_and_normalize_rewards(reward_sequences, discount_rate):
|
||||
"""Perform discounting on a number of reward sequences; then normalize values.
|
||||
|
||||
Args:
|
||||
reward_sequences: an `iterable` of reward sequences.
|
||||
discount_rate: reward discounting rate (e.g., 0.95).
|
||||
|
||||
Returns:
|
||||
A `list` of reward value `list`s, discounted and normalized.
|
||||
"""
|
||||
discounted = []
|
||||
for sequence in reward_sequences:
|
||||
discounted.append(discount_rewards(sequence, discount_rate))
|
||||
discounted = np.array(discounted)
|
||||
|
||||
# Compute overall mean and stddev.
|
||||
flattened = np.concatenate(discounted)
|
||||
mean = np.mean(flattened)
|
||||
std = np.std(flattened)
|
||||
return [((d - mean) / std) for d in discounted]
|
@ -1,197 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=line-too-long
|
||||
r"""TensorFlow Eager Execution Example: Linear Regression.
|
||||
|
||||
This example shows how to use TensorFlow Eager Execution to fit a simple linear
|
||||
regression model using some synthesized data. Specifically, it illustrates how
|
||||
to define the forward path of the linear model and the loss function, as well
|
||||
as how to obtain the gradients of the loss function with respect to the
|
||||
variables and update the variables with the gradients.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
# TODO(cais): Use tf.contrib.eager namespace when ready.
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
|
||||
|
||||
class DataGenerator(object):
|
||||
"""Generates synthetic data for linear regression."""
|
||||
|
||||
def __init__(self, w, b, noise_level, batch_size):
|
||||
self._w = w
|
||||
self._b = b
|
||||
self._noise_level = noise_level
|
||||
self._batch_size = batch_size
|
||||
self._ndims = w.shape[0]
|
||||
|
||||
def next_batch(self):
|
||||
"""Generate a synthetic batch of xs and ys."""
|
||||
xs = tf.random_normal([self._batch_size, self._ndims])
|
||||
ys = (tf.matmul(xs, self._w) + self._b +
|
||||
self._noise_level * tf.random_normal([self._batch_size, 1]))
|
||||
return xs, ys
|
||||
|
||||
|
||||
class LinearModel(object):
|
||||
"""A TensorFlow linear regression model.
|
||||
|
||||
Uses TensorFlow's eager execution.
|
||||
|
||||
For those familiar with TensorFlow graphs, notice the absence of
|
||||
`tf.Session`. The `forward()` method here immediately executes and
|
||||
returns output values. The `loss()` method immediately compares the
|
||||
output of `forward()` with the target adn returns the MSE loss value.
|
||||
The `fit()` performs gradient-descent training on the model's weights
|
||||
and bias.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Constructs a LinearModel object."""
|
||||
self._hidden_layer = tf.layers.Dense(1)
|
||||
|
||||
# loss_value_and_grad_fn is a function that when invoked, will return the
|
||||
# loss value and the gradients of loss with respect to the variables. It has
|
||||
# the same input arguments as `self.loss()`.
|
||||
self._loss_value_and_grad_fn = tfe.implicit_value_and_gradients(self.loss)
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
"""Get values of weights as a numpy array."""
|
||||
return self._hidden_layer.variables[0].read_value().numpy()
|
||||
|
||||
@property
|
||||
def biases(self):
|
||||
"""Get values of biases as a numpy array."""
|
||||
return self._hidden_layer.variables[1].read_value().numpy()
|
||||
|
||||
def forward(self, xs):
|
||||
"""Invoke the linear model.
|
||||
|
||||
Args:
|
||||
xs: input features, as a tensor of size [batch_size, ndims].
|
||||
|
||||
Returns:
|
||||
ys: the predictions of the linear mode, as a tensor of size [batch_size]
|
||||
"""
|
||||
# Note: Unlike classic TensorFlow, operations such as self._hidden_layer
|
||||
# will execute the underlying computation immediately.
|
||||
return self._hidden_layer(xs)
|
||||
|
||||
def loss(self, xs, ys):
|
||||
"""Loss of the linear model.
|
||||
|
||||
Args:
|
||||
xs: input features, as a tensor of size [batch_size, ndims].
|
||||
ys: the target values of y, as a tensor of size [batch_size].
|
||||
|
||||
Returns:
|
||||
The mean square error loss value.
|
||||
"""
|
||||
return tf.reduce_mean(tf.square(self.forward(xs) - ys))
|
||||
|
||||
def fit(self,
|
||||
batch_fn,
|
||||
optimizer,
|
||||
num_iters,
|
||||
verbose=False,
|
||||
logdir=None):
|
||||
"""Fit the linear-regression model.
|
||||
|
||||
Args:
|
||||
batch_fn: A function, which when called without any arguments, returns a
|
||||
batch of xs and ys for training.
|
||||
optimizer: The TensorFlow Optimizer object to be used.
|
||||
num_iters: Number of training iterations to perform.
|
||||
verbose: If true, will print out loss values at every iteration.
|
||||
logdir: The directory in which summaries will be written for TensorBoard
|
||||
(optional).
|
||||
"""
|
||||
if logdir:
|
||||
# Support for TensorBoard summaries. Once training has started, use:
|
||||
# tensorboard --logdir=<logdir>
|
||||
summary_writer = tfe.SummaryWriter(logdir)
|
||||
|
||||
# Training loop.
|
||||
for i in xrange(num_iters):
|
||||
# Generate a (mini-)batch of data for training.
|
||||
xs, ys = batch_fn()
|
||||
|
||||
# Call the function obtained above to get the loss and gradient values at
|
||||
# the specific training batch. The function has the same input arguments
|
||||
# as the forward function, i.e., `linear_loss()`.
|
||||
loss_value, grads_and_vars = self._loss_value_and_grad_fn(xs, ys)
|
||||
if verbose:
|
||||
print("Iteration %d: loss = %s" % (i, loss_value.numpy()))
|
||||
|
||||
# Send the gradients to the optimizer and update the Variables, i.e., `w`
|
||||
# and `b`.
|
||||
optimizer.apply_gradients(grads_and_vars)
|
||||
|
||||
if logdir:
|
||||
summary_writer.scalar("loss", loss_value)
|
||||
summary_writer.step()
|
||||
|
||||
|
||||
def main(_):
|
||||
# Ground-truth constants.
|
||||
true_w = np.array([[-2.0], [4.0], [1.0]], dtype=np.float32)
|
||||
true_b = np.array([0.5], dtype=np.float32)
|
||||
noise_level = 0.01
|
||||
|
||||
# Training constants.
|
||||
batch_size = 64
|
||||
learning_rate = 0.1
|
||||
num_iters = 20
|
||||
|
||||
print("True w: %s" % true_w)
|
||||
print("True b: %s\n" % true_b)
|
||||
|
||||
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
|
||||
print("Using device: %s" % device)
|
||||
with tf.device(device):
|
||||
linear_model = LinearModel()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
|
||||
data_gen = DataGenerator(true_w, true_b, noise_level, batch_size)
|
||||
linear_model.fit(data_gen.next_batch, optimizer, num_iters, verbose=True,
|
||||
logdir=FLAGS.logdir)
|
||||
|
||||
print("\nAfter training: w = %s" % linear_model.weights)
|
||||
print("\nAfter training: b = %s" % linear_model.biases)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--logdir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="logdir in which TensorBoard summaries will be written (optional).")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
|
||||
# Use tfe.run() instead of tf.app.run() for eager execution.
|
||||
tfe.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,529 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "U9i2Dsh-ziXr"
|
||||
},
|
||||
"source": [
|
||||
"# Eager Execution Tutorial: Basics\n",
|
||||
"\n",
|
||||
"This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n",
|
||||
"\n",
|
||||
"* Importing required packages\n",
|
||||
"* Enabling eager execution\n",
|
||||
"* Creating and using TensorFlow Tensors and Variables\n",
|
||||
"* Using TensorFlow interactively\n",
|
||||
"* Using GPUs with eager execution enabled\n",
|
||||
"\n",
|
||||
"This notebook does *not* cover modeling topics, such as gradients."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "z1JcS5iBXMRO"
|
||||
},
|
||||
"source": [
|
||||
"# Step 1: Import Eager\n",
|
||||
"\n",
|
||||
"The key imports for eager execution are the following:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "RlIWhyeLoYnG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import TensorFlow.\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"# Import TensorFlow eager execution support (subject to future changes).\n",
|
||||
"from tensorflow.contrib.eager.python import tfe"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "H9UySOPLXdaw"
|
||||
},
|
||||
"source": [
|
||||
"# Step 2: Enable eager execution\n",
|
||||
"\n",
|
||||
"All future TensorFlow calls will execute the\n",
|
||||
"underlying TensorFlow ops immediately:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "WPTUfGq6kJ5w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tfe.enable_eager_execution()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "twBfWd5xyu_d"
|
||||
},
|
||||
"source": [
|
||||
"# Step 3: Interactively Use TensorFlow!\n",
|
||||
"\n",
|
||||
"Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n",
|
||||
"\n",
|
||||
"TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "ngUe237Wt48W"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(tf.add(1, 2))\n",
|
||||
"print(tf.add([1, 2], [3, 4]))\n",
|
||||
"print(tf.square(5))\n",
|
||||
"print(tf.reduce_sum([1, 2, 3]))\n",
|
||||
"print(tf.encode_base64(\"hello world\"))\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"x = tf.constant(2)\n",
|
||||
"y = tf.constant(3)\n",
|
||||
"print(x * y + 1)\n",
|
||||
"\n",
|
||||
"# Most TensorFlow ops are directly usable with eager execution, giving\n",
|
||||
"# results immediately.\n",
|
||||
"print(tf.contrib.signal.hamming_window(x * y + 1))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "IDY4WsYRhP81"
|
||||
},
|
||||
"source": [
|
||||
"Numpy arrays are supported, too:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "lCUWzso6mbqR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"ones = np.ones([3, 3])\n",
|
||||
"\n",
|
||||
"print(\"numpy 3x3 matrix of 1s:\")\n",
|
||||
"print(ones)\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"print(\"Multiplied by 42:\")\n",
|
||||
"print(tf.multiply(ones, 42))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PBNP8yTRfu_X"
|
||||
},
|
||||
"source": [
|
||||
"# Step 4: Define and Print TensorFlow Variables\n",
|
||||
"\n",
|
||||
"To define TensorFlow variables, use the `get_variable()` function as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "3Twf_Rw-gQFM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = tf.get_variable(name=\"x\", shape=[1], dtype=tf.float32, initializer=tf.zeros_initializer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "45G7094TxsMb"
|
||||
},
|
||||
"source": [
|
||||
"## Printing TensorFlow Variables"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "UJBJeZ5XxuwA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This does NOT print the Variable's actual value:\n",
|
||||
"print(\"Printing a TensorFlow Variable:\")\n",
|
||||
"print(x)\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"# A TensorFlow variable represents a reference to a tensor.\n",
|
||||
"# The `read_value()` method provides access to the current value of the\n",
|
||||
"# variable. Tensorflow Variables are automatically initialized according to the\n",
|
||||
"# semantics defined in tf.get_variable().\n",
|
||||
"print(\"Printing a TensorFlow Variable's value using .read_value():\")\n",
|
||||
"print(x.read_value())\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n",
|
||||
"print(x.read_value().numpy())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "2njjWHcTpBEn"
|
||||
},
|
||||
"source": [
|
||||
"## Changing a TensorFlow Variable's value\n",
|
||||
"\n",
|
||||
"To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "v3wr6Erbo_hB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x.assign(42)\n",
|
||||
"print(x.read_value())\n",
|
||||
"\n",
|
||||
"x.assign_add(3)\n",
|
||||
"print(x.read_value())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "uhtynjHVpTB5"
|
||||
},
|
||||
"source": [
|
||||
"## Use a Variable just like any other Tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "7PbktdnHoehR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(x + 3)\n",
|
||||
"\n",
|
||||
"# This code will broadcast the value across the list of numbers:\n",
|
||||
"print(x * [1, 2, 4])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "GVChqwlwy1SI"
|
||||
},
|
||||
"source": [
|
||||
"# Step 5: Debug Errors with Instant Feedback\n",
|
||||
"\n",
|
||||
"TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n",
|
||||
"\n",
|
||||
"Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n",
|
||||
"one being legal and the other being illegal, leading to a runtime error that is\n",
|
||||
"raised immediately."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "23ap04N0v4k0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vector = tf.constant([10.0, 20.0, 30.0, 40.0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "FCUMsIYxxRRa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n",
|
||||
"# arguments) are within the bound of `vector`.\n",
|
||||
"print(tf.slice(vector, [1], [3]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "T8me2oCNxpFp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The following does NOT work, because the value of `size` (the 3rd\n",
|
||||
"# argument) causes the indices to go out of the bounds of `vector`. The\n",
|
||||
"# error is raised immediately.\n",
|
||||
"try:\n",
|
||||
" print(tf.slice(vector, [1], [4]))\n",
|
||||
"except tf.OpError as e:\n",
|
||||
" print(\"Caught error: %s\" % e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "irxJhAgar84v"
|
||||
},
|
||||
"source": [
|
||||
"# Step 6: Using the GPU\n",
|
||||
"\n",
|
||||
"You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n",
|
||||
"\n",
|
||||
"The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "7J4N9baqaKCL"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The example code from here on will work only if your notebook\n",
|
||||
"# is running on a machine with a functional CUDA GPU. The following\n",
|
||||
"# line checks that.\n",
|
||||
"is_gpu_available = tfe.num_gpus() \u003e 0\n",
|
||||
"\n",
|
||||
"# Create some Tensors\n",
|
||||
"SIZE = 1000\n",
|
||||
"cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
|
||||
"\n",
|
||||
"if is_gpu_available:\n",
|
||||
" gpu_tensor = cpu_tensor.gpu()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "4E-2n7VbzY1n"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Time a CPU-based matrix multiplication\n",
|
||||
"\n",
|
||||
"print(\"Time to conduct matmul on CPU:\")\n",
|
||||
"%time tf.matmul(cpu_tensor, cpu_tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "vbSFW-T5zhZF"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Time GPU-based matrix multiplications.\n",
|
||||
"\n",
|
||||
"if is_gpu_available:\n",
|
||||
" # First use of the GPU will be slow:\n",
|
||||
" print(\"Time to conduct first matmul on GPU:\")\n",
|
||||
" %time tf.matmul(gpu_tensor, gpu_tensor)\n",
|
||||
" print()\n",
|
||||
"\n",
|
||||
" # Subsequent uses are much faster:\n",
|
||||
" print(\"Time to conduct second matmul on GPU:\")\n",
|
||||
" %time tf.matmul(gpu_tensor, gpu_tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "E5pIOe3Rz7iW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Second timing demo for GPUs, after it has been used once:\n",
|
||||
"\n",
|
||||
"cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
|
||||
"print(\"Time to conduct CPU matmul:\")\n",
|
||||
"%time tf.matmul(cpu_tensor, cpu_tensor)\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"if is_gpu_available:\n",
|
||||
" gpu_tensor = cpu_tensor.gpu()\n",
|
||||
" print(\"Time to conduct GPU matmul:\")\n",
|
||||
" %time tf.matmul(gpu_tensor, gpu_tensor)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"default_view": {},
|
||||
"name": "Eager Execution Tutorial: Basics",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg",
|
||||
"timestamp": 1504118841551
|
||||
}
|
||||
],
|
||||
"version": "0.3.2",
|
||||
"views": {}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
File diff suppressed because one or more lines are too long
@ -1,218 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "U9i2Dsh-ziXr"
|
||||
},
|
||||
"source": [
|
||||
"# Eager Execution Tutorial: Importing Data\n",
|
||||
"\n",
|
||||
"This notebook demonstrates the use of the [`tf.contrib.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n",
|
||||
"\n",
|
||||
"* Creating a `Dataset`.\n",
|
||||
"* Iteration over a `Dataset` with eager execution enabled.\n",
|
||||
"\n",
|
||||
"We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n",
|
||||
"\n",
|
||||
"If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "z1JcS5iBXMRO"
|
||||
},
|
||||
"source": [
|
||||
"# Setup: Enable eager execution\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "RlIWhyeLoYnG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import TensorFlow.\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"# Import TensorFlow eager execution support (subject to future changes).\n",
|
||||
"from tensorflow.contrib.eager.python import tfe\n",
|
||||
"\n",
|
||||
"# Enable eager execution\n",
|
||||
"tfe.enable_eager_execution()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "H9UySOPLXdaw"
|
||||
},
|
||||
"source": [
|
||||
"# Step 1: Create a source `Dataset`\n",
|
||||
"\n",
|
||||
"Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "WPTUfGq6kJ5w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds_tensors = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n",
|
||||
"\n",
|
||||
"# Create a CSV file\n",
|
||||
"import tempfile\n",
|
||||
"_, filename = tempfile.mkstemp()\n",
|
||||
"with open(filename, 'w') as f:\n",
|
||||
" f.write(\"\"\"Line 1\n",
|
||||
"Line 2\n",
|
||||
"Line 3\n",
|
||||
" \"\"\")\n",
|
||||
"ds_file = tf.contrib.data.TextLineDataset(filename)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "twBfWd5xyu_d"
|
||||
},
|
||||
"source": [
|
||||
"# Step 2: Apply transformations\n",
|
||||
"\n",
|
||||
"Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.contrib.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset) for details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "ngUe237Wt48W"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n",
|
||||
"ds_file = ds_file.batch(2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "IDY4WsYRhP81"
|
||||
},
|
||||
"source": [
|
||||
"# Step 3: Iterate\n",
|
||||
"\n",
|
||||
"Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n",
|
||||
"\n",
|
||||
"If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
},
|
||||
"height": 153,
|
||||
"output_extras": [
|
||||
{
|
||||
"item_id": 1
|
||||
}
|
||||
]
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 201,
|
||||
"status": "ok",
|
||||
"timestamp": 1505952405928,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 420
|
||||
},
|
||||
"id": "lCUWzso6mbqR",
|
||||
"outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Elements of ds_tensors:\n",
|
||||
"tf.Tensor([4 9], shape=(2,), dtype=int32)\n",
|
||||
"tf.Tensor([16 25], shape=(2,), dtype=int32)\n",
|
||||
"tf.Tensor([36 1], shape=(2,), dtype=int32)\n",
|
||||
"\n",
|
||||
"Elements in ds_file:\n",
|
||||
"tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n",
|
||||
"tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print('Elements of ds_tensors:')\n",
|
||||
"for x in tfe.Iterator(ds_tensors):\n",
|
||||
" print(x)\n",
|
||||
"\n",
|
||||
"print('\\nElements in ds_file:')\n",
|
||||
"for x in tfe.Iterator(ds_file):\n",
|
||||
" print(x)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"default_view": {},
|
||||
"last_runtime": {
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "Eager Execution Tutorial: Importing Data",
|
||||
"provenance": [],
|
||||
"version": "0.3.2",
|
||||
"views": {}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.eager.python.examples import cart_pole_helper
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RewardDiscountingTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDiscountingRewards(self):
|
||||
rewards = [0.0, 10.0, 20.0]
|
||||
discount_rate = 0.9
|
||||
self.assertAllClose(
|
||||
[10 * discount_rate + 20 * discount_rate * discount_rate,
|
||||
10 + 20 * discount_rate, 20],
|
||||
cart_pole_helper.discount_rewards(rewards, discount_rate))
|
||||
self.assertAllClose(
|
||||
[-1.2], cart_pole_helper.discount_rewards([-1.2], discount_rate))
|
||||
self.assertEqual([], cart_pole_helper.discount_rewards([], discount_rate))
|
||||
|
||||
def testDiscountAndNormalizeRewardSequences(self):
|
||||
rewards1 = [0.0, 10.0, 20.0]
|
||||
rewards2 = [0.0, 5.0, -5.0]
|
||||
reward_sequences = [rewards1, rewards2]
|
||||
discount_rate = 0.9
|
||||
dn = cart_pole_helper.discount_and_normalize_rewards(reward_sequences,
|
||||
discount_rate)
|
||||
self.assertAllClose(
|
||||
[[1.03494653, 1.24685514, 0.64140196],
|
||||
[-0.83817424, -0.83439016, -1.25063922]], dn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,162 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Unit test for cart-pole reinforcement learning under eager exection."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.eager.python.examples import cart_pole
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import training
|
||||
|
||||
|
||||
class CartPoleTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(CartPoleTest, self).setUp()
|
||||
self._tmp_logdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self._tmp_logdir)
|
||||
super(CartPoleTest, self).tearDown()
|
||||
|
||||
def testGetLogitsAndAction(self):
|
||||
hidden_size = 5
|
||||
policy_network = cart_pole.PolicyNetwork(hidden_size)
|
||||
|
||||
dummy_inputs = np.array([[0.1, 0.3, 0.2, 0.5],
|
||||
[0.0, -0.2, 0.6, -0.8]], dtype=np.float32)
|
||||
logits, actions = policy_network.forward(constant_op.constant(dummy_inputs))
|
||||
|
||||
self.assertEqual((2, 1), logits.shape)
|
||||
self.assertEqual(dtypes.float32, logits.dtype)
|
||||
self.assertEqual((2, 1), actions.shape)
|
||||
self.assertEqual(dtypes.int64, actions.dtype)
|
||||
|
||||
def testCrossEntropy(self):
|
||||
hidden_size = 5
|
||||
policy_network = cart_pole.PolicyNetwork(hidden_size)
|
||||
|
||||
dummy_inputs = np.array([[0.1, 0.3, 0.2, 0.5],
|
||||
[0.0, -0.2, 0.6, -0.8]], dtype=np.float32)
|
||||
cross_entropy = policy_network._get_cross_entropy_and_save_actions(
|
||||
constant_op.constant(dummy_inputs))
|
||||
|
||||
self.assertEqual((2, 1), cross_entropy.shape)
|
||||
self.assertEqual(dtypes.float32, cross_entropy.dtype)
|
||||
|
||||
def testPlayAGame(self):
|
||||
hidden_size = 5
|
||||
cart_pole_env = gym.make("CartPole-v0")
|
||||
cart_pole_env.seed(0)
|
||||
cart_pole_env.reset()
|
||||
|
||||
device = "gpu:0" if context.context().num_gpus() > 0 else "cpu:0"
|
||||
logging.info("device = %s", device)
|
||||
with context.device(device):
|
||||
policy_network = cart_pole.PolicyNetwork(hidden_size)
|
||||
policy_network.play(cart_pole_env, max_steps=10, render=False)
|
||||
|
||||
def testTrain(self):
|
||||
hidden_size = 5
|
||||
num_games_per_iteration = 5
|
||||
max_steps_per_game = 10
|
||||
discount_rate = 0.95
|
||||
learning_rate = 0.02
|
||||
|
||||
cart_pole_env = gym.make("CartPole-v0")
|
||||
cart_pole_env.reset()
|
||||
|
||||
device = "gpu:0" if context.context().num_gpus() > 0 else "cpu:0"
|
||||
logging.info("device = %s", device)
|
||||
with context.device(device):
|
||||
policy_network = cart_pole.PolicyNetwork(hidden_size,
|
||||
train_logdir=self._tmp_logdir)
|
||||
optimizer = training.AdamOptimizer(learning_rate)
|
||||
policy_network.train(
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
discount_rate,
|
||||
num_games_per_iteration,
|
||||
max_steps_per_game)
|
||||
self.assertTrue(glob.glob(os.path.join(self._tmp_logdir, "events.out.*")))
|
||||
|
||||
|
||||
class EagerCartPoleTrainingBenchmark(test.Benchmark):
|
||||
|
||||
def benchmarkEagerCartPolePolicyNetworkTraining(self):
|
||||
burn_in_iterations = 1
|
||||
benchmark_iterations = 2
|
||||
num_games_per_iteration = 10
|
||||
max_steps_per_game = 100
|
||||
discount_rate = 0.95
|
||||
learning_rate = 0.02
|
||||
|
||||
cart_pole_env = gym.make("CartPole-v0")
|
||||
cart_pole_env.seed(0)
|
||||
random_seed.set_random_seed(0)
|
||||
cart_pole_env.reset()
|
||||
|
||||
hidden_size = 5
|
||||
policy_network = cart_pole.PolicyNetwork(hidden_size)
|
||||
optimizer = training.AdamOptimizer(learning_rate)
|
||||
|
||||
# Perform burn-in.
|
||||
for _ in xrange(burn_in_iterations):
|
||||
policy_network.train(
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
discount_rate,
|
||||
num_games_per_iteration,
|
||||
max_steps_per_game)
|
||||
|
||||
gc.collect()
|
||||
start_time = time.time()
|
||||
for _ in xrange(benchmark_iterations):
|
||||
policy_network.train(
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
discount_rate,
|
||||
num_games_per_iteration,
|
||||
max_steps_per_game)
|
||||
wall_time = time.time() - start_time
|
||||
# Named "examples"_per_sec to conform with other benchmarks.
|
||||
extras = {"examples_per_sec": benchmark_iterations / wall_time}
|
||||
self.report_benchmark(
|
||||
name="EagerCartPoleReinforcementLearning",
|
||||
iters=benchmark_iterations,
|
||||
wall_time=wall_time,
|
||||
extras=extras)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,114 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Unit tests for linear regression example under TensorFlow eager execution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python.examples import linear_regression
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
def _create_data_gen_for_test():
|
||||
true_w = np.array([[1.0], [-0.5], [2.0]], dtype=np.float32)
|
||||
true_b = np.array([1.0], dtype=np.float32)
|
||||
noise_level = 0
|
||||
batch_size = 64
|
||||
return (
|
||||
true_w, true_b, noise_level, batch_size,
|
||||
linear_regression.DataGenerator(true_w, true_b, noise_level, batch_size))
|
||||
|
||||
|
||||
class LinearRegressionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(LinearRegressionTest, self).setUp()
|
||||
self._tmp_logdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self._tmp_logdir)
|
||||
super(LinearRegressionTest, self).tearDown()
|
||||
|
||||
def testSyntheticBatch(self):
|
||||
_, _, _, batch_size, data_gen = _create_data_gen_for_test()
|
||||
|
||||
xs, ys = data_gen.next_batch()
|
||||
self.assertEqual((batch_size, 3), xs.shape)
|
||||
self.assertEqual((batch_size, 1), ys.shape)
|
||||
self.assertEqual(tf.float32, xs.dtype)
|
||||
self.assertEqual(tf.float32, ys.dtype)
|
||||
|
||||
def testLinearRegression(self):
|
||||
true_w, true_b, _, _, data_gen = _create_data_gen_for_test()
|
||||
|
||||
learning_rate = 0.1
|
||||
num_iters = 40
|
||||
|
||||
device = "gpu:0" if context.context().num_gpus() > 0 else "cpu:0"
|
||||
logging.info("device = %s", device)
|
||||
with context.device(device):
|
||||
linear_model = linear_regression.LinearModel()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
|
||||
linear_model.fit(data_gen.next_batch, optimizer, num_iters,
|
||||
logdir=self._tmp_logdir)
|
||||
|
||||
self.assertAllClose(true_w, linear_model.weights, rtol=1e-2)
|
||||
self.assertAllClose(true_b, linear_model.biases, rtol=1e-2)
|
||||
self.assertTrue(glob.glob(os.path.join(self._tmp_logdir, "events.out.*")))
|
||||
|
||||
|
||||
class EagerLinearRegressionBenchmark(test.Benchmark):
|
||||
|
||||
def benchmarkEagerLinearRegression(self):
|
||||
_, _, _, _, data_gen = _create_data_gen_for_test()
|
||||
|
||||
learning_rate = 0.1
|
||||
num_burnin_iters = 10
|
||||
num_iters = 200
|
||||
|
||||
device = "gpu:0" if context.context().num_gpus() > 0 else "cpu:0"
|
||||
logging.info("device = %s", device)
|
||||
with context.device(device):
|
||||
linear_model = linear_regression.LinearModel()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
|
||||
|
||||
# Perform burn-in.
|
||||
linear_model.fit(data_gen.next_batch, optimizer, num_burnin_iters)
|
||||
|
||||
start_time = time.time()
|
||||
linear_model.fit(data_gen.next_batch, optimizer, num_iters)
|
||||
wall_time = time.time() - start_time
|
||||
|
||||
self.report_benchmark(
|
||||
name="EagerLinearRegression",
|
||||
iters=num_iters,
|
||||
wall_time=wall_time)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,311 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import gc
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
from tensorflow.contrib.eager.python.examples import spinn
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
|
||||
def _generate_synthetic_snli_data_batch(sequence_length,
|
||||
batch_size,
|
||||
vocab_size):
|
||||
"""Generate a fake batch of SNLI data for testing."""
|
||||
with tf.device("cpu:0"):
|
||||
labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64)
|
||||
prem = tf.random_uniform(
|
||||
(sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
|
||||
prem_trans = tf.constant(np.array(
|
||||
[[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
|
||||
2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
|
||||
3, 2, 2]] * batch_size, dtype=np.int64).T)
|
||||
hypo = tf.random_uniform(
|
||||
(sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
|
||||
hypo_trans = tf.constant(np.array(
|
||||
[[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
|
||||
2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
|
||||
3, 2, 2]] * batch_size, dtype=np.int64).T)
|
||||
if tfe.num_gpus():
|
||||
labels = labels.gpu()
|
||||
prem = prem.gpu()
|
||||
prem_trans = prem_trans.gpu()
|
||||
hypo = hypo.gpu()
|
||||
hypo_trans = hypo_trans.gpu()
|
||||
return labels, prem, prem_trans, hypo, hypo_trans
|
||||
|
||||
|
||||
def _snli_classifier_config(d_embed, d_out):
|
||||
config_tuple = collections.namedtuple(
|
||||
"Config", ["d_hidden", "d_proj", "d_tracker", "predict",
|
||||
"embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp",
|
||||
"d_out", "projection", "lr"])
|
||||
config = config_tuple(
|
||||
d_hidden=d_embed,
|
||||
d_proj=d_embed * 2,
|
||||
d_tracker=8,
|
||||
predict=False,
|
||||
embed_dropout=0.1,
|
||||
mlp_dropout=0.1,
|
||||
n_mlp_layers=2,
|
||||
d_mlp=32,
|
||||
d_out=d_out,
|
||||
projection=True,
|
||||
lr=2e-3)
|
||||
return config
|
||||
|
||||
|
||||
class SpinnTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(SpinnTest, self).setUp()
|
||||
self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
|
||||
|
||||
def testBundle(self):
|
||||
with tf.device(self._test_device):
|
||||
lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32),
|
||||
np.array([[0, -1], [-2, -3]], dtype=np.float32),
|
||||
np.array([[0, 2], [4, 6]], dtype=np.float32),
|
||||
np.array([[0, -2], [-4, -6]], dtype=np.float32)]
|
||||
out = spinn._bundle(lstm_iter)
|
||||
|
||||
self.assertEqual(2, len(out))
|
||||
self.assertEqual(tf.float32, out[0].dtype)
|
||||
self.assertEqual(tf.float32, out[1].dtype)
|
||||
self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T,
|
||||
out[0].numpy())
|
||||
self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T,
|
||||
out[1].numpy())
|
||||
|
||||
def testUnbunbdle(self):
|
||||
with tf.device(self._test_device):
|
||||
state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32),
|
||||
np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)]
|
||||
out = spinn._unbundle(state)
|
||||
|
||||
self.assertEqual(2, len(out))
|
||||
self.assertEqual(tf.float32, out[0].dtype)
|
||||
self.assertEqual(tf.float32, out[1].dtype)
|
||||
self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]),
|
||||
out[0].numpy())
|
||||
self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]),
|
||||
out[1].numpy())
|
||||
|
||||
def testReduce(self):
|
||||
with tf.device(self._test_device):
|
||||
batch_size = 3
|
||||
size = 10
|
||||
tracker_size = 8
|
||||
reducer = spinn.Reduce(size, tracker_size=tracker_size)
|
||||
|
||||
left_in = []
|
||||
right_in = []
|
||||
tracking = []
|
||||
for _ in range(batch_size):
|
||||
left_in.append(tf.random_normal((1, size * 2)))
|
||||
right_in.append(tf.random_normal((1, size * 2)))
|
||||
tracking.append(tf.random_normal((1, tracker_size * 2)))
|
||||
|
||||
out = reducer(left_in, right_in, tracking=tracking)
|
||||
self.assertEqual(batch_size, len(out))
|
||||
self.assertEqual(tf.float32, out[0].dtype)
|
||||
self.assertEqual((1, size * 2), out[0].shape)
|
||||
|
||||
def testReduceTreeLSTM(self):
|
||||
with tf.device(self._test_device):
|
||||
size = 10
|
||||
tracker_size = 8
|
||||
reducer = spinn.Reduce(size, tracker_size=tracker_size)
|
||||
|
||||
lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
[0, -1, -2, -3, -4, -5, -6, -7, -8, -9]],
|
||||
dtype=np.float32)
|
||||
c1 = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32)
|
||||
|
||||
h, c = reducer._tree_lstm(c1, c2, lstm_in)
|
||||
self.assertEqual(tf.float32, h.dtype)
|
||||
self.assertEqual(tf.float32, c.dtype)
|
||||
self.assertEqual((2, 2), h.shape)
|
||||
self.assertEqual((2, 2), c.shape)
|
||||
|
||||
def testTracker(self):
|
||||
with tf.device(self._test_device):
|
||||
batch_size = 2
|
||||
size = 10
|
||||
tracker_size = 8
|
||||
buffer_length = 18
|
||||
stack_size = 3
|
||||
|
||||
tracker = spinn.Tracker(tracker_size, False)
|
||||
tracker.reset_state()
|
||||
|
||||
# Create dummy inputs for testing.
|
||||
bufs = []
|
||||
buf = []
|
||||
for _ in range(buffer_length):
|
||||
buf.append(tf.random_normal((batch_size, size * 2)))
|
||||
bufs.append(buf)
|
||||
self.assertEqual(1, len(bufs))
|
||||
self.assertEqual(buffer_length, len(bufs[0]))
|
||||
self.assertEqual((batch_size, size * 2), bufs[0][0].shape)
|
||||
|
||||
stacks = []
|
||||
stack = []
|
||||
for _ in range(stack_size):
|
||||
stack.append(tf.random_normal((batch_size, size * 2)))
|
||||
stacks.append(stack)
|
||||
self.assertEqual(1, len(stacks))
|
||||
self.assertEqual(3, len(stacks[0]))
|
||||
self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
|
||||
|
||||
for _ in range(2):
|
||||
out1, out2 = tracker(bufs, stacks)
|
||||
self.assertIsNone(out2)
|
||||
self.assertEqual(batch_size, len(out1))
|
||||
self.assertEqual(tf.float32, out1[0].dtype)
|
||||
self.assertEqual((1, tracker_size * 2), out1[0].shape)
|
||||
|
||||
self.assertEqual(tf.float32, tracker.state.c.dtype)
|
||||
self.assertEqual((batch_size, tracker_size), tracker.state.c.shape)
|
||||
self.assertEqual(tf.float32, tracker.state.h.dtype)
|
||||
self.assertEqual((batch_size, tracker_size), tracker.state.h.shape)
|
||||
|
||||
def testSPINN(self):
|
||||
with tf.device(self._test_device):
|
||||
embedding_dims = 10
|
||||
d_tracker = 8
|
||||
sequence_length = 15
|
||||
num_transitions = 27
|
||||
|
||||
config_tuple = collections.namedtuple(
|
||||
"Config", ["d_hidden", "d_proj", "d_tracker", "predict"])
|
||||
config = config_tuple(
|
||||
embedding_dims, embedding_dims * 2, d_tracker, False)
|
||||
s = spinn.SPINN(config)
|
||||
|
||||
# Create some fake data.
|
||||
buffers = tf.random_normal((sequence_length, 1, config.d_proj))
|
||||
transitions = np.array(
|
||||
[[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3],
|
||||
[2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2],
|
||||
[3], [2], [2]], dtype=np.int32)
|
||||
self.assertEqual(tf.int32, transitions.dtype)
|
||||
self.assertEqual((num_transitions, 1), transitions.shape)
|
||||
|
||||
out = s(buffers, transitions, training=True)
|
||||
self.assertEqual(tf.float32, out.dtype)
|
||||
self.assertEqual((1, embedding_dims), out.shape)
|
||||
|
||||
def testSNLIClassifierAndTrainer(self):
|
||||
with tf.device(self._test_device):
|
||||
vocab_size = 40
|
||||
batch_size = 2
|
||||
d_embed = 10
|
||||
sequence_length = 15
|
||||
d_out = 4
|
||||
|
||||
config = _snli_classifier_config(d_embed, d_out)
|
||||
|
||||
# Create fake embedding matrix.
|
||||
embed = tf.random_normal((vocab_size, d_embed))
|
||||
|
||||
model = spinn.SNLIClassifier(config, embed)
|
||||
trainer = spinn.SNLIClassifierTrainer(model, config.lr)
|
||||
|
||||
(labels, prem, prem_trans, hypo,
|
||||
hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
|
||||
batch_size,
|
||||
vocab_size)
|
||||
|
||||
# Invoke model under non-training mode.
|
||||
logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
|
||||
self.assertEqual(tf.float32, logits.dtype)
|
||||
self.assertEqual((batch_size, d_out), logits.shape)
|
||||
|
||||
# Invoke model under training model.
|
||||
logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
|
||||
self.assertEqual(tf.float32, logits.dtype)
|
||||
self.assertEqual((batch_size, d_out), logits.shape)
|
||||
|
||||
# Calculate loss.
|
||||
loss1 = trainer.loss(labels, logits)
|
||||
self.assertEqual(tf.float32, loss1.dtype)
|
||||
self.assertEqual((), loss1.shape)
|
||||
|
||||
loss2, logits = trainer.train_batch(
|
||||
labels, prem, prem_trans, hypo, hypo_trans)
|
||||
self.assertEqual(tf.float32, loss2.dtype)
|
||||
self.assertEqual((), loss2.shape)
|
||||
self.assertEqual(tf.float32, logits.dtype)
|
||||
self.assertEqual((batch_size, d_out), logits.shape)
|
||||
# Training on the batch should have led to a change in the loss value.
|
||||
self.assertNotEqual(loss1.numpy(), loss2.numpy())
|
||||
|
||||
|
||||
class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):
|
||||
|
||||
def benchmarkEagerSpinnSNLIClassifier(self):
|
||||
test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
|
||||
with tf.device(test_device):
|
||||
burn_in_iterations = 2
|
||||
benchmark_iterations = 10
|
||||
|
||||
vocab_size = 1000
|
||||
batch_size = 128
|
||||
sequence_length = 15
|
||||
d_embed = 200
|
||||
d_out = 4
|
||||
|
||||
embed = tf.random_normal((vocab_size, d_embed))
|
||||
|
||||
config = _snli_classifier_config(d_embed, d_out)
|
||||
model = spinn.SNLIClassifier(config, embed)
|
||||
trainer = spinn.SNLIClassifierTrainer(model, config.lr)
|
||||
|
||||
(labels, prem, prem_trans, hypo,
|
||||
hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
|
||||
batch_size,
|
||||
vocab_size)
|
||||
|
||||
for _ in range(burn_in_iterations):
|
||||
trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
|
||||
|
||||
gc.collect()
|
||||
start_time = time.time()
|
||||
for _ in xrange(benchmark_iterations):
|
||||
trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
|
||||
wall_time = time.time() - start_time
|
||||
# Named "examples"_per_sec to conform with other benchmarks.
|
||||
extras = {"examples_per_sec": benchmark_iterations / wall_time}
|
||||
self.report_benchmark(
|
||||
name="Eager_SPINN_SNLIClassifier_Benchmark",
|
||||
iters=benchmark_iterations,
|
||||
wall_time=wall_time,
|
||||
extras=extras)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -55,6 +55,10 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||
|
||||
@@IsolateTest
|
||||
@@run_test_in_graph_and_eager_modes
|
||||
|
||||
@@DEVICE_PLACEMENT_EXPLICIT
|
||||
@@DEVICE_PLACEMENT_WARN
|
||||
@@DEVICE_PLACEMENT_SILENT
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -71,6 +75,9 @@ from tensorflow.contrib.eager.python.saver import Saver
|
||||
from tensorflow.contrib.eager.python.summary_writer import SummaryWriter
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT
|
||||
from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN
|
||||
from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT
|
||||
from tensorflow.python.eager.context import in_eager_mode
|
||||
from tensorflow.python.eager.context import in_graph_mode
|
||||
from tensorflow.python.eager.context import list_devices
|
||||
|
@ -42,6 +42,10 @@ _device_parsing_cache = {}
|
||||
|
||||
_MAXINT32 = 2**31 - 1
|
||||
|
||||
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
|
||||
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
|
||||
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
|
||||
|
||||
|
||||
# TODO(agarwal): better name ?
|
||||
class _EagerContext(threading.local):
|
||||
@ -62,13 +66,22 @@ class _EagerContext(threading.local):
|
||||
class Context(object):
|
||||
"""Environment in which eager operations execute."""
|
||||
|
||||
def __init__(self, config=None):
|
||||
def __init__(self, config=None, device_policy=None):
|
||||
"""Creates a new Context.
|
||||
|
||||
Args:
|
||||
config: (Optional.) A `ConfigProto` protocol buffer with configuration
|
||||
options for the Context. Note that a lot of these options may be
|
||||
currently unimplemented or irrelevant for EAGER mode.
|
||||
options for the Context. Note that a lot of these options may be
|
||||
currently unimplemented or irrelevant when eager execution is enabled.
|
||||
device_policy: (Optional.) What policy to use when trying to run an
|
||||
operation on a device with inputs which are not on that device.
|
||||
Valid values:
|
||||
tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
|
||||
correct.
|
||||
tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
|
||||
right device but raises a warning.
|
||||
tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
|
||||
hide performance problems.
|
||||
"""
|
||||
self._eager_context = _EagerContext()
|
||||
self._context_handle = None
|
||||
@ -78,6 +91,7 @@ class Context(object):
|
||||
self._config = config
|
||||
self._seed = None
|
||||
self._initialize_lock = threading.Lock()
|
||||
self._device_policy = device_policy
|
||||
|
||||
def _set_global_seed(self, seed):
|
||||
"""Set a global eager mode seed for random ops."""
|
||||
@ -109,6 +123,9 @@ class Context(object):
|
||||
config_str = self._config.SerializeToString()
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetConfig(
|
||||
opts, config_str, len(config_str), status)
|
||||
if self._device_policy is not None:
|
||||
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
opts, self._device_policy)
|
||||
self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
|
||||
finally:
|
||||
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
|
||||
|
@ -266,6 +266,23 @@ class OpsTest(test_util.TensorFlowTestCase):
|
||||
shape = array_ops.shape(value)
|
||||
self.assertEqual([1], shape.numpy())
|
||||
|
||||
def testSilentCopy(self):
|
||||
if not context.context().num_gpus():
|
||||
self.skipTest('No GPUs found')
|
||||
# Temporarily replace the context
|
||||
# pylint: disable=protected-access
|
||||
del context._context
|
||||
try:
|
||||
context._context = context.Context(
|
||||
device_policy=context.DEVICE_PLACEMENT_SILENT)
|
||||
cpu_tensor = constant_op.constant(1.0)
|
||||
gpu_tensor = cpu_tensor.gpu()
|
||||
self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0)
|
||||
finally:
|
||||
del context._context
|
||||
context._context = context.Context()
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def testRandomUniform(self):
|
||||
scalar_shape = constant_op.constant([], dtype=dtypes.int32)
|
||||
|
||||
|
@ -4559,7 +4559,7 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access
|
||||
_default_graph_stack = _DefaultGraphStack()
|
||||
|
||||
|
||||
def enable_eager_execution():
|
||||
def enable_eager_execution(config=None, device_policy=None):
|
||||
"""Enables, for the rest of the lifetime of this program, eager execution.
|
||||
|
||||
If not called immediately on startup risks creating breakage and bugs.
|
||||
@ -4574,8 +4574,24 @@ def enable_eager_execution():
|
||||
assert tf.multiply(6, 7).numpy() == 42
|
||||
```
|
||||
|
||||
Args:
|
||||
config: (Optional.) A `ConfigProto` protocol buffer with configuration
|
||||
options for the Context. Note that a lot of these options may be
|
||||
currently unimplemented or irrelevant when eager execution is enabled.
|
||||
device_policy: (Optional.) What policy to use when trying to run an
|
||||
operation on a device with inputs which are not on that device.
|
||||
Valid values:
|
||||
tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
|
||||
correct.
|
||||
tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
|
||||
right device but raises a warning.
|
||||
tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
|
||||
hide performance problems.
|
||||
|
||||
Raises:
|
||||
ValueError: If this method has already been invoked in the current process.
|
||||
ValueError: If trying to create a context after using graph operations
|
||||
or if trying to create a context with nontrivial options which differ
|
||||
from those of the existing context.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
if context._default_mode == context.GRAPH_MODE:
|
||||
@ -4586,6 +4602,18 @@ def enable_eager_execution():
|
||||
raise ValueError(
|
||||
"tfe.enable_eager_execution has to be called at program startup.")
|
||||
context._default_mode = context.EAGER_MODE
|
||||
if context._context is None:
|
||||
context._context = context.Context(config=config,
|
||||
device_policy=device_policy)
|
||||
elif ((config is not None and config is not context._context._config)
|
||||
or (device_policy is not None
|
||||
and device_policy is not context._context._device_policy)):
|
||||
raise ValueError("Trying to change the options of an active eager"
|
||||
" execution. Context config: %s, specified config:"
|
||||
" %s. Context device policy: %s; specified device"
|
||||
" policy: %s." % (config, context._context._config,
|
||||
device_policy,
|
||||
context._context._device_policy))
|
||||
|
||||
|
||||
def eager_run(main=None, argv=None):
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
%rename("%s") TFE_Py_TapeExport;
|
||||
%rename("%s") TFE_NewContextOptions;
|
||||
%rename("%s") TFE_ContextOptionsSetConfig;
|
||||
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
|
||||
%rename("%s") TFE_DeleteContextOptions;
|
||||
|
||||
%{
|
||||
@ -101,6 +102,11 @@ limitations under the License.
|
||||
}
|
||||
}
|
||||
|
||||
%rename("%s") TFE_ContextDevicePlacementPolicy;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_WARN;
|
||||
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
|
||||
|
||||
%include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
|
||||
|
Loading…
Reference in New Issue
Block a user