BEGIN_PUBLIC

Delete tf.contrib.kfac. K-FAC in TensorFlow is now its own separate package.
END_PUBLIC

RELNOTES: n/a

Automated rollback of commit 938b9a4078

PiperOrigin-RevId: 209813506
Vikram Tankasali 2018-08-22 12:51:59 -07:00 committed by TensorFlower Gardener
parent c85e0a9829
commit c73964210c
49 changed files with 1 addition and 14435 deletions
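
For anyone migrating off `tf.contrib.kfac`, a minimal sketch of the replacement usage is below. This is illustrative only: the standalone package name `kfac` and the top-level `kfac.LayerCollection` / `kfac.KfacOptimizer` symbols are assumptions based on https://github.com/tensorflow/kfac, not something this commit adds.

```python
# Hypothetical migration sketch. The package name `kfac` and its top-level
# symbols are assumptions based on https://github.com/tensorflow/kfac; the
# registration pattern mirrors the tf.contrib.kfac README deleted below.
import tensorflow as tf
import kfac  # assumed standalone replacement for tf.contrib.kfac

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.int64, [None])
w = tf.get_variable("w", shape=[784, 10])
b = tf.get_variable("b", shape=[10])
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Register layer inputs/outputs and the predictive distribution, as before.
layer_collection = kfac.LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

optimizer = kfac.KfacOptimizer(
    learning_rate=0.0001,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layer_collection,
    momentum=0.9)
train_op = optimizer.minimize(loss)
```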


@@ -61,7 +61,6 @@ py_library(
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/keras",
"//tensorflow/contrib/kernel_methods",
"//tensorflow/contrib/kfac",
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",


@@ -51,7 +51,6 @@ from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import keras
from tensorflow.contrib import kernel_methods
from tensorflow.contrib import kfac
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn


@@ -247,10 +247,6 @@ tensorflow/contrib/kernel_methods/python
tensorflow/contrib/kernel_methods/python/mappers
tensorflow/contrib/kinesis/python
tensorflow/contrib/kinesis/python/ops
tensorflow/contrib/kfac
tensorflow/contrib/kfac/examples
tensorflow/contrib/kfac/python
tensorflow/contrib/kfac/python/ops
tensorflow/contrib/labeled_tensor
tensorflow/contrib/labeled_tensor/python
tensorflow/contrib/labeled_tensor/python/ops


@@ -1,26 +0,0 @@
# Description:
# Contains KfacOptimizer, an implementation of the K-FAC optimization
# algorithm in TensorFlow.
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "kfac",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
"//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
"//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
"//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
"//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
"//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
"//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
"//tensorflow/contrib/kfac/python/ops:op_queue_lib",
"//tensorflow/contrib/kfac/python/ops:utils_lib",
"//tensorflow/python:util",
],
)


@@ -1,94 +1,3 @@
# K-FAC: Kronecker-Factored Approximate Curvature
# <font color="red" size="10"><u>WARNING:</u></font>
# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
# ==removed on 15-07-2018. Please check https://github.com/tensorflow/kfac.==
**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
approximate second-order optimization method, in TensorFlow. When applied to
feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
faster, and in `>14x` fewer iterations, than SGD with Momentum.
[kfac-paper]: https://arxiv.org/abs/1503.05671
## What is K-FAC?
K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
to the [Natural Gradient][natural_gradient] algorithm designed specifically for
neural networks. It maintains a block-diagonal approximation to the [Fisher
Information matrix][fisher_information], whose inverse preconditions the
gradient.
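
Schematically (our notation, summarizing [the paper][kfac-paper] rather than anything specific to this codebase), the update it approximates is

$$\theta \leftarrow \theta - \eta\, F^{-1} \nabla_\theta \mathcal{L}(\theta), \qquad F_\ell \approx A_{\ell-1} \otimes G_\ell, \qquad (A \otimes G)^{-1} = A^{-1} \otimes G^{-1},$$

where $A_{\ell-1}$ is the second moment of layer $\ell$'s inputs and $G_\ell$ is the second moment of the back-propagated gradients with respect to its pre-activations. Because each Kronecker factor is only as large as one layer's input or output dimension, the preconditioner stays cheap to invert and apply.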
K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.
Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
are the weights for layer i?"). As such, you must add some additional code while
constructing your model to use K-FAC.
[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
## Why should I use K-FAC?
K-FAC can take advantage of the curvature of the optimization problem, resulting
in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
loss as SGD with Momentum in 3.8x less wall-clock time and 14.7x fewer updates. See how
training loss changes as a function of number of epochs, steps, and seconds:
![autoencoder](g3doc/autoencoder.png)
## Is K-FAC for me?
If you have a feedforward or convolutional model for classification that is
converging too slowly, K-FAC is for you. K-FAC can be used in your model if:
* Your model defines a posterior distribution.
* Your model uses only fully-connected or convolutional layers (residual
connections OK).
* You are training on CPU or GPU.
* You can modify model code to register layers with K-FAC.
## How do I use K-FAC?
Using K-FAC requires three steps:
1. Registering layer inputs, weights, and pre-activations with a
`LayerCollection`.
1. Minimizing the loss with a `KfacOptimizer`.
1. Keeping K-FAC's preconditioner updated.
```python
# Build model.
w = tf.get_variable("w", ...)
b = tf.get_variable("b", ...)
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Register layers.
layer_collection = LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

# Construct training ops.
optimizer = KfacOptimizer(..., layer_collection=layer_collection)
train_op = optimizer.minimize(loss)

# Minimize loss.
with tf.Session() as sess:
  ...
  sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
```
See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
## Authors
- Alok Aggarwal
- Daniel Duckworth
- James Martens
- Matthew Johnson
- Olga Wichrowska
- Roger Grosse
## K-FAC has moved to a standalone package; see https://github.com/tensorflow/kfac.


@@ -1,46 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Kronecker-factored Approximate Curvature Optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
from tensorflow.contrib.kfac.python.ops import utils_lib as utils
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long
_allowed_symbols = [
"curvature_matrix_vector_products",
"estimator",
"fisher_blocks",
"fisher_factors",
"layer_collection",
"loss_functions",
"op_queue",
"optimizer",
"utils",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)


@@ -1,80 +0,0 @@
package(default_visibility = [
"//learning/brain/contrib/kfac/examples:__subpackages__",
"//tensorflow/contrib/kfac/examples:__subpackages__",
])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_binary(
name = "mlp_mnist_main",
srcs = ["mlp_mnist_main.py"],
srcs_version = "PY2AND3",
deps = [
":mlp",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "mlp",
srcs = ["mlp.py"],
srcs_version = "PY2AND3",
deps = [
":mnist",
"//tensorflow:tensorflow_py",
],
)
py_binary(
name = "convnet_mnist_single_main",
srcs = ["convnet_mnist_single_main.py"],
srcs_version = "PY2AND3",
deps = [
":convnet",
"//tensorflow:tensorflow_py",
],
)
py_binary(
name = "convnet_mnist_multi_tower_main",
srcs = ["convnet_mnist_multi_tower_main.py"],
srcs_version = "PY2AND3",
deps = [
":convnet",
"//tensorflow:tensorflow_py",
],
)
py_binary(
name = "convnet_mnist_distributed_main",
srcs = ["convnet_mnist_distributed_main.py"],
srcs_version = "PY2AND3",
deps = [
":convnet",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "convnet",
srcs = ["convnet.py"],
srcs_version = "PY2AND3",
deps = [
":mlp",
":mnist",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
],
)
py_library(
name = "mnist",
srcs = ["mnist.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
],
)


@@ -1,667 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.
This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
following structure,
- Conv Layer: 5x5 kernel, 16 output channels.
- Max Pool: 3x3 kernel, stride 2.
- Conv Layer: 5x5 kernel, 16 output channels.
- Max Pool: 3x3 kernel, stride 2.
- Linear: 10 output dims.
After 3k~6k steps, this should reach perfect accuracy on the training set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib.kfac.examples import mlp
from tensorflow.contrib.kfac.examples import mnist
from tensorflow.contrib.kfac.python.ops import optimizer as opt
lc = tf.contrib.kfac.layer_collection
oq = tf.contrib.kfac.op_queue
opt = tf.contrib.kfac.optimizer
__all__ = [
"conv_layer",
"max_pool_layer",
"linear_layer",
"build_model",
"minimize_loss_single_machine",
"distributed_grads_only_and_ops_chief_worker",
"distributed_grads_and_ops_dedicated_workers",
"train_mnist_single_machine",
"train_mnist_distributed_sync_replicas",
"train_mnist_multitower"
]
# Inverse update ops will be run every _INVERT_EVERY iterations.
_INVERT_EVERY = 10
def conv_layer(layer_id, inputs, kernel_size, out_channels):
"""Builds a convolutional layer with ReLU non-linearity.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
corresponds to a single example.
kernel_size: int. Width and height of the convolution kernel. The kernel is
assumed to be square.
out_channels: int. Number of output features per pixel.
Returns:
preactivations: Tensor of shape [num_examples, width, height, out_channels].
Values of the layer immediately before the activation function.
activations: Tensor of shape [num_examples, width, height, out_channels].
Values of the layer immediately after the activation function.
params: Tuple of (kernel, bias), parameters for this layer.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
layer = tf.layers.Conv2D(
out_channels,
kernel_size=[kernel_size, kernel_size],
kernel_initializer=tf.random_normal_initializer(stddev=0.01),
padding="SAME",
name="conv_%d" % layer_id)
preactivations = layer(inputs)
activations = tf.nn.relu(preactivations)
# layer.weights is a list. This converts it to a (hashable) tuple.
return preactivations, activations, (layer.kernel, layer.bias)
def max_pool_layer(layer_id, inputs, kernel_size, stride):
"""Build a max-pooling layer.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
corresponds to a single example.
kernel_size: int. Width and height to pool over per input channel. The
kernel is assumed to be square.
stride: int. Step size between pooling operations.
Returns:
Tensor of shape [num_examples, width/stride, height/stride, out_channels].
Result of applying max pooling to 'inputs'.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
with tf.variable_scope("pool_%d" % layer_id):
return tf.nn.max_pool(
inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
padding="SAME",
name="pool")
def linear_layer(layer_id, inputs, output_size):
"""Builds the final linear layer for an MNIST classification problem.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
corresponds to a single example.
output_size: int. Number of output dims per example.
Returns:
activations: Tensor of shape [num_examples, output_size]. Values of the
layer immediately after the activation function.
params: Tuple of (weights, bias), parameters for this layer.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
return pre, params
def build_model(examples, labels, num_labels, layer_collection):
"""Builds a ConvNet classification model.
Args:
examples: Tensor of shape [num_examples, num_features]. Represents inputs of
model.
labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
by softmax for each example.
num_labels: int. Number of distinct values 'labels' can take on.
layer_collection: LayerCollection instance. Layers will be registered here.
Returns:
loss: 0-D Tensor representing loss to be minimized.
accuracy: 0-D Tensor representing model's accuracy.
"""
# Build a ConvNet. For each layer with parameters, we'll keep track of the
# preactivations, activations, weights, and bias.
tf.logging.info("Building model.")
pre0, act0, params0 = conv_layer(
layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
pre2, act2, params2 = conv_layer(
layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
logits, params4 = linear_layer(
layer_id=4, inputs=flat_act3, output_size=num_labels)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits))
accuracy = tf.reduce_mean(
tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
with tf.device("/cpu:0"):
tf.summary.scalar("loss", loss)
tf.summary.scalar("accuracy", accuracy)
# Register parameters. K-FAC needs to know about the inputs, outputs, and
# parameters of each conv/fully connected layer and the logits powering the
# posterior probability over classes.
tf.logging.info("Building LayerCollection.")
layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
pre0)
layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
layer_collection.register_fully_connected(params4, flat_act3, logits)
layer_collection.register_categorical_predictive_distribution(
logits, name="logits")
return loss, accuracy
def minimize_loss_single_machine(loss,
accuracy,
layer_collection,
device="/gpu:0",
session_config=None):
"""Minimize loss with K-FAC on a single machine.
A single Session is responsible for running all of K-FAC's ops. The covariance
and inverse update ops are placed on `device`. All model variables are on CPU.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
update ops are run on this device.
session_config: None or tf.ConfigProto. Configuration for tf.Session().
Returns:
final value for 'accuracy'.
"""
# Train with K-FAC.
g_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
placement_strategy="round_robin",
cov_devices=[device],
inv_devices=[device],
momentum=0.9)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
lambda: make_update_op(inv_update_thunks), tf.no_op)
with tf.control_dependencies([inverse_op]):
with tf.device(device):
train_op = optimizer.minimize(loss, global_step=g_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[g_step, loss, accuracy, train_op])
if global_step_ % _INVERT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
return accuracy_
def _is_gradient_task(task_id, num_tasks):
"""Returns True if this task should update the weights."""
if num_tasks < 3:
return True
return 0 <= task_id < 0.6 * num_tasks
def _is_cov_update_task(task_id, num_tasks):
"""Returns True if this task should update K-FAC's covariance matrices."""
if num_tasks < 3:
return False
return 0.6 * num_tasks <= task_id < num_tasks - 1
def _is_inv_update_task(task_id, num_tasks):
"""Returns True if this task should update K-FAC's preconditioner."""
if num_tasks < 3:
return False
return task_id == num_tasks - 1
def _num_gradient_tasks(num_tasks):
"""Number of tasks that will update weights."""
if num_tasks < 3:
return num_tasks
return int(np.ceil(0.6 * num_tasks))
def _make_distributed_train_op(
task_id,
num_worker_tasks,
num_ps_tasks,
layer_collection
):
"""Creates optimizer and distributed training op.
Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
the train op.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
Returns:
sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
optimizer.
optimizer: Instance of `opt.KfacOptimizer`.
global_step: `tensor`, Global step.
"""
tf.logging.info("Task id : %d", task_id)
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
global_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
momentum=0.9)
sync_optimizer = tf.train.SyncReplicasOptimizer(
opt=optimizer,
replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
total_num_replicas=num_worker_tasks)
return sync_optimizer, optimizer, global_step
def distributed_grads_only_and_ops_chief_worker(
task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
loss, accuracy, layer_collection, invert_every=10):
"""Minimize loss with a synchronous implementation of K-FAC.
All workers perform gradient computation. Chief worker applies gradient after
averaging the gradients obtained from all the workers. All workers block
execution until the update is applied. Chief worker runs covariance and
inverse update ops. Covariance and inverse matrices are placed on parameter
servers in a round robin manner. For further details on synchronous
distributed optimization check `tf.train.SyncReplicasOptimizer`.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
master: string. IP and port of TensorFlow runtime process. Set to empty
string to run locally.
checkpoint_dir: string or None. Path to store checkpoints under.
loss: 0-D Tensor. Loss to be minimized.
accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
run with each step.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
invert_every: `int`, Number of steps between inverse updates.
Returns:
final value for 'accuracy'.
Raises:
ValueError: if task_id >= num_worker_tasks.
"""
sync_optimizer, optimizer, global_step = _make_distributed_train_op(
task_id, num_worker_tasks, num_ps_tasks, layer_collection)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
tf.logging.info("Starting training.")
hooks = [sync_optimizer.make_session_run_hook(is_chief)]
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
if is_chief:
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
tf.equal(tf.mod(global_step, invert_every), 0),
lambda: make_update_op(inv_update_thunks),
tf.no_op)
with tf.control_dependencies([inverse_op]):
train_op = sync_optimizer.minimize(loss, global_step=global_step)
else:
train_op = sync_optimizer.minimize(loss, global_step=global_step)
with tf.train.MonitoredTrainingSession(
master=master,
is_chief=is_chief,
checkpoint_dir=checkpoint_dir,
hooks=hooks,
stop_grace_period_secs=0) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[global_step, loss, accuracy, train_op])
tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
loss_, accuracy_)
return accuracy_
def distributed_grads_and_ops_dedicated_workers(
task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
loss, accuracy, layer_collection):
"""Minimize loss with a synchronous implementation of K-FAC.
Different workers are responsible for different parts of K-FAC's Ops. The
first 60% of tasks compute gradients; the next 20% accumulate covariance
statistics; the last 20% invert the matrices used to precondition gradients.
The chief worker applies the averaged gradients.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
master: string. IP and port of TensorFlow runtime process. Set to empty
string to run locally.
checkpoint_dir: string or None. Path to store checkpoints under.
loss: 0-D Tensor. Loss to be minimized.
accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
run with each step.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
Returns:
final value for 'accuracy'.
Raises:
ValueError: if task_id >= num_worker_tasks.
"""
sync_optimizer, optimizer, global_step = _make_distributed_train_op(
task_id, num_worker_tasks, num_ps_tasks, layer_collection)
_, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
train_op = sync_optimizer.minimize(loss, global_step=global_step)
inv_update_queue = oq.OpQueue(inv_update_ops)
tf.logging.info("Starting training.")
is_chief = (task_id == 0)
hooks = [sync_optimizer.make_session_run_hook(is_chief)]
with tf.train.MonitoredTrainingSession(
master=master,
is_chief=is_chief,
checkpoint_dir=checkpoint_dir,
hooks=hooks,
stop_grace_period_secs=0) as sess:
while not sess.should_stop():
# Choose which op this task is responsible for running.
if _is_gradient_task(task_id, num_worker_tasks):
learning_op = train_op
elif _is_cov_update_task(task_id, num_worker_tasks):
learning_op = cov_update_op
elif _is_inv_update_task(task_id, num_worker_tasks):
# TODO(duckworthd): Running this op before cov_update_op has been run a
# few times can result in "InvalidArgumentError: Cholesky decomposition
# was not successful." Delay running this op until cov_update_op has
# been run a few times.
learning_op = inv_update_queue.next_op(sess)
else:
raise ValueError("Which op should task %d do?" % task_id)
global_step_, loss_, accuracy_, _ = sess.run(
[global_step, loss, accuracy, learning_op])
tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
loss_, accuracy_)
return accuracy_
def train_mnist_single_machine(data_dir,
num_epochs,
use_fake_data=False,
device="/gpu:0"):
"""Train a ConvNet on MNIST.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
update ops are run on this device.
Returns:
accuracy of model on the final minibatch of training data.
"""
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
examples, labels = mnist.load_mnist(
data_dir,
num_epochs=num_epochs,
batch_size=128,
use_fake_data=use_fake_data,
flatten_images=False)
# Build a ConvNet.
layer_collection = lc.LayerCollection()
loss, accuracy = build_model(
examples, labels, num_labels=10, layer_collection=layer_collection)
# Fit model.
return minimize_loss_single_machine(
loss, accuracy, layer_collection, device=device)
def train_mnist_multitower(data_dir, num_epochs, num_towers,
use_fake_data=True, devices=None):
"""Train a ConvNet on MNIST.
Training data is split equally among the towers. Each tower computes loss on
its own batch of data and the loss is aggregated on the CPU. The model
variables are placed on first tower. The covariance and inverse update ops
and variables are placed on GPUs in a round robin manner.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
num_towers: int. Number of CPUs to split inference across.
use_fake_data: bool. If True, generate a synthetic dataset.
devices: list of strings or None. CPU or GPU device names on which the
covariance and inverse update ops are run.
Returns:
accuracy of model on the final minibatch of training data.
"""
if devices:
device_count = {"GPU": num_towers}
else:
device_count = {"CPU": num_towers}
devices = devices or [
"/cpu:{}".format(tower_id) for tower_id in range(num_towers)
]
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
tower_batch_size = 128
batch_size = tower_batch_size * num_towers
tf.logging.info(
("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
"tower batch size.") % (batch_size, num_towers, tower_batch_size))
examples, labels = mnist.load_mnist(
data_dir,
num_epochs=num_epochs,
batch_size=batch_size,
use_fake_data=use_fake_data,
flatten_images=False)
# Split minibatch across towers.
examples = tf.split(examples, num_towers)
labels = tf.split(labels, num_towers)
# Build an MLP. Each tower's layers will be added to the LayerCollection.
layer_collection = lc.LayerCollection()
tower_results = []
for tower_id in range(num_towers):
with tf.device(devices[tower_id]):
with tf.name_scope("tower%d" % tower_id):
with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
tf.logging.info("Building tower %d." % tower_id)
tower_results.append(
build_model(examples[tower_id], labels[tower_id], 10,
layer_collection))
losses, accuracies = zip(*tower_results)
# Average across towers.
loss = tf.reduce_mean(losses)
accuracy = tf.reduce_mean(accuracies)
# Fit model.
session_config = tf.ConfigProto(
allow_soft_placement=False,
device_count=device_count,
)
g_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
placement_strategy="round_robin",
cov_devices=devices,
inv_devices=devices,
momentum=0.9)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
inverse_op = tf.cond(
tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
lambda: make_update_op(inv_update_thunks), tf.no_op)
with tf.control_dependencies([inverse_op]):
train_op = optimizer.minimize(loss, global_step=g_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[g_step, loss, accuracy, train_op])
if global_step_ % _INVERT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
def train_mnist_distributed_sync_replicas(task_id,
is_chief,
num_worker_tasks,
num_ps_tasks,
master,
data_dir,
num_epochs,
op_strategy,
use_fake_data=False):
"""Train a ConvNet on MNIST using Sync replicas optimizer.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables.
master: string. IP and port of TensorFlow runtime process.
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
op_strategy: `string`, Strategy to run the covariance and inverse
ops. If op_strategy == `chief_worker` then covariance and inverse
update ops are run on chief worker otherwise they are run on dedicated
workers.
use_fake_data: bool. If True, generate a synthetic dataset.
Returns:
accuracy of model on the final minibatch of training data.
Raises:
ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
"""
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
examples, labels = mnist.load_mnist(
data_dir,
num_epochs=num_epochs,
batch_size=128,
use_fake_data=use_fake_data,
flatten_images=False)
# Build a ConvNet.
layer_collection = lc.LayerCollection()
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
loss, accuracy = build_model(
examples, labels, num_labels=10, layer_collection=layer_collection)
# Fit model.
checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
if op_strategy == "chief_worker":
return distributed_grads_only_and_ops_chief_worker(
task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
checkpoint_dir, loss, accuracy, layer_collection)
elif op_strategy == "dedicated_workers":
return distributed_grads_and_ops_dedicated_workers(
task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
checkpoint_dir, loss, accuracy, layer_collection)
else:
raise ValueError("Only supported op strategies are : {}, {}".format(
"chief_worker", "dedicated_workers"))
if __name__ == "__main__":
tf.app.run()


@@ -1,62 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.
Distributed training with sync replicas optimizer. See
`convnet.train_mnist_distributed_sync_replicas` for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from tensorflow.contrib.kfac.examples import convnet
FLAGS = flags.FLAGS
flags.DEFINE_integer("task", -1, "Task identifier")
flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
flags.DEFINE_string(
"cov_inv_op_strategy", "chief_worker",
"In dist training mode run the cov, inv ops on chief or dedicated workers."
)
flags.DEFINE_string("master", "local", "Session master.")
flags.DEFINE_integer("ps_tasks", 2,
"Number of tasks in the parameter server job.")
flags.DEFINE_integer("replicas_to_aggregate", 5,
"Number of replicas to aggregate.")
flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
def _is_chief():
"""Determines whether a job is the chief worker."""
if "chief_worker" in FLAGS.brain_jobs:
return FLAGS.brain_job_name == "chief_worker"
else:
return FLAGS.task == 0
def main(unused_argv):
_ = unused_argv
convnet.train_mnist_distributed_sync_replicas(
FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
if __name__ == "__main__":
tf.app.run(main=main)


@@ -1,48 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.
Multi tower training mode. See `convnet.train_mnist_multitower` for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from tensorflow.contrib.kfac.examples import convnet
FLAGS = flags.FLAGS
flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
flags.DEFINE_integer("num_towers", 2,
"Number of towers for multi tower training.")
def main(unused_argv):
_ = unused_argv
assert FLAGS.num_towers > 1
devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
convnet.train_mnist_multitower(
FLAGS.data_dir,
num_epochs=200,
num_towers=FLAGS.num_towers,
devices=devices)
if __name__ == "__main__":
tf.app.run(main=main)


@@ -1,39 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.
Train on single machine. See `convnet.train_mnist_single_machine` for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from tensorflow.contrib.kfac.examples import convnet
FLAGS = flags.FLAGS
flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
def main(unused_argv):
convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
if __name__ == "__main__":
tf.app.run(main=main)


@@ -1,354 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train an MLP on MNIST using K-FAC.
This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
~25k steps, this should reach perfect accuracy on the training set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.kfac.examples import mnist
lc = tf.contrib.kfac.layer_collection
opt = tf.contrib.kfac.optimizer
__all__ = [
"fc_layer",
"train_mnist",
"train_mnist_multitower",
]
def fc_layer(layer_id, inputs, output_size):
"""Builds a fully connected layer.
Args:
layer_id: int. Integer ID for this layer's variables.
inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
to a single example.
output_size: int. Number of output dimensions after fully connected layer.
Returns:
preactivations: Tensor of shape [num_examples, output_size]. Values of the
layer immediately before the activation function.
activations: Tensor of shape [num_examples, output_size]. Values of the
layer immediately after the activation function.
params: Tuple of (weights, bias), parameters for this layer.
"""
# TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
layer = tf.layers.Dense(
output_size,
kernel_initializer=tf.random_normal_initializer(),
name="fc_%d" % layer_id)
preactivations = layer(inputs)
activations = tf.nn.tanh(preactivations)
# layer.weights is a list. This converts it to a (hashable) tuple.
return preactivations, activations, (layer.kernel, layer.bias)
def build_model(examples, labels, num_labels, layer_collection):
"""Builds an MLP classification model.
Args:
examples: Tensor of shape [num_examples, num_features]. Represents inputs of
model.
labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
by softmax for each example.
num_labels: int. Number of distinct values 'labels' can take on.
layer_collection: LayerCollection instance describing model architecture.
Returns:
loss: 0-D Tensor representing loss to be minimized.
accuracy: 0-D Tensor representing model's accuracy.
"""
# Build an MLP. For each layer, we'll keep track of the preactivations,
# activations, weights, and bias.
pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits))
accuracy = tf.reduce_mean(
tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
# Register parameters. K-FAC needs to know about the inputs, outputs, and
# parameters of each layer and the logits powering the posterior probability
# over classes.
tf.logging.info("Building LayerCollection.")
layer_collection.register_fully_connected(params0, examples, pre0)
layer_collection.register_fully_connected(params1, act0, pre1)
layer_collection.register_fully_connected(params2, act1, pre2)
layer_collection.register_fully_connected(params3, act2, logits)
layer_collection.register_categorical_predictive_distribution(
logits, name="logits")
return loss, accuracy
def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
"""Minimize 'loss' with KfacOptimizer.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance. Describes layers in model.
num_towers: int. Number of CPUs to split minibatch across.
session_config: tf.ConfigProto. Configuration for tf.Session().
Returns:
accuracy of classifier on final minibatch.
"""
devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
# Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
# every 10k iterations.
tf.logging.info("Building KFAC Optimizer.")
global_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=tf.train.exponential_decay(
0.00002, global_step, 10000, 0.5, staircase=True),
cov_ema_decay=0.95,
damping=0.0005,
layer_collection=layer_collection,
momentum=0.99,
placement_strategy="round_robin",
cov_devices=devices,
inv_devices=devices)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
# TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
# once that gets moved over? Could still leave more advanced examples as they
# are (e.g. train_mnist_estimator in this file)
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
# We update the inverses only every 100 iterations.
inverse_op = tf.cond(
tf.equal(tf.mod(global_step, 100), 0),
lambda: make_update_op(inv_update_thunks), tf.no_op)
with tf.control_dependencies([inverse_op]):
train_op = optimizer.minimize(loss, global_step=global_step)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
global_step_, loss_, accuracy_, _ = sess.run(
[global_step, loss, accuracy, train_op])
if global_step_ % 100 == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
global_step_, loss_, accuracy_)
return accuracy_
def train_mnist(data_dir, num_epochs, use_fake_data=False):
"""Train an MLP on MNIST.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
Returns:
accuracy of model on the final minibatch of training data.
"""
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
examples, labels = mnist.load_mnist(
data_dir,
num_epochs=num_epochs,
batch_size=64,
flatten_images=True,
use_fake_data=use_fake_data)
# Build an MLP. The model's layers will be added to the LayerCollection.
tf.logging.info("Building model.")
layer_collection = lc.LayerCollection()
loss, accuracy = build_model(examples, labels, 10, layer_collection)
# Fit model.
return minimize(loss, accuracy, layer_collection, 1)
def train_mnist_multitower(data_dir,
num_epochs,
num_towers,
use_fake_data=False):
"""Train an MLP on MNIST, splitting the minibatch across multiple towers.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
num_towers: int. Number of CPUs to split minibatch across.
use_fake_data: bool. If True, generate a synthetic dataset.
Returns:
accuracy of model on the final minibatch of training data.
"""
# Load a dataset.
tower_batch_size = 64
batch_size = tower_batch_size * num_towers
tf.logging.info(
("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
"tower batch size.") % (batch_size, num_towers, tower_batch_size))
examples, labels = mnist.load_mnist(
data_dir,
num_epochs=num_epochs,
batch_size=batch_size,
flatten_images=True,
use_fake_data=use_fake_data)
# Split minibatch across towers.
examples = tf.split(examples, num_towers)
labels = tf.split(labels, num_towers)
# Build an MLP. Each tower's layers will be added to the LayerCollection.
layer_collection = lc.LayerCollection()
tower_results = []
for tower_id in range(num_towers):
with tf.device("/cpu:%d" % tower_id):
with tf.name_scope("tower%d" % tower_id):
with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
tf.logging.info("Building tower %d." % tower_id)
tower_results.append(
build_model(examples[tower_id], labels[tower_id], 10,
layer_collection))
losses, accuracies = zip(*tower_results)
# Average across towers.
loss = tf.reduce_mean(losses)
accuracy = tf.reduce_mean(accuracies)
# Fit model.
session_config = tf.ConfigProto(
allow_soft_placement=False, device_count={
"CPU": num_towers
})
return minimize(
loss, accuracy, layer_collection, num_towers,
session_config=session_config)
def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
"""Train an MLP on MNIST using tf.estimator.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
Returns:
accuracy of model on the final minibatch of training data.
"""
# Load a dataset.
def input_fn():
tf.logging.info("Loading MNIST into memory.")
return mnist.load_mnist(
data_dir,
num_epochs=num_epochs,
batch_size=64,
flatten_images=True,
use_fake_data=use_fake_data)
def model_fn(features, labels, mode, params):
"""Model function for MLP trained with K-FAC.
Args:
features: Tensor of shape [batch_size, input_size]. Input features.
labels: Tensor of shape [batch_size]. Target labels for training.
mode: tf.estimator.ModeKey. Must be TRAIN.
params: ignored.
Returns:
EstimatorSpec for training.
Raises:
ValueError: If 'mode' is anything other than TRAIN.
"""
del params
if mode != tf.estimator.ModeKeys.TRAIN:
raise ValueError("Only training is supported with this API.")
# Build a ConvNet.
layer_collection = lc.LayerCollection()
loss, accuracy = build_model(
features, labels, num_labels=10, layer_collection=layer_collection)
# Train with K-FAC.
global_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=tf.train.exponential_decay(
0.00002, global_step, 10000, 0.5, staircase=True),
cov_ema_decay=0.95,
damping=0.0001,
layer_collection=layer_collection,
momentum=0.99)
(cov_update_thunks,
inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
def make_update_op(update_thunks):
update_ops = [thunk() for thunk in update_thunks]
return tf.group(*update_ops)
def make_batch_executed_op(update_thunks, batch_size=1):
return tf.group(*tf.contrib.kfac.utils.batch_execute(
global_step, update_thunks, batch_size=batch_size))
# Run cov_update_op every step. Run one inv_update_op per step.
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
# But make sure to execute all the inverse ops on the first step
inverse_op = tf.cond(tf.equal(global_step, 0),
lambda: make_update_op(inv_update_thunks),
lambda: make_batch_executed_op(inv_update_thunks))
with tf.control_dependencies([inverse_op]):
train_op = optimizer.minimize(loss, global_step=global_step)
# Print metrics every 5 sec.
hooks = [
tf.train.LoggingTensorHook(
{
"loss": loss,
"accuracy": accuracy
}, every_n_secs=5),
]
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
run_config = tf.estimator.RunConfig(
model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
# Train with Estimator until input_fn() is exhausted. This is a prerequisite
# for TPU compatibility.
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=input_fn)


@@ -1,64 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train an MLP on MNIST using K-FAC.
See mlp.py for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.contrib.kfac.examples import mlp
FLAGS = None
def main(argv):
_ = argv
if FLAGS.use_estimator:
if FLAGS.num_towers != 1:
raise ValueError("Only 1 device supported in tf.estimator example.")
mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
elif FLAGS.num_towers > 1:
mlp.train_mnist_multitower(
FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
else:
mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/mnist",
help="Directory to store dataset in.")
parser.add_argument(
"--num_towers",
type=int,
default=1,
help="Number of CPUs to split minibatch across.")
parser.add_argument(
"--use_estimator",
action="store_true",
help="Use tf.estimator API to train.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


@@ -1,69 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for loading MNIST into TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
__all__ = [
'load_mnist',
]
def load_mnist(data_dir,
num_epochs,
batch_size,
flatten_images=True,
use_fake_data=False):
"""Loads MNIST dataset into memory.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the dataset.
batch_size: int. Number of examples per minibatch.
flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
[784]-shaped vectors.
use_fake_data: bool. If True, generate a synthetic dataset rather than
reading MNIST in.
Returns:
examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
True, else [batch_size, 28, 28, 1]. Each row is one example.
Values in [0, 1].
labels: Tensor of shape [batch_size]. Indices of integer corresponding to
each example. Values in {0...9}.
"""
if use_fake_data:
rng = np.random.RandomState(42)
num_examples = batch_size * 4
images = rng.rand(num_examples, 28 * 28)
if not flatten_images:
images = np.reshape(images, [num_examples, 28, 28, 1])
labels = rng.randint(10, size=num_examples)
else:
mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
data_dir, reshape=flatten_images)
num_examples = len(mnist_data.train.labels)
images = mnist_data.train.images
labels = mnist_data.train.labels
dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
.make_one_shot_iterator().get_next())


@@ -1,52 +0,0 @@
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "mlp_test",
size = "large",
srcs = ["mlp_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_pip",
"notsan",
],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/kfac/examples:mlp",
"//third_party/py/numpy",
],
)
py_test(
name = "convnet_test",
size = "large",
srcs = ["convnet_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_pip",
"notsan",
],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/kfac",
"//tensorflow/contrib/kfac/examples:convnet",
"//third_party/py/numpy",
],
)
py_test(
name = "mnist_test",
srcs = ["mnist_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/kfac/examples:mnist",
"//third_party/py/numpy",
],
)


@@ -1,166 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convnet.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.kfac import layer_collection as lc
from tensorflow.contrib.kfac.examples import convnet
class ConvNetTest(tf.test.TestCase):
def testConvLayer(self):
with tf.Graph().as_default():
pre, act, (w, b) = convnet.conv_layer(
layer_id=1,
inputs=tf.zeros([5, 3, 3, 2]),
kernel_size=3,
out_channels=5)
self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
self.assertIsInstance(w, tf.Variable)
self.assertIsInstance(b, tf.Variable)
self.assertIn("conv_1", w.op.name)
self.assertIn("conv_1", b.op.name)
def testMaxPoolLayer(self):
with tf.Graph().as_default():
act = convnet.max_pool_layer(
layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
self.assertEqual(act.op.name, "pool_1/pool")
def testLinearLayer(self):
with tf.Graph().as_default():
act, (w, b) = convnet.linear_layer(
layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
self.assertShapeEqual(np.zeros([5, 5]), act)
self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
self.assertIsInstance(w, tf.Variable)
self.assertIsInstance(b, tf.Variable)
self.assertIn("fc_1", w.op.name)
self.assertIn("fc_1", b.op.name)
def testBuildModel(self):
with tf.Graph().as_default():
x = tf.placeholder(tf.float32, [None, 6, 6, 3])
y = tf.placeholder(tf.int64, [None])
layer_collection = lc.LayerCollection()
loss, accuracy = convnet.build_model(
x, y, num_labels=5, layer_collection=layer_collection)
# Ensure layers and logits were registered.
self.assertEqual(len(layer_collection.fisher_blocks), 3)
self.assertEqual(len(layer_collection.losses), 1)
# Ensure inference doesn't crash.
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
feed_dict = {
x: np.random.randn(10, 6, 6, 3).astype(np.float32),
y: np.random.randint(5, size=10).astype(np.int64),
}
sess.run([loss, accuracy], feed_dict=feed_dict)
def _build_toy_problem(self):
"""Construct a toy linear regression problem.
Initial loss should be 2.5 = 0.5 * (1^2 + 2^2).
Returns:
loss: 0-D Tensor representing loss to be minimized.
accuracy: 0-D Tensor representing model accuracy.
layer_collection: LayerCollection instance describing model architecture.
"""
x = np.asarray([[1.], [2.]]).astype(np.float32)
y = np.asarray([1., 2.]).astype(np.float32)
x, y = (tf.data.Dataset.from_tensor_slices((x, y))
.repeat(100).batch(2).make_one_shot_iterator().get_next())
w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
y_hat = tf.matmul(x, w)
loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
accuracy = loss
layer_collection = lc.LayerCollection()
layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
layer_collection.register_normal_predictive_distribution(y_hat)
return loss, accuracy, layer_collection
def testMinimizeLossSingleMachine(self):
with tf.Graph().as_default():
loss, accuracy, layer_collection = self._build_toy_problem()
accuracy_ = convnet.minimize_loss_single_machine(
loss, accuracy, layer_collection, device="/cpu:0")
self.assertLess(accuracy_, 2.0)
def testMinimizeLossDistributed(self):
with tf.Graph().as_default():
loss, accuracy, layer_collection = self._build_toy_problem()
accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
task_id=0,
is_chief=True,
num_worker_tasks=1,
num_ps_tasks=0,
master="",
checkpoint_dir=None,
loss=loss,
accuracy=accuracy,
layer_collection=layer_collection)
self.assertLess(accuracy_, 2.0)
def testTrainMnistSingleMachine(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
#
# Ideally, we should check that accuracy increases as the model converges,
# but there are too few parameters for the model to effectively memorize
# the training set the way an MLP can.
convnet.train_mnist_single_machine(
data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
def testTrainMnistMultitower(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
convnet.train_mnist_multitower(
data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
def testTrainMnistDistributed(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
convnet.train_mnist_distributed_sync_replicas(
task_id=0,
is_chief=True,
num_worker_tasks=1,
num_ps_tasks=0,
master="",
data_dir=None,
num_epochs=2,
op_strategy="chief_worker",
use_fake_data=True)
if __name__ == "__main__":
tf.test.main()


@ -1,63 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mlp.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.kfac.examples import mlp
class MlpTest(tf.test.TestCase):
def testFcLayer(self):
with tf.Graph().as_default():
pre, act, (w, b) = mlp.fc_layer(
layer_id=1, inputs=tf.zeros([5, 3]), output_size=10)
self.assertShapeEqual(np.zeros([5, 10]), pre)
self.assertShapeEqual(np.zeros([5, 10]), act)
self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w))
self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b))
self.assertIsInstance(w, tf.Variable)
self.assertIsInstance(b, tf.Variable)
self.assertIn("fc_1/", w.op.name)
self.assertIn("fc_1/", b.op.name)
def testTrainMnist(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
#
# Ideally, we should check that accuracy increases as the model converges,
# but that takes a non-trivial amount of compute.
mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True)
def testTrainMnistMultitower(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
mlp.train_mnist_multitower(
data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
def testTrainMnistEstimator(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
if __name__ == "__main__":
tf.test.main()


@ -1,72 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mnist.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.kfac.examples import mnist
class MnistTest(tf.test.TestCase):
def testValues(self):
"""Ensure values are in their expected range."""
with tf.Graph().as_default():
examples, labels = mnist.load_mnist(
data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
with self.test_session() as sess:
examples_, labels_ = sess.run([examples, labels])
self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
def testFlattenedShapes(self):
"""Ensure images are flattened into their appropriate shape."""
with tf.Graph().as_default():
examples, labels = mnist.load_mnist(
data_dir=None,
num_epochs=1,
batch_size=64,
flatten_images=True,
use_fake_data=True)
with self.test_session() as sess:
examples_, labels_ = sess.run([examples, labels])
self.assertEqual(examples_.shape, (64, 784))
self.assertEqual(labels_.shape, (64,))
def testNotFlattenedShapes(self):
"""Ensure non-flattened images are their appropriate shape."""
with tf.Graph().as_default():
examples, labels = mnist.load_mnist(
data_dir=None,
num_epochs=1,
batch_size=64,
flatten_images=False,
use_fake_data=True)
with self.test_session() as sess:
examples_, labels_ = sess.run([examples, labels])
self.assertEqual(examples_.shape, (64, 28, 28, 1))
self.assertEqual(labels_.shape, (64,))
if __name__ == '__main__':
tf.test.main()

Binary file not shown (image, 53 KiB).


@ -1,160 +0,0 @@
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "estimator_test",
srcs = ["estimator_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_estimator",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_test(
name = "fisher_factors_test",
srcs = ["fisher_factors_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_test(
name = "fisher_blocks_test",
srcs = ["fisher_blocks_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/contrib/kfac/python/ops:linear_operator",
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_test(
name = "layer_collection_test",
srcs = ["layer_collection_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variable_scope",
],
)
py_test(
name = "optimizer_test",
srcs = ["optimizer_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
"//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_test(
name = "utils_test",
srcs = ["utils_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_test(
name = "op_queue_test",
srcs = ["op_queue_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:op_queue",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
)
py_test(
name = "loss_functions_test",
srcs = ["loss_functions_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:loss_functions",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:random_ops",
"//third_party/py/numpy",
],
)


@ -1,310 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
class EstimatorTest(test.TestCase):
def setUp(self):
self._graph = ops.Graph()
with self._graph.as_default():
self.layer_collection = lc.LayerCollection()
self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
self.weights = variable_scope.get_variable(
"w", shape=(2, 2), dtype=dtypes.float32)
self.bias = variable_scope.get_variable(
"b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
# Only register the weights.
self.layer_collection.register_fully_connected(
params=(self.weights,), inputs=self.inputs, outputs=self.output)
self.outputs = math_ops.tanh(self.output)
self.targets = array_ops.zeros_like(self.outputs)
self.layer_collection.register_categorical_predictive_distribution(
logits=self.outputs, targets=self.targets)
def testEstimatorInitManualRegistration(self):
with self._graph.as_default():
# We should be able to build an estimator for only the registered vars.
estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection
)
# Check that we throw an error if we try to build an estimator for vars
# that were not manually registered.
with self.assertRaises(ValueError):
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights, self.bias],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection
)
est.make_vars_and_create_op_thunks()
# Check that we throw an error if we don't include registered variables,
# i.e. self.weights
with self.assertRaises(ValueError):
est = estimator.FisherEstimatorRoundRobin(
variables=[],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection)
est.make_vars_and_create_op_thunks()
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
def testVariableWrongNumberOfUses(self, mock_uses):
with self.assertRaises(ValueError):
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection)
est.make_vars_and_create_op_thunks()
def testInvalidEstimationMode(self):
with self.assertRaises(ValueError):
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="not_a_real_mode")
est.make_vars_and_create_op_thunks()
def testGradientsModeBuild(self):
with self._graph.as_default():
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="gradients")
est.make_vars_and_create_op_thunks()
def testEmpiricalModeBuild(self):
with self._graph.as_default():
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="empirical")
est.make_vars_and_create_op_thunks()
def testCurvaturePropModeBuild(self):
with self._graph.as_default():
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="curvature_prop")
est.make_vars_and_create_op_thunks()
def testExactModeBuild(self):
with self._graph.as_default():
est = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
cov_ema_decay=0.1,
damping=0.2,
layer_collection=self.layer_collection,
estimation_mode="exact")
est.make_vars_and_create_op_thunks()
def test_cov_update_thunks(self):
"""Ensures covariance update ops run once per global_step."""
with self._graph.as_default(), self.cached_session() as sess:
fisher_estimator = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
cov_ema_decay=0.0)
# Construct an op that executes one covariance update per step.
global_step = training_util.get_or_create_global_step()
(cov_variable_thunks, cov_update_op_thunks, _,
_) = fisher_estimator.create_ops_and_vars_thunks()
for thunk in cov_variable_thunks:
thunk()
cov_matrices = [
fisher_factor.get_cov()
for fisher_factor in self.layer_collection.get_factors()
]
cov_update_op = control_flow_ops.case(
[(math_ops.equal(global_step, i), thunk)
for i, thunk in enumerate(cov_update_op_thunks)])
increment_global_step = global_step.assign_add(1)
sess.run(variables.global_variables_initializer())
initial_cov_values = sess.run(cov_matrices)
# Ensure there's one update per covariance matrix.
self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
# Test is no-op if only 1 covariance matrix.
assert len(cov_matrices) > 1
for i in range(len(cov_matrices)):
# Compare new and old covariance values
new_cov_values = sess.run(cov_matrices)
is_cov_equal = [
np.allclose(initial_cov_value, new_cov_value)
for (initial_cov_value,
new_cov_value) in zip(initial_cov_values, new_cov_values)
]
num_cov_equal = sum(is_cov_equal)
# Ensure exactly one covariance matrix changes per step.
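# After i iterations, i of the covariance matrices have been updated away from
# their initial values, so len(cov_matrices) - i of them should still match.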
self.assertEqual(num_cov_equal, len(cov_matrices) - i)
# Run all covariance update ops.
sess.run(cov_update_op)
sess.run(increment_global_step)
def test_round_robin_placement(self):
"""Check if the ops and variables are placed on devices correctly."""
with self._graph.as_default():
fisher_estimator = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
cov_ema_decay=0.0,
cov_devices=["/cpu:{}".format(i) for i in range(2)],
inv_devices=["/cpu:{}".format(i) for i in range(2)])
# Construct an op that executes one covariance update per step.
(cov_update_thunks,
inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
scope="test")
cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
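# With two devices and round-robin placement, successive update ops should
# alternate between /cpu:0 and /cpu:1, which the assertions below verify.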
self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
cov_matrices = [
fisher_factor.get_cov()
for fisher_factor in self.layer_collection.get_factors()
]
inv_matrices = [
matrix
for fisher_factor in self.layer_collection.get_factors()
for matrix in fisher_factor._matpower_by_exp_and_damping.values()
]
self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
# Inverse matrices need to be explicitly placed.
self.assertEqual(inv_matrices[0].device, "")
self.assertEqual(inv_matrices[1].device, "")
def test_inv_update_thunks(self):
"""Ensures inverse update ops run once per global_step."""
with self._graph.as_default(), self.cached_session() as sess:
fisher_estimator = estimator.FisherEstimatorRoundRobin(
variables=[self.weights],
layer_collection=self.layer_collection,
damping=0.2,
cov_ema_decay=0.0)
# Construct op that updates one inverse per global step.
global_step = training_util.get_or_create_global_step()
(cov_variable_thunks, _, inv_variable_thunks,
inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
for thunk in cov_variable_thunks:
thunk()
for thunk in inv_variable_thunks:
thunk()
inv_matrices = [
matrix
for fisher_factor in self.layer_collection.get_factors()
for matrix in fisher_factor._matpower_by_exp_and_damping.values()
]
inv_update_op = control_flow_ops.case(
[(math_ops.equal(global_step, i), thunk)
for i, thunk in enumerate(inv_update_op_thunks)])
increment_global_step = global_step.assign_add(1)
sess.run(variables.global_variables_initializer())
initial_inv_values = sess.run(inv_matrices)
# Ensure there's one update per inverse matrix. This is true as long as
# there's no fan-in/fan-out or parameter re-use.
self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
# Test is no-op if only 1 inverse matrix.
assert len(inv_matrices) > 1
# Assign each covariance matrix a value other than the identity. This
# ensures that the inverse matrices are updated to something different as
# well.
cov_matrices = [
fisher_factor.get_cov()
for fisher_factor in self.layer_collection.get_factors()
]
sess.run([
cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
for cov_matrix in cov_matrices
])
for i in range(len(inv_matrices)):
# Compare new and old inverse values
new_inv_values = sess.run(inv_matrices)
is_inv_equal = [
np.allclose(initial_inv_value, new_inv_value)
for (initial_inv_value,
new_inv_value) in zip(initial_inv_values, new_inv_values)
]
num_inv_equal = sum(is_inv_equal)
# Ensure exactly one inverse matrix changes per step.
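# As above: after i iterations, len(inv_matrices) - i inverses should still
# match their initial values.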
self.assertEqual(num_inv_equal, len(inv_matrices) - i)
# Run all inverse update ops.
sess.run(inv_update_op)
sess.run(increment_global_step)
if __name__ == "__main__":
test.main()


@ -1,955 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.fisher_factors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
# We need to set these constants since the numerical values used in the tests
# were chosen when these used to be the defaults.
ff.set_global_constants(init_covariances_at_zero=False,
zero_debias=False,
init_inverses_at_zero=False)
def make_damping_func(damping):
return fb._package_func(lambda: damping, damping)
class FisherFactorTestingDummy(ff.FisherFactor):
"""Dummy class to test the non-abstract methods on ff.FisherFactor."""
@property
def _var_scope(self):
return 'dummy/a_b_c'
@property
def _cov_shape(self):
raise NotImplementedError
@property
def _num_sources(self):
return 1
@property
def _dtype(self):
return dtypes.float32
def _compute_new_cov(self):
raise NotImplementedError
def instantiate_covariance(self):
pass
def make_inverse_update_ops(self):
return []
def get_cov(self):
raise NotImplementedError
def instantiate_inv_variables(self):
raise NotImplementedError
def _num_towers(self):
raise NotImplementedError
def _get_data_device(self):
raise NotImplementedError
def register_matpower(self, exp, damping_func):
raise NotImplementedError
def register_cholesky(self, damping_func):
raise NotImplementedError
def register_cholesky_inverse(self, damping_func):
raise NotImplementedError
def get_matpower(self, exp, damping_func):
raise NotImplementedError
def get_cholesky(self, damping_func):
raise NotImplementedError
def get_cholesky_inverse(self, damping_func):
raise NotImplementedError
def get_cov_as_linear_operator(self):
raise NotImplementedError
class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
"""Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
"""
def __init__(self, shape):
self._shape = shape
super(DenseSquareMatrixFactorTestingDummy, self).__init__()
@property
def _var_scope(self):
return 'dummy/a_b_c'
@property
def _cov_shape(self):
return self._shape
@property
def _num_sources(self):
return 1
@property
def _dtype(self):
return dtypes.float32
def _compute_new_cov(self):
raise NotImplementedError
def instantiate_covariance(self):
pass
def _num_towers(self):
raise NotImplementedError
def _get_data_device(self):
raise NotImplementedError
class NumericalUtilsTest(test.TestCase):
def testComputeCovAgainstNumpy(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
npr.seed(0)
random_seed.set_random_seed(200)
x = npr.randn(100, 3)
cov = ff.compute_cov(array_ops.constant(x))
np_cov = np.dot(x.T, x) / x.shape[0]
self.assertAllClose(sess.run(cov), np_cov)
def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
npr.seed(0)
random_seed.set_random_seed(200)
normalizer = 10.
x = npr.randn(100, 3)
cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
np_cov = np.dot(x.T, x) / normalizer
self.assertAllClose(sess.run(cov), np_cov)
def testAppendHomog(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
npr.seed(0)
m, n = 3, 4
a = npr.randn(m, n)
a_homog = ff.append_homog(array_ops.constant(a))
np_result = np.hstack([a, np.ones((m, 1))])
self.assertAllClose(sess.run(a_homog), np_result)
class NameStringUtilFunctionTest(test.TestCase):
def _make_tensor(self):
x = array_ops.placeholder(dtypes.float64, (3, 1))
w = array_ops.constant(npr.RandomState(0).randn(3, 3))
y = math_ops.matmul(w, x)
g = gradients_impl.gradients(y, x)[0]
return g
def testScopeStringFromParamsSingleTensor(self):
with tf_ops.Graph().as_default():
g = self._make_tensor()
scope_string = ff.scope_string_from_params(g)
self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
def testScopeStringFromParamsMultipleTensors(self):
with tf_ops.Graph().as_default():
x = array_ops.constant(1,)
y = array_ops.constant(2,)
scope_string = ff.scope_string_from_params((x, y))
self.assertEqual('Const_Const_1', scope_string)
def testScopeStringFromParamsMultipleTypes(self):
with tf_ops.Graph().as_default():
x = array_ops.constant(1,)
y = array_ops.constant(2,)
scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
(x, y)])
self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
def testScopeStringFromParamsUnsupportedType(self):
with tf_ops.Graph().as_default():
x = array_ops.constant(1,)
y = array_ops.constant(2,)
unsupported = 1.2 # Floats are not supported.
with self.assertRaises(ValueError):
ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
unsupported])
def testScopeStringFromName(self):
with tf_ops.Graph().as_default():
g = self._make_tensor()
scope_string = ff.scope_string_from_name(g)
self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
def testScalarOrTensorToString(self):
with tf_ops.Graph().as_default():
self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
g = self._make_tensor()
scope_string = ff.scope_string_from_name(g)
self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
class FisherFactorTest(test.TestCase):
def testMakeInverseUpdateOps(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
factor = FisherFactorTestingDummy()
self.assertEqual(0, len(factor.make_inverse_update_ops()))
class DenseSquareMatrixFactorTest(test.TestCase):
def testRegisterDampedInverse(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
shape = [2, 2]
factor = DenseSquareMatrixFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
damping_funcs = [make_damping_func(0.1),
make_damping_func(0.1),
make_damping_func(1e-5),
make_damping_func(1e-5)]
for damping_func in damping_funcs:
factor.register_inverse(damping_func)
factor.instantiate_inv_variables()
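# Registrations that share a damping value should map to the same inverse
# variable, while different damping values should get distinct variables; the
# assertions below check both cases.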
inv = factor.get_inverse(damping_funcs[0]).to_dense()
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
factor.get_inverse(damping_funcs[3]).to_dense())
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
self.assertEqual(set([inv,
factor.get_inverse(damping_funcs[2]).to_dense()]),
set(factor_tensors))
self.assertEqual(shape, inv.get_shape())
def testRegisterMatpower(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
shape = [3, 3]
factor = DenseSquareMatrixFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
# TODO(b/74201126): Change to using the same func for both once
# Topohash is in place.
damping_func_1 = make_damping_func(0.5)
damping_func_2 = make_damping_func(0.5)
factor.register_matpower(-0.5, damping_func_1)
factor.register_matpower(2, damping_func_2)
factor.instantiate_inv_variables()
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
self.assertEqual(shape, matpower1.get_shape())
self.assertEqual(shape, matpower2.get_shape())
def testMakeInverseUpdateOps(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
factor = FisherFactorTestingDummy()
self.assertEqual(0, len(factor.make_inverse_update_ops()))
def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[1., 2.], [3., 4.]])
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_funcs = []
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
damping_funcs.append(make_damping_func(1./i))
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
factor.register_inverse(damping_funcs[i])
factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
new_invs = []
sess.run(ops)
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
# The inverse op will assign the damped inverse of cov to the inv var.
new_invs.append(
sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
# We want to see that the new invs are all different from each other.
for i in range(len(new_invs)):
for j in range(i + 1, len(new_invs)):
# Just check the first element.
self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[6., 2.], [2., 4.]])
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
damping = 0.5
damping_func = make_damping_func(damping)
factor.register_matpower(exp, damping_func)
factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
sess.run(ops[0])
matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
self.assertAllClose(matpower, matpower_np)
def testMakeInverseUpdateOpsNoEigenDecomp(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_func = make_damping_func(0)
factor.register_inverse(damping_func)
factor.instantiate_inv_variables()
ops = factor.make_inverse_update_ops()
self.assertEqual(1, len(ops))
sess.run(tf_variables.global_variables_initializer())
# The inverse op will assign the damped inverse of cov to the inv var.
old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
self.assertAllClose(
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
sess.run(ops)
new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
self.assertAllClose(new_inv, np.linalg.inv(cov))
class FullFactorTest(test.TestCase):
def testFullFactorInit(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.FullFactor((tensor,), 32)
factor.instantiate_cov_variables()
self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
def testFullFactorInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullFactor((tensor,), 32)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 6], cov.get_shape().as_list())
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([1., 2.], name='a/b/c')
factor = ff.FullFactor((tensor,), 2)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
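# Hand-worked check (assuming the covariance is initialized to the identity,
# per init_covariances_at_zero=False above): with decay 0.5 and x = [1, 2],
# new_cov = 0.5 * I + 0.5 * (x x^T / 2) = [[0.75, 0.5], [0.5, 1.5]].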
self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
class NaiveDiagonalFactorTest(test.TestCase):
def testNaiveDiagonalFactorInit(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
def testNaiveDiagonalFactorInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 1], cov.get_shape().as_list())
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([1., 2.], name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 2)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
self.assertAllClose([[0.75], [1.5]], new_cov)
class EmbeddingInputKroneckerFactorTest(test.TestCase):
def testInitialization(self):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.shape.as_list(), [vocab_size])
def testCovarianceUpdateOp(self):
with tf_ops.Graph().as_default():
input_ids = array_ops.constant([[0], [1], [4]])
vocab_size = 5
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor.instantiate_cov_variables()
cov_update_op = factor.make_covariance_update_op(0.0)
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(cov_update_op)
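# Ids 0, 1 and 4 each appear once in a batch of three, so with decay 0.0 the
# expected diagonal entries are their empirical frequencies of 1/3.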
self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
class ConvDiagonalFactorTest(test.TestCase):
def setUp(self):
self.batch_size = 10
self.height = self.width = 32
self.in_channels = 3
self.out_channels = 1
self.kernel_height = self.kernel_width = 3
self.strides = [1, 2, 2, 1]
self.data_format = 'NHWC'
self.padding = 'SAME'
self.kernel_shape = [
self.kernel_height, self.kernel_width, self.in_channels,
self.out_channels
]
def testInit(self):
with tf_ops.Graph().as_default():
inputs = random_ops.random_uniform(
[self.batch_size, self.height, self.width, self.in_channels])
outputs_grads = [
random_ops.random_uniform([
self.batch_size, self.height // self.strides[1],
self.width // self.strides[2], self.out_channels
]) for _ in range(3)
]
factor = ff.ConvDiagonalFactor(
(inputs,),
(outputs_grads,),
self.kernel_shape,
self.strides,
self.padding,
data_format=self.data_format)
factor.instantiate_cov_variables()
# Ensure covariance matrix's shape makes sense.
self.assertEqual([
self.kernel_height * self.kernel_width * self.in_channels,
self.out_channels
],
factor.get_cov().shape.as_list())
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default():
# Construct all arguments such that convolution kernel is applied in
# exactly one spatial location.
inputs = np.random.randn(
1, # batch_size
self.kernel_height,
self.kernel_width,
self.in_channels) # in_channels
outputs_grad = np.random.randn(
1, # batch_size
1, # output_height
1, # output_width
self.out_channels)
factor = ff.ConvDiagonalFactor(
(constant_op.constant(inputs),),
((constant_op.constant(outputs_grad),),),
self.kernel_shape,
strides=[1, 1, 1, 1],
padding='VALID')
factor.instantiate_cov_variables()
# Completely forget initial value on first update.
cov_update_op = factor.make_covariance_update_op(0.0)
# Ensure new covariance value is same as outer-product of inputs/outputs
# vectorized, squared.
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
cov = sess.run(cov_update_op)
expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
self.assertAllClose(expected_cov, cov)
def testHasBias(self):
with tf_ops.Graph().as_default():
inputs = random_ops.random_uniform(
[self.batch_size, self.height, self.width, self.in_channels])
outputs_grads = [
random_ops.random_uniform([
self.batch_size, self.height // self.strides[1],
self.width // self.strides[2], self.out_channels
]) for _ in range(3)
]
factor = ff.ConvDiagonalFactor(
(inputs,),
(outputs_grads,),
self.kernel_shape,
self.strides,
self.padding,
data_format=self.data_format,
has_bias=True)
factor.instantiate_cov_variables()
# Ensure shape accounts for bias.
self.assertEqual([
self.kernel_height * self.kernel_width * self.in_channels + 1,
self.out_channels
],
factor.get_cov().shape.as_list())
# Ensure update op doesn't crash.
cov_update_op = factor.make_covariance_update_op(0.0)
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(cov_update_op)
class FullyConnectedKroneckerFactorTest(test.TestCase):
def _testFullyConnectedKroneckerFactorInit(self,
has_bias,
final_shape,
dtype=dtypes.float32_ref):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual(final_shape, cov.get_shape().as_list())
def testFullyConnectedKroneckerFactorInitNoBias(self):
for dtype in (dtypes.float32_ref, dtypes.float64_ref):
self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
def testFullyConnectedKroneckerFactorInitWithBias(self):
for dtype in (dtypes.float32_ref, dtypes.float64_ref):
self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
def testMakeCovarianceUpdateOpWithBias(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
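# Hand-worked check (assuming identity initialization): appending a homogeneous
# column of ones gives A = [[1, 2, 1], [3, 4, 1]], and
# 0.5 * I + 0.5 * (A^T A / 2) = [[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]].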
self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
def testMakeCovarianceUpdateOpNoBias(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
class ConvFactorTestCase(test.TestCase):
def assertMatrixRank(self, rank, matrix, atol=1e-5):
assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
eigvals = np.linalg.eigvals(matrix)
nnz_eigvals = np.sum(eigvals > atol)
self.assertEqual(
rank,
nnz_eigvals,
msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
(nnz_eigvals, rank, eigvals)))
class ConvInputKroneckerFactorTest(ConvFactorTestCase):
def test3DConvolution(self):
with tf_ops.Graph().as_default():
batch_size = 1
width = 3
in_channels = 3**3
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
inputs=(random_ops.random_uniform(
(batch_size, width, width, width, in_channels), seed=0),),
filter_shape=(width, width, width, in_channels, out_channels),
padding='SAME',
strides=(2, 2, 2),
extract_patches_fn='extract_convolution_patches',
has_bias=False)
factor.instantiate_cov_variables()
# Ensure shape of covariance matches input size of filter.
input_size = in_channels * (width**3)
self.assertEqual([input_size, input_size],
factor.get_cov().shape.as_list())
# Ensure cov_update_op doesn't crash.
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov())
# Cov should be rank-8, as the filter will be applied at each corner of
# the 3-D cube.
self.assertMatrixRank(8, cov)
def testPointwiseConv2d(self):
with tf_ops.Graph().as_default():
batch_size = 1
width = 3
in_channels = 3**2
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
inputs=(random_ops.random_uniform(
(batch_size, width, width, in_channels), seed=0),),
filter_shape=(1, 1, in_channels, out_channels),
padding='SAME',
strides=(1, 1, 1, 1),
extract_patches_fn='extract_pointwise_conv2d_patches',
has_bias=False)
factor.instantiate_cov_variables()
# Ensure shape of covariance matches input size of filter.
self.assertEqual([in_channels, in_channels],
factor.get_cov().shape.as_list())
# Ensure cov_update_op doesn't crash.
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov())
# Cov should be rank-9, as the filter will be applied at each location.
self.assertMatrixRank(9, cov)
def testStrides(self):
with tf_ops.Graph().as_default():
batch_size = 1
width = 3
in_channels = 3**2
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
inputs=(random_ops.random_uniform(
(batch_size, width, width, in_channels), seed=0),),
filter_shape=(1, 1, in_channels, out_channels),
padding='SAME',
strides=(1, 2, 1, 1),
extract_patches_fn='extract_image_patches',
has_bias=False)
factor.instantiate_cov_variables()
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov())
# Cov should be the sum of 3 * 2 = 6 outer products.
self.assertMatrixRank(6, cov)
def testDilationRate(self):
with tf_ops.Graph().as_default():
batch_size = 1
width = 3
in_channels = 2
out_channels = 4
factor = ff.ConvInputKroneckerFactor(
inputs=(random_ops.random_uniform(
(batch_size, width, width, in_channels), seed=0),),
filter_shape=(3, 3, in_channels, out_channels),
padding='SAME',
extract_patches_fn='extract_image_patches',
strides=(1, 1, 1, 1),
dilation_rate=(1, width, width, 1),
has_bias=False)
factor.instantiate_cov_variables()
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov())
# Cov should be rank = in_channels, as only the center of the filter
# receives non-zero input for each input channel.
self.assertMatrixRank(in_channels, cov)
def testConvInputKroneckerFactorInitNoBias(self):
with tf_ops.Graph().as_default():
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
inputs=(tensor,),
filter_shape=(1, 2, 3, 4),
padding='SAME',
has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
factor.get_cov().get_shape().as_list())
def testConvInputKroneckerFactorInit(self):
with tf_ops.Graph().as_default():
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
factor = ff.ConvInputKroneckerFactor(
(tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
factor.get_cov().get_shape().as_list())
def testConvInputKroneckerFactorInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
factor = ff.ConvInputKroneckerFactor(
(tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
cov.get_shape().as_list())
def testMakeCovarianceUpdateOpWithBias(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
input_shape = (2, 1, 1, 1)
tensor = array_ops.constant(
np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
np.float32))
factor = ff.ConvInputKroneckerFactor(
(tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(0.))
self.assertAllClose(
[
[(1. + 4.) / 2., (1. + 2.) / 2.], #
[(1. + 2.) / 2., (1. + 1.) / 2.]
], #
new_cov)
def testMakeCovarianceUpdateOpNoBias(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
input_shape = (2, 1, 1, 1)
tensor = array_ops.constant(
np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
np.float32))
factor = ff.ConvInputKroneckerFactor(
(tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(0.))
self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
def testSubSample(self):
with tf_ops.Graph().as_default():
patches_1 = array_ops.constant(1, shape=(10, 2))
patches_2 = array_ops.constant(1, shape=(10, 8))
patches_3 = array_ops.constant(1, shape=(3, 3))
patches_1_sub = ff._subsample_for_cov_computation(patches_1)
patches_2_sub = ff._subsample_for_cov_computation(patches_2)
patches_3_sub = ff._subsample_for_cov_computation(patches_3)
patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
self.assertEqual(2, patches_1_sub_batch_size)
self.assertEqual(8, patches_2_sub_batch_size)
self.assertEqual(3, patches_3_sub_batch_size)
class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
def test3DConvolution(self):
with tf_ops.Graph().as_default():
batch_size = 1
width = 3
out_channels = width**3
factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
random_ops.random_uniform(
(batch_size, width, width, width, out_channels), seed=0)
],))
factor.instantiate_cov_variables()
with self.cached_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov())
# Cov should be rank 3^3, as each spatial position contributes a rank-1
# update.
self.assertMatrixRank(width**3, cov)
def testConvOutputKroneckerFactorInit(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
factor = ff.ConvOutputKroneckerFactor(((tensor,),))
factor.instantiate_cov_variables()
self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
def testConvOutputKroneckerFactorInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
factor = ff.ConvOutputKroneckerFactor(((tensor,),))
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([5, 5], cov.get_shape().as_list())
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
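# Hand-worked check (assuming identity initialization): the 2x2x2x2 tensor is
# flattened to A with 8 rows of 2 output channels, and
# 0.5 * I + 0.5 * (A^T A / 8) = [[43, 46.5], [46.5, 51.5]].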
self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)
class FullyConnectedMultiKFTest(test.TestCase):
def testFullyConnectedMultiKFInit(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
factor.instantiate_cov_variables()
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
def testFullyConnectedMultiKFInitFloat64(self):
with tf_ops.Graph().as_default():
dtype = dtypes.float64_ref
random_seed.set_random_seed(200)
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
factor.instantiate_cov_variables()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([3, 3], cov.get_shape().as_list())
def testMakeCovarianceUpdateOpWithBias(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
def testMakeCovarianceUpdateOpNoBias(self):
with tf_ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
factor = ff.FullyConnectedMultiKF(((tensor,),))
factor.instantiate_cov_variables()
sess.run(tf_variables.global_variables_initializer())
new_cov = sess.run(factor.make_covariance_update_op(.5))
self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
if __name__ == '__main__':
test.main()

View File

@ -1,597 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.layer_collection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import fisher_blocks
from tensorflow.contrib.kfac.python.ops import fisher_factors
from tensorflow.contrib.kfac.python.ops import layer_collection
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
class MockFisherBlock(object):
"""A fake FisherBlock."""
num_registered_towers = 2
def __init__(self, name='MockFisherBlock'):
self.name = name
def __eq__(self, other):
return isinstance(other, MockFisherBlock) and other.name == self.name
def __hash__(self):
return hash(self.name)
class LayerParametersDictTest(test.TestCase):
def testSetItem(self):
"""Ensure insertion, contains, retrieval works for supported key types."""
with ops.Graph().as_default():
lp_dict = layer_collection.LayerParametersDict()
x = array_ops.constant(0)
y0 = array_ops.constant(0)
y1 = array_ops.constant(0)
z0 = array_ops.constant(0)
z1 = array_ops.constant(0)
keys = [x, (y0, y1), [z0, z1]]
for key in keys:
lp_dict[key] = key
for key in keys:
self.assertTrue(key in lp_dict)
self.assertEqual(lp_dict[key], key)
def testSetItemOverlap(self):
"""Ensure insertion fails if key overlaps with existing key."""
with ops.Graph().as_default():
lp_dict = layer_collection.LayerParametersDict()
x = array_ops.constant(0)
y = array_ops.constant(0)
lp_dict[x] = 'value'
with self.assertRaises(ValueError):
lp_dict[(x, y)] = 'value'
# Ensure 'y' wasn't inserted.
self.assertTrue(x in lp_dict)
self.assertFalse(y in lp_dict)
class LayerCollectionTest(test.TestCase):
def testLayerCollectionInit(self):
lc = layer_collection.LayerCollection()
self.assertEqual(0, len(lc.get_blocks()))
self.assertEqual(0, len(lc.get_factors()))
self.assertFalse(lc.losses)
def testRegisterBlocks(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
lc = layer_collection.LayerCollection()
lc.register_fully_connected(
array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
lc.register_fully_connected(
array_ops.constant(1),
array_ops.constant(2),
array_ops.constant(3),
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_conv2d(
params=array_ops.ones((2, 3, 4, 5)),
strides=[1, 1, 1, 1],
padding='SAME',
inputs=array_ops.ones((1, 2, 3, 4)),
outputs=array_ops.ones((1, 1, 1, 5)))
lc.register_conv2d(
params=array_ops.ones((2, 3, 4, 5)),
strides=[1, 1, 1, 1],
padding='SAME',
inputs=array_ops.ones((1, 2, 3, 4)),
outputs=array_ops.ones((1, 1, 1, 5)),
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_separable_conv2d(
depthwise_params=array_ops.ones((3, 3, 1, 2)),
pointwise_params=array_ops.ones((1, 1, 2, 4)),
inputs=array_ops.ones((32, 5, 5, 1)),
depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
strides=[1, 1, 1, 1],
padding='SAME')
lc.register_convolution(
params=array_ops.ones((3, 3, 1, 8)),
inputs=array_ops.ones((32, 5, 5, 1)),
outputs=array_ops.ones((32, 5, 5, 8)),
padding='SAME')
lc.register_generic(
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
lc.register_generic(
array_ops.constant(6),
16,
approx=layer_collection.APPROX_DIAGONAL_NAME)
lc.register_fully_connected_multi(
array_ops.constant(1),
(array_ops.constant(2), array_ops.constant(3)),
(array_ops.constant(4), array_ops.constant(5)))
lc.register_conv2d_multi(
params=array_ops.ones((2, 3, 4, 5)),
strides=[1, 1, 1, 1],
padding='SAME',
inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
lc.register_embedding_multi(
array_ops.constant((1,)),
(array_ops.constant(2), array_ops.constant(3)),
(array_ops.constant(4), array_ops.constant(5)))
self.assertEqual(12, len(lc.get_blocks()))
def testRegisterBlocksMultipleRegistrations(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
lc = layer_collection.LayerCollection()
key = array_ops.constant(1)
lc.register_fully_connected(key, array_ops.constant(2),
array_ops.constant(3))
with self.assertRaises(ValueError) as cm:
lc.register_generic(key, 16)
self.assertIn('already in LayerCollection', str(cm.exception))
def testRegisterSingleParamNotRegistered(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {
variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
'1'
}
lc.register_block(x, 'foo')
def testShouldRegisterSingleParamRegistered(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {x: '1'}
with self.assertRaises(ValueError) as cm:
lc.register_block(x, 'foo')
self.assertIn('already in LayerCollection', str(cm.exception))
def testRegisterSingleParamRegisteredInTuple(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {(x, y): '1'}
with self.assertRaises(ValueError) as cm:
lc.register_block(x, 'foo')
self.assertIn('was already registered', str(cm.exception))
def testRegisterTupleParamNotRegistered(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {
variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
'1'
}
lc.register_block((x, y), 'foo')
self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
def testRegisterTupleParamRegistered(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {(x, y): '1'}
with self.assertRaises(ValueError) as cm:
lc.register_block((x, y), 'foo')
self.assertIn('already in LayerCollection', str(cm.exception))
def testRegisterTupleParamRegisteredInSuperset(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {(x, y, z): '1'}
with self.assertRaises(ValueError) as cm:
lc.register_block((x, y), 'foo')
self.assertIn('was already registered', str(cm.exception))
def testRegisterTupleParamSomeRegistered(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
with self.assertRaises(ValueError) as cm:
lc.register_block((x, y), MockFisherBlock('foo'))
self.assertIn('was already registered', str(cm.exception))
def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
with self.assertRaises(ValueError) as cm:
lc.register_block((x, y), 'foo')
self.assertIn('was already registered', str(cm.exception))
def testRegisterCategoricalPredictiveDistribution(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
logits = linalg_ops.eye(2)
lc = layer_collection.LayerCollection()
lc.register_categorical_predictive_distribution(logits, seed=200)
single_loss = sess.run(lc.total_sampled_loss())
lc2 = layer_collection.LayerCollection()
lc2.register_categorical_predictive_distribution(logits, seed=200)
lc2.register_categorical_predictive_distribution(logits, seed=200)
double_loss = sess.run(lc2.total_sampled_loss())
self.assertAlmostEqual(2 * single_loss, double_loss)
def testLossFunctionByName(self):
"""Ensure loss functions can be identified by name."""
with ops.Graph().as_default():
logits = linalg_ops.eye(2)
lc = layer_collection.LayerCollection()
# Create a new loss function by name.
lc.register_categorical_predictive_distribution(logits, name='loss1')
self.assertEqual(1, len(lc.towers_by_loss))
# Add logits to same loss function.
lc.register_categorical_predictive_distribution(
logits, name='loss1', reuse=True)
self.assertEqual(1, len(lc.towers_by_loss))
# Add another new loss function.
lc.register_categorical_predictive_distribution(logits, name='loss2')
self.assertEqual(2, len(lc.towers_by_loss))
def testLossFunctionWithoutName(self):
"""Ensure loss functions get unique names if 'name' not specified."""
with ops.Graph().as_default():
logits = linalg_ops.eye(2)
lc = layer_collection.LayerCollection()
# Create a new loss function with default names.
lc.register_categorical_predictive_distribution(logits)
lc.register_categorical_predictive_distribution(logits)
self.assertEqual(2, len(lc.losses))
def testCategoricalPredictiveDistributionMultipleMinibatches(self):
"""Ensure multiple minibatches are registered."""
with ops.Graph().as_default():
batch_size = 3
output_size = 2
logits = array_ops.zeros([batch_size, output_size])
targets = array_ops.ones([batch_size], dtype=dtypes.int32)
lc = layer_collection.LayerCollection()
# Create a new loss function.
lc.register_categorical_predictive_distribution(
logits, targets=targets, name='loss1')
# Can add when reuse=True
lc.register_categorical_predictive_distribution(
logits, targets=targets, name='loss1', reuse=True)
# Can add when reuse=VARIABLE_SCOPE and reuse=True there.
with variable_scope.variable_scope(
variable_scope.get_variable_scope(), reuse=True):
lc.register_categorical_predictive_distribution(
logits,
targets=targets,
name='loss1',
reuse=layer_collection.VARIABLE_SCOPE)
# Can't add when reuse=False
with self.assertRaises(KeyError):
lc.register_categorical_predictive_distribution(
logits, targets=targets, name='loss1', reuse=False)
# Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
with self.assertRaises(KeyError):
lc.register_categorical_predictive_distribution(
logits,
targets=targets,
name='loss1',
reuse=layer_collection.VARIABLE_SCOPE)
self.assertEqual(len(lc.towers_by_loss), 1)
# Three successful registrations.
self.assertEqual(len(lc.towers_by_loss[0]), 3)
def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
logits = random_ops.random_normal((1, 2))
lc = layer_collection.LayerCollection()
lc.register_categorical_predictive_distribution(logits, seed=200)
def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
lc = layer_collection.LayerCollection()
targets = array_ops.constant([0, 1], dtype=dtypes.int32)
lc.register_categorical_predictive_distribution(logits, targets=targets)
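# The expected constant below is consistent with total_loss() summing the
# per-example softmax cross-entropies -log softmax(logits)[i, targets[i]]:
# row 0 gives -log(e**1 / (e**1 + e**2)) ~= 1.3133 and row 1 gives
# -log(e**4 / (e**3 + e**4)) ~= 0.3133, for a total of ~1.6265.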
single_loss = sess.run(lc.total_loss())
self.assertAlmostEqual(1.6265233, single_loss)
def testRegisterNormalPredictiveDistribution(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
predictions = array_ops.constant(
[[1., 2.], [3., 4]], dtype=dtypes.float32)
lc = layer_collection.LayerCollection()
lc.register_normal_predictive_distribution(predictions, 1., seed=200)
single_loss = sess.run(lc.total_sampled_loss())
lc2 = layer_collection.LayerCollection()
lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
double_loss = sess.run(lc2.total_sampled_loss())
self.assertAlmostEqual(2 * single_loss, double_loss)
def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
predictions = array_ops.constant(
[[1., 2.], [3., 4.]], dtype=dtypes.float32)
lc = layer_collection.LayerCollection()
targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
lc.register_normal_predictive_distribution(
predictions, 2.**2, targets=targets)
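# The expected constant below is consistent with an element-wise Gaussian
# negative log probability with variance 2.**2 = 4:
# 0.5 * sum((predictions - targets)**2) / 4 + 0.5 * 4 * log(2 * pi * 4)
# = 1.25 + 6.4483 ~= 7.6983.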
single_loss = sess.run(lc.total_loss())
self.assertAlmostEqual(7.6983433, single_loss)
def ensureLayerReuseWorks(self, register_fn):
"""Ensure the 'reuse' keyword argument function as intended.
Args:
register_fn: function for registering a layer. Arguments are
layer_collection, reuse, and approx.
"""
# Fails on second if reuse=False.
lc = layer_collection.LayerCollection()
register_fn(lc)
with self.assertRaises(ValueError):
register_fn(lc, reuse=False)
# Succeeds on second if reuse=True.
lc = layer_collection.LayerCollection()
register_fn(lc)
register_fn(lc, reuse=True)
# Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
lc = layer_collection.LayerCollection()
register_fn(lc)
with self.assertRaises(ValueError):
register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
# Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
lc = layer_collection.LayerCollection()
register_fn(lc)
with variable_scope.variable_scope(
variable_scope.get_variable_scope(), reuse=True):
register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
# Fails if block type changes.
lc = layer_collection.LayerCollection()
register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
with self.assertRaises(ValueError):
register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
# Fails if reuse requested but no FisherBlock exists.
lc = layer_collection.LayerCollection()
with self.assertRaises(KeyError):
register_fn(lc, reuse=True)
def testRegisterFullyConnectedReuse(self):
"""Ensure the 'reuse' works with register_fully_connected."""
with ops.Graph().as_default():
inputs = array_ops.ones([2, 10])
outputs = array_ops.zeros([2, 5])
params = (
variable_scope.get_variable('w', [10, 5]), #
variable_scope.get_variable('b', [5]))
def register_fn(lc, **kwargs):
lc.register_fully_connected(
params=params, inputs=inputs, outputs=outputs, **kwargs)
self.ensureLayerReuseWorks(register_fn)
def testRegisterConv2dReuse(self):
"""Ensure the 'reuse' works with register_conv2d."""
with ops.Graph().as_default():
inputs = array_ops.ones([2, 5, 5, 10])
outputs = array_ops.zeros([2, 5, 5, 3])
params = (
variable_scope.get_variable('w', [1, 1, 10, 3]), #
variable_scope.get_variable('b', [3]))
def register_fn(lc, **kwargs):
lc.register_conv2d(
params=params,
strides=[1, 1, 1, 1],
padding='SAME',
inputs=inputs,
outputs=outputs,
**kwargs)
self.ensureLayerReuseWorks(register_fn)
def testReuseWithInvalidRegistration(self):
"""Invalid registrations shouldn't overwrite existing blocks."""
with ops.Graph().as_default():
inputs = array_ops.ones([2, 5, 5, 10])
outputs = array_ops.zeros([2, 5, 5, 3])
w = variable_scope.get_variable('w', [1, 1, 10, 3])
b = variable_scope.get_variable('b', [3])
lc = layer_collection.LayerCollection()
lc.register_fully_connected(w, inputs, outputs)
self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
with self.assertRaises(KeyError):
lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
self.assertNotIn((w, b), lc.fisher_blocks)
self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
lc.register_fully_connected(w, inputs, outputs, reuse=True)
self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
def testMakeOrGetFactor(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
lc = layer_collection.LayerCollection()
key = array_ops.constant(1)
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
lc.make_or_get_factor(fisher_factors.FullFactor,
((array_ops.constant(2),), 16))
self.assertEqual(2, len(lc.get_factors()))
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertTrue(
all([var.name.startswith('LayerCollection') for var in variables]))
def testMakeOrGetFactorCustomScope(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
scope = 'Foo'
lc = layer_collection.LayerCollection(name=scope)
key = array_ops.constant(1)
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
lc.make_or_get_factor(fisher_factors.FullFactor,
((array_ops.constant(2),), 16))
self.assertEqual(2, len(lc.get_factors()))
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', shape=())
y = variable_scope.get_variable('y', shape=())
z = variable_scope.get_variable('z', shape=())
lc = layer_collection.LayerCollection()
lc.define_linked_parameters((x, y))
with self.assertRaises(ValueError):
lc.define_linked_parameters((x, z))
def testIdentifySubsetPreviouslyRegisteredTensor(self):
x = variable_scope.get_variable('x', shape=())
y = variable_scope.get_variable('y', shape=())
lc = layer_collection.LayerCollection()
lc.define_linked_parameters((x, y))
with self.assertRaises(ValueError):
lc.define_linked_parameters(x)
def testSpecifyApproximation(self):
w_0 = variable_scope.get_variable('w_0', [10, 10])
w_1 = variable_scope.get_variable('w_1', [10, 10])
b_0 = variable_scope.get_variable('b_0', [10])
b_1 = variable_scope.get_variable('b_1', [10])
x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
pre_bias_0 = math_ops.matmul(x_0, w_0)
pre_bias_1 = math_ops.matmul(x_1, w_1)
# Build the fully connected layers in the graph.
pre_bias_0 + b_0 # pylint: disable=pointless-statement
pre_bias_1 + b_1 # pylint: disable=pointless-statement
lc = layer_collection.LayerCollection()
lc.define_linked_parameters(
w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
lc.define_linked_parameters(
w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
lc.define_linked_parameters(
b_0, approximation=layer_collection.APPROX_FULL_NAME)
lc.define_linked_parameters(
b_1, approximation=layer_collection.APPROX_FULL_NAME)
lc.register_fully_connected(w_0, x_0, pre_bias_0)
lc.register_fully_connected(
w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
self.assertIsInstance(lc.fisher_blocks[w_0],
fisher_blocks.FullyConnectedDiagonalFB)
self.assertIsInstance(lc.fisher_blocks[w_1],
fisher_blocks.FullyConnectedKFACBasicFB)
lc.register_generic(b_0, batch_size=1)
lc.register_generic(
b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
def testDefaultLayerCollection(self):
with ops.Graph().as_default():
# Can't get default if there isn't one set.
with self.assertRaises(ValueError):
layer_collection.get_default_layer_collection()
# Can't set default twice.
lc = layer_collection.LayerCollection()
layer_collection.set_default_layer_collection(lc)
with self.assertRaises(ValueError):
layer_collection.set_default_layer_collection(lc)
# Same as one set.
self.assertTrue(lc is layer_collection.get_default_layer_collection())
# Can set to None.
layer_collection.set_default_layer_collection(None)
with self.assertRaises(ValueError):
layer_collection.get_default_layer_collection()
# as_default() is the same as setting/clearing.
with lc.as_default():
self.assertTrue(lc is layer_collection.get_default_layer_collection())
with self.assertRaises(ValueError):
layer_collection.get_default_layer_collection()
if __name__ == '__main__':
test.main()


@ -1,190 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.loss_functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.kfac.python.ops import loss_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class InsertSliceInZerosTest(test.TestCase):
def testBadShape(self):
bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
with self.assertRaises(ValueError):
loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
def test3d(self):
input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
with self.cached_session() as sess:
actual_output_array = sess.run(op)
self.assertAllEqual(expected_output_array, actual_output_array)
class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
def testSample(self):
"""Ensure samples can be drawn."""
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.asarray([
[0., 0., 0.], #
[1., -1., 0.]
]).astype(np.float32)
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
array_ops.constant(logits))
sample = loss.sample(42)
sample = sess.run(sample)
self.assertEqual(sample.shape, (2,))
def testEvaluateOnTargets(self):
"""Ensure log probability can be evaluated correctly."""
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.asarray([
[0., 0., 0.], #
[1., -1., 0.]
]).astype(np.float32)
targets = np.asarray([2, 1]).astype(np.int32)
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
array_ops.constant(logits), targets=array_ops.constant(targets))
neg_log_prob = loss.evaluate()
neg_log_prob = sess.run(neg_log_prob)
# Calculate explicit log probability of targets.
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
log_probs = np.log([
probs[0, targets[0]], #
probs[1, targets[1]]
])
expected_log_prob = np.sum(log_probs)
self.assertAllClose(neg_log_prob, -expected_log_prob)
def testEvaluateOnSample(self):
"""Ensure log probability of a sample can be drawn."""
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.asarray([
[0., 0., 0.], #
[1., -1., 0.]
]).astype(np.float32)
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
array_ops.constant(logits))
neg_log_prob = loss.evaluate_on_sample(42)
# Simply ensure this doesn't crash. As the output is random, it's
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
def testMultiplyFisherSingleVector(self):
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.array([1., 2., 3.])
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
# the LossFunction.multiply_fisher docstring only says it supports the
# case where the vector is the same shape as the input natural parameters
# (i.e. the logits here), but here we also test leading dimensions
vector = np.array([1., 2., 3.])
vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
probs = np.exp(logits - np.logaddexp.reduce(logits))
fisher = np.diag(probs) - np.outer(probs, probs)
for vector in vectors:
result = loss.multiply_fisher(vector)
expected_result = np.dot(vector, fisher)
self.assertAllClose(expected_result, sess.run(result))
def testMultiplyFisherBatch(self):
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.array([[1., 2., 3.], [4., 6., 8.]])
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
vector = np.array([[1., 2., 3.], [5., 3., 1.]])
na = np.newaxis
probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
keepdims=True))
fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
result = loss.multiply_fisher(vector)
expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
self.assertEqual(sess.run(result).shape, logits.shape)
self.assertAllClose(expected_result, sess.run(result))
class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
def testSample(self):
"""Ensure samples can be drawn."""
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.asarray([
[0., 0., 0.], #
[1., -1., 0.]
]).astype(np.float32)
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
array_ops.constant(logits))
sample = loss.sample(42)
sample = sess.run(sample)
self.assertEqual(sample.shape, (2, 3))
def testEvaluateOnTargets(self):
"""Ensure log probability can be evaluated correctly."""
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.asarray([
[0., 0., 0.], #
[1., -1., 0.]
]).astype(np.float32)
targets = np.asarray([2, 1]).astype(np.int32)
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
neg_log_prob = loss.evaluate()
neg_log_prob = sess.run(neg_log_prob)
# Calculate explicit log probability of targets.
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
log_probs = np.log([
probs[0, targets[0]], #
probs[1, targets[1]]
])
expected_log_prob = np.sum(log_probs)
self.assertAllClose(neg_log_prob, -expected_log_prob)
def testEvaluateOnSample(self):
"""Ensure log probability of a sample can be drawn."""
with ops.Graph().as_default(), self.cached_session() as sess:
logits = np.asarray([
[0., 0., 0.], #
[1., -1., 0.]
]).astype(np.float32)
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
array_ops.constant(logits))
neg_log_prob = loss.evaluate_on_sample(42)
# Simply ensure this doesn't crash. As the output is random, it's
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
if __name__ == "__main__":
test.main()


@ -1,50 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.op_queue."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import op_queue
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class OpQueueTest(test.TestCase):
def testNextOp(self):
"""Ensures all ops get selected eventually."""
with tf_ops.Graph().as_default():
ops = [
math_ops.add(1, 2),
math_ops.subtract(1, 2),
math_ops.reduce_mean([1, 2]),
]
queue = op_queue.OpQueue(ops, seed=0)
with self.cached_session() as sess:
# Ensure every inv update op gets selected.
selected_ops = set([queue.next_op(sess) for _ in ops])
self.assertEqual(set(ops), set(selected_ops))
# Ensure additional calls don't return any new ops.
selected_ops.add(queue.next_op(sess))
self.assertEqual(set(ops), set(selected_ops))
if __name__ == "__main__":
test.main()


@ -1,219 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test
# We need to set these constants since the numerical values used in the tests
# were chosen when these used to be the defaults.
ff.set_global_constants(init_covariances_at_zero=False,
zero_debias=False,
init_inverses_at_zero=False)
def dummy_layer_collection():
lcoll = lc.LayerCollection()
dummy = array_ops.constant([1., 2.])
lcoll.register_categorical_predictive_distribution(logits=dummy)
return lcoll
class OptimizerTest(test.TestCase):
def testOptimizerInitInvalidMomentumRegistration(self):
with self.assertRaises(ValueError):
optimizer.KfacOptimizer(
0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
def testOptimizerInit(self):
with ops.Graph().as_default():
layer_collection = lc.LayerCollection()
inputs = array_ops.ones((2, 1)) * 2
weights_val = np.ones((1, 1), dtype=np.float32) * 3.
weights = variable_scope.get_variable(
'w', initializer=array_ops.constant(weights_val))
bias = variable_scope.get_variable(
'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
output = math_ops.matmul(inputs, weights) + bias
layer_collection.register_fully_connected((weights, bias), inputs, output)
logits = math_ops.tanh(output)
targets = array_ops.constant([[0.], [1.]])
output = math_ops.reduce_mean(
nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
layer_collection.register_categorical_predictive_distribution(logits)
optimizer.KfacOptimizer(
0.1,
0.2,
0.3,
layer_collection,
momentum=0.5,
momentum_type='regular')
def testSquaredFisherNorm(self):
with ops.Graph().as_default(), self.cached_session() as sess:
grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
(array_ops.constant([[2., 3.], [4., 5.]]), None)]
pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
(array_ops.constant([[7., 8.], [9., 10.]]), None)]
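# The expected value 174 below is consistent with the squared Fisher norm
# being the sum of element-wise products of the gradients with the
# preconditioned gradients: (3 + 8 + 15 + 24) + (14 + 24 + 36 + 50) = 174.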
opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
def testUpdateClipCoeff(self):
with ops.Graph().as_default(), self.cached_session() as sess:
grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
(array_ops.constant([[2., 3.], [4., 5.]]), None)]
pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
(array_ops.constant([[7., 8.], [9., 10.]]), None)]
lrate = 0.1
# Note: without rescaling, the squared Fisher norm of the update
# is lrate**2 * 174 = 1.74 (see testSquaredFisherNorm above).
# If the update already satisfies the norm constraint, there should
# be no rescaling.
opt = optimizer.KfacOptimizer(
lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
self.assertAlmostEqual(1., sess.run(coeff), places=5)
# If the update violates the constraint, it should be rescaled to
# be on the constraint boundary.
opt = optimizer.KfacOptimizer(
lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
def testComputeUpdateStepsRegular(self):
# TODO(olganw): implement this.
pass
def testComputeUpdateStepsAdam(self):
# TODO(olganw): implement this.
pass
def testUpdateVelocities(self):
with ops.Graph().as_default(), self.cached_session() as sess:
layers = lc.LayerCollection()
layers.register_categorical_predictive_distribution(
array_ops.constant([1.0]))
opt = optimizer.KfacOptimizer(
0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
y = variable_scope.get_variable(
'y', initializer=array_ops.ones((2, 2)) * 2)
vec1 = array_ops.ones((2, 2)) * 3
vec2 = array_ops.ones((2, 2)) * 4
model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
opt_vars = [
v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
if v not in model_vars
]
sess.run(tf_variables.global_variables_initializer())
old_opt_vars = sess.run(opt_vars)
# Optimizer vars start out at 0.
for opt_var in old_opt_vars:
self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
sess.run(update_op)
new_opt_vars = sess.run(opt_vars)
# After one update, the velocities are equal to the vectors.
for vec, opt_var in zip([vec1, vec2], new_opt_vars):
self.assertAllEqual(sess.run(vec), opt_var)
sess.run(update_op)
final_opt_vars = sess.run(opt_vars)
for first, second in zip(new_opt_vars, final_opt_vars):
self.assertFalse(np.equal(first, second).all())
def testApplyGradients(self):
with ops.Graph().as_default(), self.cached_session() as sess:
layer_collection = lc.LayerCollection()
inputs = array_ops.ones((2, 1)) * 2
weights_val = np.ones((1, 1), dtype=np.float32) * 3.
weights = variable_scope.get_variable(
'w', initializer=array_ops.constant(weights_val))
bias = variable_scope.get_variable(
'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
output = math_ops.matmul(inputs, weights) + bias
layer_collection.register_fully_connected((weights, bias), inputs, output)
logits = math_ops.tanh(output)
targets = array_ops.constant([[0.], [1.]])
output = math_ops.reduce_mean(
nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
layer_collection.register_categorical_predictive_distribution(logits)
opt = optimizer.KfacOptimizer(
0.1,
0.2,
0.3,
layer_collection,
momentum=0.5,
momentum_type='regular')
(cov_update_thunks,
inv_update_thunks) = opt.make_vars_and_create_op_thunks()
cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
grads_and_vars = opt.compute_gradients(output, [weights, bias])
all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
op = opt.apply_gradients(grads_and_vars)
sess.run(tf_variables.global_variables_initializer())
old_vars = sess.run(all_vars)
sess.run(cov_update_ops)
sess.run(inv_update_ops)
sess.run(op)
new_vars = sess.run(all_vars)
for old_var, new_var in zip(old_vars, new_vars):
self.assertNotEqual(old_var, new_var)
if __name__ == '__main__':
test.main()


@ -1,410 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
class SequenceDictTest(test.TestCase):
def testSequenceDictInit(self):
seq_dict = utils.SequenceDict()
self.assertFalse(seq_dict._dict)
def testSequenceDictInitWithIterable(self):
reg_dict = {'a': 'foo', 'b': 'bar'}
itr = zip(reg_dict.keys(), reg_dict.values())
seq_dict = utils.SequenceDict(itr)
self.assertEqual(reg_dict, seq_dict._dict)
def testGetItemSingleKey(self):
seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
self.assertEqual('foo', seq_dict['a'])
def testGetItemMultipleKeys(self):
seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
def testSetItemSingleKey(self):
seq_dict = utils.SequenceDict()
seq_dict['a'] = 'foo'
self.assertEqual([('a', 'foo')], seq_dict.items())
def testSetItemMultipleKeys(self):
seq_dict = utils.SequenceDict()
keys = ('a', 'b', 'c')
values = ('foo', 'bar', 'baz')
seq_dict[keys] = values
self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
class SubGraphTest(test.TestCase):
def testBasicGraph(self):
a = array_ops.constant([[1., 2.], [3., 4.]])
b = array_ops.constant([[5., 6.], [7., 8.]])
c = a + b
d = a * b
sub_graph = utils.SubGraph((c,))
self.assertTrue(sub_graph.is_member(a))
self.assertTrue(sub_graph.is_member(b))
self.assertTrue(sub_graph.is_member(c))
self.assertFalse(sub_graph.is_member(d))
def testRepeatedAdds(self):
a = array_ops.constant([[1., 2.], [3., 4.]])
b = array_ops.constant([[5., 6.], [7., 8.]])
c = a + b + a # note that a appears twice in this graph
sub_graph = utils.SubGraph((c,))
self.assertTrue(sub_graph.is_member(a))
self.assertTrue(sub_graph.is_member(b))
self.assertTrue(sub_graph.is_member(c))
def testFilterList(self):
a = array_ops.constant([[1., 2.], [3., 4.]])
b = array_ops.constant([[5., 6.], [7., 8.]])
c = a + b
d = a * b
sub_graph = utils.SubGraph((c,))
input_list = [b, d]
filtered_list = sub_graph.filter_list(input_list)
self.assertEqual(filtered_list, [b])
def testVariableUses(self):
with ops.Graph().as_default():
var = variable_scope.get_variable('var', shape=[10, 10])
resource_var = variable_scope.get_variable(
'resource_var', shape=[10, 10], use_resource=True)
x = array_ops.zeros([3, 10])
z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
z1 = math_ops.matmul(x, resource_var)
sub_graph = utils.SubGraph((z0, z1))
self.assertEqual(2, sub_graph.variable_uses(var))
self.assertEqual(1, sub_graph.variable_uses(resource_var))
class UtilsTest(test.TestCase):
def _fully_connected_layer_params(self):
weights_part = array_ops.constant([[1., 2.], [4., 3.]])
bias_part = array_ops.constant([1., 2.])
return (weights_part, bias_part)
def _conv_layer_params(self):
weights_shape = 2, 2, 3, 4
biases_shape = weights_shape[-1:]
weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
return (weights, biases)
def testFullyConnectedLayerParamsTupleToMat2d(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
layer_params = self._fully_connected_layer_params()
output = utils.layer_params_to_mat2d(layer_params)
self.assertListEqual([3, 2], output.get_shape().as_list())
self.assertAllClose(
sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
def testFullyConnectedLayerParamsTensorToMat2d(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
layer_params = self._fully_connected_layer_params()
output = utils.layer_params_to_mat2d(layer_params[0])
self.assertListEqual([2, 2], output.get_shape().as_list())
self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
def testConvLayerParamsTupleToMat2d(self):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
layer_params = self._conv_layer_params()
output = utils.layer_params_to_mat2d(layer_params)
self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
def testKron(self):
with ops.Graph().as_default(), self.cached_session() as sess:
mat1 = np.array([[1., 2.], [3., 4.]])
mat2 = np.array([[5., 6.], [7., 8.]])
mat1_tf = array_ops.constant(mat1)
mat2_tf = array_ops.constant(mat2)
ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
ans_np = np.kron(mat1, mat2)
self.assertAllClose(ans_tf, ans_np)
def testMat2dToFullyConnectedLayerParamsTuple(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
vector_template = self._fully_connected_layer_params()
mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
self.assertIsInstance(output, tuple)
self.assertEqual(len(output), 2)
a, b = output
self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
self.assertAllClose(b, np.array([1., 0.]))
def testMat2dToFullyConnectedLayerParamsTensor(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
vector_template = self._fully_connected_layer_params()[0]
mat2d = array_ops.constant([[5., 4.], [3., 2.]])
output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
def testTensorsToColumn(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
output = utils.tensors_to_column(vector)
self.assertListEqual([4, 1], output.get_shape().as_list())
self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
vector = self._fully_connected_layer_params()
output = utils.tensors_to_column(vector)
self.assertListEqual([6, 1], output.get_shape().as_list())
self.assertAllClose(
sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
vector = list(vector)
vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
output = utils.tensors_to_column(vector)
self.assertListEqual([10, 1], output.get_shape().as_list())
self.assertAllClose(
sess.run(output),
np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
def testColumnToTensors(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
colvec = array_ops.constant(np.arange(4.)[:, None])
output = sess.run(utils.column_to_tensors(vector_template, colvec))
self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
vector_template = self._fully_connected_layer_params()
colvec = array_ops.constant(np.arange(6.)[:, None])
output = sess.run(utils.column_to_tensors(vector_template, colvec))
self.assertIsInstance(output, tuple)
self.assertEqual(len(output), 2)
a, b = output
self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
self.assertAllClose(b, np.array([4., 5.]))
vector_template = list(vector_template)
vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
colvec = array_ops.constant(np.arange(10.)[:, None])
output = sess.run(utils.column_to_tensors(vector_template, colvec))
self.assertIsInstance(output, tuple)
self.assertEqual(len(output), 3)
a, b, c = output
self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
self.assertAllClose(b, np.array([4., 5.]))
self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
def testPosDefInvCholesky(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
npr.seed(0)
square = lambda x: np.dot(x, x.T)
size = 3
x = square(npr.randn(size, size))
damp = 0.1
identity = linalg_ops.eye(size, dtype=dtypes.float64)
tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
np_inv = np.linalg.inv(x + damp * np.eye(size))
self.assertAllClose(sess.run(tf_inv), np_inv)
def testPosDefInvMatrixInverse(self):
with ops.Graph().as_default(), self.cached_session() as sess:
random_seed.set_random_seed(200)
npr.seed(0)
square = lambda x: np.dot(x, x.T)
size = 3
x = square(npr.randn(size, size))
damp = 0.1
identity = linalg_ops.eye(size, dtype=dtypes.float64)
tf_inv = utils.posdef_inv_matrix_inverse(
array_ops.constant(x), identity, damp)
np_inv = np.linalg.inv(x + damp * np.eye(size))
self.assertAllClose(sess.run(tf_inv), np_inv)
def testCrossReplicaMean(self):
"""Ensures that cross_replica_mean() executes only when num_shards > 1."""
with ops.Graph().as_default():
with tpu_function.tpu_shard_context(4):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
self.assertNotEqual(mean, tensor)
with ops.Graph().as_default():
with tpu_function.tpu_shard_context(1):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
self.assertEqual(mean, tensor)
with ops.Graph().as_default():
with self.assertRaises(ValueError): # Outside of TPU context.
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
def testBatchExecute(self):
"""Ensure batch_execute runs in a round-robin fashion."""
def increment_var(var):
return lambda: var.assign_add(1)
with ops.Graph().as_default(), self.cached_session() as sess:
i = variable_scope.get_variable('i', initializer=0)
accumulators = [
variable_scope.get_variable('var%d' % j, initializer=0)
for j in range(3)
]
thunks = [increment_var(var) for var in accumulators]
increment_accumulators = utils.batch_execute(i, thunks, 2)
increment_i = i.assign_add(1)
sess.run(variables.global_variables_initializer())
# Ensure one op per thunk.
self.assertEqual(3, len(increment_accumulators))
# Ensure round-robin execution.
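# With 3 thunks and a batch size of 2, run i of increment_accumulators
# presumably executes thunks [(2 * i) % 3, (2 * i + 1) % 3], which produces
# the accumulator values asserted below: [1, 1, 0], [2, 1, 1], [2, 2, 2], ...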
values = []
for _ in range(5):
sess.run(increment_accumulators)
sess.run(increment_i)
values.append(sess.run(accumulators))
self.assertAllClose(
[
[1, 1, 0], #
[2, 1, 1], #
[2, 2, 2], #
[3, 3, 2], #
[4, 3, 3]
],
values)
def testExtractConvolutionPatches(self):
with ops.Graph().as_default(), self.cached_session() as sess:
batch_size = 10
image_spatial_shape = [9, 10, 11]
in_channels = out_channels = 32
kernel_spatial_shape = [5, 3, 3]
spatial_strides = [1, 2, 1]
spatial_dilation = [1, 1, 1]
padding = 'SAME'
images = random_ops.random_uniform(
[batch_size] + image_spatial_shape + [in_channels], seed=0)
kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
kernel = random_ops.random_uniform(kernel_shape, seed=1)
# Ensure shape matches expectation.
patches = utils.extract_convolution_patches(
images,
kernel_shape,
padding,
strides=spatial_strides,
dilation_rate=spatial_dilation)
result_spatial_shape = (
patches.shape.as_list()[1:1 + len(image_spatial_shape)])
self.assertEqual(patches.shape.as_list(),
[batch_size] + result_spatial_shape +
kernel_spatial_shape + [in_channels])
# Ensure the extract...patches() + matmul() and convolution() implementations
# give the same answer.
outputs = nn_ops.convolution(
images,
kernel,
padding,
strides=spatial_strides,
dilation_rate=spatial_dilation)
patches_flat = array_ops.reshape(
patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
def testExtractPointwiseConv2dPatches(self):
with ops.Graph().as_default(), self.cached_session() as sess:
batch_size = 10
image_height = image_width = 8
in_channels = out_channels = 3
kernel_height = kernel_width = 1
strides = [1, 1, 1, 1]
padding = 'VALID'
images = random_ops.random_uniform(
[batch_size, image_height, image_width, in_channels], seed=0)
kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
kernel = random_ops.random_uniform(kernel_shape, seed=1)
# Ensure shape matches expectation.
patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
self.assertEqual(patches.shape.as_list(), [
batch_size, image_height, image_width, kernel_height, kernel_width,
in_channels
])
# Ensure the extract...patches() + matmul() and conv2d() implementations
# give the same answer.
outputs = nn_ops.conv2d(images, kernel, strides, padding)
patches_flat = array_ops.reshape(
patches, [-1, kernel_height * kernel_width * in_channels])
kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
if __name__ == '__main__':
test.main()


@ -1,263 +0,0 @@
package(default_visibility = [
"//tensorflow/contrib/kfac:__pkg__",
"//tensorflow/contrib/kfac/python/kernel_tests:__pkg__",
])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "fisher_blocks",
srcs = ["fisher_blocks.py"],
srcs_version = "PY2AND3",
deps = [
":fisher_factors",
":utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:math_ops",
"@six_archive//:six",
],
)
py_library(
name = "fisher_blocks_lib",
srcs = ["fisher_blocks_lib.py"],
srcs_version = "PY2AND3",
deps = [
":fisher_blocks",
"//tensorflow/python:util",
],
)
py_library(
name = "fisher_factors",
srcs = ["fisher_factors.py"],
srcs_version = "PY2AND3",
deps = [
":linear_operator",
":utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:special_math_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "fisher_factors_lib",
srcs = ["fisher_factors_lib.py"],
srcs_version = "PY2AND3",
deps = [
":fisher_factors",
"//tensorflow/python:util",
],
)
py_library(
name = "linear_operator",
srcs = ["linear_operator.py"],
srcs_version = "PY2AND3",
deps = [
":utils",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/ops/linalg",
"@six_archive//:six",
],
)
py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/ops/distributions",
"@six_archive//:six",
],
)
py_library(
name = "loss_functions_lib",
srcs = ["loss_functions_lib.py"],
srcs_version = "PY2AND3",
deps = [
":loss_functions",
"//tensorflow/python:util",
],
)
py_library(
name = "curvature_matrix_vector_products",
srcs = ["curvature_matrix_vector_products.py"],
srcs_version = "PY2AND3",
deps = [
":utils",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
],
)
py_library(
name = "curvature_matrix_vector_products_lib",
srcs = ["curvature_matrix_vector_products_lib.py"],
srcs_version = "PY2AND3",
deps = [
":curvature_matrix_vector_products",
"//tensorflow/python:util",
],
)
py_library(
name = "layer_collection",
srcs = ["layer_collection.py"],
srcs_version = "PY2AND3",
deps = [
":fisher_blocks",
":loss_functions",
":utils",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"@six_archive//:six",
],
)
py_library(
name = "layer_collection_lib",
srcs = ["layer_collection_lib.py"],
srcs_version = "PY2AND3",
deps = [
":layer_collection",
"//tensorflow/python:util",
],
)
py_library(
name = "kfac_optimizer",
srcs = [
"optimizer.py",
],
srcs_version = "PY2AND3",
deps = [
":curvature_matrix_vector_products",
":fisher_estimator",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)
py_library(
name = "kfac_optimizer_lib",
srcs = [
"optimizer_lib.py",
],
srcs_version = "PY2AND3",
deps = [
":kfac_optimizer",
"//tensorflow/python:util",
],
)
py_library(
name = "fisher_estimator",
srcs = [
"estimator.py",
"placement.py",
],
srcs_version = "PY2AND3",
deps = [
":utils",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:util",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "fisher_estimator_lib",
srcs = [
"estimator_lib.py",
],
srcs_version = "PY2AND3",
deps = [
":fisher_estimator",
"//tensorflow/python:util",
],
)
py_library(
name = "utils",
srcs = ["utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//third_party/py/numpy",
],
)
py_library(
name = "utils_lib",
srcs = ["utils_lib.py"],
srcs_version = "PY2AND3",
deps = [
":utils",
"//tensorflow/python:util",
],
)
py_library(
name = "op_queue",
srcs = ["op_queue.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/python:framework_ops",
],
)
py_library(
name = "op_queue_lib",
srcs = ["op_queue_lib.py"],
srcs_version = "PY2AND3",
deps = [
":op_queue",
"//tensorflow/python:util",
],
)


@ -1,183 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Curvature matrix-vector multiplication."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
class CurvatureMatrixVectorProductComputer(object):
"""Class for computing matrix-vector products for Fishers, GGNs and Hessians.
In other words we compute M*v where M is the matrix, v is the vector, and
* refers to standard matrix/vector multiplication (not element-wise
multiplication).
The matrices are defined in terms of some differential quantity of the total
loss function with respect to a provided list of tensors ("wrt_tensors").
For example, the Fisher associated with a log-prob loss w.r.t. the
parameters.
The 'vecs' argument to each method is a list of tensors that must be the
same size as the corresponding ones from "wrt_tensors". It represents
the vector being multiplied.
"factors" of the matrix M are defined as matrices B such that B*B^T = M.
Methods that multiply by the factor B take a 'loss_inner_vecs' argument
instead of 'vecs', which must be a list of tensors with shapes given by the
corresponding XXX_inner_shapes property.
Note that matrix-vector products are not normalized by the batch size, nor
are any damping terms added to the results. These things can be easily
applied externally, if desired.
See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
and https://arxiv.org/abs/1412.1193 for more information about the
generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
products.
"""
def __init__(self, losses, wrt_tensors):
"""Create a CurvatureMatrixVectorProductComputer object.
Args:
losses: A list of LossFunction instances whose sum defines the total loss.
wrt_tensors: A list of Tensors to compute the differential quantities
(defining the matrices) with respect to. See class description for more
info.
"""
self._losses = losses
self._inputs_to_losses = list(loss.inputs for loss in losses)
self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses)
self._wrt_tensors = wrt_tensors
@property
def _total_loss(self):
return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses))
# Jacobian multiplication functions:
def _multiply_jacobian(self, vecs):
"""Multiply vecs by the Jacobian of losses."""
# We stop gradients at wrt_tensors to produce partial derivatives (which is
# what we want for Jacobians).
jacobian_vecs_flat = utils.fwd_gradients(
self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,
stop_gradients=self._wrt_tensors)
return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)
def _multiply_jacobian_transpose(self, loss_vecs):
"""Multiply vecs by the transpose Jacobian of losses."""
loss_vecs_flat = nest.flatten(loss_vecs)
# We stop gradients at wrt_tensors to produce partial derivatives (which is
# what we want for Jacobians).
return gradients_impl.gradients(
self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat,
stop_gradients=self._wrt_tensors)
# Losses Fisher/Hessian multiplication functions:
def _multiply_loss_fisher(self, loss_vecs):
"""Multiply loss_vecs by Fisher of total loss."""
return tuple(
loss.multiply_fisher(loss_vec)
for loss, loss_vec in zip(self._losses, loss_vecs))
def _multiply_loss_fisher_factor(self, loss_inner_vecs):
"""Multiply loss_inner_vecs by factor of Fisher of total loss."""
return tuple(
loss.multiply_fisher_factor(loss_vec)
for loss, loss_vec in zip(self._losses, loss_inner_vecs))
def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
"""Multiply loss_vecs by transpose factor of Fisher of total loss."""
return tuple(
loss.multiply_fisher_factor_transpose(loss_vec)
for loss, loss_vec in zip(self._losses, loss_vecs))
def _multiply_loss_hessian(self, loss_vecs):
"""Multiply loss_vecs by Hessian of total loss."""
return tuple(
loss.multiply_hessian(loss_vec)
for loss, loss_vec in zip(self._losses, loss_vecs))
def _multiply_loss_hessian_factor(self, loss_inner_vecs):
"""Multiply loss_inner_vecs by factor of Hessian of total loss."""
return tuple(
loss.multiply_hessian_factor(loss_vec)
for loss, loss_vec in zip(self._losses, loss_inner_vecs))
def _multiply_loss_hessian_factor_transpose(self, loss_vecs):
"""Multiply loss_vecs by transpose factor of Hessian of total loss."""
return tuple(
loss.multiply_hessian_factor_transpose(loss_vec)
for loss, loss_vec in zip(self._losses, loss_vecs))
# Matrix-vector product functions:
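# These compose the helpers above: writing J for the Jacobian of the loss
# inputs w.r.t. wrt_tensors and F_loss for the Fisher of the losses w.r.t.
# their inputs, the overall Fisher is F = J^T F_loss J, so F * v is computed
# as J^T (F_loss (J v)). The Hessian and GGN products follow the same pattern.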
def multiply_fisher(self, vecs):
"""Multiply vecs by Fisher of total loss."""
jacobian_vecs = self._multiply_jacobian(vecs)
loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)
return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)
def multiply_fisher_factor_transpose(self, vecs):
"""Multiply vecs by transpose of factor of Fisher of total loss."""
jacobian_vecs = self._multiply_jacobian(vecs)
return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)
def multiply_fisher_factor(self, loss_inner_vecs):
"""Multiply loss_inner_vecs by factor of Fisher of total loss."""
fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose(
loss_inner_vecs)
return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)
def multiply_hessian(self, vecs):
"""Multiply vecs by Hessian of total loss."""
return gradients_impl.gradients(
gradients_impl.gradients(self._total_loss, self._wrt_tensors),
self._wrt_tensors,
grad_ys=vecs)
def multiply_generalized_gauss_newton(self, vecs):
"""Multiply vecs by generalized Gauss-Newton of total loss."""
jacobian_vecs = self._multiply_jacobian(vecs)
loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs)
return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs)
def multiply_generalized_gauss_newton_factor_transpose(self, vecs):
"""Multiply vecs by transpose of factor of GGN of total loss."""
jacobian_vecs = self._multiply_jacobian(vecs)
return self._multiply_loss_hessian_factor_transpose(jacobian_vecs)
def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs):
"""Multiply loss_inner_vecs by factor of GGN of total loss."""
hessian_factor_vecs = self._multiply_loss_hessian_factor(loss_inner_vecs)
return self._multiply_jacobian_transpose(hessian_factor_vecs)
# Shape properties for multiply_XXX_factor methods:
@property
def fisher_factor_inner_shapes(self):
"""Shapes required by multiply_fisher_factor."""
return tuple(loss.fisher_factor_inner_shape for loss in self._losses)
@property
def generalized_gauss_newton_factor_inner_shapes(self):
"""Shapes required by multiply_generalized_gauss_newton_factor."""
return tuple(loss.hessian_factor_inner_shape for loss in self._losses)
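A minimal usage sketch of the computer above, assuming its constructor takes (losses, wrt_tensors) as the docstring fragment suggests; the tiny softmax model below is purely illustrative and not part of the original file:

    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
    from tensorflow.contrib.kfac.python.ops import loss_functions

    x = tf.random_normal([8, 4])                      # batch of 8 illustrative inputs
    labels = tf.zeros([8], dtype=tf.int32)            # dummy class labels
    weights = tf.get_variable("w", shape=[4, 3])
    logits = tf.matmul(x, weights)

    # Register a loss object over the logits, then build the computer over the
    # parameters ("wrt" tensors) that the curvature is taken with respect to.
    loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits, targets=labels)
    computer = cmvp.CurvatureMatrixVectorProductComputer([loss], [weights])
    # Fisher-vector product; the vector must match the shape of the wrt tensor.
    fvp = computer.multiply_fisher([tf.ones_like(weights)])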


@ -1,30 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Curvature matrix-vector multiplication."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
'CurvatureMatrixVectorProductComputer',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)


@ -1,516 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the high-level Fisher estimator class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import numpy as np
import six
from tensorflow.contrib.kfac.python.ops import placement
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
# The linter is confused.
# pylint: disable=abstract-class-instantiated
def make_fisher_estimator(placement_strategy=None, **kwargs):
"""Creates Fisher estimator instances based on the placement strategy.
For example, if the `placement_strategy` is 'round_robin' then a
`FisherEstimatorRoundRobin` instance is returned.
Args:
placement_strategy: `string`, Strategy to be used for placing covariance
variables, covariance ops and inverse ops. Check
`placement.FisherEstimatorRoundRobin` for a concrete example.
**kwargs: Arguments to be passed into `FisherEstimator` class initializer.
Returns:
An instance of a class which inherits from `FisherEstimator` and the mixin
which implements the specific placement strategy. See, for example,
`FisherEstimatorRoundRobin`, which inherits from `FisherEstimator` and
`RoundRobinPlacementMixin`.
Raises:
ValueError: If the `placement_strategy` is not equal to 'round_robin'.
"""
if placement_strategy in [None, "round_robin"]:
return FisherEstimatorRoundRobin(**kwargs)
else:
raise ValueError("Unimplemented vars and ops "
"placement strategy : {}".format(placement_strategy))
# pylint: enable=abstract-class-instantiated
@six.add_metaclass(abc.ABCMeta)
class FisherEstimator(object):
"""Fisher estimator class supporting various approximations of the Fisher.
This is an abstract base class which does not implement a strategy for
placing covariance variables, covariance update ops and inverse update ops.
The placement strategies are implemented in `placement.py`. See
`FisherEstimatorRoundRobin` for an example of a concrete subclass with
a round-robin placement strategy.
"""
def __init__(self,
variables,
cov_ema_decay,
damping,
layer_collection,
exps=(-1,),
estimation_mode="gradients",
colocate_gradients_with_ops=True,
name="FisherEstimator",
compute_cholesky=False,
compute_cholesky_inverse=False):
"""Create a FisherEstimator object.
Args:
variables: A `list` of variables or `callable` which returns the variables
for which to estimate the Fisher. This must match the variables
registered in layer_collection (if it is not None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: float. The damping factor used to stabilize training due to
errors in the local approximation with the Fisher information matrix,
and to regularize the update direction by making it closer to the
gradient. (Higher damping means the update looks more like a standard
gradient update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the Fisher
blocks, Kronecker factors, and losses associated with the
graph.
exps: List of floats or ints. These represent the different matrix
powers of the approximate Fisher that the FisherEstimator will be able
to multiply vectors by. If the user asks for a matrix power other than
one of these (or 1, which is always supported), there will be a
failure. (Default: (-1,))
estimation_mode: The type of estimator to use for the Fishers. Can be
'gradients', 'empirical', 'curvature_prop', or 'exact'.
(Default: 'gradients'). 'gradients' is the basic estimation approach
from the original K-FAC paper. 'empirical' computes the 'empirical'
Fisher information matrix (which uses the data's distribution for the
targets, as opposed to the true Fisher which uses the model's
distribution) and requires that each registered loss have specified
targets. 'curvature_prop' (curvature propagation) is a method which estimates the
Fisher using self-products of random 1/-1 vectors times "half-factors"
of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
Finally, 'exact' is the obvious generalization of Curvature
Propagation to compute the exact Fisher (modulo any additional
diagonal or Kronecker approximations) by looping over one-hot vectors
for each coordinate of the output instead of using 1/-1 vectors. It
is more expensive to compute than the other three options by a factor
equal to the output dimension, roughly speaking.
colocate_gradients_with_ops: Whether we should request gradients be
colocated with their respective ops. (Default: True)
name: A string. A name given to this estimator, which is added to the
variable scope when constructing variables and ops.
(Default: "FisherEstimator")
compute_cholesky: Bool. Whether or not the FisherEstimator will be
able to multiply vectors by the Cholesky factor.
(Default: False)
compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
will be able to multiply vectors by the Cholesky factor inverse.
(Default: False)
Raises:
ValueError: If no losses have been registered with layer_collection.
"""
self._variables = variables
self._cov_ema_decay = cov_ema_decay
self._damping = damping
self._estimation_mode = estimation_mode
self._layers = layer_collection
self._gradient_fns = {
"gradients": self._get_grads_lists_gradients,
"empirical": self._get_grads_lists_empirical,
"curvature_prop": self._get_grads_lists_curvature_prop,
"exact": self._get_grads_lists_exact
}
self._colocate_gradients_with_ops = colocate_gradients_with_ops
self._made_vars = False
self._exps = exps
self._compute_cholesky = compute_cholesky
self._compute_cholesky_inverse = compute_cholesky_inverse
self._name = name
@property
def variables(self):
if callable(self._variables):
return self._variables()
else:
return self._variables
@property
def damping(self):
return self._damping
@property
def blocks(self):
"""All registered FisherBlocks."""
return self._layers.get_blocks()
@property
def factors(self):
"""All registered FisherFactors."""
return self._layers.get_factors()
@property
def name(self):
return self._name
@abc.abstractmethod
def make_vars_and_create_op_thunks(self, scope=None):
"""Make vars and create op thunks with a specific placement strategy.
For each factor, all of that factor's cov variables and their associated
update ops will be placed on a particular device. A new device is chosen
for each factor by cycling through the list of devices in the cov_devices
argument. If cov_devices is None then no explicit device placement occurs.
An analogous strategy is followed for inverse update ops, with the list of
devices being given by the inv_devices argument.
Inverse variables on the other hand are not placed on any specific device
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
they will be accessed most often, which is the device that actually applies
the preconditioner to the gradient. The user will be responsible for setting
the device context for this.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All variables will be created,
and all thunks will execute, inside of a variable scope of the given
name. (Default: None)
Returns:
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
"""
pass
def _apply_transformation(self, vecs_and_vars, transform):
"""Applies an block-wise transformation to the corresponding vectors.
Args:
vecs_and_vars: List of (vector, variable) pairs.
transform: A function of the form f(fb, vec), where vec is the vector
to transform and fb is its corresponding block in the matrix, that
returns the transformed vector.
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
trans_vecs = utils.SequenceDict()
for params, fb in self._layers.fisher_blocks.items():
trans_vecs[params] = transform(fb, vecs[params])
return [(trans_vecs[var], var) for _, var in vecs_and_vars]
def multiply_inverse(self, vecs_and_vars):
"""Multiplies the vecs by the corresponding (damped) inverses of the blocks.
Args:
vecs_and_vars: List of (vector, variable) pairs.
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
return self.multiply_matpower(-1, vecs_and_vars)
def multiply(self, vecs_and_vars):
"""Multiplies the vectors by the corresponding (damped) blocks.
Args:
vecs_and_vars: List of (vector, variable) pairs.
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
return self.multiply_matpower(1, vecs_and_vars)
def multiply_matpower(self, exp, vecs_and_vars):
"""Multiplies the vecs by the corresponding matrix powers of the blocks.
Args:
exp: A float representing the power to raise the blocks by before
multiplying it by the vector.
vecs_and_vars: List of (vector, variable) pairs.
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
assert exp in self._exps
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
return self._apply_transformation(vecs_and_vars, fcn)
def multiply_cholesky(self, vecs_and_vars, transpose=False):
"""Multiplies the vecs by the corresponding Cholesky factors.
Args:
vecs_and_vars: List of (vector, variable) pairs.
transpose: Bool. If true the Cholesky factors are transposed before
multiplying the vecs. (Default: False)
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
assert self._compute_cholesky
fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
return self._apply_transformation(vecs_and_vars, fcn)
def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
"""Mults the vecs by the inverses of the corresponding Cholesky factors.
Note: if you are using Cholesky inverse multiplication to sample from
a matrix-variate Gaussian you will want to multiply by the transpose.
Let L be the Cholesky factor of F and observe that
L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
Thus we want to multiply by L^-T in order to sample from Gaussian with
covariance F^-1.
Args:
vecs_and_vars: List of (vector, variable) pairs.
transpose: Bool. If true the Cholesky factor inverses are transposed
before multiplying the vecs. (Default: False)
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
assert self._compute_cholesky_inverse
fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
return self._apply_transformation(vecs_and_vars, fcn)
def _instantiate_factors(self):
"""Instantiates FisherFactors' variables.
Raises:
ValueError: If estimation_mode was improperly specified at construction.
"""
blocks = self.blocks
tensors_to_compute_grads = [
block.tensors_to_compute_grads() for block in blocks
]
try:
grads_lists = self._gradient_fns[self._estimation_mode](
tensors_to_compute_grads)
except KeyError:
raise ValueError("Unrecognized value {} for estimation_mode.".format(
self._estimation_mode))
for grads_list, block in zip(grads_lists, blocks):
block.instantiate_factors(grads_list, self.damping)
def _check_vars_unmade_and_set_made_flag(self):
if self._made_vars:
raise Exception("Already made variables.")
self._made_vars = True
def made_vars(self):
return self._made_vars
def _register_matrix_functions(self):
for block in self.blocks:
for exp in self._exps:
block.register_matpower(exp)
if self._compute_cholesky:
block.register_cholesky()
if self._compute_cholesky_inverse:
block.register_cholesky_inverse()
def _finalize_layer_collection(self):
self._layers.create_subgraph()
self._layers.check_registration(self.variables)
self._instantiate_factors()
self._register_matrix_functions()
def create_ops_and_vars_thunks(self, scope=None):
"""Create thunks that make the ops and vars on demand.
This function returns 4 lists of thunks: cov_variable_thunks,
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
The length of each list is the number of factors and the i-th element of
each list corresponds to the i-th factor (given by the "factors" property).
Note that the execution of these thunks must happen in a certain
partial order. The i-th element of cov_variable_thunks must execute
before the i-th element of cov_update_thunks (and also the i-th element
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
must execute before the i-th element of inv_update_thunks.
TL;DR (oversimplified): Execute the thunks according to the order that
they are returned.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All thunks will execute inside
of a variable scope of the given name. (Default: None)
Returns:
cov_variable_thunks: A list of thunks that make the cov variables.
cov_update_thunks: A list of thunks that make the cov update ops.
inv_variable_thunks: A list of thunks that make the inv variables.
inv_update_thunks: A list of thunks that make the inv update ops.
"""
self._check_vars_unmade_and_set_made_flag()
self._finalize_layer_collection()
scope = self.name if scope is None else scope
cov_variable_thunks = [
self._create_cov_variable_thunk(factor, scope)
for factor in self.factors
]
cov_update_thunks = [
self._create_cov_update_thunk(factor, scope) for factor in self.factors
]
inv_variable_thunks = [
self._create_inv_variable_thunk(factor, scope)
for factor in self.factors
]
inv_update_thunks = [
self._create_inv_update_thunk(factor, scope) for factor in self.factors
]
return (cov_variable_thunks, cov_update_thunks,
inv_variable_thunks, inv_update_thunks)
def _create_cov_variable_thunk(self, factor, scope):
"""Constructs a covariance variable thunk for a single FisherFactor."""
def thunk():
with variable_scope.variable_scope(scope):
return factor.instantiate_cov_variables()
return thunk
def _create_cov_update_thunk(self, factor, scope):
"""Constructs a covariance update thunk for a single FisherFactor."""
def thunk():
with variable_scope.variable_scope(scope):
return factor.make_covariance_update_op(self._cov_ema_decay)
return thunk
def _create_inv_variable_thunk(self, factor, scope):
"""Constructs a inverse variable thunk for a single FisherFactor."""
def thunk():
with variable_scope.variable_scope(scope):
return factor.instantiate_inv_variables()
return thunk
def _create_inv_update_thunk(self, factor, scope):
"""Constructs an inverse update thunk for a single FisherFactor."""
def thunk():
with variable_scope.variable_scope(scope):
return control_flow_ops.group(factor.make_inverse_update_ops())
return thunk
def _get_grads_lists_gradients(self, tensors):
# Passing in a list of loss values is better than passing in the sum as
# the latter creates unnecessary ops on the default device
grads_flat = gradients_impl.gradients(
self._layers.eval_losses_on_samples(),
nest.flatten(tensors),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
return tuple((grad,) for grad in grads_all)
def _get_grads_lists_empirical(self, tensors):
# Passing in a list of loss values is better than passing in the sum as
# the latter creates unnecessary ops on the default device
grads_flat = gradients_impl.gradients(
self._layers.eval_losses(),
nest.flatten(tensors),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
return tuple((grad,) for grad in grads_all)
def _get_transformed_random_signs(self):
transformed_random_signs = []
for loss in self._layers.losses:
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
transformed_random_signs.append(
loss.multiply_fisher_factor(
utils.generate_random_signs(loss.fisher_factor_inner_shape)))
return transformed_random_signs
def _get_grads_lists_curvature_prop(self, tensors):
loss_inputs = list(loss.inputs for loss in self._layers.losses)
transformed_random_signs = self._get_transformed_random_signs()
grads_flat = gradients_impl.gradients(
nest.flatten(loss_inputs),
nest.flatten(tensors),
grad_ys=nest.flatten(transformed_random_signs),
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all = nest.pack_sequence_as(tensors, grads_flat)
return tuple((grad,) for grad in grads_all)
def _get_grads_lists_exact(self, tensors):
"""No docstring required."""
# Loop over all coordinates of all losses.
grads_all = []
for loss in self._layers.losses:
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
index)
grads_flat = gradients_impl.gradients(
loss.inputs,
nest.flatten(tensors),
grad_ys=transformed_one_hot,
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
return zip(*grads_all)
class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
FisherEstimator):
"""Fisher estimator which provides round robin device placement strategy."""
pass
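A hedged sketch (not from the original file) of how the factory and the thunk interface above fit together. Here `layer_collection` is assumed to be a LayerCollection with the model's layers and losses already registered, and `grads_and_vars` is an assumed list of (gradient, variable) pairs:

    estimator = make_fisher_estimator(
        placement_strategy="round_robin",
        variables=tf.trainable_variables(),
        cov_ema_decay=0.95,
        damping=1e-3,
        layer_collection=layer_collection)

    (cov_variable_thunks, cov_update_thunks,
     inv_variable_thunks, inv_update_thunks) = estimator.create_ops_and_vars_thunks()

    # Respect the partial order described above: create the variables before
    # building the ops that update them.
    for thunk in cov_variable_thunks + inv_variable_thunks:
      thunk()
    cov_update_ops = [thunk() for thunk in cov_update_thunks]
    inv_update_ops = [thunk() for thunk in inv_update_thunks]

    # Precondition gradients with the damped approximate Fisher inverse.
    precond_grads_and_vars = estimator.multiply_inverse(grads_and_vars)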


@ -1,31 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the high-level Fisher estimator class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.estimator import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
'FisherEstimator',
'make_fisher_estimator',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

File diff suppressed because it is too large


@ -1,45 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FisherBlock definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.fisher_blocks import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
'FisherBlock',
'FullFB',
'NaiveDiagonalFB',
'FullyConnectedDiagonalFB',
'KroneckerProductFB',
'EmbeddingKFACFB',
'FullyConnectedKFACBasicFB',
'ConvKFCBasicFB',
'ConvDiagonalFB',
'set_global_constants',
'compute_pi_tracenorm',
'compute_pi_adjusted_damping',
'num_conv_locations',
'normalize_damping',
'LEFT_MULTIPLY',
'RIGHT_MULTIPLY',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

File diff suppressed because it is too large


@ -1,38 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FisherFactor definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.fisher_factors import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
"inverse_initializer", "covariance_initializer",
"diagonal_covariance_initializer", "scope_string_from_params",
"scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
"InverseProvidingFactor", "FullFactor", "DiagonalFactor",
"NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
"FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
"ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
"ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
"compute_cov", "append_homog"
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

File diff suppressed because it is too large


@ -1,46 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for layers and their parameters/variables.
This represents the collection of all layers in the approximate Fisher
information matrix to which a particular FisherBlock may belong. That is, we
might have several layer collections for one TF graph (if we have multiple K-FAC
optimizers being used, for example).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.layer_collection import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
"get_default_layer_collection",
"set_default_layer_collection",
"LayerParametersDict",
"LayerCollection",
"APPROX_KRONECKER_NAME",
"APPROX_DIAGONAL_NAME",
"APPROX_FULL_NAME",
"VARIABLE_SCOPE",
"APPROX_KRONECKER_INDEP_NAME",
"APPROX_KRONECKER_SERIES_1_NAME",
"APPROX_KRONECKER_SERIES_2_NAME"
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)


@ -1,95 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SmartMatrices definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.ops.linalg import linear_operator_util as lou
class LinearOperatorExtras(object): # pylint: disable=missing-docstring
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
with self._name_scope(name, values=[x]):
if isinstance(x, ops.IndexedSlices):
return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
x = ops.convert_to_tensor(x, name="x")
self._check_input_dtype(x)
self_dim = -2 if adjoint else -1
arg_dim = -1 if adjoint_arg else -2
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
with self._name_scope(name, values=[x]):
if isinstance(x, ops.IndexedSlices):
return self._matmul_right_sparse(
x, adjoint=adjoint, adjoint_arg=adjoint_arg)
x = ops.convert_to_tensor(x, name="x")
self._check_input_dtype(x)
self_dim = -1 if adjoint else -2
arg_dim = -2 if adjoint_arg else -1
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
class LinearOperatorFullMatrix(LinearOperatorExtras,
linalg.LinearOperatorFullMatrix):
# TODO(b/78117889) Remove this definition once core LinearOperator
# has _matmul_right.
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
return lou.matmul_with_broadcast(
x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
raise NotImplementedError
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
assert not adjoint and not adjoint_arg
return utils.matmul_sparse_dense(x, self._matrix)
class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
linalg.LinearOperatorDiag):
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
x = linalg_impl.adjoint(x) if adjoint_arg else x
return diag_mat * x
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
assert not adjoint_arg
return utils.matmul_diag_sparse(diag_mat, x)
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
raise NotImplementedError
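A small sketch (not from the original file) of the intended semantics of these extras, assuming the usual operator convention that matmul(x) computes D @ x while matmul_right(x) computes x @ D; the sparse variants accept tf.IndexedSlices arguments:

    import tensorflow as tf

    op = LinearOperatorDiag(tf.constant([1., 2., 3.]))
    left = op.matmul(tf.ones([3, 4]))         # D @ x  -> shape [3, 4]
    right = op.matmul_right(tf.ones([5, 3]))  # x @ D  -> shape [5, 3]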


@ -1,754 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss functions to be used by LayerCollection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.contrib.distributions.python.ops import onehot_categorical
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal
@six.add_metaclass(abc.ABCMeta)
class LossFunction(object):
"""Abstract base class for loss functions.
Note that unlike typical loss functions used in neural networks these are
summed and not averaged across cases in the batch, since this is what the
users of this class (FisherEstimator and CurvatureMatrixVectorProductComputer)
will be expecting. The implication of this is that you may want to
normalize things like Fisher-vector products by the batch size when you
use this class. It depends on the use case.
"""
@abc.abstractproperty
def targets(self):
"""The targets being predicted by the model.
Returns:
None or Tensor of appropriate shape for calling self._evaluate() on.
"""
pass
@abc.abstractproperty
def inputs(self):
"""The inputs to the loss function (excluding the targets)."""
pass
def evaluate(self):
"""Evaluate the loss function on the targets."""
if self.targets is not None:
# We treat the targets as "constant". It's only the inputs that get
# "back-propped" through.
return self._evaluate(array_ops.stop_gradient(self.targets))
else:
raise Exception("Cannot evaluate losses with unspecified targets.")
@abc.abstractmethod
def _evaluate(self, targets):
"""Evaluates the negative log probability of the targets.
Args:
targets: Tensor that distribution can calculate log_prob() of.
Returns:
negative log probability of each target, summed across all targets.
"""
pass
@abc.abstractmethod
def multiply_hessian(self, vector):
"""Right-multiply a vector by the Hessian.
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
of the loss function with respect to its inputs.
Args:
vector: The vector to multiply. Must be the same shape(s) as the
'inputs' property.
Returns:
The vector right-multiplied by the Hessian. Will be of the same shape(s)
as the 'inputs' property.
"""
pass
@abc.abstractmethod
def multiply_hessian_factor(self, vector):
"""Right-multiply a vector by a factor B of the Hessian.
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
of the loss function with respect to its inputs. Typically this will be
block-diagonal across different cases in the batch, since the loss function
is typically summed across cases.
Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
but will agree with the one used in the other methods of this class.
Args:
vector: The vector to multiply. Must be of the shape given by the
'hessian_factor_inner_shape' property.
Returns:
The vector right-multiplied by B. Will be of the same shape(s) as the
'inputs' property.
"""
pass
@abc.abstractmethod
def multiply_hessian_factor_transpose(self, vector):
"""Right-multiply a vector by the transpose of a factor B of the Hessian.
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
of the loss function with respect to its inputs. Typically this will be
block-diagonal across different cases in the batch, since the loss function
is typically summed across cases.
Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
but will agree with the one used in the other methods of this class.
Args:
vector: The vector to multiply. Must be the same shape(s) as the
'inputs' property.
Returns:
The vector right-multiplied by B^T. Will be of the shape given by the
'hessian_factor_inner_shape' property.
"""
pass
@abc.abstractmethod
def multiply_hessian_factor_replicated_one_hot(self, index):
"""Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
of the loss function with respect to its inputs. Typically this will be
block-diagonal across different cases in the batch, since the loss function
is typically summed across cases.
A 'replicated-one-hot' vector means a tensor which, for each slice along the
batch dimension (assumed to be dimension 0), is 1.0 in the entry
corresponding to the given index and 0 elsewhere.
Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
but will agree with the one used in the other methods of this class.
Args:
index: A tuple representing the index of the entry in each slice that
is 1.0. Note that len(index) must be equal to the number of elements
of the 'hessian_factor_inner_shape' tensor minus one.
Returns:
The vector right-multiplied by B^T. Will be of the same shape(s) as the
'inputs' property.
"""
pass
@abc.abstractproperty
def hessian_factor_inner_shape(self):
"""The shape of the tensor returned by multiply_hessian_factor."""
pass
@abc.abstractproperty
def hessian_factor_inner_static_shape(self):
"""Static version of hessian_factor_inner_shape."""
pass
@six.add_metaclass(abc.ABCMeta)
class NegativeLogProbLoss(LossFunction):
"""Abstract base class for loss functions that are negative log probs."""
def __init__(self, seed=None):
self._default_seed = seed
super(NegativeLogProbLoss, self).__init__()
@property
def inputs(self):
return self.params
@abc.abstractproperty
def params(self):
"""Parameters to the underlying distribution."""
pass
@abc.abstractmethod
def multiply_fisher(self, vector):
"""Right-multiply a vector by the Fisher.
Args:
vector: The vector to multiply. Must be the same shape(s) as the
'inputs' property.
Returns:
The vector right-multiplied by the Fisher. Will be of the same shape(s)
as the 'inputs' property.
"""
pass
@abc.abstractmethod
def multiply_fisher_factor(self, vector):
"""Right-multiply a vector by a factor B of the Fisher.
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
product of gradients) with respect to the parameters of the underlying
probability distribution (whose log-prob defines the loss). Typically this
will be block-diagonal across different cases in the batch, since the
distribution is usually (but not always) conditionally iid across different
cases.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
but will agree with the one used in the other methods of this class.
Args:
vector: The vector to multiply. Must be of the shape given by the
'fisher_factor_inner_shape' property.
Returns:
The vector right-multiplied by B. Will be of the same shape(s) as the
'inputs' property.
"""
pass
@abc.abstractmethod
def multiply_fisher_factor_transpose(self, vector):
"""Right-multiply a vector by the transpose of a factor B of the Fisher.
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
product of gradients) with respect to the parameters of the underlying
probability distribution (whose log-prob defines the loss). Typically this
will be block-diagonal across different cases in the batch, since the
distribution is usually (but not always) conditionally iid across different
cases.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
but will agree with the one used in the other methods of this class.
Args:
vector: The vector to multiply. Must be the same shape(s) as the
'inputs' property.
Returns:
The vector right-multiplied by B^T. Will be of the shape given by the
'fisher_factor_inner_shape' property.
"""
pass
@abc.abstractmethod
def multiply_fisher_factor_replicated_one_hot(self, index):
"""Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
product of gradients) with respect to the parameters of the underlying
probability distribution (whose log-prob defines the loss). Typically this
will be block-diagonal across different cases in the batch, since the
distribution is usually (but not always) conditionally iid across different
cases.
A 'replicated-one-hot' vector means a tensor which, for each slice along the
batch dimension (assumed to be dimension 0), is 1.0 in the entry
corresponding to the given index and 0 elsewhere.
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
but will agree with the one used in the other methods of this class.
Args:
index: A tuple representing the index of the entry in each slice that
is 1.0. Note that len(index) must be equal to the number of elements
of the 'fisher_factor_inner_shape' tensor minus one.
Returns:
The vector right-multiplied by B. Will be of the same shape(s) as the
'inputs' property.
"""
pass
@abc.abstractproperty
def fisher_factor_inner_shape(self):
"""The shape of the tensor returned by multiply_fisher_factor."""
pass
@abc.abstractproperty
def fisher_factor_inner_static_shape(self):
"""Static version of fisher_factor_inner_shape."""
pass
@abc.abstractmethod
def sample(self, seed):
"""Sample 'targets' from the underlying distribution."""
pass
def evaluate_on_sample(self, seed=None):
"""Evaluates the log probability on a random sample.
Args:
seed: int or None. Random seed for this draw from the distribution.
Returns:
Log probability of sampled targets, summed across examples.
"""
if seed is None:
seed = self._default_seed
# We treat the targets as "constant". It's only the inputs that get
# "back-propped" through.
return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
# inheritance, or is there a better way?
class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
"""Base class for neg log prob losses whose inputs are 'natural' parameters.
Note that the Hessian and Fisher for natural parameters of exponential-
family models are the same, hence the purpose of this class.
See here: https://arxiv.org/abs/1412.1193
'Natural parameters' are defined for exponential-family models. See for
example: https://en.wikipedia.org/wiki/Exponential_family
"""
def multiply_hessian(self, vector):
return self.multiply_fisher(vector)
def multiply_hessian_factor(self, vector):
return self.multiply_fisher_factor(vector)
def multiply_hessian_factor_transpose(self, vector):
return self.multiply_fisher_factor_transpose(vector)
def multiply_hessian_factor_replicated_one_hot(self, index):
return self.multiply_fisher_factor_replicated_one_hot(index)
@property
def hessian_factor_inner_shape(self):
return self.fisher_factor_inner_shape
@property
def hessian_factor_inner_static_shape(self):
return self.fisher_factor_inner_static_shape
class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
"""Base class for neg log prob losses that use the TF Distribution classes."""
def __init__(self, seed=None):
super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
@abc.abstractproperty
def dist(self):
"""The underlying tf.distributions.Distribution."""
pass
def _evaluate(self, targets):
return -math_ops.reduce_sum(self.dist.log_prob(targets))
def sample(self, seed):
return self.dist.sample(seed=seed)
class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
NaturalParamsNegativeLogProbLoss):
"""Neg log prob loss for a normal distribution parameterized by a mean vector.
Note that the covariance is treated as a constant 'var' times the identity.
Also note that the Fisher for such a normal distribution with respect to the mean
parameter is given by:
F = (1/var) * I
See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
"""
def __init__(self, mean, var=0.5, targets=None, seed=None):
self._mean = mean
self._var = var
self._targets = targets
super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
@property
def targets(self):
return self._targets
@property
def dist(self):
return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
@property
def params(self):
return self._mean
def multiply_fisher(self, vector):
return (1. / self._var) * vector
def multiply_fisher_factor(self, vector):
return self._var**-0.5 * vector
def multiply_fisher_factor_transpose(self, vector):
return self.multiply_fisher_factor(vector) # it's symmetric in this case
def multiply_fisher_factor_replicated_one_hot(self, index):
assert len(index) == 1, "Length of index was {}".format(len(index))
ones_slice = array_ops.expand_dims(
array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
axis=-1)
output_slice = self._var**-0.5 * ones_slice
return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
index[0])
@property
def fisher_factor_inner_shape(self):
return array_ops.shape(self._mean)
@property
def fisher_factor_inner_static_shape(self):
return self._mean.shape
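A quick consistency check on the factor used above (stated here in the same notation as the class docstring): multiply_fisher_factor applies B = var^(-1/2) * I, so

    B * B^T = (var^(-1/2) * I) * (var^(-1/2) * I)^T = (1/var) * I = F,

which matches the Fisher quoted in the class docstring.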
class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
"""Negative log prob loss for a normal distribution with mean and variance.
This class parameterizes a multivariate normal distribution with n independent
dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
assume the variance is held constant. The Fisher Information for n = 1
is given by,
F = [[1 / variance, 0],
[ 0, 0.5 / variance^2]]
where the parameters of the distribution are concatenated into a single
vector as [mean, variance]. For n > 1, the mean parameter vector is
concatenated with the variance parameter vector.
See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
"""
def __init__(self, mean, variance, targets=None, seed=None):
assert len(mean.shape) == 2, "Expect 2D mean tensor."
assert len(variance.shape) == 2, "Expect 2D variance tensor."
self._mean = mean
self._variance = variance
self._targets = targets
super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
@property
def targets(self):
return self._targets
@property
def dist(self):
return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
@property
def params(self):
return self._mean, self._variance
def _concat(self, mean, variance):
return array_ops.concat([mean, variance], axis=-1)
def _split(self, params):
return array_ops.split(params, 2, axis=-1)
@property
def _fisher_mean(self):
return 1. / self._variance
@property
def _fisher_mean_factor(self):
return 1. / math_ops.sqrt(self._variance)
@property
def _fisher_var(self):
return 1. / (2 * math_ops.square(self._variance))
@property
def _fisher_var_factor(self):
return 1. / (math_ops.sqrt(2.) * self._variance)
def multiply_fisher(self, vecs):
mean_vec, var_vec = vecs
return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
def multiply_fisher_factor(self, vecs):
mean_vec, var_vec = self._split(vecs)
return (self._fisher_mean_factor * mean_vec,
self._fisher_var_factor * var_vec)
def multiply_fisher_factor_transpose(self, vecs):
mean_vec, var_vec = vecs
return self._concat(self._fisher_mean_factor * mean_vec,
self._fisher_var_factor * var_vec)
def multiply_fisher_factor_replicated_one_hot(self, index):
assert len(index) == 1, "Length of index was {}".format(len(index))
index = index[0]
if index < int(self._mean.shape[-1]):
# Index corresponds to mean parameter.
mean_slice = self._fisher_mean_factor[:, index]
mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
mean_output = insert_slice_in_zeros(mean_slice, 1, int(
self._mean.shape[1]), index)
var_output = array_ops.zeros_like(mean_output)
else:
index -= int(self._mean.shape[-1])
# Index corresponds to variance parameter.
var_slice = self._fisher_var_factor[:, index]
var_slice = array_ops.expand_dims(var_slice, axis=-1)
var_output = insert_slice_in_zeros(var_slice, 1,
int(self._variance.shape[1]), index)
mean_output = array_ops.zeros_like(var_output)
return mean_output, var_output
@property
def fisher_factor_inner_shape(self):
return array_ops.concat(
[
array_ops.shape(self._mean)[:-1],
2 * array_ops.shape(self._mean)[-1:]
],
axis=0)
@property
def fisher_factor_inner_static_shape(self):
shape = self._mean.shape.as_list()
return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])
def multiply_hessian(self, vector):
raise NotImplementedError()
def multiply_hessian_factor(self, vector):
raise NotImplementedError()
def multiply_hessian_factor_transpose(self, vector):
raise NotImplementedError()
def multiply_hessian_factor_replicated_one_hot(self, index):
raise NotImplementedError()
@property
def hessian_factor_inner_shape(self):
raise NotImplementedError()
@property
def hessian_factor_inner_static_shape(self):
raise NotImplementedError()
class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
NaturalParamsNegativeLogProbLoss):
"""Neg log prob loss for a categorical distribution parameterized by logits.
Note that the Fisher (for a single case) of a categorical distribution, with
respect to the natural parameters (i.e. the logits), is given by:
F = diag(p) - p*p^T
where p = softmax(logits). F can be factorized as F = B * B^T where
B = diag(q) - p*q^T
where q is the entry-wise square root of p. This is easy to verify using the
fact that q^T*q = 1.
"""
def __init__(self, logits, targets=None, seed=None):
"""Instantiates a CategoricalLogitsNegativeLogProbLoss.
Args:
logits: Tensor of shape [batch_size, output_size]. Parameters for
underlying distribution.
targets: None or Tensor of shape [output_size]. Each element contains an
index in [0, output_size).
seed: int or None. Default random seed when sampling.
"""
self._logits = logits
self._targets = targets
super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
@property
def targets(self):
return self._targets
@property
def dist(self):
return categorical.Categorical(logits=self._logits)
@property
def _probs(self):
return self.dist.probs
@property
def _sqrt_probs(self):
return math_ops.sqrt(self._probs)
@property
def params(self):
return self._logits
def multiply_fisher(self, vector):
probs = self._probs
return vector * probs - probs * math_ops.reduce_sum(
vector * probs, axis=-1, keepdims=True)
def multiply_fisher_factor(self, vector):
probs = self._probs
sqrt_probs = self._sqrt_probs
return sqrt_probs * vector - probs * math_ops.reduce_sum(
sqrt_probs * vector, axis=-1, keepdims=True)
def multiply_fisher_factor_transpose(self, vector):
probs = self._probs
sqrt_probs = self._sqrt_probs
return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
probs * vector, axis=-1, keepdims=True)
def multiply_fisher_factor_replicated_one_hot(self, index):
assert len(index) == 1, "Length of index was {}".format(len(index))
probs = self._probs
sqrt_probs = self._sqrt_probs
sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
int(sqrt_probs.shape[1]), index[0])
return padded_slice - probs * sqrt_probs_slice
@property
def fisher_factor_inner_shape(self):
return array_ops.shape(self._logits)
@property
def fisher_factor_inner_static_shape(self):
return self._logits.shape
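Expanding the "easy to verify" claim from the class docstring above, using diag(q) * q = p and q^T * q = sum(p) = 1:

    B * B^T = (diag(q) - p*q^T) * (diag(q) - q*p^T)
            = diag(p) - p*p^T - p*p^T + (q^T*q) * p*p^T
            = diag(p) - p*p^T
            = F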
class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
NaturalParamsNegativeLogProbLoss):
"""Neg log prob loss for multiple Bernoulli distributions param'd by logits.
Represents N independent Bernoulli distributions where N = len(logits). Its
Fisher Information matrix is given by,
F = diag(p * (1-p))
p = sigmoid(logits)
As F is diagonal with positive entries, its factor B is,
B = diag(sqrt(p * (1-p)))
"""
def __init__(self, logits, targets=None, seed=None):
self._logits = logits
self._targets = targets
super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
@property
def targets(self):
return self._targets
@property
def dist(self):
return bernoulli.Bernoulli(logits=self._logits)
@property
def _probs(self):
return self.dist.probs
@property
def params(self):
return self._logits
def multiply_fisher(self, vector):
return self._probs * (1 - self._probs) * vector
def multiply_fisher_factor(self, vector):
return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
def multiply_fisher_factor_transpose(self, vector):
return self.multiply_fisher_factor(vector) # it's symmetric in this case
def multiply_fisher_factor_replicated_one_hot(self, index):
assert len(index) == 1, "Length of index was {}".format(len(index))
probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
index[0])
@property
def fisher_factor_inner_shape(self):
return array_ops.shape(self._logits)
@property
def fisher_factor_inner_static_shape(self):
return self._logits.shape
def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
"""Inserts slice into a larger tensor of zeros.
Forms a new tensor which is the same shape as slice_to_insert, except that
the dimension given by 'dim' is expanded to the size given by 'dim_size'.
'position' determines the position (index) at which to insert the slice within
that dimension.
Assumes slice_to_insert.shape[dim] = 1.
Args:
slice_to_insert: The slice to insert.
dim: The dimension to expand with zeros.
dim_size: The new size of the 'dim' dimension.
position: The position of 'slice_to_insert' in the new tensor.
Returns:
The new tensor.
Raises:
ValueError: If the slice's shape at the given dim is not 1.
"""
slice_shape = slice_to_insert.shape
if slice_shape[dim] != 1:
raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
"was {}".format(dim, slice_to_insert.shape[dim]))
before = [0] * int(len(slice_shape))
after = before[:]
before[dim] = position
after[dim] = dim_size - position - 1
return array_ops.pad(slice_to_insert, list(zip(before, after)))
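A small illustration of the helper above, with hypothetical values:

    import tensorflow as tf

    slice_to_insert = tf.ones([3, 1])
    # Expand dim 1 from size 1 to size 5, placing the slice at column 2;
    # the result is a [3, 5] tensor that is zero everywhere except column 2.
    padded = insert_slice_in_zeros(slice_to_insert, dim=1, dim_size=5, position=2)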
class OnehotCategoricalLogitsNegativeLogProbLoss(
CategoricalLogitsNegativeLogProbLoss):
"""Neg log prob loss for a categorical distribution with onehot targets.
Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
distribution is OneHotCategorical as opposed to Categorical.
"""
@property
def dist(self):
return onehot_categorical.OneHotCategorical(logits=self._logits)
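Since these losses are summed rather than averaged over the batch (see the LossFunction docstring above), per-example quantities have to be normalized explicitly. A hedged sketch, not from the original file, with hypothetical `predictions` and `labels` tensors:

    loss = NormalMeanNegativeLogProbLoss(predictions, targets=labels)
    total_nll = loss.evaluate()  # summed over the batch, not averaged
    batch_size = tf.cast(tf.shape(predictions)[0], total_nll.dtype)
    mean_nll = total_nll / batch_size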


@ -1,39 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss functions to be used by LayerCollection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.loss_functions import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
"LossFunction",
"NegativeLogProbLoss",
"NaturalParamsNegativeLogProbLoss",
"DistributionNegativeLogProbLoss",
"NormalMeanNegativeLogProbLoss",
"NormalMeanVarianceNegativeLogProbLoss",
"CategoricalLogitsNegativeLogProbLoss",
"OnehotCategoricalLogitsNegativeLogProbLoss",
"MultiBernoulliNegativeLogProbLoss",
"insert_slice_in_zeros",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)


@ -1,69 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper for choosing which op to run next in a distributed setting."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops as tf_ops
class OpQueue(object):
"""Class for choosing which Op to run next.
Constructs an infinitely repeating sequence of Ops in shuffled order.
In K-FAC, this can be used to distribute inverse update operations among
workers.
"""
def __init__(self, ops, seed=None):
"""Initializes an OpQueue.
Args:
ops: list of TensorFlow Ops. Ops to be selected from. All workers must
initialize with the same set of ops.
seed: int or None. Random seed used when shuffling order of ops.
"""
self._ops_by_name = {op.name: op for op in ops}
# Construct a (shuffled) Dataset with Op names.
op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops)))
op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names)
.shuffle(len(ops), seed=seed).repeat())
self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()
@property
def ops(self):
"""Ops this OpQueue can return in next_op()."""
return self._ops_by_name.values()
def next_op(self, sess):
"""Chooses which op to run next.
Note: This method invokes sess.run().
Args:
sess: tf.Session.
Returns:
Next Op chosen from 'ops'.
"""
# In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')
# returns a str.
next_op_name = sess.run(self._next_op_name).decode('ascii')
return self._ops_by_name[next_op_name]
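# Editorial sketch (not part of the deleted file): a minimal illustration of how
# OpQueue could be driven from a TF 1.x session, assuming the kfac package is still
# importable. The op names and loop count below are made up for illustration.
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops.op_queue import OpQueue

# Three stand-in "inverse update" ops; in practice these would be the real
# inverse update ops produced by the Fisher estimator.
inv_update_ops = [tf.no_op(name="inv_update_%d" % i) for i in range(3)]
queue = OpQueue(inv_update_ops, seed=0)

with tf.Session() as sess:
  for _ in range(5):
    # next_op() runs the shuffled-name iterator and returns the matching Op.
    sess.run(queue.next_op(sess))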

View File

@ -1,30 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper for choosing which op to run next in a distributed setting."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.op_queue import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
'OpQueue',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@ -1,727 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The KFAC optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
# pylint: disable=line-too-long
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
# pylint: enable=line-too-long
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import gradient_descent
class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
"""The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
def __init__(self,
learning_rate,
cov_ema_decay,
damping,
layer_collection,
var_list=None,
momentum=0.9,
momentum_type="regular",
norm_constraint=None,
name="KFAC",
estimation_mode="gradients",
colocate_gradients_with_ops=True,
batch_size=None,
placement_strategy=None,
**kwargs):
"""Initializes the KFAC optimizer with the given settings.
Args:
learning_rate: The base learning rate for the optimizer. Should probably
be set to 1.0 when using momentum_type = 'qmodel', but can still be
set lower if desired (effectively lowering the trust in the
quadratic model).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: The damping factor used to stabilize training due to errors in
the local approximation with the Fisher information matrix, and to
regularize the update direction by making it closer to the gradient.
If damping is adapted during training then this value is used for
initializing damping variable.
(Higher damping means the update looks more like a standard gradient
update - see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher
blocks, Kronecker factors, and losses associated with the
graph. The layer_collection cannot be modified after KfacOptimizer's
initialization.
var_list: Optional list or tuple of variables to train. Defaults to the
list of variables collected in the graph under the key
`GraphKeys.TRAINABLE_VARIABLES`.
momentum: The momentum decay constant to use. Only applies when
momentum_type is 'regular' or 'adam'. (Default: 0.9)
momentum_type: The type of momentum to use in this optimizer, one of
'regular', 'adam', or 'qmodel'. (Default: 'regular')
norm_constraint: float or Tensor. If specified, the update is scaled down
so that its approximate squared Fisher norm v^T F v is at most the
specified value. May only be used with momentum type 'regular'.
(Default: None)
name: The name for this optimizer. (Default: 'KFAC')
estimation_mode: The type of estimator to use for the Fishers. Can be
'gradients', 'empirical', 'curvature_propagation', or 'exact'.
(Default: 'gradients'). See the doc-string for FisherEstimator for
a more detailed description of these options.
colocate_gradients_with_ops: Whether we should request gradients we
compute in the estimator be colocated with their respective ops.
(Default: True)
batch_size: The size of the mini-batch. Only needed when momentum_type
== 'qmodel' or when automatic adjustment is used. (Default: None)
placement_strategy: string, Device placement strategy used when creating
covariance variables, covariance ops, and inverse ops.
(Default: `None`)
**kwargs: Arguments to be passed to specific placement
strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
Raises:
ValueError: If the momentum type is unsupported.
ValueError: If clipping is used with momentum type other than 'regular'.
ValueError: If no losses have been registered with layer_collection.
ValueError: If momentum is non-zero and momentum_type is not 'regular'
or 'adam'.
"""
warnings.warn(
"third_party.tensorflow.contrib.kfac is deprecated."
"This will be removed on 15-07-2018. Check README for further details.",
DeprecationWarning)
# Parameters to be passed to the Fisher estimator:
self._variables = var_list or tf_variables.trainable_variables
self._cov_ema_decay = cov_ema_decay
self._layers = layer_collection
self._estimation_mode = estimation_mode
self._colocate_gradients_with_ops = colocate_gradients_with_ops
# The below parameters are required only if damping needs to be adapted.
# These parameters can be set by calling
# set_damping_adaptation_params() explicitly.
self._damping_adaptation_decay = 0.95
self._damping_adaptation_interval = 5
# Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval)
self._omega = (
self._damping_adaptation_decay**self._damping_adaptation_interval)
self._adapt_damping = False
self._min_damping = 1e-5
self._prev_train_batch = None
self._is_chief = False
self._loss_fn = None
self._damping_constant = damping
self._damping = None
self._rho = None
self._prev_loss = None
self._q_model_change = None
self._update_damping_op = None
momentum_type = momentum_type.lower()
legal_momentum_types = ["regular", "adam", "qmodel"]
if momentum_type not in legal_momentum_types:
raise ValueError("Unsupported momentum type {}. Must be one of {}."
.format(momentum_type, legal_momentum_types))
if momentum_type != "regular" and norm_constraint is not None:
raise ValueError("Update clipping is only supported with momentum "
"type 'regular'.")
if momentum_type not in ["regular", "adam"] and momentum != 0:
raise ValueError("Momentum must be unspecified if using a momentum_type "
"other than 'regular' or 'adam'.")
# Extra parameters of the optimizer
self._momentum = momentum
self._momentum_type = momentum_type
self._norm_constraint = norm_constraint
self._batch_size = batch_size
self._placement_strategy = placement_strategy
with variable_scope.variable_scope(name):
self._fisher_est = est.make_fisher_estimator(
placement_strategy=placement_strategy,
variables=self._variables,
cov_ema_decay=self._cov_ema_decay,
damping=self.damping,
layer_collection=self._layers,
exps=(-1,),
estimation_mode=self._estimation_mode,
colocate_gradients_with_ops=self._colocate_gradients_with_ops,
**kwargs)
super(KfacOptimizer, self).__init__(learning_rate, name=name)
def set_damping_adaptation_params(self,
is_chief,
prev_train_batch,
loss_fn,
min_damping=1e-5,
damping_adaptation_decay=0.99,
damping_adaptation_interval=5):
"""Sets parameters required to adapt damping during training.
When called, enables damping adaptation according to the Levenberg-Marquardt
style rule described in Section 6.5 of "Optimizing Neural Networks with
Kronecker-factored Approximate Curvature".
Note that this function creates Tensorflow variables which store a few
scalars and are accessed by the ops which update the damping (as part
of the training op returned by the minimize() method).
Args:
is_chief: `Boolean`, `True` if the worker is chief.
prev_train_batch: Training data used to minimize loss in the previous
step. This will be used to evaluate loss by calling
`loss_fn(prev_train_batch)`.
loss_fn: `function` that takes as input training data tensor and returns
a scalar loss.
min_damping: `float`(Optional), Minimum value the damping parameter
can take. Default value 1e-5.
damping_adaptation_decay: `float`(Optional), The `damping` parameter is
multiplied by the `damping_adaptation_decay` every
`damping_adaptation_interval` number of iterations. Default value 0.99.
damping_adaptation_interval: `int`(Optional), Number of steps in between
updating the `damping` parameter. Default value 5.
Raises:
ValueError: If `set_damping_adaptation_params` has already been called and
`adapt_damping` is `True`.
"""
if self._adapt_damping:
raise ValueError("Damping adaptation parameters already set.")
with variable_scope.variable_scope(self.get_name()):
self._adapt_damping = True
self._is_chief = is_chief
self._prev_train_batch = prev_train_batch
self._loss_fn = loss_fn
self._damping_adaptation_decay = damping_adaptation_decay
self._damping_adaptation_interval = damping_adaptation_interval
self._omega = (
self._damping_adaptation_decay**self._damping_adaptation_interval)
self._min_damping = min_damping
self._rho = variable_scope.get_variable(
"rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
self._prev_loss = variable_scope.get_variable(
"prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
self._q_model_change = variable_scope.get_variable(
"q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
self._damping = variable_scope.get_variable(
"damping", initializer=self._damping_constant, trainable=False)
@property
def variables(self):
return self._fisher_est.variables
@property
def damping(self):
if self._damping is not None:
return self._damping
else:
return self._damping_constant
@property
def damping_adaptation_interval(self):
return self._damping_adaptation_interval
def make_vars_and_create_op_thunks(self):
"""Make vars and create op thunks.
Returns:
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
"""
scope = self.get_name() + "/" + self._fisher_est.name
return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
def create_ops_and_vars_thunks(self):
"""Create thunks that make the ops and vars on demand.
This function returns 4 lists of thunks: cov_variable_thunks,
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
The length of each list is the number of factors and the i-th element of
each list corresponds to the i-th factor (given by the "factors" property).
Note that the execution of these thunks must happen in a certain
partial order. The i-th element of cov_variable_thunks must execute
before the i-th element of cov_update_thunks (and also the i-th element
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
must execute before the i-th element of inv_update_thunks.
TL;DR (oversimplified): Execute the thunks according to the order that
they are returned.
Returns:
cov_variable_thunks: A list of thunks that make the cov variables.
cov_update_thunks: A list of thunks that make the cov update ops.
inv_variable_thunks: A list of thunks that make the inv variables.
inv_update_thunks: A list of thunks that make the inv update ops.
"""
scope = self.get_name() + "/" + self._fisher_est.name
return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
def minimize(self, *args, **kwargs):
# Should this variable scope encompass everything below? Or will the super-
# class make another copy of the same name scope?
with variable_scope.variable_scope(self.get_name()):
kwargs["var_list"] = kwargs.get("var_list") or self.variables
if set(kwargs["var_list"]) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
if self._adapt_damping and self._is_chief:
global_step = kwargs.get("global_step", None)
if not global_step:
raise KeyError("global_step needs to be passed to optimizer.minimize "
"if damping parameter is adapted.")
update_damping_op = self._update_damping(self._prev_train_batch,
global_step)
with ops.control_dependencies([update_damping_op]):
loss = args[0]
loss_assign_op = state_ops.assign(self._prev_loss, loss)
train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
return control_flow_ops.group(loss_assign_op, train_op)
else:
return super(KfacOptimizer, self).minimize(*args, **kwargs)
def compute_gradients(self, *args, **kwargs):
# args[1] could be our var_list
if len(args) > 1:
var_list = args[1]
else:
kwargs["var_list"] = kwargs.get("var_list") or self.variables
var_list = kwargs["var_list"]
if set(var_list) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
def apply_gradients(self, grads_and_vars, *args, **kwargs):
"""Applies gradients to variables.
Args:
grads_and_vars: List of (gradient, variable) pairs.
*args: Additional arguments for super.apply_gradients.
**kwargs: Additional keyword arguments for super.apply_gradients.
Returns:
An `Operation` that applies the specified gradients.
"""
# In Python 3, grads_and_vars can be a zip() object which can only be
# iterated over once. By converting it to a list, we ensure that it can be
# iterated over more than once.
grads_and_vars = list(grads_and_vars)
# Compute step.
steps_and_vars = self._compute_update_steps(grads_and_vars)
# Update trainable variables with this step.
return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
**kwargs)
def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
"""Computes the squared (approximate) Fisher norm of the updates.
This is defined as v^T F v, where F is the approximate Fisher matrix
as computed by the estimator, and v = F^{-1} g, where g is the gradient.
This is computed efficiently as v^T g.
Args:
grads_and_vars: List of (gradient, variable) pairs.
precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
Must be the result of calling `self._fisher_est.multiply_inverse`
on `grads_and_vars`.
Returns:
Scalar representing the squared norm.
Raises:
ValueError: if the two list arguments do not contain the same variables,
in the same order.
"""
for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
if gvar is not pgvar:
raise ValueError("The variables referenced by the two arguments "
"must match.")
terms = [
math_ops.reduce_sum(grad * pgrad)
for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
]
return math_ops.reduce_sum(terms)
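# Editorial sketch (not part of the original file): a small numerical check of the
# identity the docstring relies on. When v = F^{-1} g, the quadratic form v^T F v
# equals the inner product v^T g. The matrix and gradient below are made up.
import numpy as np

f = np.array([[2.0, 0.3],
              [0.3, 1.0]])        # stand-in for the approximate Fisher F
g = np.array([1.0, -2.0])         # gradient
v = np.linalg.solve(f, g)         # preconditioned gradient, v = F^{-1} g

assert np.isclose(v @ f @ v, v @ g)   # v^T F v == v^T g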
def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
"""Computes the scale factor for the update to satisfy the norm constraint.
Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
F is the approximate Fisher matrix, and r is the update vector, i.e.
-alpha * v, where alpha is the learning rate, and v is the preconditioned
gradient.
This is based on Section 5 of Ba et al., Distributed Second-Order
Optimization using Kronecker-Factored Approximations. Note that they
absorb the learning rate alpha (which they denote eta_max) into the formula
for the coefficient, while in our implementation, the rescaling is done
before multiplying by alpha. Hence, our formula differs from theirs by a
factor of alpha.
Args:
grads_and_vars: List of (gradient, variable) pairs.
precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
Must be the result of calling `self._fisher_est.multiply_inverse`
on `grads_and_vars`.
Returns:
Scalar representing the coefficient which should be applied to the
preconditioned gradients to satisfy the norm constraint.
"""
sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
precon_grads_and_vars)
sq_norm_up = sq_norm_grad * self._learning_rate**2
return math_ops.minimum(1.,
math_ops.sqrt(self._norm_constraint / sq_norm_up))
def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
"""Rescales the preconditioned gradients to satisfy the norm constraint.
Rescales the preconditioned gradients such that the resulting update r
(after multiplying by the learning rate) will satisfy the norm constraint.
This constraint is that r^T F r <= C, where F is the approximate Fisher
matrix, and C is the norm_constraint attribute. See Section 5 of
Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
Approximations.
Args:
grads_and_vars: List of (gradient, variable) pairs.
precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
Must be the result of calling `self._fisher_est.multiply_inverse`
on `grads_and_vars`.
Returns:
List of (rescaled preconditioned gradient, variable) pairs.
"""
coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
def _compute_prev_updates(self, variables):
"""Computes previous updates as negative velocities scaled by learning rate.
Args:
variables: List of variables in the graph that the update will be
applied to.
Returns:
List of previous updates applied to the `variables`.
"""
return list(
-1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
for var in variables)
def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
variables):
"""Compute optimal update hyperparameters from the quadratic model.
More specifically, if L is the loss, we minimize (with respect to alpha and mu)
a quadratic approximation of L(theta + d), denoted qmodel(d), evaluated at
d = alpha*precon_grad + mu*prev_update, where
qmodel(d) = (1/2) * d^T * B * d + grad^T * d + L(theta).
Unlike in the KL clipping approach we use the non-approximated quadratic
model where the curvature matrix C is the true Fisher on the current
mini-batch (computed without any approximations beyond mini-batch sampling),
with the usual Tikhonov damping/regularization applied,
C = F + damping * I
See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
the formula. See Appendix C for a discussion of the trick of using
a factorized Fisher matrix to more efficiently compute the required
vector-matrix-vector products.
Note that the elements of all 4 lists passed to this function must
be in correspondence with each other.
Args:
precon_grads: List of preconditioned gradients.
prev_updates: List of updates computed at the previous iteration.
grads: List of gradients.
variables: List of variables in the graph that the update will be
applied to. (Note that this function doesn't actually apply the
update.)
Returns:
(alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
quadratic model, and
qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
= qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
"""
cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
variables)
# compute the matrix-vector products with the transposed Fisher factor
fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
batch_size = math_ops.cast(
self._batch_size, dtype=fft_precon_grads[0].dtype)
# compute the entries of the 2x2 matrix
m_11 = (
_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
self.damping * _inner_product_list(precon_grads, precon_grads))
m_21 = (
_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
self.damping * _inner_product_list(prev_updates, precon_grads))
m_22 = (
_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
self.damping * _inner_product_list(prev_updates, prev_updates))
def non_zero_prevupd_case():
r"""Computes optimal (alpha, mu) given non-zero previous update.
We solve the full 2x2 linear system. See Martens & Grosse (2015),
Section 7, definition of $\alpha^*$ and $\mu^*$.
Returns:
(alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
the quadratic model, and
qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
"""
m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
[_inner_product_list(grads, prev_updates)]])
sol = -1. * _two_by_two_solve(m, c)
alpha = sol[0]
mu = sol[1]
qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
return alpha, mu, qmodel_change
def zero_prevupd_case():
r"""Computes optimal (alpha, mu) given all-zero previous update.
The linear system reduces to 1x1. See Martens & Grosse (2015),
Section 6.4, definition of $\alpha^*$.
Returns:
(alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
quadratic model, and
qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
"""
m = m_11
c = _inner_product_list(grads, precon_grads)
alpha = -c / m
mu = 0.0
qmodel_change = 0.5 * alpha * c
return alpha, mu, qmodel_change
return control_flow_ops.cond(
math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
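# Editorial sketch (not part of the original file): the 2x2 solve above worked
# through in NumPy on made-up numbers. It mirrors sol = -_two_by_two_solve(m, c)
# and qmodel_change = 0.5 * reduce_sum(sol * c) from non_zero_prevupd_case().
import numpy as np

m = np.array([[2.0, 0.3],    # [[m_11, m_21],
              [0.3, 1.5]])   #  [m_21, m_22]]
c = np.array([[-4.0],        # <grad, precon_grad>
              [-1.0]])       # <grad, prev_update>

sol = -np.linalg.solve(m, c)               # [[alpha], [mu]]
alpha, mu = float(sol[0]), float(sol[1])
qmodel_change = 0.5 * float(np.sum(sol * c))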
def _assign_q_model_change(self, q_model_change):
"""Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
Note only the chief worker does the assignment.
Args:
q_model_change: Scalar tensor of type `float32`.
Returns:
If `adapt_damping` is `True` then returns an assign op, Otherwise returns
a no_op().
"""
if self._adapt_damping and self._is_chief:
q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
else:
q_model_assign_op = control_flow_ops.no_op()
return q_model_assign_op
def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
precon_grads_and_vars):
"""Wrapper function for `self._compute_qmodel_hyperparams`.
Constructs a list of preconditioned gradients and variables. Also creates an
op to assign the computed q model change to `self._q_model_change`.
Args:
grads_and_vars: List of (gradient, variable) pairs.
precon_grads_and_vars: List of (preconditioned gradients, variable)
pairs.
Returns:
(alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
the quadratic model, `q_model_assign_op` assigns the computed q model
change to `self._q_model_change`.
"""
precon_grads = list(
precon_grad for (precon_grad, _) in precon_grads_and_vars)
grads = list(grad for (grad, _) in grads_and_vars)
variables = list(var for (_, var) in grads_and_vars)
prev_updates = self._compute_prev_updates(variables)
# Compute optimal velocity update parameters according to quadratic model
alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
precon_grads, prev_updates, grads, variables)
return alpha, mu, self._assign_q_model_change(q_model_change)
def _compute_update_steps(self, grads_and_vars):
"""Computes the update steps for the variables given the gradients.
Args:
grads_and_vars: List of (gradient, variable) pairs.
Returns:
A list of tuples (assign_op, var) where `assign_op` assigns the update
steps to `var`.
"""
if self._momentum_type == "regular":
# Compute "preconditioned" gradient.
precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
# Apply "KL clipping" if asked for.
if self._norm_constraint is not None:
precon_grads_and_vars = self._clip_updates(grads_and_vars,
precon_grads_and_vars)
# Update the velocity with this and return it as the step.
if self._adapt_damping and self._is_chief:
_, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
grads_and_vars, precon_grads_and_vars)
with ops.control_dependencies([q_model_assign_op]):
return self._update_velocities(precon_grads_and_vars, self._momentum)
else:
return self._update_velocities(precon_grads_and_vars, self._momentum)
elif self._momentum_type == "adam":
# Update velocity.
velocities_and_vars = self._update_velocities(grads_and_vars,
self._momentum)
# Return "preconditioned" velocity vector as the step.
return self._fisher_est.multiply_inverse(velocities_and_vars)
elif self._momentum_type == "qmodel":
# Compute "preconditioned" gradient.
precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
# Compute optimal velocity update parameters according to quadratic model
alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
grads_and_vars, precon_grads_and_vars)
with ops.control_dependencies([q_model_assign_op]):
return self._update_velocities(
precon_grads_and_vars, mu, vec_coeff=-alpha)
def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
"""Updates the velocities of the variables with the given vectors.
Args:
vecs_and_vars: List of (vector, variable) pairs.
decay: How much to decay the old velocity by. This is often referred to
as the 'momentum constant'.
vec_coeff: Coefficient to apply to the vectors before adding them to the
velocity.
Returns:
A list of (velocity, var) indicating the new velocity for each var.
"""
def _update_velocity(vec, var):
velocity = self._zeros_slot(var, "velocity", self._name)
with ops.colocate_with(velocity):
# NOTE(mattjj): read/modify/write race condition not suitable for async.
# Compute the new velocity for this variable.
new_velocity = decay * velocity + vec_coeff * vec
# Save the updated velocity.
return (array_ops.identity(velocity.assign(new_velocity)), var)
# Go through variable and update its associated part of the velocity vector.
return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
def _update_damping(self, prev_batch, global_step):
"""Adapts damping parameter. Check KFAC (Section 6.5) for the details.
The damping parameter is updated according to the Levenberg-Marquardt rule
every `self._damping_adaptation_interval` iterations.
Args:
prev_batch: Tensor or tuple of tensors which can be passed to
`self._loss_fn` to evaluate loss.
global_step: `Variable` which keeps track of number of times the training
variables have been updated.
Returns:
A `tf.cond` op which updates the damping parameter.
"""
def compute_damping():
""""Adapts damping parameter based on "reduction ratio".
Reduction ratio captures how closely the quadratic approximation to the
loss function approximates the actual loss within a trust region. The
damping update tries to make the damping as small as possible while
maintaining the property that the quadratic model remains a good local
approximation to the loss function.
Returns:
An Op to assign newly computed damping value to `self._damping`.
"""
prev_batch_loss = self._loss_fn(prev_batch)
with ops.control_dependencies([prev_batch_loss]):
rho_assign = self._rho.assign(
(prev_batch_loss - self._prev_loss) / self._q_model_change)
with ops.control_dependencies([rho_assign]):
new_damping = control_flow_ops.case(
[(self._rho < 0.25, lambda: self.damping / self._omega),
(self._rho > 0.75, lambda: self.damping * self._omega)],
lambda: self.damping)
with ops.control_dependencies([new_damping]):
new_damping_min = math_ops.maximum(new_damping, self._min_damping)
return control_flow_ops.group(self._damping.assign(new_damping_min))
return control_flow_ops.cond(
math_ops.equal(
math_ops.mod(global_step + 1, self._damping_adaptation_interval),
0), compute_damping, control_flow_ops.no_op)
def _inner_product_list(list1, list2):
return math_ops.add_n(
[math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])
def _two_by_two_solve(m, c):
# it might be better just to crank out the exact formula for 2x2 inverses
return math_ops.matmul(linalg_ops.matrix_inverse(m), c)
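# Editorial sketch (not part of the deleted file): rough shape of how KfacOptimizer
# was typically wired up. It assumes `layer_collection` has already had the model's
# layers and loss registered and that `loss` is a scalar training loss tensor; the
# hyperparameter values are placeholders, and running the cov/inv updates on every
# step via control dependencies is just one possible schedule.
import tensorflow as tf

optimizer = KfacOptimizer(
    learning_rate=1e-4,
    cov_ema_decay=0.95,
    damping=1e-3,
    layer_collection=layer_collection,  # assumed to be pre-populated
    momentum=0.9)

cov_update_thunks, inv_update_thunks = (
    optimizer.make_vars_and_create_op_thunks())
cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))
inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))

with tf.control_dependencies([cov_update_op, inv_update_op]):
  train_op = optimizer.minimize(loss)  # `loss` assumed to be defined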

View File

@ -1,30 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The KFAC optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.optimizer import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
"KfacOptimizer",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@ -1,114 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements placement strategies for cov and inv ops, cov variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from tensorflow.python.framework import ops as tf_ops
def _make_thunk_on_device(func, device):
def thunk():
with tf_ops.device(device):
return func()
return thunk
class RoundRobinPlacementMixin(object):
"""Implements round robin placement strategy for ops and variables."""
def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
"""Initializes the RoundRobinPlacementMixin class.
Args:
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
**kwargs: Need something here?
"""
super(RoundRobinPlacementMixin, self).__init__(**kwargs)
self._cov_devices = cov_devices
self._inv_devices = inv_devices
def make_vars_and_create_op_thunks(self, scope=None):
"""Make vars and create op thunks w/ a round-robin device placement start.
For each factor, all of that factor's cov variables and their associated
update ops will be placed on a particular device. A new device is chosen
for each factor by cycling through list of devices in the
`self._cov_devices` attribute. If `self._cov_devices` is `None` then no
explicit device placement occurs.
An analogous strategy is followed for inverse update ops, with the list of
devices being given by the `self._inv_devices` attribute.
Inverse variables on the other hand are not placed on any specific device
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
they will be accessed most often, which is the device that actually applies
the preconditioner to the gradient. The user is responsible for setting
the device context for this.
Args:
scope: A string or None. If None it will be set to the name of this
estimator (given by the name property). All variables will be created,
and all thunks will execute, inside of a variable scope of the given
name. (Default: None)
Returns:
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
the list of factors given by the "factors" property.
"""
# Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
(cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
if self._cov_devices:
cov_update_thunks = []
for cov_variable_thunk, cov_update_thunk, device in zip(
cov_variable_thunks_raw, cov_update_thunks_raw,
itertools.cycle(self._cov_devices)):
with tf_ops.device(device):
cov_variable_thunk()
cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
device))
else:
for cov_variable_thunk in cov_variable_thunks_raw:
cov_variable_thunk()
cov_update_thunks = cov_update_thunks_raw
for inv_variable_thunk in inv_variable_thunks_raw:
inv_variable_thunk()
if self._inv_devices:
inv_update_thunks = []
for inv_update_thunk, device in zip(inv_update_thunks_raw,
itertools.cycle(self._inv_devices)):
inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
device))
else:
inv_update_thunks = inv_update_thunks_raw
return cov_update_thunks, inv_update_thunks
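# Editorial sketch (not part of the original file): the round-robin assignment
# described above reduces to cycling a device list over the factors. Plain Python,
# with made-up factor and device names.
import itertools

factors = ["fc1", "fc2", "conv1", "conv2", "logits"]
cov_devices = ["/gpu:0", "/gpu:1"]

assignment = list(zip(factors, itertools.cycle(cov_devices)))
# [('fc1', '/gpu:0'), ('fc2', '/gpu:1'), ('conv1', '/gpu:0'),
#  ('conv2', '/gpu:1'), ('logits', '/gpu:0')]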

View File

@ -1,709 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
# Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky"
POSDEF_EIG_METHOD = "self_adjoint"
def set_global_constants(posdef_inv_method=None):
"""Sets various global constants used by the classes in this module."""
global POSDEF_INV_METHOD
if posdef_inv_method is not None:
POSDEF_INV_METHOD = posdef_inv_method
class SequenceDict(object):
"""A dict convenience wrapper that allows getting/setting with sequences."""
def __init__(self, iterable=None):
self._dict = dict(iterable or [])
def __getitem__(self, key_or_keys):
if isinstance(key_or_keys, (tuple, list)):
return list(map(self.__getitem__, key_or_keys))
else:
return self._dict[key_or_keys]
def __setitem__(self, key_or_keys, val_or_vals):
if isinstance(key_or_keys, (tuple, list)):
for key, value in zip(key_or_keys, val_or_vals):
self[key] = value
else:
self._dict[key_or_keys] = val_or_vals
def items(self):
return list(self._dict.items())
def tensors_to_column(tensors):
"""Converts a tensor or list of tensors to a column vector.
Args:
tensors: A tensor or list of tensors.
Returns:
The tensors reshaped into vectors and stacked on top of each other.
"""
if isinstance(tensors, (tuple, list)):
return array_ops.concat(
tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
else:
return array_ops.reshape(tensors, [-1, 1])
def column_to_tensors(tensors_template, colvec):
"""Converts a column vector back to the shape of the given template.
Args:
tensors_template: A tensor or list of tensors.
colvec: A 2d column vector with the same shape as the value of
tensors_to_column(tensors_template).
Returns:
X, where X is tensor or list of tensors with the properties:
1) tensors_to_column(X) = colvec
2) X (or its elements) have the same shape as tensors_template (or its
elements)
"""
if isinstance(tensors_template, (tuple, list)):
offset = 0
tensors = []
for tensor_template in tensors_template:
sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
tensor = array_ops.reshape(colvec[offset:(offset + sz)],
tensor_template.shape)
tensors.append(tensor)
offset += sz
tensors = tuple(tensors)
else:
tensors = array_ops.reshape(colvec, tensors_template.shape)
return tensors
def kronecker_product(mat1, mat2):
"""Computes the Kronecker product two matrices."""
m1, n1 = mat1.get_shape().as_list()
mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
m2, n2 = mat2.get_shape().as_list()
mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
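# Editorial sketch (not part of the original file): the reshape-and-broadcast trick
# used by kronecker_product, checked against NumPy's np.kron on small made-up
# matrices.
import numpy as np

a = np.arange(4.0).reshape(2, 2)          # m1 x n1
b = np.arange(6.0).reshape(2, 3)          # m2 x n2
m1, n1 = a.shape
m2, n2 = b.shape
kron = (a.reshape(m1, 1, n1, 1) * b.reshape(1, m2, 1, n2)).reshape(m1 * m2, n1 * n2)
assert np.allclose(kron, np.kron(a, b))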
def layer_params_to_mat2d(vector):
"""Converts a vector shaped like layer parameters to a 2D matrix.
In particular, we reshape the weights/filter component of the vector to be
2D, flattening all leading (input) dimensions. If there is a bias component,
we concatenate it to the reshaped weights/filter component.
Args:
vector: A Tensor or pair of Tensors shaped like layer parameters.
Returns:
A 2D Tensor with the same coefficients and the same output dimension.
"""
if isinstance(vector, (tuple, list)):
w_part, b_part = vector
w_part_reshaped = array_ops.reshape(w_part,
[-1, w_part.shape.as_list()[-1]])
return array_ops.concat(
(w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
elif isinstance(vector, ops.IndexedSlices):
return vector
else: # Tensor or Tensor-like.
return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
def mat2d_to_layer_params(vector_template, mat2d):
"""Converts a canonical 2D matrix representation back to a vector.
Args:
vector_template: A Tensor or pair of Tensors shaped like layer parameters.
mat2d: A 2D Tensor with the same shape as the value of
layer_params_to_mat2d(vector_template).
Returns:
A Tensor or pair of Tensors with the same coefficients as mat2d and the same
shape as vector_template.
"""
if isinstance(vector_template, (tuple, list)):
w_part, b_part = mat2d[:-1], mat2d[-1]
return array_ops.reshape(w_part, vector_template[0].shape), b_part
elif isinstance(vector_template, ops.IndexedSlices):
if not isinstance(mat2d, ops.IndexedSlices):
raise TypeError(
"If vector_template is an IndexedSlices, so should mat2d.")
return mat2d
else:
return array_ops.reshape(mat2d, vector_template.shape)
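# Editorial sketch (not part of the original file): the (weights, bias) <-> 2D-matrix
# packing performed by layer_params_to_mat2d / mat2d_to_layer_params, in NumPy, for
# a made-up dense layer.
import numpy as np

w = np.arange(12.0).reshape(3, 4)     # e.g. a 3-in, 4-out dense layer
b = np.arange(4.0)
mat2d = np.concatenate([w.reshape(-1, w.shape[-1]), b.reshape(1, -1)], axis=0)
# mat2d has shape (4, 4): flattened input dims stacked over the bias row.
w_back, b_back = mat2d[:-1].reshape(w.shape), mat2d[-1]
assert np.allclose(w, w_back) and np.allclose(b, b_back)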
def posdef_inv(tensor, damping):
"""Computes the inverse of tensor + damping * identity."""
identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
damping = math_ops.cast(damping, dtype=tensor.dtype)
return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
def posdef_inv_matrix_inverse(tensor, identity, damping):
"""Computes inverse(tensor + damping * identity) directly."""
return linalg_ops.matrix_inverse(tensor + damping * identity)
def posdef_inv_cholesky(tensor, identity, damping):
"""Computes inverse(tensor + damping * identity) with Cholesky."""
chol = linalg_ops.cholesky(tensor + damping * identity)
return linalg_ops.cholesky_solve(chol, identity)
def posdef_inv_eig(tensor, identity, damping):
"""Computes inverse(tensor + damping * identity) with eigendecomposition."""
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
tensor + damping * identity)
return math_ops.matmul(
eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
posdef_inv_functions = {
"matrix_inverse": posdef_inv_matrix_inverse,
"cholesky": posdef_inv_cholesky,
"eig": posdef_inv_eig,
}
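# Editorial sketch (not part of the original file): the three damped-inverse routes
# above agree on positive-definite input. A NumPy check on a small made-up matrix.
import numpy as np

t = np.array([[2.0, 0.5], [0.5, 1.0]])
damping = 0.1
identity = np.eye(2)

direct = np.linalg.inv(t + damping * identity)

chol = np.linalg.cholesky(t + damping * identity)
via_cholesky = np.linalg.solve(chol @ chol.T, identity)   # cholesky_solve equivalent

evals, evecs = np.linalg.eigh(t + damping * identity)
via_eig = (evecs / evals) @ evecs.T                       # V diag(1/lambda) V^T

assert np.allclose(direct, via_cholesky) and np.allclose(direct, via_eig)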
def posdef_eig(mat):
"""Computes the eigendecomposition of a positive semidefinite matrix."""
return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
def posdef_eig_svd(mat):
"""Computes the singular values and left singular vectors of a matrix."""
evals, evecs, _ = linalg_ops.svd(mat)
return evals, evecs
def posdef_eig_self_adjoint(mat):
"""Computes eigendecomposition using self_adjoint_eig."""
evals, evecs = linalg_ops.self_adjoint_eig(mat)
evals = math_ops.abs(evals) # Should be equivalent to svd approach.
return evals, evecs
posdef_eig_functions = {
"self_adjoint": posdef_eig_self_adjoint,
"svd": posdef_eig_svd,
}
def cholesky(tensor, damping):
"""Computes the inverse of tensor + damping * identity."""
identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
damping = math_ops.cast(damping, dtype=tensor.dtype)
return linalg_ops.cholesky(tensor + damping * identity)
class SubGraph(object):
"""Defines a subgraph given by all the dependencies of a given set of outputs.
"""
def __init__(self, outputs):
# Set of all ancestor Tensors, Ops to 'outputs'.
self._members = set()
self._iter_add(outputs)
def _iter_add(self, root):
"""Iteratively adds all of nodes' ancestors using depth first search."""
stack = [root]
while stack:
nodes = stack.pop()
for node in nodes:
if node in self._members:
continue
self._members.add(node)
if isinstance(node, ops.Tensor):
stack.append((node.op,))
elif isinstance(node, ops.Operation):
stack.append(node.inputs)
def is_member(self, node):
"""Check if 'node' is in this subgraph."""
return node in self._members
def variable_uses(self, var):
"""Computes number of times a variable is used.
Args:
var: Variable or ResourceVariable instance.
Returns:
Number of times a variable is used within this subgraph.
Raises:
ValueError: If 'var' is not a variable type.
"""
if isinstance(var, resource_variable_ops.ResourceVariable):
var = var.handle
elif isinstance(var, variables.Variable):
var = var.value()
else:
raise ValueError("%s does not appear to be a variable." % str(var))
return len(self._members.intersection(set(var.consumers())))
def filter_list(self, node_list):
"""Filters 'node_list' to nodes in this subgraph."""
filtered_list = []
for node in node_list:
if self.is_member(node):
filtered_list.append(node)
return filtered_list
def generate_random_signs(shape, dtype=dtypes.float32):
"""Generate a random tensor with {-1, +1} entries."""
ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
return 2 * math_ops.cast(ints, dtype=dtype) - 1
def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
"""Compute forward-mode gradients."""
# See b/37888268.
# This version of forward-mode autodiff is based on code by Tim Cooijmans
# and handles list arguments and certain special cases such as when the
# ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are
# generated by the first gradients_impl.gradients call.
us = [array_ops.zeros_like(y) + float("nan") for y in ys]
dydxs = gradients_impl.gradients(
ys, xs, grad_ys=us, stop_gradients=stop_gradients)
# Deal with strange types that gradients_impl.gradients returns but cannot
# accept as inputs.
dydxs = [
ops.convert_to_tensor(dydx)
if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
]
dydxs = [
array_ops.zeros_like(x) if dydx is None else dydx
for x, dydx in zip(xs, dydxs)
]
dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
return dysdx
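# Editorial sketch (not part of the original file): the double reverse-mode trick
# behind fwd_gradients, shown on a scalar toy case in TF 1.x graph mode. The
# Jacobian-vector product of y = x**2 at x = 3 with tangent v = 1 should be 6.
import tensorflow as tf

x = tf.constant(3.0)
v = tf.constant(1.0)                       # tangent vector
y = x * x

u = tf.zeros_like(y)                       # dummy cotangent (the original uses NaNs)
dydu = tf.gradients(y, x, grad_ys=u)[0]    # expression linear in u: u * dy/dx
jvp = tf.gradients(dydu, u, grad_ys=v)[0]  # differentiate w.r.t. u to recover (dy/dx) * v

with tf.Session() as sess:
  print(sess.run(jvp))                     # 6.0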
def on_tpu():
"""Returns True when building a TPU computation."""
return tpu_function.get_tpu_context().number_of_shards is not None
def cross_replica_mean(tensor, name=None):
"""Takes mean value of a Tensor across all TPU cores.
Args:
tensor: Tensor to be synchronized.
name: None or string. Name of Op.
Returns:
Average of Tensor across all TPU cores.
Raises:
ValueError: If called outside of TPU context.
"""
with ops.name_scope(name, "cross_replica_mean", [tensor]):
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
raise ValueError(
"Cannot take cross_replica_mean() outside of TPU Context.")
if num_shards == 1:
return tensor
return tpu_ops.cross_replica_sum(tensor / num_shards)
def ensure_sequence(obj):
"""If `obj` isn't a tuple or list, return a tuple containing `obj`."""
if isinstance(obj, (tuple, list)):
return obj
else:
return (obj,)
def batch_execute(global_step, thunks, batch_size, name=None):
"""Executes a subset of ops per global step.
Given a list of thunks, each of which produces a single stateful op,
ensures that exactly 'batch_size' ops are run per global step. Ops are
scheduled in a round-robin fashion. For example, with 3 ops
global_step | op0 | op1 | op2
------------+-----+-----+-----
0 | x | x |
------------+-----+-----+-----
1 | x | | x
------------+-----+-----+-----
2 | | x | x
------------+-----+-----+-----
3 | x | x |
------------+-----+-----+-----
4 | x | | x
Does not guarantee order of op execution within a single global step.
Args:
global_step: Tensor indicating time. Determines which ops run.
thunks: List of thunks. Each thunk encapsulates one op. Return values are
ignored.
batch_size: int. Number of ops to execute per global_step.
name: string or None. Name scope for newly added ops.
Returns:
List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
every global step.
"""
def true_fn(thunk):
"""Ensures thunk is executed and returns an Op (not a Tensor)."""
def result():
with ops.control_dependencies([thunk()]):
return control_flow_ops.no_op()
return result
def false_fn(_):
"""Executes a no-op."""
def result():
return control_flow_ops.no_op()
return result
with ops.name_scope(name, "batch_execute"):
true_fns = [true_fn(thunk) for thunk in thunks]
false_fns = [false_fn(thunk) for thunk in thunks]
num_thunks = len(thunks)
conditions = [
math_ops.less(
math_ops.mod(batch_size - 1 + global_step * batch_size - j,
num_thunks), batch_size) for j in range(num_thunks)
]
result = [
control_flow_ops.cond(condition, true_fn, false_fn)
for (condition, true_fn,
false_fn) in zip(conditions, true_fns, false_fns)
]
return result
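# Editorial sketch (not part of the original file): the modular condition above
# reproduces the schedule table from the docstring. Plain-Python check for 3 thunks
# and batch_size = 2.
num_thunks, batch_size = 3, 2
for global_step in range(5):
  runs = [
      (batch_size - 1 + global_step * batch_size - j) % num_thunks < batch_size
      for j in range(num_thunks)
  ]
  print(global_step, runs)
# 0 [True, True, False]
# 1 [True, False, True]
# 2 [False, True, True]
# ...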
def extract_convolution_patches(inputs,
filter_shape,
padding,
strides=None,
dilation_rate=None,
name=None,
data_format=None):
"""Extracts inputs to each output coordinate in tf.nn.convolution.
This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
where the number of spatial dimensions may be something other than 2.
Assumes:
- First dimension of inputs is batch_size
- Convolution filter is applied to all input channels.
Args:
inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().
filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
padding: string. Padding method. One of "VALID", "SAME".
strides: None or list of ints. Strides along spatial dimensions.
dilation_rate: None or list of ints. Dilation along spatial dimensions.
name: None or str. Name of Op.
data_format: None or str. Format of data.
Returns:
Tensor of shape [batch_size, ..spatial_image_shape..,
..spatial_filter_shape.., in_channels]
Raises:
ValueError: If data_format does not put channel last.
ValueError: If inputs and filter disagree on in_channels.
"""
if not is_data_format_channel_last(data_format):
raise ValueError("Channel must be last dimension.")
with ops.name_scope(name, "extract_convolution_patches",
[inputs, filter_shape, padding, strides, dilation_rate]):
batch_size = inputs.shape.as_list()[0]
in_channels = inputs.shape.as_list()[-1]
# filter_shape = spatial_filter_shape + [in_channels, out_channels]
spatial_filter_shape = filter_shape[:-2]
if in_channels != filter_shape[-2]:
raise ValueError("inputs and filter_shape must agree on in_channels.")
# Map each input feature to a location in the output.
out_channels = np.prod(spatial_filter_shape) * in_channels
filters = linalg_ops.eye(out_channels)
filters = array_ops.reshape(
filters,
list(spatial_filter_shape) + [in_channels, out_channels])
result = nn_ops.convolution(
inputs,
filters,
padding=padding,
strides=strides,
dilation_rate=dilation_rate)
spatial_output_shape = result.shape.as_list()[1:-1]
result = array_ops.reshape(result,
[batch_size or -1] + spatial_output_shape +
list(spatial_filter_shape) + [in_channels])
return result
def extract_pointwise_conv2d_patches(inputs,
filter_shape,
name=None,
data_format=None):
"""Extract patches for a 1x1 conv2d.
Args:
inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
name: None or str. Name for Op.
data_format: None or str. Format for data. See 'data_format' in
tf.nn.conv2d() for details.
Returns:
Tensor of shape [batch_size, ..spatial_input_shape..,
..spatial_filter_shape.., in_channels]
Raises:
ValueError: if inputs is not 4-D.
ValueError: if filter_shape is not [1, 1, ?, ?]
ValueError: if data_format is not channels-last.
"""
if inputs.shape.ndims != 4:
raise ValueError("inputs must have 4 dims.")
if len(filter_shape) != 4:
raise ValueError("filter_shape must have 4 dims.")
if filter_shape[0] != 1 or filter_shape[1] != 1:
raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
if not is_data_format_channel_last(data_format):
raise ValueError("data_format must be channels last.")
with ops.name_scope(name, "extract_pointwise_conv2d_patches",
[inputs, filter_shape]):
ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
strides = [1, 1, 1, 1] # Operate on all pixels.
rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
padding = "VALID" # Doesn't matter.
result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
padding)
batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
filter_height, filter_width, in_channels, _ = filter_shape
return array_ops.reshape(result, [
batch_size, input_height, input_width, filter_height, filter_width,
in_channels
])
def is_data_format_channel_last(data_format):
"""True if data_format puts channel last."""
if data_format is None:
return True
return data_format.endswith("C")
def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is sparse, B is dense.
Args:
A: tf.IndexedSlices with dense shape [m, n].
B: tf.Tensor with shape [n, k].
name: str. Name of op.
transpose_a: Bool. If true we transpose A before multiplying it by B.
(Default: False)
transpose_b: Bool. If true we transpose B before multiplying it by A.
(Default: False)
Returns:
tf.IndexedSlices resulting from matmul(A, B).
Raises:
ValueError: If A doesn't represent a matrix.
ValueError: If B is not rank-2.
"""
with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
raise ValueError("A must represent a matrix. Found: %s." % A)
if B.shape.ndims != 2:
raise ValueError("B must be a matrix.")
new_values = math_ops.matmul(
A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
return ops.IndexedSlices(
new_values,
A.indices,
dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
Args:
A_diag: diagonal entries of matrix A of shape [m, m].
B: tf.IndexedSlices. Represents matrix of shape [m, n].
name: str. Name of op.
Returns:
tf.IndexedSlices resulting from matmul(A, B).
Raises:
ValueError: If A_diag is not rank-1.
ValueError: If B doesn't represent a matrix.
"""
with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
A_diag = ops.convert_to_tensor(A_diag)
if A_diag.shape.ndims != 1:
raise ValueError("A_diag must be a rank-1 Tensor.")
if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
raise ValueError("B must represent a matrix. Found: %s." % B)
a = array_ops.gather(A_diag, B.indices)
a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
class PartitionedTensor(object):
"""A Tensor partitioned across its 0-th dimension."""
def __init__(self, tensors):
"""Initializes PartitionedTensor.
Args:
tensors: List of Tensors. All Tensors must agree on shape (excepting
batch dimension) and dtype.
Raises:
ValueError: If 'tensors' has length zero.
ValueError: if contents of 'tensors' don't agree on shape or dtype.
"""
if not tensors:
raise ValueError("tensors must be a list of 1+ Tensors.")
dtype = tensors[0].dtype
if not all(tensor.dtype == dtype for tensor in tensors):
raise ValueError("all tensors must have dtype = %s." % dtype)
shape = tensors[0].shape[1:]
if not all(tensor.shape[1:] == shape for tensor in tensors):
raise ValueError("All tensors must have shape = %s (excluding batch "
"dimension)." % shape)
self.tensors = tensors
self._concats = {} # {device: Tensor}
@property
def shape(self):
feature_shape = self.tensors[0].shape[1:]
batch_size = sum([tensor.shape[0] for tensor in self.tensors],
tensor_shape.Dimension(0))
return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
def get_shape(self):
return self.shape
@property
def dtype(self):
return self.tensors[0].dtype
def __str__(self):
return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
def __hash__(self):
return hash(tuple(self.tensors))
def __eq__(self, other):
if not isinstance(other, PartitionedTensor):
return False
return self.tensors == other.tensors
def __ne__(self, other):
return not self == other # pylint: disable=g-comparison-negation
def __getitem__(self, key):
return self.as_tensor()[key]
def as_tensor(self, dtype=None, name=None, as_ref=False):
with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
assert not as_ref
assert dtype in [None, self.dtype]
result = array_ops.concat(self.tensors, axis=0)
# Cache 'result' if we haven't already cached a value for this device.
if result.device not in self._concats:
self._concats[result.device] = result
return self._concats[result.device]
@property
def device(self):
# PartitionedTensors in general do not live on a single device. If the
# device cannot be determined unambiguously this property will return None.
device = self.tensors[0].device
if all(tensor.device == device for tensor in self.tensors):
return device
return None
ops.register_tensor_conversion_function(
PartitionedTensor,
lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.

View File

@ -1,50 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.kfac.python.ops.utils import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
"set_global_constants",
"SequenceDict",
"tensors_to_column",
"column_to_tensors",
"kronecker_product",
"layer_params_to_mat2d",
"mat2d_to_layer_params",
"posdef_inv",
"posdef_inv_matrix_inverse",
"posdef_inv_cholesky",
"posdef_inv_funcs",
"SubGraph",
"generate_random_signs",
"fwd_gradients",
"ensure_sequence",
"batch_execute",
"extract_convolution_patches",
"extract_pointwise_conv2d_patches",
"is_data_format_channel_last",
"matmul_sparse_dense",
"matmul_diag_sparse",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)