BEGIN_PUBLIC
Delete tf.contrib.kfac. K-FAC in TensorFlow is now its own separate package.
END_PUBLIC
RELNOTES: n/a
Automated rollback of commit 938b9a4078
PiperOrigin-RevId: 209813506
parent c85e0a9829
commit c73964210c
@@ -61,7 +61,6 @@ py_library(
        "//tensorflow/contrib/integrate:integrate_py",
        "//tensorflow/contrib/keras",
        "//tensorflow/contrib/kernel_methods",
        "//tensorflow/contrib/kfac",
        "//tensorflow/contrib/labeled_tensor",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/contrib/learn",
@@ -51,7 +51,6 @@ from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import keras
from tensorflow.contrib import kernel_methods
from tensorflow.contrib import kfac
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
@@ -247,10 +247,6 @@ tensorflow/contrib/kernel_methods/python
tensorflow/contrib/kernel_methods/python/mappers
tensorflow/contrib/kinesis/python
tensorflow/contrib/kinesis/python/ops
tensorflow/contrib/kfac
tensorflow/contrib/kfac/examples
tensorflow/contrib/kfac/python
tensorflow/contrib/kfac/python/ops
tensorflow/contrib/labeled_tensor
tensorflow/contrib/labeled_tensor/python
tensorflow/contrib/labeled_tensor/python/ops
@@ -1,26 +0,0 @@
# Description:
#   Contains KfacOptimizer, an implementation of the K-FAC optimization
#   algorithm in TensorFlow.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_library(
    name = "kfac",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
        "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
        "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
        "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
        "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
        "//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
        "//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
        "//tensorflow/contrib/kfac/python/ops:op_queue_lib",
        "//tensorflow/contrib/kfac/python/ops:utils_lib",
        "//tensorflow/python:util",
    ],
)
@@ -1,94 +1,3 @@
# K-FAC: Kronecker-Factored Approximate Curvature

# <font color="red" size="10"><u>WARNING: </u></font>
# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
# ==removed on 15-07-2018. Please check https://github.com/tensorflow/kfac.==
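For users of the deprecated `tf.contrib.kfac`, migration mostly amounts to
switching imports to the standalone package. Below is a minimal, hypothetical
sketch, assuming the standalone `kfac` package linked above exposes
`LayerCollection` and `KfacOptimizer` entry points analogous to the ones being
removed here, and reusing the `w`, `b`, `x`, `logits`, and `loss` tensors from
the usage example later in this README (check that repository for the
authoritative API):

```python
# Before (removed by this change):
#   from tensorflow.contrib import kfac
#   layer_collection = kfac.layer_collection.LayerCollection()

# After (standalone package; names assumed, verify against the kfac repository):
import kfac

layer_collection = kfac.LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

optimizer = kfac.KfacOptimizer(
    learning_rate=0.0001,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layer_collection,
    momentum=0.9)
train_op = optimizer.minimize(loss)
```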

**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
approximate second-order optimization method, in TensorFlow. When applied to
feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
faster in `>14x` fewer iterations than SGD with Momentum.

[kfac-paper]: https://arxiv.org/abs/1503.05671

## What is K-FAC?

K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
to the [Natural Gradient][natural_gradient] algorithm designed specifically for
neural networks. It maintains a block-diagonal approximation to the [Fisher
Information matrix][fisher_information], whose inverse preconditions the
gradient.

K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.

Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
are the weights for layer i?"). As such, you must add some additional code while
constructing your model to use K-FAC.

[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
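
To make the block-diagonal approximation concrete, here is a small illustrative
sketch (not part of the library, and simplified relative to what the optimizer
actually tracks) of how a Kronecker-factored Fisher block for a single
fully-connected layer preconditions that layer's gradient:

```python
import numpy as np

def kfac_precondition(grad_w, a, g, damping=1e-3):
  """Applies an approximate inverse-Fisher preconditioner to one layer.

  grad_w: [n_out, n_in] gradient of the loss w.r.t. the layer's weights.
  a:      [batch, n_in] layer inputs (activations of the previous layer).
  g:      [batch, n_out] gradients w.r.t. the layer's pre-activations.

  The layer's Fisher block is approximated by the Kronecker product
  A (x) G with A = E[a a^T] and G = E[g g^T]; applying its inverse to the
  gradient then reduces to two small matrix solves: G^{-1} grad_w A^{-1}.
  """
  A = a.T.dot(a) / a.shape[0] + damping * np.eye(a.shape[1])
  G = g.T.dot(g) / g.shape[0] + damping * np.eye(g.shape[1])
  # Inverting the two small factors replaces inverting the full
  # (n_in * n_out) x (n_in * n_out) Fisher block.
  return np.linalg.solve(G, grad_w).dot(np.linalg.inv(A))
```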

## Why should I use K-FAC?

K-FAC can take advantage of the curvature of the optimization problem, resulting
in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how
training loss changes as a function of number of epochs, steps, and seconds:

![]()

## Is K-FAC for me?

If you have a feedforward or convolutional model for classification that is
converging too slowly, K-FAC is for you. K-FAC can be used in your model if:

*   Your model defines a posterior distribution.
*   Your model uses only fully-connected or convolutional layers (residual
    connections OK).
*   You are training on CPU or GPU.
*   You can modify model code to register layers with K-FAC.

## How do I use K-FAC?

Using K-FAC requires three steps:

1.  Registering layer inputs, weights, and pre-activations with a
    `LayerCollection`.
1.  Minimizing the loss with a `KfacOptimizer`.
1.  Keeping K-FAC's preconditioner updated.

```python
# Build model.
w = tf.get_variable("w", ...)
b = tf.get_variable("b", ...)
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Register layers.
layer_collection = LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

# Construct training ops.
optimizer = KfacOptimizer(..., layer_collection=layer_collection)
train_op = optimizer.minimize(loss)

# Minimize loss.
with tf.Session() as sess:
  ...
  sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
```

See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
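
The examples under `examples/` (also removed by this change) drive the
preconditioner updates through the optimizer's thunk-based API rather than the
`cov_update_op`/`inv_update_op` properties shown above. A condensed sketch of
that pattern, mirroring `convnet.py` below:

```python
optimizer = KfacOptimizer(..., layer_collection=layer_collection)
(cov_update_thunks,
 inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

def make_update_op(update_thunks):
  return tf.group(*[thunk() for thunk in update_thunks])

# Refresh covariance statistics every step, but run the (expensive) inverse
# updates only every `invert_every` steps, gated on the global step.
cov_update_op = make_update_op(cov_update_thunks)
with tf.control_dependencies([cov_update_op]):
  inverse_op = tf.cond(
      tf.equal(tf.mod(global_step, invert_every), 0),
      lambda: make_update_op(inv_update_thunks), tf.no_op)
  with tf.control_dependencies([inverse_op]):
    train_op = optimizer.minimize(loss, global_step=global_step)
```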

## Authors

- Alok Aggarwal
- Daniel Duckworth
- James Martens
- Matthew Johnson
- Olga Wichrowska
- Roger Grosse

## KFAC moved to third_party/tensorflow_kfac.
@@ -1,46 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Kronecker-factored Approximate Curvature Optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
from tensorflow.contrib.kfac.python.ops import utils_lib as utils
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long

_allowed_symbols = [
    "curvature_matrix_vector_products",
    "estimator",
    "fisher_blocks",
    "fisher_factors",
    "layer_collection",
    "loss_functions",
    "op_queue",
    "optimizer",
    "utils",
]

remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
@@ -1,80 +0,0 @@
package(default_visibility = [
    "//learning/brain/contrib/kfac/examples:__subpackages__",
    "//tensorflow/contrib/kfac/examples:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_binary(
    name = "mlp_mnist_main",
    srcs = ["mlp_mnist_main.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":mlp",
        "//tensorflow:tensorflow_py",
    ],
)

py_library(
    name = "mlp",
    srcs = ["mlp.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":mnist",
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "convnet_mnist_single_main",
    srcs = ["convnet_mnist_single_main.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":convnet",
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "convnet_mnist_multi_tower_main",
    srcs = ["convnet_mnist_multi_tower_main.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":convnet",
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "convnet_mnist_distributed_main",
    srcs = ["convnet_mnist_distributed_main.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":convnet",
        "//tensorflow:tensorflow_py",
    ],
)

py_library(
    name = "convnet",
    srcs = ["convnet.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":mlp",
        ":mnist",
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "mnist",
    srcs = ["mnist.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//third_party/py/numpy",
    ],
)
@@ -1,667 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.

This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
following structure:

- Conv Layer: 5x5 kernel, 16 output channels.
- Max Pool: 3x3 kernel, stride 2.
- Conv Layer: 5x5 kernel, 16 output channels.
- Max Pool: 3x3 kernel, stride 2.
- Linear: 10 output dims.

After 3k~6k steps, this should reach perfect accuracy on the training set.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

from tensorflow.contrib.kfac.examples import mlp
from tensorflow.contrib.kfac.examples import mnist
from tensorflow.contrib.kfac.python.ops import optimizer as opt


lc = tf.contrib.kfac.layer_collection
oq = tf.contrib.kfac.op_queue
opt = tf.contrib.kfac.optimizer

__all__ = [
    "conv_layer",
    "max_pool_layer",
    "linear_layer",
    "build_model",
    "minimize_loss_single_machine",
    "distributed_grads_only_and_ops_chief_worker",
    "distributed_grads_and_ops_dedicated_workers",
    "train_mnist_single_machine",
    "train_mnist_distributed_sync_replicas",
    "train_mnist_multitower"
]


# Inverse update ops will be run every _INVERT_EVERY iterations.
_INVERT_EVERY = 10


def conv_layer(layer_id, inputs, kernel_size, out_channels):
  """Builds a convolutional layer with ReLU non-linearity.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
      row corresponds to a single example.
    kernel_size: int. Width and height of the convolution kernel. The kernel
      is assumed to be square.
    out_channels: int. Number of output features per pixel.

  Returns:
    preactivations: Tensor of shape [num_examples, width, height,
      out_channels]. Values of the layer immediately before the activation
      function.
    activations: Tensor of shape [num_examples, width, height, out_channels].
      Values of the layer immediately after the activation function.
    params: Tuple of (kernel, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  layer = tf.layers.Conv2D(
      out_channels,
      kernel_size=[kernel_size, kernel_size],
      kernel_initializer=tf.random_normal_initializer(stddev=0.01),
      padding="SAME",
      name="conv_%d" % layer_id)
  preactivations = layer(inputs)
  activations = tf.nn.relu(preactivations)

  # layer.weights is a list. This converts it to a (hashable) tuple.
  return preactivations, activations, (layer.kernel, layer.bias)


def max_pool_layer(layer_id, inputs, kernel_size, stride):
  """Build a max-pooling layer.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
      row corresponds to a single example.
    kernel_size: int. Width and height to pool over per input channel. The
      kernel is assumed to be square.
    stride: int. Step size between pooling operations.

  Returns:
    Tensor of shape [num_examples, width/stride, height/stride, out_channels].
    Result of applying max pooling to 'inputs'.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  with tf.variable_scope("pool_%d" % layer_id):
    return tf.nn.max_pool(
        inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
        padding="SAME",
        name="pool")


def linear_layer(layer_id, inputs, output_size):
  """Builds the final linear layer for an MNIST classification problem.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
      row corresponds to a single example.
    output_size: int. Number of output dims per example.

  Returns:
    activations: Tensor of shape [num_examples, output_size]. Values of the
      layer immediately after the activation function.
    params: Tuple of (weights, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
  return pre, params


def build_model(examples, labels, num_labels, layer_collection):
  """Builds a ConvNet classification model.

  Args:
    examples: Tensor of shape [num_examples, num_features]. Represents inputs
      of model.
    labels: Tensor of shape [num_examples]. Contains integer IDs to be
      predicted by softmax for each example.
    num_labels: int. Number of distinct values 'labels' can take on.
    layer_collection: LayerCollection instance. Layers will be registered here.

  Returns:
    loss: 0-D Tensor representing loss to be minimized.
    accuracy: 0-D Tensor representing model's accuracy.
  """
  # Build a ConvNet. For each layer with parameters, we'll keep track of the
  # preactivations, activations, weights, and bias.
  tf.logging.info("Building model.")
  pre0, act0, params0 = conv_layer(
      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
  pre2, act2, params2 = conv_layer(
      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
  logits, params4 = linear_layer(
      layer_id=4, inputs=flat_act3, output_size=num_labels)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))

  with tf.device("/cpu:0"):
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)

  # Register parameters. K-FAC needs to know about the inputs, outputs, and
  # parameters of each conv/fully connected layer and the logits powering the
  # posterior probability over classes.
  tf.logging.info("Building LayerCollection.")
  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
                                   pre0)
  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
  layer_collection.register_fully_connected(params4, flat_act3, logits)
  layer_collection.register_categorical_predictive_distribution(
      logits, name="logits")

  return loss, accuracy


def minimize_loss_single_machine(loss,
                                 accuracy,
                                 layer_collection,
                                 device="/gpu:0",
                                 session_config=None):
  """Minimize loss with K-FAC on a single machine.

  A single Session is responsible for running all of K-FAC's ops. The
  covariance and inverse update ops are placed on `device`. All model
  variables are on CPU.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
      update ops are run on this device.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
  # Train with K-FAC.
  g_step = tf.train.get_or_create_global_step()
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      placement_strategy="round_robin",
      cov_devices=[device],
      inv_devices=[device],
      momentum=0.9)
  (cov_update_thunks,
   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

  def make_update_op(update_thunks):
    update_ops = [thunk() for thunk in update_thunks]
    return tf.group(*update_ops)

  cov_update_op = make_update_op(cov_update_thunks)
  with tf.control_dependencies([cov_update_op]):
    inverse_op = tf.cond(
        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
        lambda: make_update_op(inv_update_thunks), tf.no_op)
    with tf.control_dependencies([inverse_op]):
      with tf.device(device):
        train_op = optimizer.minimize(loss, global_step=g_step)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _ = sess.run(
          [g_step, loss, accuracy, train_op])

      if global_step_ % _INVERT_EVERY == 0:
        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                        global_step_, loss_, accuracy_)

  return accuracy_


def _is_gradient_task(task_id, num_tasks):
  """Returns True if this task should update the weights."""
  if num_tasks < 3:
    return True
  return 0 <= task_id < 0.6 * num_tasks


def _is_cov_update_task(task_id, num_tasks):
  """Returns True if this task should update K-FAC's covariance matrices."""
  if num_tasks < 3:
    return False
  return 0.6 * num_tasks <= task_id < num_tasks - 1


def _is_inv_update_task(task_id, num_tasks):
  """Returns True if this task should update K-FAC's preconditioner."""
  if num_tasks < 3:
    return False
  return task_id == num_tasks - 1


def _num_gradient_tasks(num_tasks):
  """Number of tasks that will update weights."""
  if num_tasks < 3:
    return num_tasks
  return int(np.ceil(0.6 * num_tasks))


def _make_distributed_train_op(
    task_id,
    num_worker_tasks,
    num_ps_tasks,
    layer_collection
):
  """Creates optimizer and distributed training op.

  Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
  the train op.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    num_worker_tasks: int. Number of workers in this distributed training
      setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.

  Returns:
    sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
      optimizer.
    optimizer: Instance of `opt.KfacOptimizer`.
    global_step: `tensor`, Global step.
  """
  tf.logging.info("Task id : %d", task_id)
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    global_step = tf.train.get_or_create_global_step()
    optimizer = opt.KfacOptimizer(
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9)
    sync_optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
        total_num_replicas=num_worker_tasks)
    return sync_optimizer, optimizer, global_step


def distributed_grads_only_and_ops_chief_worker(
    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
    loss, accuracy, layer_collection, invert_every=10):
  """Minimize loss with a synchronous implementation of K-FAC.

  All workers perform gradient computation. The chief worker applies the
  gradient after averaging the gradients obtained from all the workers. All
  workers block execution until the update is applied. The chief worker runs
  the covariance and inverse update ops. Covariance and inverse matrices are
  placed on parameter servers in a round robin manner. For further details on
  synchronous distributed optimization check `tf.train.SyncReplicasOptimizer`.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    is_chief: `boolean`, `True` if the worker is chief worker.
    num_worker_tasks: int. Number of workers in this distributed training
      setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    master: string. IP and port of TensorFlow runtime process. Set to empty
      string to run locally.
    checkpoint_dir: string or None. Path to store checkpoints under.
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
      run with each step.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    invert_every: `int`, number of steps between inverse updates.

  Returns:
    final value for 'accuracy'.

  Raises:
    ValueError: if task_id >= num_worker_tasks.
  """

  sync_optimizer, optimizer, global_step = _make_distributed_train_op(
      task_id, num_worker_tasks, num_ps_tasks, layer_collection)
  (cov_update_thunks,
   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

  tf.logging.info("Starting training.")
  hooks = [sync_optimizer.make_session_run_hook(is_chief)]

  def make_update_op(update_thunks):
    update_ops = [thunk() for thunk in update_thunks]
    return tf.group(*update_ops)

  if is_chief:
    cov_update_op = make_update_op(cov_update_thunks)
    with tf.control_dependencies([cov_update_op]):
      inverse_op = tf.cond(
          tf.equal(tf.mod(global_step, invert_every), 0),
          lambda: make_update_op(inv_update_thunks),
          tf.no_op)
      with tf.control_dependencies([inverse_op]):
        train_op = sync_optimizer.minimize(loss, global_step=global_step)
  else:
    train_op = sync_optimizer.minimize(loss, global_step=global_step)

  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=checkpoint_dir,
      hooks=hooks,
      stop_grace_period_secs=0) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _ = sess.run(
          [global_step, loss, accuracy, train_op])
      tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                      global_step_, loss_, accuracy_)
  return accuracy_


def distributed_grads_and_ops_dedicated_workers(
    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
    loss, accuracy, layer_collection):
  """Minimize loss with a synchronous implementation of K-FAC.

  Different workers are responsible for different parts of K-FAC's Ops. The
  first 60% of tasks compute gradients; the next 20% accumulate covariance
  statistics; the last 20% invert the matrices used to precondition gradients.
  The chief worker applies the gradient.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    is_chief: `boolean`, `True` if the worker is chief worker.
    num_worker_tasks: int. Number of workers in this distributed training
      setup.
    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
      parameter servers are not used.
    master: string. IP and port of TensorFlow runtime process. Set to empty
      string to run locally.
    checkpoint_dir: string or None. Path to store checkpoints under.
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
      run with each step.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.

  Returns:
    final value for 'accuracy'.

  Raises:
    ValueError: if task_id >= num_worker_tasks.
  """
  sync_optimizer, optimizer, global_step = _make_distributed_train_op(
      task_id, num_worker_tasks, num_ps_tasks, layer_collection)
  _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
  train_op = sync_optimizer.minimize(loss, global_step=global_step)
  inv_update_queue = oq.OpQueue(inv_update_ops)

  tf.logging.info("Starting training.")
  is_chief = (task_id == 0)
  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=checkpoint_dir,
      hooks=hooks,
      stop_grace_period_secs=0) as sess:
    while not sess.should_stop():
      # Choose which op this task is responsible for running.
      if _is_gradient_task(task_id, num_worker_tasks):
        learning_op = train_op
      elif _is_cov_update_task(task_id, num_worker_tasks):
        learning_op = cov_update_op
      elif _is_inv_update_task(task_id, num_worker_tasks):
        # TODO(duckworthd): Running this op before cov_update_op has been run
        # a few times can result in "InvalidArgumentError: Cholesky
        # decomposition was not successful." Delay running this op until
        # cov_update_op has been run a few times.
        learning_op = inv_update_queue.next_op(sess)
      else:
        raise ValueError("Which op should task %d do?" % task_id)

      global_step_, loss_, accuracy_, _ = sess.run(
          [global_step, loss, accuracy, learning_op])
      tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                      global_step_, loss_, accuracy_)

  return accuracy_


def train_mnist_single_machine(data_dir,
                               num_epochs,
                               use_fake_data=False,
                               device="/gpu:0"):
  """Train a ConvNet on MNIST.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.
    device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
      update ops are run on this device.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=128,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Build a ConvNet.
  layer_collection = lc.LayerCollection()
  loss, accuracy = build_model(
      examples, labels, num_labels=10, layer_collection=layer_collection)

  # Fit model.
  return minimize_loss_single_machine(
      loss, accuracy, layer_collection, device=device)


def train_mnist_multitower(data_dir, num_epochs, num_towers,
                           use_fake_data=True, devices=None):
  """Train a ConvNet on MNIST.

  Training data is split equally among the towers. Each tower computes loss on
  its own batch of data and the loss is aggregated on the CPU. The model
  variables are placed on the first tower. The covariance and inverse update
  ops and variables are placed on GPUs in a round robin manner.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of towers to split inference across.
    use_fake_data: bool. If True, generate a synthetic dataset.
    devices: list of strings or None. Devices (CPU or GPU) on which the
      covariance and inverse update ops are run; defaults to one CPU device
      per tower.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  if devices:
    device_count = {"GPU": num_towers}
  else:
    device_count = {"CPU": num_towers}

  devices = devices or [
      "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
  ]
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  tower_batch_size = 128
  batch_size = tower_batch_size * num_towers
  tf.logging.info(
      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=batch_size,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Split minibatch across towers.
  examples = tf.split(examples, num_towers)
  labels = tf.split(labels, num_towers)

  # Build an MLP. Each tower's layers will be added to the LayerCollection.
  layer_collection = lc.LayerCollection()
  tower_results = []
  for tower_id in range(num_towers):
    with tf.device(devices[tower_id]):
      with tf.name_scope("tower%d" % tower_id):
        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
          tf.logging.info("Building tower %d." % tower_id)
          tower_results.append(
              build_model(examples[tower_id], labels[tower_id], 10,
                          layer_collection))
  losses, accuracies = zip(*tower_results)

  # Average across towers.
  loss = tf.reduce_mean(losses)
  accuracy = tf.reduce_mean(accuracies)

  # Fit model.

  session_config = tf.ConfigProto(
      allow_soft_placement=False,
      device_count=device_count,
  )

  g_step = tf.train.get_or_create_global_step()
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      placement_strategy="round_robin",
      cov_devices=devices,
      inv_devices=devices,
      momentum=0.9)
  (cov_update_thunks,
   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

  def make_update_op(update_thunks):
    update_ops = [thunk() for thunk in update_thunks]
    return tf.group(*update_ops)

  cov_update_op = make_update_op(cov_update_thunks)
  with tf.control_dependencies([cov_update_op]):
    inverse_op = tf.cond(
        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
        lambda: make_update_op(inv_update_thunks), tf.no_op)
    with tf.control_dependencies([inverse_op]):
      train_op = optimizer.minimize(loss, global_step=g_step)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _ = sess.run(
          [g_step, loss, accuracy, train_op])

      if global_step_ % _INVERT_EVERY == 0:
        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                        global_step_, loss_, accuracy_)


def train_mnist_distributed_sync_replicas(task_id,
                                          is_chief,
                                          num_worker_tasks,
                                          num_ps_tasks,
                                          master,
                                          data_dir,
                                          num_epochs,
                                          op_strategy,
                                          use_fake_data=False):
  """Train a ConvNet on MNIST using the sync replicas optimizer.

  Args:
    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
    is_chief: `boolean`, `True` if the worker is chief worker.
    num_worker_tasks: int. Number of workers in this distributed training
      setup.
    num_ps_tasks: int. Number of parameter servers holding variables.
    master: string. IP and port of TensorFlow runtime process.
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    op_strategy: `string`, strategy to run the covariance and inverse ops. If
      op_strategy == `chief_worker` then covariance and inverse update ops are
      run on the chief worker; otherwise they are run on dedicated workers.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.

  Raises:
    ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=128,
      use_fake_data=use_fake_data,
      flatten_images=False)

  # Build a ConvNet.
  layer_collection = lc.LayerCollection()
  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
    loss, accuracy = build_model(
        examples, labels, num_labels=10, layer_collection=layer_collection)

  # Fit model.
  checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
  if op_strategy == "chief_worker":
    return distributed_grads_only_and_ops_chief_worker(
        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
        checkpoint_dir, loss, accuracy, layer_collection)
  elif op_strategy == "dedicated_workers":
    return distributed_grads_and_ops_dedicated_workers(
        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
        checkpoint_dir, loss, accuracy, layer_collection)
  else:
    raise ValueError("Only supported op strategies are : {}, {}".format(
        "chief_worker", "dedicated_workers"))


if __name__ == "__main__":
  tf.app.run()
@@ -1,62 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.

Distributed training with sync replicas optimizer. See
`convnet.train_mnist_distributed_sync_replicas` for details.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from absl import flags
import tensorflow as tf

from tensorflow.contrib.kfac.examples import convnet

FLAGS = flags.FLAGS
flags.DEFINE_integer("task", -1, "Task identifier")
flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
flags.DEFINE_string(
    "cov_inv_op_strategy", "chief_worker",
    "In dist training mode run the cov, inv ops on chief or dedicated workers."
)
flags.DEFINE_string("master", "local", "Session master.")
flags.DEFINE_integer("ps_tasks", 2,
                     "Number of tasks in the parameter server job.")
flags.DEFINE_integer("replicas_to_aggregate", 5,
                     "Number of replicas to aggregate.")
flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
flags.DEFINE_integer("num_epochs", None, "Number of epochs.")


def _is_chief():
  """Determines whether a job is the chief worker."""
  # `brain_jobs`/`brain_job_name` are only defined in some internal setups;
  # fall back to treating task 0 as chief when they are absent.
  if "brain_jobs" in FLAGS and "chief_worker" in FLAGS.brain_jobs:
    return FLAGS.brain_job_name == "chief_worker"
  else:
    return FLAGS.task == 0


def main(unused_argv):
  _ = unused_argv
  convnet.train_mnist_distributed_sync_replicas(
      FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
      FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)

if __name__ == "__main__":
  tf.app.run(main=main)
@@ -1,48 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.

Multi tower training mode. See `convnet.train_mnist_multitower` for details.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from absl import flags
import tensorflow as tf

from tensorflow.contrib.kfac.examples import convnet

FLAGS = flags.FLAGS
flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
flags.DEFINE_integer("num_towers", 2,
                     "Number of towers for multi tower training.")


def main(unused_argv):
  _ = unused_argv
  assert FLAGS.num_towers > 1
  devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
  convnet.train_mnist_multitower(
      FLAGS.data_dir,
      num_epochs=200,
      num_towers=FLAGS.num_towers,
      devices=devices)


if __name__ == "__main__":
  tf.app.run(main=main)
@@ -1,39 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.

Train on single machine. See `convnet.train_mnist_single_machine` for details.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from absl import flags
import tensorflow as tf

from tensorflow.contrib.kfac.examples import convnet

FLAGS = flags.FLAGS
flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")


def main(unused_argv):
  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)


if __name__ == "__main__":
  tf.app.run(main=main)
@@ -1,354 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Train an MLP on MNIST using K-FAC.

This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
~25k steps, this should reach perfect accuracy on the training set.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.kfac.examples import mnist

lc = tf.contrib.kfac.layer_collection
opt = tf.contrib.kfac.optimizer

__all__ = [
    "fc_layer",
    "train_mnist",
    "train_mnist_multitower",
]


def fc_layer(layer_id, inputs, output_size):
  """Builds a fully connected layer.

  Args:
    layer_id: int. Integer ID for this layer's variables.
    inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
      to a single example.
    output_size: int. Number of output dimensions after fully connected layer.

  Returns:
    preactivations: Tensor of shape [num_examples, output_size]. Values of the
      layer immediately before the activation function.
    activations: Tensor of shape [num_examples, output_size]. Values of the
      layer immediately after the activation function.
    params: Tuple of (weights, bias), parameters for this layer.
  """
  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
  layer = tf.layers.Dense(
      output_size,
      kernel_initializer=tf.random_normal_initializer(),
      name="fc_%d" % layer_id)
  preactivations = layer(inputs)
  activations = tf.nn.tanh(preactivations)

  # layer.weights is a list. This converts it to a (hashable) tuple.
  return preactivations, activations, (layer.kernel, layer.bias)


def build_model(examples, labels, num_labels, layer_collection):
  """Builds an MLP classification model.

  Args:
    examples: Tensor of shape [num_examples, num_features]. Represents inputs
      of model.
    labels: Tensor of shape [num_examples]. Contains integer IDs to be
      predicted by softmax for each example.
    num_labels: int. Number of distinct values 'labels' can take on.
    layer_collection: LayerCollection instance describing model architecture.

  Returns:
    loss: 0-D Tensor representing loss to be minimized.
    accuracy: 0-D Tensor representing model's accuracy.
  """
  # Build an MLP. For each layer, we'll keep track of the preactivations,
  # activations, weights, and bias.
  pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
  pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
  pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
  logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))

  # Register parameters. K-FAC needs to know about the inputs, outputs, and
  # parameters of each layer and the logits powering the posterior probability
  # over classes.
  tf.logging.info("Building LayerCollection.")
  layer_collection.register_fully_connected(params0, examples, pre0)
  layer_collection.register_fully_connected(params1, act0, pre1)
  layer_collection.register_fully_connected(params2, act1, pre2)
  layer_collection.register_fully_connected(params3, act2, logits)
  layer_collection.register_categorical_predictive_distribution(
      logits, name="logits")

  return loss, accuracy


def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
  """Minimize 'loss' with KfacOptimizer.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance. Describes layers in model.
    num_towers: int. Number of CPUs to split minibatch across.
    session_config: tf.ConfigProto. Configuration for tf.Session().

  Returns:
    accuracy of classifier on final minibatch.
  """
  devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))

  # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
  # every 10k iterations.
  tf.logging.info("Building KFAC Optimizer.")
  global_step = tf.train.get_or_create_global_step()
  optimizer = opt.KfacOptimizer(
      learning_rate=tf.train.exponential_decay(
          0.00002, global_step, 10000, 0.5, staircase=True),
      cov_ema_decay=0.95,
      damping=0.0005,
      layer_collection=layer_collection,
      momentum=0.99,
      placement_strategy="round_robin",
      cov_devices=devices,
      inv_devices=devices)

  (cov_update_thunks,
   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

  def make_update_op(update_thunks):
    update_ops = [thunk() for thunk in update_thunks]
    return tf.group(*update_ops)

  # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
  # once that gets moved over? Could still leave more advanced examples as they
  # are (e.g. train_mnist_estimator in this file)

  cov_update_op = make_update_op(cov_update_thunks)
  with tf.control_dependencies([cov_update_op]):
    # We update the inverses only every 100 iterations.
    inverse_op = tf.cond(
        tf.equal(tf.mod(global_step, 100), 0),
        lambda: make_update_op(inv_update_thunks), tf.no_op)
    with tf.control_dependencies([inverse_op]):
      train_op = optimizer.minimize(loss, global_step=global_step)

  tf.logging.info("Starting training.")
  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    while not sess.should_stop():
      global_step_, loss_, accuracy_, _ = sess.run(
          [global_step, loss, accuracy, train_op])

      if global_step_ % 100 == 0:
        tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
                        global_step_, loss_, accuracy_)

  return accuracy_


def train_mnist(data_dir, num_epochs, use_fake_data=False):
  """Train an MLP on MNIST.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tf.logging.info("Loading MNIST into memory.")
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=64,
      flatten_images=True,
      use_fake_data=use_fake_data)

  # Build an MLP. The model's layers will be added to the LayerCollection.
  tf.logging.info("Building model.")
  layer_collection = lc.LayerCollection()
  loss, accuracy = build_model(examples, labels, 10, layer_collection)

  # Fit model.
  minimize(loss, accuracy, layer_collection, 1)


def train_mnist_multitower(data_dir,
                           num_epochs,
                           num_towers,
                           use_fake_data=False):
  """Train an MLP on MNIST, splitting the minibatch across multiple towers.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of CPUs to split minibatch across.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """
  # Load a dataset.
  tower_batch_size = 64
  batch_size = tower_batch_size * num_towers
  tf.logging.info(
      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
  examples, labels = mnist.load_mnist(
      data_dir,
      num_epochs=num_epochs,
      batch_size=batch_size,
      flatten_images=True,
      use_fake_data=use_fake_data)

  # Split minibatch across towers.
  examples = tf.split(examples, num_towers)
  labels = tf.split(labels, num_towers)

  # Build an MLP. Each tower's layers will be added to the LayerCollection.
  layer_collection = lc.LayerCollection()
  tower_results = []
  for tower_id in range(num_towers):
    with tf.device("/cpu:%d" % tower_id):
      with tf.name_scope("tower%d" % tower_id):
        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
          tf.logging.info("Building tower %d." % tower_id)
          tower_results.append(
              build_model(examples[tower_id], labels[tower_id], 10,
                          layer_collection))
  losses, accuracies = zip(*tower_results)

  # Average across towers.
  loss = tf.reduce_mean(losses)
  accuracy = tf.reduce_mean(accuracies)

  # Fit model.
  session_config = tf.ConfigProto(
      allow_soft_placement=False, device_count={
          "CPU": num_towers
      })
  return minimize(
      loss, accuracy, layer_collection, num_towers,
      session_config=session_config)


def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
  """Train an MLP on MNIST using tf.estimator.

  Args:
    data_dir: string. Directory to read MNIST examples from.
    num_epochs: int. Number of passes to make over the training set.
    use_fake_data: bool. If True, generate a synthetic dataset.

  Returns:
    accuracy of model on the final minibatch of training data.
  """

  # Load a dataset.
  def input_fn():
    tf.logging.info("Loading MNIST into memory.")
    return mnist.load_mnist(
        data_dir,
        num_epochs=num_epochs,
        batch_size=64,
        flatten_images=True,
        use_fake_data=use_fake_data)

  def model_fn(features, labels, mode, params):
    """Model function for MLP trained with K-FAC.

    Args:
      features: Tensor of shape [batch_size, input_size]. Input features.
      labels: Tensor of shape [batch_size]. Target labels for training.
      mode: tf.estimator.ModeKey. Must be TRAIN.
      params: ignored.

    Returns:
      EstimatorSpec for training.

    Raises:
      ValueError: If 'mode' is anything other than TRAIN.
    """
    del params

    if mode != tf.estimator.ModeKeys.TRAIN:
      raise ValueError("Only training is supported with this API.")

    # Build a ConvNet.
    layer_collection = lc.LayerCollection()
    loss, accuracy = build_model(
        features, labels, num_labels=10, layer_collection=layer_collection)

    # Train with K-FAC.
    global_step = tf.train.get_or_create_global_step()
    optimizer = opt.KfacOptimizer(
        learning_rate=tf.train.exponential_decay(
            0.00002, global_step, 10000, 0.5, staircase=True),
        cov_ema_decay=0.95,
        damping=0.0001,
        layer_collection=layer_collection,
        momentum=0.99)

    (cov_update_thunks,
     inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

    def make_update_op(update_thunks):
      update_ops = [thunk() for thunk in update_thunks]
      return tf.group(*update_ops)

    def make_batch_executed_op(update_thunks, batch_size=1):
      return tf.group(*tf.contrib.kfac.utils.batch_execute(
          global_step, update_thunks, batch_size=batch_size))

    # Run cov_update_op every step. Run one inv_update_op per step.
    cov_update_op = make_update_op(cov_update_thunks)
    with tf.control_dependencies([cov_update_op]):
      # But make sure to execute all the inverse ops on the first step.
      inverse_op = tf.cond(tf.equal(global_step, 0),
                           lambda: make_update_op(inv_update_thunks),
                           lambda: make_batch_executed_op(inv_update_thunks))
      with tf.control_dependencies([inverse_op]):
        train_op = optimizer.minimize(loss, global_step=global_step)

    # Print metrics every 5 sec.
    hooks = [
        tf.train.LoggingTensorHook(
            {
                "loss": loss,
                "accuracy": accuracy
            }, every_n_secs=5),
    ]
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)

  run_config = tf.estimator.RunConfig(
      model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)

  # Train until input_fn() is empty with Estimator. This is a prerequisite for
  # TPU compatibility.
  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
  estimator.train(input_fn=input_fn)
@ -1,64 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Train an MLP on MNIST using K-FAC.
|
||||
|
||||
See mlp.py for details.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.kfac.examples import mlp
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def main(argv):
|
||||
_ = argv
|
||||
if FLAGS.use_estimator:
|
||||
if FLAGS.num_towers != 1:
|
||||
raise ValueError("Only 1 device supported in tf.estimator example.")
|
||||
mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
|
||||
elif FLAGS.num_towers > 1:
|
||||
mlp.train_mnist_multitower(
|
||||
FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
|
||||
else:
|
||||
mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default="/tmp/mnist",
|
||||
help="Directory to store dataset in.")
|
||||
parser.add_argument(
|
||||
"--num_towers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of CPUs to split minibatch across.")
|
||||
parser.add_argument(
|
||||
"--use_estimator",
|
||||
action="store_true",
|
||||
help="Use tf.estimator API to train.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,69 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for loading MNIST into TensorFlow."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
__all__ = [
|
||||
'load_mnist',
|
||||
]
|
||||
|
||||
|
||||
def load_mnist(data_dir,
|
||||
num_epochs,
|
||||
batch_size,
|
||||
flatten_images=True,
|
||||
use_fake_data=False):
|
||||
"""Loads MNIST dataset into memory.
|
||||
|
||||
Args:
|
||||
data_dir: string. Directory to read MNIST examples from.
|
||||
num_epochs: int. Number of passes to make over the dataset.
|
||||
batch_size: int. Number of examples per minibatch.
|
||||
flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
|
||||
[784]-shaped vectors.
|
||||
use_fake_data: bool. If True, generate a synthetic dataset rather than
|
||||
reading MNIST in.
|
||||
|
||||
Returns:
|
||||
examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
|
||||
True, else [batch_size, 28, 28, 1]. Each row is one example.
|
||||
Values in [0, 1].
|
||||
labels: Tensor of shape [batch_size]. Indices of integer corresponding to
|
||||
each example. Values in {0...9}.
|
||||
"""
|
||||
if use_fake_data:
|
||||
rng = np.random.RandomState(42)
|
||||
num_examples = batch_size * 4
|
||||
images = rng.rand(num_examples, 28 * 28)
|
||||
if not flatten_images:
|
||||
images = np.reshape(images, [num_examples, 28, 28, 1])
|
||||
labels = rng.randint(10, size=num_examples)
|
||||
else:
|
||||
mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
|
||||
data_dir, reshape=flatten_images)
|
||||
num_examples = len(mnist_data.train.labels)
|
||||
images = mnist_data.train.images
|
||||
labels = mnist_data.train.labels
|
||||
|
||||
dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
|
||||
images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
|
||||
return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
|
||||
.make_one_shot_iterator().get_next())
|
@ -1,52 +0,0 @@
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_test(
|
||||
name = "mlp_test",
|
||||
size = "large",
|
||||
srcs = ["mlp_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_pip",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/kfac/examples:mlp",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "convnet_test",
|
||||
size = "large",
|
||||
srcs = ["convnet_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_pip",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/kfac",
|
||||
"//tensorflow/contrib/kfac/examples:convnet",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "mnist_test",
|
||||
srcs = ["mnist_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/kfac/examples:mnist",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
@ -1,166 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for convnet.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.kfac import layer_collection as lc
|
||||
from tensorflow.contrib.kfac.examples import convnet
|
||||
|
||||
|
||||
class ConvNetTest(tf.test.TestCase):
|
||||
|
||||
def testConvLayer(self):
|
||||
with tf.Graph().as_default():
|
||||
pre, act, (w, b) = convnet.conv_layer(
|
||||
layer_id=1,
|
||||
inputs=tf.zeros([5, 3, 3, 2]),
|
||||
kernel_size=3,
|
||||
out_channels=5)
|
||||
self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
|
||||
self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
|
||||
self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
|
||||
self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
|
||||
self.assertIsInstance(w, tf.Variable)
|
||||
self.assertIsInstance(b, tf.Variable)
|
||||
self.assertIn("conv_1", w.op.name)
|
||||
self.assertIn("conv_1", b.op.name)
|
||||
|
||||
def testMaxPoolLayer(self):
|
||||
with tf.Graph().as_default():
|
||||
act = convnet.max_pool_layer(
|
||||
layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
|
||||
self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
|
||||
self.assertEqual(act.op.name, "pool_1/pool")
|
||||
|
||||
def testLinearLayer(self):
|
||||
with tf.Graph().as_default():
|
||||
act, (w, b) = convnet.linear_layer(
|
||||
layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
|
||||
self.assertShapeEqual(np.zeros([5, 5]), act)
|
||||
self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
|
||||
self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
|
||||
self.assertIsInstance(w, tf.Variable)
|
||||
self.assertIsInstance(b, tf.Variable)
|
||||
self.assertIn("fc_1", w.op.name)
|
||||
self.assertIn("fc_1", b.op.name)
|
||||
|
||||
def testBuildModel(self):
|
||||
with tf.Graph().as_default():
|
||||
x = tf.placeholder(tf.float32, [None, 6, 6, 3])
|
||||
y = tf.placeholder(tf.int64, [None])
|
||||
layer_collection = lc.LayerCollection()
|
||||
loss, accuracy = convnet.build_model(
|
||||
x, y, num_labels=5, layer_collection=layer_collection)
|
||||
|
||||
# Ensure layers and logits were registered.
|
||||
self.assertEqual(len(layer_collection.fisher_blocks), 3)
|
||||
self.assertEqual(len(layer_collection.losses), 1)
|
||||
|
||||
# Ensure inference doesn't crash.
|
||||
with self.test_session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
feed_dict = {
|
||||
x: np.random.randn(10, 6, 6, 3).astype(np.float32),
|
||||
y: np.random.randint(5, size=10).astype(np.int64),
|
||||
}
|
||||
sess.run([loss, accuracy], feed_dict=feed_dict)
|
||||
|
||||
def _build_toy_problem(self):
|
||||
"""Construct a toy linear regression problem.
|
||||
|
||||
Initial loss should be,
|
||||
2.5 = 0.5 * (1^2 + 2^2)
|
||||
|
||||
Returns:
|
||||
loss: 0-D Tensor representing loss to be minimized.
|
||||
accuracy: 0-D Tensors representing model accuracy.
|
||||
layer_collection: LayerCollection instance describing model architecture.
|
||||
"""
|
||||
x = np.asarray([[1.], [2.]]).astype(np.float32)
|
||||
y = np.asarray([1., 2.]).astype(np.float32)
|
||||
x, y = (tf.data.Dataset.from_tensor_slices((x, y))
|
||||
.repeat(100).batch(2).make_one_shot_iterator().get_next())
|
||||
w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
|
||||
y_hat = tf.matmul(x, w)
|
||||
loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
|
||||
accuracy = loss
|
||||
|
||||
layer_collection = lc.LayerCollection()
|
||||
layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
|
||||
layer_collection.register_normal_predictive_distribution(y_hat)
|
||||
|
||||
return loss, accuracy, layer_collection
|
||||
|
||||
def testMinimizeLossSingleMachine(self):
|
||||
with tf.Graph().as_default():
|
||||
loss, accuracy, layer_collection = self._build_toy_problem()
|
||||
accuracy_ = convnet.minimize_loss_single_machine(
|
||||
loss, accuracy, layer_collection, device="/cpu:0")
|
||||
self.assertLess(accuracy_, 2.0)
|
||||
|
||||
def testMinimizeLossDistributed(self):
|
||||
with tf.Graph().as_default():
|
||||
loss, accuracy, layer_collection = self._build_toy_problem()
|
||||
accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
|
||||
task_id=0,
|
||||
is_chief=True,
|
||||
num_worker_tasks=1,
|
||||
num_ps_tasks=0,
|
||||
master="",
|
||||
checkpoint_dir=None,
|
||||
loss=loss,
|
||||
accuracy=accuracy,
|
||||
layer_collection=layer_collection)
|
||||
self.assertLess(accuracy_, 2.0)
|
||||
|
||||
def testTrainMnistSingleMachine(self):
|
||||
with tf.Graph().as_default():
|
||||
# Ensure model training doesn't crash.
|
||||
#
|
||||
# Ideally, we should check that accuracy increases as the model converges,
|
||||
# but there are too few parameters for the model to effectively memorize
|
||||
# the training set the way an MLP can.
|
||||
convnet.train_mnist_single_machine(
|
||||
data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
|
||||
|
||||
def testTrainMnistMultitower(self):
|
||||
with tf.Graph().as_default():
|
||||
# Ensure model training doesn't crash.
|
||||
convnet.train_mnist_multitower(
|
||||
data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
|
||||
|
||||
def testTrainMnistDistributed(self):
|
||||
with tf.Graph().as_default():
|
||||
# Ensure model training doesn't crash.
|
||||
convnet.train_mnist_distributed_sync_replicas(
|
||||
task_id=0,
|
||||
is_chief=True,
|
||||
num_worker_tasks=1,
|
||||
num_ps_tasks=0,
|
||||
master="",
|
||||
data_dir=None,
|
||||
num_epochs=2,
|
||||
op_strategy="chief_worker",
|
||||
use_fake_data=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -1,63 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for mlp.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.kfac.examples import mlp
|
||||
|
||||
|
||||
class MlpTest(tf.test.TestCase):
|
||||
|
||||
def testFcLayer(self):
|
||||
with tf.Graph().as_default():
|
||||
pre, act, (w, b) = mlp.fc_layer(
|
||||
layer_id=1, inputs=tf.zeros([5, 3]), output_size=10)
|
||||
self.assertShapeEqual(np.zeros([5, 10]), pre)
|
||||
self.assertShapeEqual(np.zeros([5, 10]), act)
|
||||
self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w))
|
||||
self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b))
|
||||
self.assertIsInstance(w, tf.Variable)
|
||||
self.assertIsInstance(b, tf.Variable)
|
||||
self.assertIn("fc_1/", w.op.name)
|
||||
self.assertIn("fc_1/", b.op.name)
|
||||
|
||||
def testTrainMnist(self):
|
||||
with tf.Graph().as_default():
|
||||
# Ensure model training doesn't crash.
|
||||
#
|
||||
# Ideally, we should check that accuracy increases as the model converges,
|
||||
# but that takes a non-trivial amount of compute.
|
||||
mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True)
|
||||
|
||||
def testTrainMnistMultitower(self):
|
||||
with tf.Graph().as_default():
|
||||
# Ensure model training doesn't crash.
|
||||
mlp.train_mnist_multitower(
|
||||
data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
|
||||
|
||||
def testTrainMnistEstimator(self):
|
||||
with tf.Graph().as_default():
|
||||
# Ensure model training doesn't crash.
|
||||
mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -1,72 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for mnist.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.kfac.examples import mnist
|
||||
|
||||
|
||||
class MnistTest(tf.test.TestCase):
|
||||
|
||||
def testValues(self):
|
||||
"""Ensure values are in their expected range."""
|
||||
with tf.Graph().as_default():
|
||||
examples, labels = mnist.load_mnist(
|
||||
data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
|
||||
|
||||
with self.test_session() as sess:
|
||||
examples_, labels_ = sess.run([examples, labels])
|
||||
self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
|
||||
self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
|
||||
|
||||
def testFlattenedShapes(self):
|
||||
"""Ensure images are flattened into their appropriate shape."""
|
||||
with tf.Graph().as_default():
|
||||
examples, labels = mnist.load_mnist(
|
||||
data_dir=None,
|
||||
num_epochs=1,
|
||||
batch_size=64,
|
||||
flatten_images=True,
|
||||
use_fake_data=True)
|
||||
|
||||
with self.test_session() as sess:
|
||||
examples_, labels_ = sess.run([examples, labels])
|
||||
self.assertEqual(examples_.shape, (64, 784))
|
||||
self.assertEqual(labels_.shape, (64,))
|
||||
|
||||
def testNotFlattenedShapes(self):
|
||||
"""Ensure non-flattened images are their appropriate shape."""
|
||||
with tf.Graph().as_default():
|
||||
examples, labels = mnist.load_mnist(
|
||||
data_dir=None,
|
||||
num_epochs=1,
|
||||
batch_size=64,
|
||||
flatten_images=False,
|
||||
use_fake_data=True)
|
||||
|
||||
with self.test_session() as sess:
|
||||
examples_, labels_ = sess.run([examples, labels])
|
||||
self.assertEqual(examples_.shape, (64, 28, 28, 1))
|
||||
self.assertEqual(labels_.shape, (64,))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Binary file not shown.
Before Width: | Height: | Size: 53 KiB |
@ -1,160 +0,0 @@
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_test(
|
||||
name = "estimator_test",
|
||||
srcs = ["estimator_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_estimator",
|
||||
"//tensorflow/contrib/kfac/python/ops:layer_collection",
|
||||
"//tensorflow/contrib/kfac/python/ops:utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "fisher_factors_test",
|
||||
srcs = ["fisher_factors_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_seed",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "fisher_blocks_test",
|
||||
srcs = ["fisher_blocks_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
|
||||
"//tensorflow/contrib/kfac/python/ops:layer_collection",
|
||||
"//tensorflow/contrib/kfac/python/ops:linear_operator",
|
||||
"//tensorflow/contrib/kfac/python/ops:utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:random_seed",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "layer_collection_test",
|
||||
srcs = ["layer_collection_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
|
||||
"//tensorflow/contrib/kfac/python/ops:layer_collection",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:random_seed",
|
||||
"//tensorflow/python:variable_scope",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "optimizer_test",
|
||||
srcs = ["optimizer_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:fisher_factors",
|
||||
"//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
|
||||
"//tensorflow/contrib/kfac/python/ops:layer_collection",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "utils_test",
|
||||
srcs = ["utils_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_windows"], # TODO: needs investigation on Windows
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:utils",
|
||||
"//tensorflow/contrib/tpu",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:random_seed",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "op_queue_test",
|
||||
srcs = ["op_queue_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:op_queue",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "loss_functions_test",
|
||||
srcs = ["loss_functions_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:loss_functions",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
@ -1,310 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import estimator
|
||||
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
|
||||
|
||||
|
||||
class EstimatorTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._graph = ops.Graph()
|
||||
with self._graph.as_default():
|
||||
self.layer_collection = lc.LayerCollection()
|
||||
|
||||
self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
|
||||
self.weights = variable_scope.get_variable(
|
||||
"w", shape=(2, 2), dtype=dtypes.float32)
|
||||
self.bias = variable_scope.get_variable(
|
||||
"b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
|
||||
self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
|
||||
|
||||
# Only register the weights.
|
||||
self.layer_collection.register_fully_connected(
|
||||
params=(self.weights,), inputs=self.inputs, outputs=self.output)
|
||||
|
||||
self.outputs = math_ops.tanh(self.output)
|
||||
self.targets = array_ops.zeros_like(self.outputs)
|
||||
self.layer_collection.register_categorical_predictive_distribution(
|
||||
logits=self.outputs, targets=self.targets)
|
||||
|
||||
def testEstimatorInitManualRegistration(self):
|
||||
with self._graph.as_default():
|
||||
# We should be able to build an estimator for only the registered vars.
|
||||
estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection
|
||||
)
|
||||
|
||||
# Check that we throw an error if we try to build an estimator for vars
|
||||
# that were not manually registered.
|
||||
with self.assertRaises(ValueError):
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights, self.bias],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection
|
||||
)
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
# Check that we throw an error if we don't include registered variables,
|
||||
# i.e. self.weights
|
||||
with self.assertRaises(ValueError):
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection)
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
@test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
|
||||
def testVariableWrongNumberOfUses(self, mock_uses):
|
||||
with self.assertRaises(ValueError):
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection)
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
def testInvalidEstimationMode(self):
|
||||
with self.assertRaises(ValueError):
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection,
|
||||
estimation_mode="not_a_real_mode")
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
def testGradientsModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection,
|
||||
estimation_mode="gradients")
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
def testEmpiricalModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection,
|
||||
estimation_mode="empirical")
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
def testCurvaturePropModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection,
|
||||
estimation_mode="curvature_prop")
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
def testExactModeBuild(self):
|
||||
with self._graph.as_default():
|
||||
est = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
cov_ema_decay=0.1,
|
||||
damping=0.2,
|
||||
layer_collection=self.layer_collection,
|
||||
estimation_mode="exact")
|
||||
est.make_vars_and_create_op_thunks()
|
||||
|
||||
def test_cov_update_thunks(self):
|
||||
"""Ensures covariance update ops run once per global_step."""
|
||||
with self._graph.as_default(), self.cached_session() as sess:
|
||||
fisher_estimator = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
layer_collection=self.layer_collection,
|
||||
damping=0.2,
|
||||
cov_ema_decay=0.0)
|
||||
|
||||
# Construct an op that executes one covariance update per step.
|
||||
global_step = training_util.get_or_create_global_step()
|
||||
(cov_variable_thunks, cov_update_op_thunks, _,
|
||||
_) = fisher_estimator.create_ops_and_vars_thunks()
|
||||
for thunk in cov_variable_thunks:
|
||||
thunk()
|
||||
cov_matrices = [
|
||||
fisher_factor.get_cov()
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
]
|
||||
cov_update_op = control_flow_ops.case(
|
||||
[(math_ops.equal(global_step, i), thunk)
|
||||
for i, thunk in enumerate(cov_update_op_thunks)])
|
||||
increment_global_step = global_step.assign_add(1)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
initial_cov_values = sess.run(cov_matrices)
|
||||
|
||||
# Ensure there's one update per covariance matrix.
|
||||
self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
|
||||
|
||||
# Test is no-op if only 1 covariance matrix.
|
||||
assert len(cov_matrices) > 1
|
||||
|
||||
for i in range(len(cov_matrices)):
|
||||
# Compare new and old covariance values
|
||||
new_cov_values = sess.run(cov_matrices)
|
||||
is_cov_equal = [
|
||||
np.allclose(initial_cov_value, new_cov_value)
|
||||
for (initial_cov_value,
|
||||
new_cov_value) in zip(initial_cov_values, new_cov_values)
|
||||
]
|
||||
num_cov_equal = sum(is_cov_equal)
|
||||
|
||||
# Ensure exactly one covariance matrix changes per step.
|
||||
self.assertEqual(num_cov_equal, len(cov_matrices) - i)
|
||||
|
||||
# Run all covariance update ops.
|
||||
sess.run(cov_update_op)
|
||||
sess.run(increment_global_step)
|
||||
|
||||
def test_round_robin_placement(self):
|
||||
"""Check if the ops and variables are placed on devices correctly."""
|
||||
with self._graph.as_default():
|
||||
fisher_estimator = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
layer_collection=self.layer_collection,
|
||||
damping=0.2,
|
||||
cov_ema_decay=0.0,
|
||||
cov_devices=["/cpu:{}".format(i) for i in range(2)],
|
||||
inv_devices=["/cpu:{}".format(i) for i in range(2)])
|
||||
|
||||
# Construct an op that executes one covariance update per step.
|
||||
(cov_update_thunks,
|
||||
inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
|
||||
scope="test")
|
||||
cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
|
||||
inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
|
||||
self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
|
||||
self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
|
||||
self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
|
||||
self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
|
||||
cov_matrices = [
|
||||
fisher_factor.get_cov()
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
]
|
||||
inv_matrices = [
|
||||
matrix
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
for matrix in fisher_factor._matpower_by_exp_and_damping.values()
|
||||
]
|
||||
self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
|
||||
self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
|
||||
# Inverse matrices need to be explicitly placed.
|
||||
self.assertEqual(inv_matrices[0].device, "")
|
||||
self.assertEqual(inv_matrices[1].device, "")
|
||||
|
||||
def test_inv_update_thunks(self):
|
||||
"""Ensures inverse update ops run once per global_step."""
|
||||
with self._graph.as_default(), self.cached_session() as sess:
|
||||
fisher_estimator = estimator.FisherEstimatorRoundRobin(
|
||||
variables=[self.weights],
|
||||
layer_collection=self.layer_collection,
|
||||
damping=0.2,
|
||||
cov_ema_decay=0.0)
|
||||
|
||||
# Construct op that updates one inverse per global step.
|
||||
global_step = training_util.get_or_create_global_step()
|
||||
(cov_variable_thunks, _, inv_variable_thunks,
|
||||
inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
|
||||
for thunk in cov_variable_thunks:
|
||||
thunk()
|
||||
for thunk in inv_variable_thunks:
|
||||
thunk()
|
||||
inv_matrices = [
|
||||
matrix
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
for matrix in fisher_factor._matpower_by_exp_and_damping.values()
|
||||
]
|
||||
inv_update_op = control_flow_ops.case(
|
||||
[(math_ops.equal(global_step, i), thunk)
|
||||
for i, thunk in enumerate(inv_update_op_thunks)])
|
||||
increment_global_step = global_step.assign_add(1)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
initial_inv_values = sess.run(inv_matrices)
|
||||
|
||||
# Ensure there's one update per inverse matrix. This is true as long as
|
||||
# there's no fan-in/fan-out or parameter re-use.
|
||||
self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
|
||||
|
||||
# Test is no-op if only 1 invariance matrix.
|
||||
assert len(inv_matrices) > 1
|
||||
|
||||
# Assign each covariance matrix a value other than the identity. This
|
||||
# ensures that the inverse matrices are updated to something different as
|
||||
# well.
|
||||
cov_matrices = [
|
||||
fisher_factor.get_cov()
|
||||
for fisher_factor in self.layer_collection.get_factors()
|
||||
]
|
||||
sess.run([
|
||||
cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
|
||||
for cov_matrix in cov_matrices
|
||||
])
|
||||
|
||||
for i in range(len(inv_matrices)):
|
||||
# Compare new and old inverse values
|
||||
new_inv_values = sess.run(inv_matrices)
|
||||
is_inv_equal = [
|
||||
np.allclose(initial_inv_value, new_inv_value)
|
||||
for (initial_inv_value,
|
||||
new_inv_value) in zip(initial_inv_values, new_inv_values)
|
||||
]
|
||||
num_inv_equal = sum(is_inv_equal)
|
||||
|
||||
# Ensure exactly one inverse matrix changes per step.
|
||||
self.assertEqual(num_inv_equal, len(inv_matrices) - i)
|
||||
|
||||
# Run all inverse update ops.
|
||||
sess.run(inv_update_op)
|
||||
sess.run(increment_global_step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
File diff suppressed because it is too large
Load Diff
@ -1,955 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.fisher_factors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import numpy.random as npr
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# We need to set these constants since the numerical values used in the tests
|
||||
# were chosen when these used to be the defaults.
|
||||
ff.set_global_constants(init_covariances_at_zero=False,
|
||||
zero_debias=False,
|
||||
init_inverses_at_zero=False)
|
||||
|
||||
|
||||
def make_damping_func(damping):
|
||||
return fb._package_func(lambda: damping, damping)
|
||||
|
||||
|
||||
class FisherFactorTestingDummy(ff.FisherFactor):
|
||||
"""Dummy class to test the non-abstract methods on ff.FisherFactor."""
|
||||
|
||||
@property
|
||||
def _var_scope(self):
|
||||
return 'dummy/a_b_c'
|
||||
|
||||
@property
|
||||
def _cov_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _num_sources(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def _dtype(self):
|
||||
return dtypes.float32
|
||||
|
||||
def _compute_new_cov(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def instantiate_covariance(self):
|
||||
pass
|
||||
|
||||
def make_inverse_update_ops(self):
|
||||
return []
|
||||
|
||||
def get_cov(self):
|
||||
return NotImplementedError
|
||||
|
||||
def instantiate_inv_variables(self):
|
||||
return NotImplementedError
|
||||
|
||||
def _num_towers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_data_device(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def register_matpower(self, exp, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def register_cholesky(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def register_cholesky_inverse(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_matpower(self, exp, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cholesky(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cholesky_inverse(self, damping_func):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cov_as_linear_operator(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
|
||||
"""Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
|
||||
"""
|
||||
|
||||
def __init__(self, shape):
|
||||
self._shape = shape
|
||||
super(DenseSquareMatrixFactorTestingDummy, self).__init__()
|
||||
|
||||
@property
|
||||
def _var_scope(self):
|
||||
return 'dummy/a_b_c'
|
||||
|
||||
@property
|
||||
def _cov_shape(self):
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def _num_sources(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def _dtype(self):
|
||||
return dtypes.float32
|
||||
|
||||
def _compute_new_cov(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def instantiate_covariance(self):
|
||||
pass
|
||||
|
||||
def _num_towers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_data_device(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NumericalUtilsTest(test.TestCase):
|
||||
|
||||
def testComputeCovAgainstNumpy(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
npr.seed(0)
|
||||
random_seed.set_random_seed(200)
|
||||
|
||||
x = npr.randn(100, 3)
|
||||
cov = ff.compute_cov(array_ops.constant(x))
|
||||
np_cov = np.dot(x.T, x) / x.shape[0]
|
||||
|
||||
self.assertAllClose(sess.run(cov), np_cov)
|
||||
|
||||
def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
npr.seed(0)
|
||||
random_seed.set_random_seed(200)
|
||||
|
||||
normalizer = 10.
|
||||
x = npr.randn(100, 3)
|
||||
cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
|
||||
np_cov = np.dot(x.T, x) / normalizer
|
||||
|
||||
self.assertAllClose(sess.run(cov), np_cov)
|
||||
|
||||
def testAppendHomog(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
npr.seed(0)
|
||||
|
||||
m, n = 3, 4
|
||||
a = npr.randn(m, n)
|
||||
a_homog = ff.append_homog(array_ops.constant(a))
|
||||
np_result = np.hstack([a, np.ones((m, 1))])
|
||||
|
||||
self.assertAllClose(sess.run(a_homog), np_result)
|
||||
|
||||
|
||||
class NameStringUtilFunctionTest(test.TestCase):
|
||||
|
||||
def _make_tensor(self):
|
||||
x = array_ops.placeholder(dtypes.float64, (3, 1))
|
||||
w = array_ops.constant(npr.RandomState(0).randn(3, 3))
|
||||
y = math_ops.matmul(w, x)
|
||||
g = gradients_impl.gradients(y, x)[0]
|
||||
return g
|
||||
|
||||
def testScopeStringFromParamsSingleTensor(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
g = self._make_tensor()
|
||||
scope_string = ff.scope_string_from_params(g)
|
||||
self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
|
||||
|
||||
def testScopeStringFromParamsMultipleTensors(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
x = array_ops.constant(1,)
|
||||
y = array_ops.constant(2,)
|
||||
scope_string = ff.scope_string_from_params((x, y))
|
||||
self.assertEqual('Const_Const_1', scope_string)
|
||||
|
||||
def testScopeStringFromParamsMultipleTypes(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
x = array_ops.constant(1,)
|
||||
y = array_ops.constant(2,)
|
||||
scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
|
||||
(x, y)])
|
||||
self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
|
||||
|
||||
def testScopeStringFromParamsUnsupportedType(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
x = array_ops.constant(1,)
|
||||
y = array_ops.constant(2,)
|
||||
unsupported = 1.2 # Floats are not supported.
|
||||
with self.assertRaises(ValueError):
|
||||
ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
|
||||
unsupported])
|
||||
|
||||
def testScopeStringFromName(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
g = self._make_tensor()
|
||||
scope_string = ff.scope_string_from_name(g)
|
||||
self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
|
||||
|
||||
def testScalarOrTensorToString(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
|
||||
|
||||
g = self._make_tensor()
|
||||
scope_string = ff.scope_string_from_name(g)
|
||||
self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
|
||||
|
||||
|
||||
class FisherFactorTest(test.TestCase):
|
||||
|
||||
def testMakeInverseUpdateOps(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
factor = FisherFactorTestingDummy()
|
||||
|
||||
self.assertEqual(0, len(factor.make_inverse_update_ops()))
|
||||
|
||||
|
||||
class DenseSquareMatrixFactorTest(test.TestCase):
|
||||
|
||||
def testRegisterDampedInverse(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
shape = [2, 2]
|
||||
factor = DenseSquareMatrixFactorTestingDummy(shape)
|
||||
factor_var_scope = 'dummy/a_b_c'
|
||||
|
||||
damping_funcs = [make_damping_func(0.1),
|
||||
make_damping_func(0.1),
|
||||
make_damping_func(1e-5),
|
||||
make_damping_func(1e-5)]
|
||||
for damping_func in damping_funcs:
|
||||
factor.register_inverse(damping_func)
|
||||
|
||||
factor.instantiate_inv_variables()
|
||||
|
||||
inv = factor.get_inverse(damping_funcs[0]).to_dense()
|
||||
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
|
||||
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
|
||||
self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
|
||||
factor.get_inverse(damping_funcs[3]).to_dense())
|
||||
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
factor_var_scope)
|
||||
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
|
||||
|
||||
self.assertEqual(set([inv,
|
||||
factor.get_inverse(damping_funcs[2]).to_dense()]),
|
||||
set(factor_tensors))
|
||||
self.assertEqual(shape, inv.get_shape())
|
||||
|
||||
def testRegisterMatpower(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
shape = [3, 3]
|
||||
factor = DenseSquareMatrixFactorTestingDummy(shape)
|
||||
factor_var_scope = 'dummy/a_b_c'
|
||||
|
||||
# TODO(b/74201126): Change to using the same func for both once
|
||||
# Topohash is in place.
|
||||
damping_func_1 = make_damping_func(0.5)
|
||||
damping_func_2 = make_damping_func(0.5)
|
||||
|
||||
factor.register_matpower(-0.5, damping_func_1)
|
||||
factor.register_matpower(2, damping_func_2)
|
||||
|
||||
factor.instantiate_inv_variables()
|
||||
|
||||
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
factor_var_scope)
|
||||
|
||||
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
|
||||
|
||||
matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
|
||||
matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
|
||||
|
||||
self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
|
||||
|
||||
self.assertEqual(shape, matpower1.get_shape())
|
||||
self.assertEqual(shape, matpower2.get_shape())
|
||||
|
||||
def testMakeInverseUpdateOps(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
factor = FisherFactorTestingDummy()
|
||||
|
||||
self.assertEqual(0, len(factor.make_inverse_update_ops()))
|
||||
|
||||
def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
cov = np.array([[1., 2.], [3., 4.]])
|
||||
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
|
||||
damping_funcs = []
|
||||
for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
|
||||
damping_funcs.append(make_damping_func(1./i))
|
||||
|
||||
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
|
||||
factor.register_inverse(damping_funcs[i])
|
||||
|
||||
factor.instantiate_inv_variables()
|
||||
ops = factor.make_inverse_update_ops()
|
||||
self.assertEqual(1, len(ops))
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_invs = []
|
||||
sess.run(ops)
|
||||
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
|
||||
# The inverse op will assign the damped inverse of cov to the inv var.
|
||||
new_invs.append(
|
||||
sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
|
||||
|
||||
# We want to see that the new invs are all different from each other.
|
||||
for i in range(len(new_invs)):
|
||||
for j in range(i + 1, len(new_invs)):
|
||||
# Just check the first element.
|
||||
self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
|
||||
|
||||
def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
cov = np.array([[6., 2.], [2., 4.]])
|
||||
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
|
||||
damping = 0.5
|
||||
damping_func = make_damping_func(damping)
|
||||
|
||||
factor.register_matpower(exp, damping_func)
|
||||
factor.instantiate_inv_variables()
|
||||
ops = factor.make_inverse_update_ops()
|
||||
self.assertEqual(1, len(ops))
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(ops[0])
|
||||
matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
|
||||
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
|
||||
self.assertAllClose(matpower, matpower_np)
|
||||
|
||||
def testMakeInverseUpdateOpsNoEigenDecomp(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
|
||||
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
|
||||
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
|
||||
|
||||
damping_func = make_damping_func(0)
|
||||
|
||||
factor.register_inverse(damping_func)
|
||||
factor.instantiate_inv_variables()
|
||||
ops = factor.make_inverse_update_ops()
|
||||
self.assertEqual(1, len(ops))
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
# The inverse op will assign the damped inverse of cov to the inv var.
|
||||
old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
|
||||
self.assertAllClose(
|
||||
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
|
||||
|
||||
sess.run(ops)
|
||||
new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
|
||||
self.assertAllClose(new_inv, np.linalg.inv(cov))
|
||||
|
||||
|
||||
class FullFactorTest(test.TestCase):
|
||||
|
||||
def testFullFactorInit(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.FullFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testFullFactorInitFloat64(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
dtype = dtypes.float64_ref
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.FullFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([6, 6], cov.get_shape().as_list())
|
||||
|
||||
def testMakeCovarianceUpdateOp(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([1., 2.], name='a/b/c')
|
||||
factor = ff.FullFactor((tensor,), 2)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
|
||||
|
||||
|
||||
class NaiveDiagonalFactorTest(test.TestCase):
|
||||
|
||||
def testNaiveDiagonalFactorInit(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testNaiveDiagonalFactorInitFloat64(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
dtype = dtypes.float64_ref
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 32)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([6, 1], cov.get_shape().as_list())
|
||||
|
||||
def testMakeCovarianceUpdateOp(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([1., 2.], name='a/b/c')
|
||||
factor = ff.NaiveDiagonalFactor((tensor,), 2)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
self.assertAllClose([[0.75], [1.5]], new_cov)
|
||||
|
||||
|
||||
class EmbeddingInputKroneckerFactorTest(test.TestCase):
|
||||
|
||||
def testInitialization(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
input_ids = array_ops.constant([[0], [1], [4]])
|
||||
vocab_size = 5
|
||||
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.shape.as_list(), [vocab_size])
|
||||
|
||||
def testCovarianceUpdateOp(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
input_ids = array_ops.constant([[0], [1], [4]])
|
||||
vocab_size = 5
|
||||
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
|
||||
factor.instantiate_cov_variables()
|
||||
cov_update_op = factor.make_covariance_update_op(0.0)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(cov_update_op)
|
||||
self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
|
||||
|
||||
|
||||
class ConvDiagonalFactorTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.batch_size = 10
|
||||
self.height = self.width = 32
|
||||
self.in_channels = 3
|
||||
self.out_channels = 1
|
||||
self.kernel_height = self.kernel_width = 3
|
||||
self.strides = [1, 2, 2, 1]
|
||||
self.data_format = 'NHWC'
|
||||
self.padding = 'SAME'
|
||||
self.kernel_shape = [
|
||||
self.kernel_height, self.kernel_width, self.in_channels,
|
||||
self.out_channels
|
||||
]
|
||||
|
||||
def testInit(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
inputs = random_ops.random_uniform(
|
||||
[self.batch_size, self.height, self.width, self.in_channels])
|
||||
outputs_grads = [
|
||||
random_ops.random_uniform([
|
||||
self.batch_size, self.height // self.strides[1],
|
||||
self.width // self.strides[2], self.out_channels
|
||||
]) for _ in range(3)
|
||||
]
|
||||
|
||||
factor = ff.ConvDiagonalFactor(
|
||||
(inputs,),
|
||||
(outputs_grads,),
|
||||
self.kernel_shape,
|
||||
self.strides,
|
||||
self.padding,
|
||||
data_format=self.data_format)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
# Ensure covariance matrix's shape makes sense.
|
||||
self.assertEqual([
|
||||
self.kernel_height * self.kernel_width * self.in_channels,
|
||||
self.out_channels
|
||||
],
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
def testMakeCovarianceUpdateOp(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
# Construct all arguments such that convolution kernel is applied in
|
||||
# exactly one spatial location.
|
||||
inputs = np.random.randn(
|
||||
1, # batch_size
|
||||
self.kernel_height,
|
||||
self.kernel_width,
|
||||
self.in_channels) # in_channels
|
||||
outputs_grad = np.random.randn(
|
||||
1, # batch_size
|
||||
1, # output_height
|
||||
1, # output_width
|
||||
self.out_channels)
|
||||
|
||||
factor = ff.ConvDiagonalFactor(
|
||||
(constant_op.constant(inputs),),
|
||||
((constant_op.constant(outputs_grad),),),
|
||||
self.kernel_shape,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='VALID')
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
# Completely forget initial value on first update.
|
||||
cov_update_op = factor.make_covariance_update_op(0.0)
|
||||
|
||||
# Ensure new covariance value is same as outer-product of inputs/outputs
|
||||
# vectorized, squared.
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
cov = sess.run(cov_update_op)
|
||||
expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
|
||||
self.assertAllClose(expected_cov, cov)
|
||||
|
||||
def testHasBias(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
inputs = random_ops.random_uniform(
|
||||
[self.batch_size, self.height, self.width, self.in_channels])
|
||||
outputs_grads = [
|
||||
random_ops.random_uniform([
|
||||
self.batch_size, self.height // self.strides[1],
|
||||
self.width // self.strides[2], self.out_channels
|
||||
]) for _ in range(3)
|
||||
]
|
||||
|
||||
factor = ff.ConvDiagonalFactor(
|
||||
(inputs,),
|
||||
(outputs_grads,),
|
||||
self.kernel_shape,
|
||||
self.strides,
|
||||
self.padding,
|
||||
data_format=self.data_format,
|
||||
has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
# Ensure shape accounts for bias.
|
||||
self.assertEqual([
|
||||
self.kernel_height * self.kernel_width * self.in_channels + 1,
|
||||
self.out_channels
|
||||
],
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
# Ensure update op doesn't crash.
|
||||
cov_update_op = factor.make_covariance_update_op(0.0)
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(cov_update_op)
|
||||
|
||||
|
||||
class FullyConnectedKroneckerFactorTest(test.TestCase):
|
||||
|
||||
def _testFullyConnectedKroneckerFactorInit(self,
|
||||
has_bias,
|
||||
final_shape,
|
||||
dtype=dtypes.float32_ref):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual(final_shape, cov.get_shape().as_list())
|
||||
|
||||
def testFullyConnectedKroneckerFactorInitNoBias(self):
|
||||
for dtype in (dtypes.float32_ref, dtypes.float64_ref):
|
||||
self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
|
||||
|
||||
def testFullyConnectedKroneckerFactorInitWithBias(self):
|
||||
for dtype in (dtypes.float32_ref, dtypes.float64_ref):
|
||||
self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
|
||||
|
||||
  def testMakeCovarianceUpdateOpWithBias(self):
    with tf_ops.Graph().as_default(), self.cached_session() as sess:
      random_seed.set_random_seed(200)
      tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
      factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
      factor.instantiate_cov_variables()

      sess.run(tf_variables.global_variables_initializer())
      new_cov = sess.run(factor.make_covariance_update_op(.5))
      self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)

  def testMakeCovarianceUpdateOpNoBias(self):
    with tf_ops.Graph().as_default(), self.cached_session() as sess:
      random_seed.set_random_seed(200)
      tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
      factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
      factor.instantiate_cov_variables()

      sess.run(tf_variables.global_variables_initializer())
      new_cov = sess.run(factor.make_covariance_update_op(.5))
      self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)

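# --- Editor's note: illustrative snippet, not part of the deleted file. ---
# The fully connected Kronecker factor accumulates acts^T acts / batch_size,
# with a column of ones appended when has_bias=True. A standalone NumPy sketch
# reproducing the WithBias expectation, again assuming identity initialization
# and a 0.5 / 0.5 moving-average split.
import numpy as np

decay = 0.5
acts = np.array([[1., 2.], [3., 4.]])
acts = np.hstack([acts, np.ones((2, 1))])   # homogeneous (bias) coordinate
batch_cov = acts.T.dot(acts) / acts.shape[0]
print(decay * np.eye(3) + (1. - decay) * batch_cov)
# == [[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]]
# ---------------------------------------------------------------------------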
class ConvFactorTestCase(test.TestCase):
|
||||
|
||||
def assertMatrixRank(self, rank, matrix, atol=1e-5):
|
||||
assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
|
||||
eigvals = np.linalg.eigvals(matrix)
|
||||
nnz_eigvals = np.sum(eigvals > atol)
|
||||
self.assertEqual(
|
||||
rank,
|
||||
nnz_eigvals,
|
||||
msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
|
||||
(nnz_eigvals, rank, eigvals)))
|
||||
|
||||
|
||||
class ConvInputKroneckerFactorTest(ConvFactorTestCase):
|
||||
|
||||
def test3DConvolution(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
batch_size = 1
|
||||
width = 3
|
||||
in_channels = 3**3
|
||||
out_channels = 4
|
||||
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
inputs=(random_ops.random_uniform(
|
||||
(batch_size, width, width, width, in_channels), seed=0),),
|
||||
filter_shape=(width, width, width, in_channels, out_channels),
|
||||
padding='SAME',
|
||||
strides=(2, 2, 2),
|
||||
extract_patches_fn='extract_convolution_patches',
|
||||
has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
# Ensure shape of covariance matches input size of filter.
|
||||
input_size = in_channels * (width**3)
|
||||
self.assertEqual([input_size, input_size],
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
# Ensure cov_update_op doesn't crash.
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank-8, as the filter will be applied at each corner of
|
||||
# the 4-D cube.
|
||||
self.assertMatrixRank(8, cov)
|
||||
|
||||
def testPointwiseConv2d(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
batch_size = 1
|
||||
width = 3
|
||||
in_channels = 3**2
|
||||
out_channels = 4
|
||||
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
inputs=(random_ops.random_uniform(
|
||||
(batch_size, width, width, in_channels), seed=0),),
|
||||
filter_shape=(1, 1, in_channels, out_channels),
|
||||
padding='SAME',
|
||||
strides=(1, 1, 1, 1),
|
||||
extract_patches_fn='extract_pointwise_conv2d_patches',
|
||||
has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
# Ensure shape of covariance matches input size of filter.
|
||||
self.assertEqual([in_channels, in_channels],
|
||||
factor.get_cov().shape.as_list())
|
||||
|
||||
# Ensure cov_update_op doesn't crash.
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank-9, as the filter will be applied at each location.
|
||||
self.assertMatrixRank(9, cov)
|
||||
|
||||
def testStrides(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
batch_size = 1
|
||||
width = 3
|
||||
in_channels = 3**2
|
||||
out_channels = 4
|
||||
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
inputs=(random_ops.random_uniform(
|
||||
(batch_size, width, width, in_channels), seed=0),),
|
||||
filter_shape=(1, 1, in_channels, out_channels),
|
||||
padding='SAME',
|
||||
strides=(1, 2, 1, 1),
|
||||
extract_patches_fn='extract_image_patches',
|
||||
has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be the sum of 3 * 2 = 6 outer products.
|
||||
self.assertMatrixRank(6, cov)
|
||||
|
||||
def testDilationRate(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
batch_size = 1
|
||||
width = 3
|
||||
in_channels = 2
|
||||
out_channels = 4
|
||||
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
inputs=(random_ops.random_uniform(
|
||||
(batch_size, width, width, in_channels), seed=0),),
|
||||
filter_shape=(3, 3, in_channels, out_channels),
|
||||
padding='SAME',
|
||||
extract_patches_fn='extract_image_patches',
|
||||
strides=(1, 1, 1, 1),
|
||||
dilation_rate=(1, width, width, 1),
|
||||
has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank = in_channels, as only the center of the filter
|
||||
# receives non-zero input for each input channel.
|
||||
self.assertMatrixRank(in_channels, cov)
|
||||
|
||||
def testConvInputKroneckerFactorInitNoBias(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
inputs=(tensor,),
|
||||
filter_shape=(1, 2, 3, 4),
|
||||
padding='SAME',
|
||||
has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
|
||||
factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testConvInputKroneckerFactorInit(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
(tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
|
||||
factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testConvInputKroneckerFactorInitFloat64(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
dtype = dtypes.float64_ref
|
||||
tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
(tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
|
||||
cov.get_shape().as_list())
|
||||
|
||||
def testMakeCovarianceUpdateOpWithBias(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
input_shape = (2, 1, 1, 1)
|
||||
tensor = array_ops.constant(
|
||||
np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
|
||||
np.float32))
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
(tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(0.))
|
||||
self.assertAllClose(
|
||||
[
|
||||
[(1. + 4.) / 2., (1. + 2.) / 2.], #
|
||||
[(1. + 2.) / 2., (1. + 1.) / 2.]
|
||||
], #
|
||||
new_cov)
|
||||
|
||||
def testMakeCovarianceUpdateOpNoBias(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
input_shape = (2, 1, 1, 1)
|
||||
tensor = array_ops.constant(
|
||||
np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
|
||||
np.float32))
|
||||
factor = ff.ConvInputKroneckerFactor(
|
||||
(tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(0.))
|
||||
self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
|
||||
|
||||
def testSubSample(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
patches_1 = array_ops.constant(1, shape=(10, 2))
|
||||
patches_2 = array_ops.constant(1, shape=(10, 8))
|
||||
patches_3 = array_ops.constant(1, shape=(3, 3))
|
||||
patches_1_sub = ff._subsample_for_cov_computation(patches_1)
|
||||
patches_2_sub = ff._subsample_for_cov_computation(patches_2)
|
||||
patches_3_sub = ff._subsample_for_cov_computation(patches_3)
|
||||
patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
|
||||
patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
|
||||
patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
|
||||
self.assertEqual(2, patches_1_sub_batch_size)
|
||||
self.assertEqual(8, patches_2_sub_batch_size)
|
||||
self.assertEqual(3, patches_3_sub_batch_size)
|
||||
|
||||
|
||||
class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
|
||||
|
||||
def test3DConvolution(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
batch_size = 1
|
||||
width = 3
|
||||
out_channels = width**3
|
||||
|
||||
factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
|
||||
random_ops.random_uniform(
|
||||
(batch_size, width, width, width, out_channels), seed=0)
|
||||
],))
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
sess.run(factor.make_covariance_update_op(0.0))
|
||||
cov = sess.run(factor.get_cov())
|
||||
|
||||
# Cov should be rank 3^3, as each spatial position donates a rank-1
|
||||
# update.
|
||||
self.assertMatrixRank(width**3, cov)
|
||||
|
||||
def testConvOutputKroneckerFactorInit(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
|
||||
factor = ff.ConvOutputKroneckerFactor(((tensor,),))
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testConvOutputKroneckerFactorInitFloat64(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
dtype = dtypes.float64_ref
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
|
||||
factor = ff.ConvOutputKroneckerFactor(((tensor,),))
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([5, 5], cov.get_shape().as_list())
|
||||
|
||||
  def testMakeCovarianceUpdateOp(self):
    with tf_ops.Graph().as_default(), self.cached_session() as sess:
      random_seed.set_random_seed(200)
      tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
      factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
      factor.instantiate_cov_variables()

      sess.run(tf_variables.global_variables_initializer())
      new_cov = sess.run(factor.make_covariance_update_op(.5))
      self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)

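# --- Editor's note: illustrative snippet, not part of the deleted file. ---
# ConvOutputKroneckerFactor treats every (batch, spatial) position as one row
# of output gradients. A standalone NumPy sketch of the expectation above,
# assuming identity initialization of the covariance variable.
import numpy as np

decay = 0.5
grads = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
rows = grads.reshape(-1, 2)                 # 8 rows of out_channels == 2
batch_cov = rows.T.dot(rows) / rows.shape[0]
print(decay * np.eye(2) + (1. - decay) * batch_cov)
# == [[43, 46.5], [46.5, 51.5]]
# ---------------------------------------------------------------------------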
class FullyConnectedMultiKFTest(test.TestCase):
|
||||
|
||||
def testFullyConnectedMultiKFInit(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), name='a/b/c')
|
||||
factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
|
||||
|
||||
def testFullyConnectedMultiKFInitFloat64(self):
|
||||
with tf_ops.Graph().as_default():
|
||||
dtype = dtypes.float64_ref
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
|
||||
factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
|
||||
factor.instantiate_cov_variables()
|
||||
cov = factor.get_cov()
|
||||
self.assertEqual(cov.dtype, dtype)
|
||||
self.assertEqual([3, 3], cov.get_shape().as_list())
|
||||
|
||||
def testMakeCovarianceUpdateOpWithBias(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
|
||||
factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
|
||||
|
||||
def testMakeCovarianceUpdateOpNoBias(self):
|
||||
with tf_ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
|
||||
factor = ff.FullyConnectedMultiKF(((tensor,),))
|
||||
factor.instantiate_cov_variables()
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
new_cov = sess.run(factor.make_covariance_update_op(.5))
|
||||
self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,597 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.layer_collection."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_blocks
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_factors
|
||||
from tensorflow.contrib.kfac.python.ops import layer_collection
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MockFisherBlock(object):
|
||||
"""A fake FisherBlock."""
|
||||
|
||||
num_registered_towers = 2
|
||||
|
||||
def __init__(self, name='MockFisherBlock'):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MockFisherBlock) and other.name == self.name
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
|
||||
class LayerParametersDictTest(test.TestCase):
|
||||
|
||||
def testSetItem(self):
|
||||
"""Ensure insertion, contains, retrieval works for supported key types."""
|
||||
with ops.Graph().as_default():
|
||||
lp_dict = layer_collection.LayerParametersDict()
|
||||
|
||||
x = array_ops.constant(0)
|
||||
y0 = array_ops.constant(0)
|
||||
y1 = array_ops.constant(0)
|
||||
z0 = array_ops.constant(0)
|
||||
z1 = array_ops.constant(0)
|
||||
keys = [x, (y0, y1), [z0, z1]]
|
||||
for key in keys:
|
||||
lp_dict[key] = key
|
||||
|
||||
for key in keys:
|
||||
self.assertTrue(key in lp_dict)
|
||||
self.assertEqual(lp_dict[key], key)
|
||||
|
||||
def testSetItemOverlap(self):
|
||||
"""Ensure insertion fails if key overlaps with existing key."""
|
||||
with ops.Graph().as_default():
|
||||
lp_dict = layer_collection.LayerParametersDict()
|
||||
|
||||
x = array_ops.constant(0)
|
||||
y = array_ops.constant(0)
|
||||
lp_dict[x] = 'value'
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
lp_dict[(x, y)] = 'value'
|
||||
|
||||
# Ensure 'y' wasn't inserted.
|
||||
self.assertTrue(x in lp_dict)
|
||||
self.assertFalse(y in lp_dict)
|
||||
|
||||
|
||||
class LayerCollectionTest(test.TestCase):
|
||||
|
||||
def testLayerCollectionInit(self):
|
||||
lc = layer_collection.LayerCollection()
|
||||
self.assertEqual(0, len(lc.get_blocks()))
|
||||
self.assertEqual(0, len(lc.get_factors()))
|
||||
self.assertFalse(lc.losses)
|
||||
|
||||
def testRegisterBlocks(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(
|
||||
array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
|
||||
lc.register_fully_connected(
|
||||
array_ops.constant(1),
|
||||
array_ops.constant(2),
|
||||
array_ops.constant(3),
|
||||
approx=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.register_conv2d(
|
||||
params=array_ops.ones((2, 3, 4, 5)),
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME',
|
||||
inputs=array_ops.ones((1, 2, 3, 4)),
|
||||
outputs=array_ops.ones((1, 1, 1, 5)))
|
||||
lc.register_conv2d(
|
||||
params=array_ops.ones((2, 3, 4, 5)),
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME',
|
||||
inputs=array_ops.ones((1, 2, 3, 4)),
|
||||
outputs=array_ops.ones((1, 1, 1, 5)),
|
||||
approx=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.register_separable_conv2d(
|
||||
depthwise_params=array_ops.ones((3, 3, 1, 2)),
|
||||
pointwise_params=array_ops.ones((1, 1, 2, 4)),
|
||||
inputs=array_ops.ones((32, 5, 5, 1)),
|
||||
depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
|
||||
pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
lc.register_convolution(
|
||||
params=array_ops.ones((3, 3, 1, 8)),
|
||||
inputs=array_ops.ones((32, 5, 5, 1)),
|
||||
outputs=array_ops.ones((32, 5, 5, 8)),
|
||||
padding='SAME')
|
||||
lc.register_generic(
|
||||
array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
|
||||
lc.register_generic(
|
||||
array_ops.constant(6),
|
||||
16,
|
||||
approx=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.register_fully_connected_multi(
|
||||
array_ops.constant(1),
|
||||
(array_ops.constant(2), array_ops.constant(3)),
|
||||
(array_ops.constant(4), array_ops.constant(5)))
|
||||
lc.register_conv2d_multi(
|
||||
params=array_ops.ones((2, 3, 4, 5)),
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME',
|
||||
inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
|
||||
outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
|
||||
lc.register_embedding_multi(
|
||||
array_ops.constant((1,)),
|
||||
(array_ops.constant(2), array_ops.constant(3)),
|
||||
(array_ops.constant(4), array_ops.constant(5)))
|
||||
|
||||
self.assertEqual(12, len(lc.get_blocks()))
|
||||
|
||||
def testRegisterBlocksMultipleRegistrations(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
lc = layer_collection.LayerCollection()
|
||||
key = array_ops.constant(1)
|
||||
lc.register_fully_connected(key, array_ops.constant(2),
|
||||
array_ops.constant(3))
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_generic(key, 16)
|
||||
self.assertIn('already in LayerCollection', str(cm.exception))
|
||||
|
||||
def testRegisterSingleParamNotRegistered(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {
|
||||
variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
|
||||
'1'
|
||||
}
|
||||
lc.register_block(x, 'foo')
|
||||
|
||||
def testShouldRegisterSingleParamRegistered(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {x: '1'}
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_block(x, 'foo')
|
||||
self.assertIn('already in LayerCollection', str(cm.exception))
|
||||
|
||||
def testRegisterSingleParamRegisteredInTuple(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {(x, y): '1'}
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_block(x, 'foo')
|
||||
self.assertIn('was already registered', str(cm.exception))
|
||||
|
||||
def testRegisterTupleParamNotRegistered(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {
|
||||
variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
|
||||
'1'
|
||||
}
|
||||
|
||||
lc.register_block((x, y), 'foo')
|
||||
self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
|
||||
|
||||
def testRegisterTupleParamRegistered(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {(x, y): '1'}
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_block((x, y), 'foo')
|
||||
self.assertIn('already in LayerCollection', str(cm.exception))
|
||||
|
||||
def testRegisterTupleParamRegisteredInSuperset(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
|
||||
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {(x, y, z): '1'}
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_block((x, y), 'foo')
|
||||
self.assertIn('was already registered', str(cm.exception))
|
||||
|
||||
def testRegisterTupleParamSomeRegistered(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
|
||||
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_block((x, y), MockFisherBlock('foo'))
|
||||
self.assertIn('was already registered', str(cm.exception))
|
||||
|
||||
def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
|
||||
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
|
||||
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
|
||||
w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
lc.register_block((x, y), 'foo')
|
||||
self.assertIn('was already registered', str(cm.exception))
|
||||
|
||||
def testRegisterCategoricalPredictiveDistribution(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
logits = linalg_ops.eye(2)
|
||||
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_categorical_predictive_distribution(logits, seed=200)
|
||||
single_loss = sess.run(lc.total_sampled_loss())
|
||||
|
||||
lc2 = layer_collection.LayerCollection()
|
||||
lc2.register_categorical_predictive_distribution(logits, seed=200)
|
||||
lc2.register_categorical_predictive_distribution(logits, seed=200)
|
||||
double_loss = sess.run(lc2.total_sampled_loss())
|
||||
self.assertAlmostEqual(2 * single_loss, double_loss)
|
||||
|
||||
def testLossFunctionByName(self):
|
||||
"""Ensure loss functions can be identified by name."""
|
||||
with ops.Graph().as_default():
|
||||
logits = linalg_ops.eye(2)
|
||||
lc = layer_collection.LayerCollection()
|
||||
|
||||
# Create a new loss function by name.
|
||||
lc.register_categorical_predictive_distribution(logits, name='loss1')
|
||||
self.assertEqual(1, len(lc.towers_by_loss))
|
||||
|
||||
# Add logits to same loss function.
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits, name='loss1', reuse=True)
|
||||
self.assertEqual(1, len(lc.towers_by_loss))
|
||||
|
||||
# Add another new loss function.
|
||||
lc.register_categorical_predictive_distribution(logits, name='loss2')
|
||||
self.assertEqual(2, len(lc.towers_by_loss))
|
||||
|
||||
def testLossFunctionWithoutName(self):
|
||||
"""Ensure loss functions get unique names if 'name' not specified."""
|
||||
with ops.Graph().as_default():
|
||||
logits = linalg_ops.eye(2)
|
||||
lc = layer_collection.LayerCollection()
|
||||
|
||||
# Create a new loss function with default names.
|
||||
lc.register_categorical_predictive_distribution(logits)
|
||||
lc.register_categorical_predictive_distribution(logits)
|
||||
self.assertEqual(2, len(lc.losses))
|
||||
|
||||
def testCategoricalPredictiveDistributionMultipleMinibatches(self):
|
||||
"""Ensure multiple minibatches are registered."""
|
||||
with ops.Graph().as_default():
|
||||
batch_size = 3
|
||||
output_size = 2
|
||||
logits = array_ops.zeros([batch_size, output_size])
|
||||
targets = array_ops.ones([batch_size], dtype=dtypes.int32)
|
||||
lc = layer_collection.LayerCollection()
|
||||
|
||||
# Create a new loss function.
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits, targets=targets, name='loss1')
|
||||
|
||||
# Can add when reuse=True
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits, targets=targets, name='loss1', reuse=True)
|
||||
|
||||
# Can add when reuse=VARIABLE_SCOPE and reuse=True there.
|
||||
with variable_scope.variable_scope(
|
||||
variable_scope.get_variable_scope(), reuse=True):
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits,
|
||||
targets=targets,
|
||||
name='loss1',
|
||||
reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
# Can't add when reuse=False
|
||||
with self.assertRaises(KeyError):
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits, targets=targets, name='loss1', reuse=False)
|
||||
|
||||
# Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
|
||||
with self.assertRaises(KeyError):
|
||||
lc.register_categorical_predictive_distribution(
|
||||
logits,
|
||||
targets=targets,
|
||||
name='loss1',
|
||||
reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
self.assertEqual(len(lc.towers_by_loss), 1)
|
||||
# Three successful registrations.
|
||||
self.assertEqual(len(lc.towers_by_loss[0]), 3)
|
||||
|
||||
def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
logits = random_ops.random_normal((1, 2))
|
||||
lc = layer_collection.LayerCollection()
|
||||
|
||||
lc.register_categorical_predictive_distribution(logits, seed=200)
|
||||
|
||||
  def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
    with ops.Graph().as_default(), self.cached_session() as sess:
      random_seed.set_random_seed(200)
      logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
      lc = layer_collection.LayerCollection()
      targets = array_ops.constant([0, 1], dtype=dtypes.int32)

      lc.register_categorical_predictive_distribution(logits, targets=targets)
      single_loss = sess.run(lc.total_loss())
      self.assertAlmostEqual(1.6265233, single_loss)

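# --- Editor's note: illustrative snippet, not part of the deleted file. ---
# The constant 1.6265233 checked above is just the summed softmax
# cross-entropy of the two rows of logits against targets [0, 1]; a
# standalone NumPy sketch:
import numpy as np

logits = np.array([[1., 2.], [3., 4.]])
targets = np.array([0, 1])
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
print(-log_probs[np.arange(2), targets].sum())        # ~= 1.6265233
# ---------------------------------------------------------------------------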
def testRegisterNormalPredictiveDistribution(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
predictions = array_ops.constant(
|
||||
[[1., 2.], [3., 4]], dtype=dtypes.float32)
|
||||
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_normal_predictive_distribution(predictions, 1., seed=200)
|
||||
single_loss = sess.run(lc.total_sampled_loss())
|
||||
|
||||
lc2 = layer_collection.LayerCollection()
|
||||
lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
|
||||
lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
|
||||
double_loss = sess.run(lc2.total_sampled_loss())
|
||||
|
||||
self.assertAlmostEqual(2 * single_loss, double_loss)
|
||||
|
||||
  def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
    with ops.Graph().as_default(), self.cached_session() as sess:
      random_seed.set_random_seed(200)
      predictions = array_ops.constant(
          [[1., 2.], [3., 4.]], dtype=dtypes.float32)
      lc = layer_collection.LayerCollection()
      targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)

      lc.register_normal_predictive_distribution(
          predictions, 2.**2, targets=targets)
      single_loss = sess.run(lc.total_loss())
      self.assertAlmostEqual(7.6983433, single_loss)

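# --- Editor's note: illustrative snippet, not part of the deleted file. ---
# The constant 7.6983433 checked above is the Gaussian negative log-likelihood
# of the targets under N(predictions, var=2**2), summed over all four entries;
# a standalone NumPy sketch:
import numpy as np

predictions = np.array([[1., 2.], [3., 4.]])
targets = np.array([[3., 1.], [4., 2.]])
var = 2. ** 2
nll = 0.5 * ((targets - predictions) ** 2 / var + np.log(2. * np.pi * var))
print(nll.sum())                            # ~= 7.6983433
# ---------------------------------------------------------------------------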
def ensureLayerReuseWorks(self, register_fn):
|
||||
"""Ensure the 'reuse' keyword argument function as intended.
|
||||
|
||||
Args:
|
||||
register_fn: function for registering a layer. Arguments are
|
||||
layer_collection, reuse, and approx.
|
||||
"""
|
||||
# Fails on second if reuse=False.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
with self.assertRaises(ValueError):
|
||||
register_fn(lc, reuse=False)
|
||||
|
||||
# Succeeds on second if reuse=True.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
register_fn(lc, reuse=True)
|
||||
|
||||
# Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
with self.assertRaises(ValueError):
|
||||
register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
# Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc)
|
||||
with variable_scope.variable_scope(
|
||||
variable_scope.get_variable_scope(), reuse=True):
|
||||
register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
|
||||
|
||||
# Fails if block type changes.
|
||||
lc = layer_collection.LayerCollection()
|
||||
register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
|
||||
with self.assertRaises(ValueError):
|
||||
register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
|
||||
|
||||
# Fails if reuse requested but no FisherBlock exists.
|
||||
lc = layer_collection.LayerCollection()
|
||||
with self.assertRaises(KeyError):
|
||||
register_fn(lc, reuse=True)
|
||||
|
||||
def testRegisterFullyConnectedReuse(self):
|
||||
"""Ensure the 'reuse' works with register_fully_connected."""
|
||||
with ops.Graph().as_default():
|
||||
inputs = array_ops.ones([2, 10])
|
||||
outputs = array_ops.zeros([2, 5])
|
||||
params = (
|
||||
variable_scope.get_variable('w', [10, 5]), #
|
||||
variable_scope.get_variable('b', [5]))
|
||||
|
||||
def register_fn(lc, **kwargs):
|
||||
lc.register_fully_connected(
|
||||
params=params, inputs=inputs, outputs=outputs, **kwargs)
|
||||
|
||||
self.ensureLayerReuseWorks(register_fn)
|
||||
|
||||
def testRegisterConv2dReuse(self):
|
||||
"""Ensure the 'reuse' works with register_conv2d."""
|
||||
with ops.Graph().as_default():
|
||||
inputs = array_ops.ones([2, 5, 5, 10])
|
||||
outputs = array_ops.zeros([2, 5, 5, 3])
|
||||
params = (
|
||||
variable_scope.get_variable('w', [1, 1, 10, 3]), #
|
||||
variable_scope.get_variable('b', [3]))
|
||||
|
||||
def register_fn(lc, **kwargs):
|
||||
lc.register_conv2d(
|
||||
params=params,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME',
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
**kwargs)
|
||||
|
||||
self.ensureLayerReuseWorks(register_fn)
|
||||
|
||||
def testReuseWithInvalidRegistration(self):
|
||||
"""Invalid registrations shouldn't overwrite existing blocks."""
|
||||
with ops.Graph().as_default():
|
||||
inputs = array_ops.ones([2, 5, 5, 10])
|
||||
outputs = array_ops.zeros([2, 5, 5, 3])
|
||||
w = variable_scope.get_variable('w', [1, 1, 10, 3])
|
||||
b = variable_scope.get_variable('b', [3])
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.register_fully_connected(w, inputs, outputs)
|
||||
self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
|
||||
with self.assertRaises(KeyError):
|
||||
lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
|
||||
self.assertNotIn((w, b), lc.fisher_blocks)
|
||||
self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
|
||||
lc.register_fully_connected(w, inputs, outputs, reuse=True)
|
||||
self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
|
||||
|
||||
def testMakeOrGetFactor(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
lc = layer_collection.LayerCollection()
|
||||
key = array_ops.constant(1)
|
||||
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
|
||||
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
|
||||
lc.make_or_get_factor(fisher_factors.FullFactor,
|
||||
((array_ops.constant(2),), 16))
|
||||
|
||||
self.assertEqual(2, len(lc.get_factors()))
|
||||
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.assertTrue(
|
||||
all([var.name.startswith('LayerCollection') for var in variables]))
|
||||
|
||||
def testMakeOrGetFactorCustomScope(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
scope = 'Foo'
|
||||
lc = layer_collection.LayerCollection(name=scope)
|
||||
key = array_ops.constant(1)
|
||||
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
|
||||
lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
|
||||
lc.make_or_get_factor(fisher_factors.FullFactor,
|
||||
((array_ops.constant(2),), 16))
|
||||
|
||||
self.assertEqual(2, len(lc.get_factors()))
|
||||
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
|
||||
|
||||
def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
|
||||
x = variable_scope.get_variable('x', shape=())
|
||||
y = variable_scope.get_variable('y', shape=())
|
||||
z = variable_scope.get_variable('z', shape=())
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.define_linked_parameters((x, y))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
lc.define_linked_parameters((x, z))
|
||||
|
||||
def testIdentifySubsetPreviouslyRegisteredTensor(self):
|
||||
x = variable_scope.get_variable('x', shape=())
|
||||
y = variable_scope.get_variable('y', shape=())
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.define_linked_parameters((x, y))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
lc.define_linked_parameters(x)
|
||||
|
||||
def testSpecifyApproximation(self):
|
||||
w_0 = variable_scope.get_variable('w_0', [10, 10])
|
||||
w_1 = variable_scope.get_variable('w_1', [10, 10])
|
||||
|
||||
b_0 = variable_scope.get_variable('b_0', [10])
|
||||
b_1 = variable_scope.get_variable('b_1', [10])
|
||||
|
||||
x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
|
||||
x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
|
||||
|
||||
pre_bias_0 = math_ops.matmul(x_0, w_0)
|
||||
pre_bias_1 = math_ops.matmul(x_1, w_1)
|
||||
|
||||
# Build the fully connected layers in the graph.
|
||||
pre_bias_0 + b_0 # pylint: disable=pointless-statement
|
||||
pre_bias_1 + b_1 # pylint: disable=pointless-statement
|
||||
|
||||
lc = layer_collection.LayerCollection()
|
||||
lc.define_linked_parameters(
|
||||
w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.define_linked_parameters(
|
||||
w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
lc.define_linked_parameters(
|
||||
b_0, approximation=layer_collection.APPROX_FULL_NAME)
|
||||
lc.define_linked_parameters(
|
||||
b_1, approximation=layer_collection.APPROX_FULL_NAME)
|
||||
|
||||
lc.register_fully_connected(w_0, x_0, pre_bias_0)
|
||||
lc.register_fully_connected(
|
||||
w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
|
||||
self.assertIsInstance(lc.fisher_blocks[w_0],
|
||||
fisher_blocks.FullyConnectedDiagonalFB)
|
||||
self.assertIsInstance(lc.fisher_blocks[w_1],
|
||||
fisher_blocks.FullyConnectedKFACBasicFB)
|
||||
|
||||
lc.register_generic(b_0, batch_size=1)
|
||||
lc.register_generic(
|
||||
b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
|
||||
self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
|
||||
self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
|
||||
|
||||
def testDefaultLayerCollection(self):
|
||||
with ops.Graph().as_default():
|
||||
# Can't get default if there isn't one set.
|
||||
with self.assertRaises(ValueError):
|
||||
layer_collection.get_default_layer_collection()
|
||||
|
||||
# Can't set default twice.
|
||||
lc = layer_collection.LayerCollection()
|
||||
layer_collection.set_default_layer_collection(lc)
|
||||
with self.assertRaises(ValueError):
|
||||
layer_collection.set_default_layer_collection(lc)
|
||||
|
||||
# Same as one set.
|
||||
self.assertTrue(lc is layer_collection.get_default_layer_collection())
|
||||
|
||||
# Can set to None.
|
||||
layer_collection.set_default_layer_collection(None)
|
||||
with self.assertRaises(ValueError):
|
||||
layer_collection.get_default_layer_collection()
|
||||
|
||||
# as_default() is the same as setting/clearing.
|
||||
with lc.as_default():
|
||||
self.assertTrue(lc is layer_collection.get_default_layer_collection())
|
||||
with self.assertRaises(ValueError):
|
||||
layer_collection.get_default_layer_collection()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,190 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.loss_functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import loss_functions
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InsertSliceInZerosTest(test.TestCase):
|
||||
|
||||
def testBadShape(self):
|
||||
bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
|
||||
with self.assertRaises(ValueError):
|
||||
loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
|
||||
|
||||
def test3d(self):
|
||||
input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
|
||||
expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
|
||||
op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
|
||||
with self.cached_session() as sess:
|
||||
actual_output_array = sess.run(op)
|
||||
self.assertAllEqual(expected_output_array, actual_output_array)
|
||||
|
||||
|
||||
class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
||||
|
||||
def testSample(self):
|
||||
"""Ensure samples can be drawn."""
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.asarray([
|
||||
[0., 0., 0.], #
|
||||
[1., -1., 0.]
|
||||
]).astype(np.float32)
|
||||
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
|
||||
array_ops.constant(logits))
|
||||
sample = loss.sample(42)
|
||||
sample = sess.run(sample)
|
||||
self.assertEqual(sample.shape, (2,))
|
||||
|
||||
def testEvaluateOnTargets(self):
|
||||
"""Ensure log probability can be evaluated correctly."""
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.asarray([
|
||||
[0., 0., 0.], #
|
||||
[1., -1., 0.]
|
||||
]).astype(np.float32)
|
||||
targets = np.asarray([2, 1]).astype(np.int32)
|
||||
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
|
||||
array_ops.constant(logits), targets=array_ops.constant(targets))
|
||||
neg_log_prob = loss.evaluate()
|
||||
neg_log_prob = sess.run(neg_log_prob)
|
||||
|
||||
# Calculate explicit log probability of targets.
|
||||
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
|
||||
log_probs = np.log([
|
||||
probs[0, targets[0]], #
|
||||
probs[1, targets[1]]
|
||||
])
|
||||
expected_log_prob = np.sum(log_probs)
|
||||
|
||||
self.assertAllClose(neg_log_prob, -expected_log_prob)
|
||||
|
||||
def testEvaluateOnSample(self):
|
||||
"""Ensure log probability of a sample can be drawn."""
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.asarray([
|
||||
[0., 0., 0.], #
|
||||
[1., -1., 0.]
|
||||
]).astype(np.float32)
|
||||
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
|
||||
array_ops.constant(logits))
|
||||
neg_log_prob = loss.evaluate_on_sample(42)
|
||||
|
||||
# Simply ensure this doesn't crash. As the output is random, it's
|
||||
# difficult to say if the output is correct or not...
|
||||
neg_log_prob = sess.run(neg_log_prob)
|
||||
|
||||
def testMultiplyFisherSingleVector(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.array([1., 2., 3.])
|
||||
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
|
||||
|
||||
# the LossFunction.multiply_fisher docstring only says it supports the
|
||||
# case where the vector is the same shape as the input natural parameters
|
||||
# (i.e. the logits here), but here we also test leading dimensions
|
||||
vector = np.array([1., 2., 3.])
|
||||
vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
|
||||
|
||||
probs = np.exp(logits - np.logaddexp.reduce(logits))
|
||||
fisher = np.diag(probs) - np.outer(probs, probs)
|
||||
|
||||
for vector in vectors:
|
||||
result = loss.multiply_fisher(vector)
|
||||
expected_result = np.dot(vector, fisher)
|
||||
self.assertAllClose(expected_result, sess.run(result))
|
||||
|
||||
def testMultiplyFisherBatch(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.array([[1., 2., 3.], [4., 6., 8.]])
|
||||
loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
|
||||
|
||||
vector = np.array([[1., 2., 3.], [5., 3., 1.]])
|
||||
|
||||
na = np.newaxis
|
||||
probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
|
||||
keepdims=True))
|
||||
fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
|
||||
|
||||
result = loss.multiply_fisher(vector)
|
||||
expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
|
||||
self.assertEqual(sess.run(result).shape, logits.shape)
|
||||
self.assertAllClose(expected_result, sess.run(result))
|
||||
|
||||
|
||||
class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
||||
|
||||
def testSample(self):
|
||||
"""Ensure samples can be drawn."""
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.asarray([
|
||||
[0., 0., 0.], #
|
||||
[1., -1., 0.]
|
||||
]).astype(np.float32)
|
||||
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||
array_ops.constant(logits))
|
||||
sample = loss.sample(42)
|
||||
sample = sess.run(sample)
|
||||
self.assertEqual(sample.shape, (2, 3))
|
||||
|
||||
def testEvaluateOnTargets(self):
|
||||
"""Ensure log probability can be evaluated correctly."""
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.asarray([
|
||||
[0., 0., 0.], #
|
||||
[1., -1., 0.]
|
||||
]).astype(np.float32)
|
||||
targets = np.asarray([2, 1]).astype(np.int32)
|
||||
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||
array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
|
||||
neg_log_prob = loss.evaluate()
|
||||
neg_log_prob = sess.run(neg_log_prob)
|
||||
|
||||
# Calculate explicit log probability of targets.
|
||||
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
|
||||
log_probs = np.log([
|
||||
probs[0, targets[0]], #
|
||||
probs[1, targets[1]]
|
||||
])
|
||||
expected_log_prob = np.sum(log_probs)
|
||||
|
||||
self.assertAllClose(neg_log_prob, -expected_log_prob)
|
||||
|
||||
def testEvaluateOnSample(self):
|
||||
"""Ensure log probability of a sample can be drawn."""
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
logits = np.asarray([
|
||||
[0., 0., 0.], #
|
||||
[1., -1., 0.]
|
||||
]).astype(np.float32)
|
||||
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||
array_ops.constant(logits))
|
||||
neg_log_prob = loss.evaluate_on_sample(42)
|
||||
|
||||
# Simply ensure this doesn't crash. As the output is random, it's
|
||||
# difficult to say if the output is correct or not...
|
||||
neg_log_prob = sess.run(neg_log_prob)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,50 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.op_queue."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import op_queue
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class OpQueueTest(test.TestCase):
|
||||
|
||||
def testNextOp(self):
|
||||
"""Ensures all ops get selected eventually."""
|
||||
with tf_ops.Graph().as_default():
|
||||
ops = [
|
||||
math_ops.add(1, 2),
|
||||
math_ops.subtract(1, 2),
|
||||
math_ops.reduce_mean([1, 2]),
|
||||
]
|
||||
queue = op_queue.OpQueue(ops, seed=0)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# Ensure every inv update op gets selected.
|
||||
selected_ops = set([queue.next_op(sess) for _ in ops])
|
||||
self.assertEqual(set(ops), set(selected_ops))
|
||||
|
||||
# Ensure additional calls don't create any new ops.
|
||||
selected_ops.add(queue.next_op(sess))
|
||||
self.assertEqual(set(ops), set(selected_ops))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,219 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.optimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
|
||||
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
|
||||
from tensorflow.contrib.kfac.python.ops import optimizer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# We need to set these constants since the numerical values used in the tests
|
||||
# were chosen when these used to be the defaults.
|
||||
ff.set_global_constants(init_covariances_at_zero=False,
|
||||
zero_debias=False,
|
||||
init_inverses_at_zero=False)
|
||||
|
||||
|
||||
def dummy_layer_collection():
|
||||
lcoll = lc.LayerCollection()
|
||||
dummy = array_ops.constant([1., 2.])
|
||||
lcoll.register_categorical_predictive_distribution(logits=dummy)
|
||||
return lcoll
|
||||
|
||||
|
||||
class OptimizerTest(test.TestCase):
|
||||
|
||||
def testOptimizerInitInvalidMomentumRegistration(self):
|
||||
with self.assertRaises(ValueError):
|
||||
optimizer.KfacOptimizer(
|
||||
0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
|
||||
|
||||
def testOptimizerInit(self):
|
||||
with ops.Graph().as_default():
|
||||
layer_collection = lc.LayerCollection()
|
||||
|
||||
inputs = array_ops.ones((2, 1)) * 2
|
||||
weights_val = np.ones((1, 1), dtype=np.float32) * 3.
|
||||
weights = variable_scope.get_variable(
|
||||
'w', initializer=array_ops.constant(weights_val))
|
||||
bias = variable_scope.get_variable(
|
||||
'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
|
||||
output = math_ops.matmul(inputs, weights) + bias
|
||||
|
||||
layer_collection.register_fully_connected((weights, bias), inputs, output)
|
||||
|
||||
logits = math_ops.tanh(output)
|
||||
targets = array_ops.constant([[0.], [1.]])
|
||||
output = math_ops.reduce_mean(
|
||||
nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
|
||||
|
||||
layer_collection.register_categorical_predictive_distribution(logits)
|
||||
|
||||
optimizer.KfacOptimizer(
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
layer_collection,
|
||||
momentum=0.5,
|
||||
momentum_type='regular')
|
||||
|
||||
def testSquaredFisherNorm(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
|
||||
(array_ops.constant([[2., 3.], [4., 5.]]), None)]
|
||||
pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
|
||||
(array_ops.constant([[7., 8.], [9., 10.]]), None)]
|
||||
opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
|
||||
sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
|
||||
self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
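# The expected value follows directly from the inputs: assuming
# _squared_fisher_norm computes sum_i <grad_i, pgrad_i>, this is
# (1*3 + 2*4 + 3*5 + 4*6) + (2*7 + 3*8 + 4*9 + 5*10) = 50 + 124 = 174.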
|
||||
|
||||
def testUpdateClipCoeff(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
|
||||
(array_ops.constant([[2., 3.], [4., 5.]]), None)]
|
||||
pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
|
||||
(array_ops.constant([[7., 8.], [9., 10.]]), None)]
|
||||
lrate = 0.1
|
||||
|
||||
# Note: without rescaling, the squared Fisher norm of the update
|
||||
# is 1.74
|
||||
|
||||
# If the update already satisfies the norm constraint, there should
|
||||
# be no rescaling.
|
||||
opt = optimizer.KfacOptimizer(
|
||||
lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
|
||||
coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
|
||||
self.assertAlmostEqual(1., sess.run(coeff), places=5)
|
||||
|
||||
# If the update violates the constraint, it should be rescaled to
|
||||
# be on the constraint boundary.
|
||||
opt = optimizer.KfacOptimizer(
|
||||
lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
|
||||
coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
|
||||
sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
|
||||
sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
|
||||
self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
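# Worked numbers for the case above, assuming the usual K-FAC norm-constraint
# rule coeff = min(1, sqrt(norm_constraint / (lrate**2 * sq_norm_pgrad))):
# with sq_norm_pgrad = 174, lrate = 0.1 and norm_constraint = 0.5, the
# coefficient is sqrt(0.5 / 1.74) ~= 0.536, so the rescaled update has
# squared Fisher norm lrate**2 * coeff**2 * 174 = 0.5, i.e. it lands exactly
# on the constraint boundary as asserted.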
|
||||
|
||||
def testComputeUpdateStepsRegular(self):
|
||||
# TODO(olganw): implement this.
|
||||
pass
|
||||
|
||||
def testComputeUpdateStepsAdam(self):
|
||||
# TODO(olganw): implement this.
|
||||
pass
|
||||
|
||||
def testUpdateVelocities(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
layers = lc.LayerCollection()
|
||||
layers.register_categorical_predictive_distribution(
|
||||
array_ops.constant([1.0]))
|
||||
opt = optimizer.KfacOptimizer(
|
||||
0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
|
||||
x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
|
||||
y = variable_scope.get_variable(
|
||||
'y', initializer=array_ops.ones((2, 2)) * 2)
|
||||
vec1 = array_ops.ones((2, 2)) * 3
|
||||
vec2 = array_ops.ones((2, 2)) * 4
|
||||
|
||||
model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
|
||||
opt_vars = [
|
||||
v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
if v not in model_vars
|
||||
]
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
old_opt_vars = sess.run(opt_vars)
|
||||
|
||||
# Optimizer vars start out at 0.
|
||||
for opt_var in old_opt_vars:
|
||||
self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
|
||||
|
||||
sess.run(update_op)
|
||||
new_opt_vars = sess.run(opt_vars)
|
||||
# After one update, the velocities are equal to the vectors.
|
||||
for vec, opt_var in zip([vec1, vec2], new_opt_vars):
|
||||
self.assertAllEqual(sess.run(vec), opt_var)
|
||||
|
||||
sess.run(update_op)
|
||||
final_opt_vars = sess.run(opt_vars)
|
||||
for first, second in zip(new_opt_vars, final_opt_vars):
|
||||
self.assertFalse(np.equal(first, second).all())
|
||||
|
||||
def testApplyGradients(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
layer_collection = lc.LayerCollection()
|
||||
|
||||
inputs = array_ops.ones((2, 1)) * 2
|
||||
weights_val = np.ones((1, 1), dtype=np.float32) * 3.
|
||||
weights = variable_scope.get_variable(
|
||||
'w', initializer=array_ops.constant(weights_val))
|
||||
bias = variable_scope.get_variable(
|
||||
'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
|
||||
output = math_ops.matmul(inputs, weights) + bias
|
||||
|
||||
layer_collection.register_fully_connected((weights, bias), inputs, output)
|
||||
|
||||
logits = math_ops.tanh(output)
|
||||
targets = array_ops.constant([[0.], [1.]])
|
||||
output = math_ops.reduce_mean(
|
||||
nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
|
||||
|
||||
layer_collection.register_categorical_predictive_distribution(logits)
|
||||
|
||||
opt = optimizer.KfacOptimizer(
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
layer_collection,
|
||||
momentum=0.5,
|
||||
momentum_type='regular')
|
||||
(cov_update_thunks,
|
||||
inv_update_thunks) = opt.make_vars_and_create_op_thunks()
|
||||
cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
|
||||
inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
|
||||
|
||||
grads_and_vars = opt.compute_gradients(output, [weights, bias])
|
||||
all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
|
||||
|
||||
op = opt.apply_gradients(grads_and_vars)
|
||||
|
||||
sess.run(tf_variables.global_variables_initializer())
|
||||
old_vars = sess.run(all_vars)
|
||||
sess.run(cov_update_ops)
|
||||
sess.run(inv_update_ops)
|
||||
sess.run(op)
|
||||
new_vars = sess.run(all_vars)
|
||||
|
||||
for old_var, new_var in zip(old_vars, new_vars):
|
||||
self.assertNotEqual(old_var, new_var)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,410 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.kfac.utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import numpy.random as npr
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SequenceDictTest(test.TestCase):
|
||||
|
||||
def testSequenceDictInit(self):
|
||||
seq_dict = utils.SequenceDict()
|
||||
self.assertFalse(seq_dict._dict)
|
||||
|
||||
def testSequenceDictInitWithIterable(self):
|
||||
reg_dict = {'a': 'foo', 'b': 'bar'}
|
||||
itr = zip(reg_dict.keys(), reg_dict.values())
|
||||
seq_dict = utils.SequenceDict(itr)
|
||||
self.assertEqual(reg_dict, seq_dict._dict)
|
||||
|
||||
def testGetItemSingleKey(self):
|
||||
seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
|
||||
self.assertEqual('foo', seq_dict['a'])
|
||||
|
||||
def testGetItemMultipleKeys(self):
|
||||
seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
|
||||
self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
|
||||
|
||||
def testSetItemSingleKey(self):
|
||||
seq_dict = utils.SequenceDict()
|
||||
seq_dict['a'] = 'foo'
|
||||
self.assertEqual([('a', 'foo')], seq_dict.items())
|
||||
|
||||
def testSetItemMultipleKeys(self):
|
||||
seq_dict = utils.SequenceDict()
|
||||
keys = ('a', 'b', 'c')
|
||||
values = ('foo', 'bar', 'baz')
|
||||
seq_dict[keys] = values
|
||||
self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
|
||||
|
||||
|
||||
class SubGraphTest(test.TestCase):
|
||||
|
||||
def testBasicGraph(self):
|
||||
a = array_ops.constant([[1., 2.], [3., 4.]])
|
||||
b = array_ops.constant([[5., 6.], [7., 8.]])
|
||||
c = a + b
|
||||
d = a * b
|
||||
sub_graph = utils.SubGraph((c,))
|
||||
self.assertTrue(sub_graph.is_member(a))
|
||||
self.assertTrue(sub_graph.is_member(b))
|
||||
self.assertTrue(sub_graph.is_member(c))
|
||||
self.assertFalse(sub_graph.is_member(d))
|
||||
|
||||
def testRepeatedAdds(self):
|
||||
a = array_ops.constant([[1., 2.], [3., 4.]])
|
||||
b = array_ops.constant([[5., 6.], [7., 8.]])
|
||||
c = a + b + a # note that a appears twice in this graph
|
||||
sub_graph = utils.SubGraph((c,))
|
||||
self.assertTrue(sub_graph.is_member(a))
|
||||
self.assertTrue(sub_graph.is_member(b))
|
||||
self.assertTrue(sub_graph.is_member(c))
|
||||
|
||||
def testFilterList(self):
|
||||
a = array_ops.constant([[1., 2.], [3., 4.]])
|
||||
b = array_ops.constant([[5., 6.], [7., 8.]])
|
||||
c = a + b
|
||||
d = a * b
|
||||
sub_graph = utils.SubGraph((c,))
|
||||
input_list = [b, d]
|
||||
filtered_list = sub_graph.filter_list(input_list)
|
||||
self.assertEqual(filtered_list, [b])
|
||||
|
||||
def testVariableUses(self):
|
||||
with ops.Graph().as_default():
|
||||
var = variable_scope.get_variable('var', shape=[10, 10])
|
||||
resource_var = variable_scope.get_variable(
|
||||
'resource_var', shape=[10, 10], use_resource=True)
|
||||
x = array_ops.zeros([3, 10])
|
||||
z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
|
||||
z1 = math_ops.matmul(x, resource_var)
|
||||
sub_graph = utils.SubGraph((z0, z1))
|
||||
self.assertEqual(2, sub_graph.variable_uses(var))
|
||||
self.assertEqual(1, sub_graph.variable_uses(resource_var))
|
||||
|
||||
|
||||
class UtilsTest(test.TestCase):
|
||||
|
||||
def _fully_connected_layer_params(self):
|
||||
weights_part = array_ops.constant([[1., 2.], [4., 3.]])
|
||||
bias_part = array_ops.constant([1., 2.])
|
||||
return (weights_part, bias_part)
|
||||
|
||||
def _conv_layer_params(self):
|
||||
weights_shape = 2, 2, 3, 4
|
||||
biases_shape = weights_shape[-1:]
|
||||
weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
|
||||
biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
|
||||
return (weights, biases)
|
||||
|
||||
def testFullyConnectedLayerParamsTupleToMat2d(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
layer_params = self._fully_connected_layer_params()
|
||||
output = utils.layer_params_to_mat2d(layer_params)
|
||||
self.assertListEqual([3, 2], output.get_shape().as_list())
|
||||
self.assertAllClose(
|
||||
sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
|
||||
|
||||
def testFullyConnectedLayerParamsTensorToMat2d(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
layer_params = self._fully_connected_layer_params()
|
||||
output = utils.layer_params_to_mat2d(layer_params[0])
|
||||
self.assertListEqual([2, 2], output.get_shape().as_list())
|
||||
self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
|
||||
|
||||
def testConvLayerParamsTupleToMat2d(self):
|
||||
with ops.Graph().as_default():
|
||||
random_seed.set_random_seed(200)
|
||||
layer_params = self._conv_layer_params()
|
||||
output = utils.layer_params_to_mat2d(layer_params)
|
||||
self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
|
||||
|
||||
def testKron(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
mat1 = np.array([[1., 2.], [3., 4.]])
|
||||
mat2 = np.array([[5., 6.], [7., 8.]])
|
||||
mat1_tf = array_ops.constant(mat1)
|
||||
mat2_tf = array_ops.constant(mat2)
|
||||
ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
|
||||
ans_np = np.kron(mat1, mat2)
|
||||
self.assertAllClose(ans_tf, ans_np)
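# For reference, np.kron replaces each entry a_ij of mat1 with the block
# a_ij * mat2, e.g. the top-left 2x2 block of the result is
# 1 * [[5., 6.], [7., 8.]] and the top-right block is 2 * [[5., 6.], [7., 8.]].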
|
||||
|
||||
def testMat2dToFullyConnectedLayerParamsTuple(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
vector_template = self._fully_connected_layer_params()
|
||||
mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
|
||||
|
||||
output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
|
||||
|
||||
self.assertIsInstance(output, tuple)
|
||||
self.assertEqual(len(output), 2)
|
||||
a, b = output
|
||||
self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
|
||||
self.assertAllClose(b, np.array([1., 0.]))
|
||||
|
||||
def testMat2dToFullyConnectedLayerParamsTensor(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
vector_template = self._fully_connected_layer_params()[0]
|
||||
mat2d = array_ops.constant([[5., 4.], [3., 2.]])
|
||||
|
||||
output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
|
||||
|
||||
self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
|
||||
|
||||
def testTensorsToColumn(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
|
||||
vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
|
||||
output = utils.tensors_to_column(vector)
|
||||
self.assertListEqual([4, 1], output.get_shape().as_list())
|
||||
self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
|
||||
|
||||
vector = self._fully_connected_layer_params()
|
||||
output = utils.tensors_to_column(vector)
|
||||
self.assertListEqual([6, 1], output.get_shape().as_list())
|
||||
self.assertAllClose(
|
||||
sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
|
||||
|
||||
vector = list(vector)
|
||||
vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
|
||||
|
||||
output = utils.tensors_to_column(vector)
|
||||
self.assertListEqual([10, 1], output.get_shape().as_list())
|
||||
self.assertAllClose(
|
||||
sess.run(output),
|
||||
np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
|
||||
|
||||
def testColumnToTensors(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
|
||||
vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
|
||||
colvec = array_ops.constant(np.arange(4.)[:, None])
|
||||
output = sess.run(utils.column_to_tensors(vector_template, colvec))
|
||||
self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
|
||||
|
||||
vector_template = self._fully_connected_layer_params()
|
||||
colvec = array_ops.constant(np.arange(6.)[:, None])
|
||||
output = sess.run(utils.column_to_tensors(vector_template, colvec))
|
||||
|
||||
self.assertIsInstance(output, tuple)
|
||||
self.assertEqual(len(output), 2)
|
||||
a, b = output
|
||||
self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
|
||||
self.assertAllClose(b, np.array([4., 5.]))
|
||||
|
||||
vector_template = list(vector_template)
|
||||
vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
|
||||
colvec = array_ops.constant(np.arange(10.)[:, None])
|
||||
output = sess.run(utils.column_to_tensors(vector_template, colvec))
|
||||
self.assertIsInstance(output, tuple)
|
||||
self.assertEqual(len(output), 3)
|
||||
a, b, c = output
|
||||
self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
|
||||
self.assertAllClose(b, np.array([4., 5.]))
|
||||
self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
|
||||
|
||||
def testPosDefInvCholesky(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
npr.seed(0)
|
||||
square = lambda x: np.dot(x, x.T)
|
||||
|
||||
size = 3
|
||||
x = square(npr.randn(size, size))
|
||||
damp = 0.1
|
||||
identity = linalg_ops.eye(size, dtype=dtypes.float64)
|
||||
|
||||
tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
|
||||
np_inv = np.linalg.inv(x + damp * np.eye(size))
|
||||
self.assertAllClose(sess.run(tf_inv), np_inv)
|
||||
|
||||
def testPosDefInvMatrixInverse(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
random_seed.set_random_seed(200)
|
||||
npr.seed(0)
|
||||
square = lambda x: np.dot(x, x.T)
|
||||
|
||||
size = 3
|
||||
x = square(npr.randn(size, size))
|
||||
damp = 0.1
|
||||
identity = linalg_ops.eye(size, dtype=dtypes.float64)
|
||||
|
||||
tf_inv = utils.posdef_inv_matrix_inverse(
|
||||
array_ops.constant(x), identity, damp)
|
||||
np_inv = np.linalg.inv(x + damp * np.eye(size))
|
||||
self.assertAllClose(sess.run(tf_inv), np_inv)
|
||||
|
||||
def testCrossReplicaMean(self):
|
||||
"""Ensures that cross_replica_mean() executes only when num_shards > 1."""
|
||||
with ops.Graph().as_default():
|
||||
with tpu_function.tpu_shard_context(4):
|
||||
tensor = array_ops.zeros([], dtype=dtypes.float32)
|
||||
mean = utils.cross_replica_mean(tensor)
|
||||
self.assertNotEqual(mean, tensor)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with tpu_function.tpu_shard_context(1):
|
||||
tensor = array_ops.zeros([], dtype=dtypes.float32)
|
||||
mean = utils.cross_replica_mean(tensor)
|
||||
self.assertEqual(mean, tensor)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaises(ValueError): # Outside of TPU context.
|
||||
tensor = array_ops.zeros([], dtype=dtypes.float32)
|
||||
mean = utils.cross_replica_mean(tensor)
|
||||
|
||||
def testBatchExecute(self):
|
||||
"""Ensure batch_execute runs in a round-robin fashion."""
|
||||
|
||||
def increment_var(var):
|
||||
return lambda: var.assign_add(1)
|
||||
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
i = variable_scope.get_variable('i', initializer=0)
|
||||
accumulators = [
|
||||
variable_scope.get_variable('var%d' % j, initializer=0)
|
||||
for j in range(3)
|
||||
]
|
||||
thunks = [increment_var(var) for var in accumulators]
|
||||
increment_accumulators = utils.batch_execute(i, thunks, 2)
|
||||
increment_i = i.assign_add(1)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
|
||||
# Ensure one op per thunk.
|
||||
self.assertEqual(3, len(increment_accumulators))
|
||||
|
||||
# Ensure round-robin execution.
|
||||
values = []
|
||||
for _ in range(5):
|
||||
sess.run(increment_accumulators)
|
||||
sess.run(increment_i)
|
||||
values.append(sess.run(accumulators))
|
||||
self.assertAllClose(
|
||||
[
|
||||
[1, 1, 0], #
|
||||
[2, 1, 1], #
|
||||
[2, 2, 2], #
|
||||
[3, 3, 2], #
|
||||
[4, 3, 3]
|
||||
],
|
||||
values)
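# These values are consistent with the (assumed) round-robin schedule: with
# batch size 2 and 3 thunks, step i runs thunks (2*i) % 3 and (2*i + 1) % 3,
# i.e. {0, 1}, {2, 0}, {1, 2}, {0, 1}, {2, 0} for i = 0..4, which produces
# the accumulator values listed above.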
|
||||
|
||||
def testExtractConvolutionPatches(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
batch_size = 10
|
||||
image_spatial_shape = [9, 10, 11]
|
||||
in_channels = out_channels = 32
|
||||
kernel_spatial_shape = [5, 3, 3]
|
||||
spatial_strides = [1, 2, 1]
|
||||
spatial_dilation = [1, 1, 1]
|
||||
padding = 'SAME'
|
||||
|
||||
images = random_ops.random_uniform(
|
||||
[batch_size] + image_spatial_shape + [in_channels], seed=0)
|
||||
kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
|
||||
kernel = random_ops.random_uniform(kernel_shape, seed=1)
|
||||
|
||||
# Ensure shape matches expectation.
|
||||
patches = utils.extract_convolution_patches(
|
||||
images,
|
||||
kernel_shape,
|
||||
padding,
|
||||
strides=spatial_strides,
|
||||
dilation_rate=spatial_dilation)
|
||||
result_spatial_shape = (
|
||||
patches.shape.as_list()[1:1 + len(image_spatial_shape)])
|
||||
self.assertEqual(patches.shape.as_list(),
|
||||
[batch_size] + result_spatial_shape +
|
||||
kernel_spatial_shape + [in_channels])
|
||||
|
||||
# Ensure extract...patches() + matmul() and convolution() implementations
|
||||
# give the same answer.
|
||||
outputs = nn_ops.convolution(
|
||||
images,
|
||||
kernel,
|
||||
padding,
|
||||
strides=spatial_strides,
|
||||
dilation_rate=spatial_dilation)
|
||||
|
||||
patches_flat = array_ops.reshape(
|
||||
patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
|
||||
kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
|
||||
outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
|
||||
|
||||
outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
|
||||
self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
|
||||
|
||||
def testExtractPointwiseConv2dPatches(self):
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
batch_size = 10
|
||||
image_height = image_width = 8
|
||||
in_channels = out_channels = 3
|
||||
kernel_height = kernel_width = 1
|
||||
strides = [1, 1, 1, 1]
|
||||
padding = 'VALID'
|
||||
|
||||
images = random_ops.random_uniform(
|
||||
[batch_size, image_height, image_width, in_channels], seed=0)
|
||||
kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
|
||||
kernel = random_ops.random_uniform(kernel_shape, seed=1)
|
||||
|
||||
# Ensure shape matches expectation.
|
||||
patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
|
||||
self.assertEqual(patches.shape.as_list(), [
|
||||
batch_size, image_height, image_width, kernel_height, kernel_width,
|
||||
in_channels
|
||||
])
|
||||
|
||||
# Ensure extract...patches() + matmul() and conv2d() implementations
|
||||
# give the same answer.
|
||||
outputs = nn_ops.conv2d(images, kernel, strides, padding)
|
||||
|
||||
patches_flat = array_ops.reshape(
|
||||
patches, [-1, kernel_height * kernel_width * in_channels])
|
||||
kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
|
||||
outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
|
||||
|
||||
outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
|
||||
self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,263 +0,0 @@
|
||||
package(default_visibility = [
|
||||
"//tensorflow/contrib/kfac:__pkg__",
|
||||
"//tensorflow/contrib/kfac/python/kernel_tests:__pkg__",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "fisher_blocks",
|
||||
srcs = ["fisher_blocks.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fisher_factors",
|
||||
":utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "fisher_blocks_lib",
|
||||
srcs = ["fisher_blocks_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fisher_blocks",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "fisher_factors",
|
||||
srcs = ["fisher_factors.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":linear_operator",
|
||||
":utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:special_math_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "fisher_factors_lib",
|
||||
srcs = ["fisher_factors_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fisher_factors",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "linear_operator",
|
||||
srcs = ["linear_operator.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions",
|
||||
srcs = ["loss_functions.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "loss_functions_lib",
|
||||
srcs = ["loss_functions_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":loss_functions",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "curvature_matrix_vector_products",
|
||||
srcs = ["curvature_matrix_vector_products.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "curvature_matrix_vector_products_lib",
|
||||
srcs = ["curvature_matrix_vector_products_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":curvature_matrix_vector_products",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "layer_collection",
|
||||
srcs = ["layer_collection.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fisher_blocks",
|
||||
":loss_functions",
|
||||
":utils",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "layer_collection_lib",
|
||||
srcs = ["layer_collection_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":layer_collection",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "kfac_optimizer",
|
||||
srcs = [
|
||||
"optimizer.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":curvature_matrix_vector_products",
|
||||
":fisher_estimator",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "kfac_optimizer_lib",
|
||||
srcs = [
|
||||
"optimizer_lib.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":kfac_optimizer",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "fisher_estimator",
|
||||
srcs = [
|
||||
"estimator.py",
|
||||
"placement.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "fisher_estimator_lib",
|
||||
srcs = [
|
||||
"estimator_lib.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fisher_estimator",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "utils",
|
||||
srcs = ["utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/tpu",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "utils_lib",
|
||||
srcs = ["utils_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "op_queue",
|
||||
srcs = ["op_queue.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "op_queue_lib",
|
||||
srcs = ["op_queue_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":op_queue",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
@ -1,183 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Curvature matrix-vector multiplication."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class CurvatureMatrixVectorProductComputer(object):
|
||||
"""Class for computing matrix-vector products for Fishers, GGNs and Hessians.
|
||||
|
||||
In other words, we compute M*v, where M is the matrix, v is the vector, and
|
||||
* refers to standard matrix/vector multiplication (not element-wise
|
||||
multiplication).
|
||||
|
||||
The matrices are defined in terms of some differential quantity of the total
|
||||
loss function with respect to a provided list of tensors ("wrt_tensors").
|
||||
For example, the Fisher associated with a log-prob loss w.r.t. the
|
||||
parameters.
|
||||
|
||||
The 'vecs' arguments to each method are lists of tensors that must be the
|
||||
same size as the corresponding ones from "wrt_tensors". They represent
|
||||
the vector being multiplied.
|
||||
|
||||
"factors" of the matrix M are defined as matrices B such that B*B^T = M.
|
||||
Methods that multiply by the factor B take a 'loss_inner_vecs' argument
|
||||
instead of 'vecs', which must be a list of tensors with shapes given by the
|
||||
corresponding XXX_inner_shapes property.
|
||||
|
||||
Note that matrix-vector products are not normalized by the batch size, nor
|
||||
are any damping terms added to the results. These things can be easily
|
||||
applied externally, if desired.
|
||||
|
||||
See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
|
||||
and https://arxiv.org/abs/1412.1193 for more information about the
|
||||
generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
|
||||
products.
|
||||
"""
|
||||
|
||||
def __init__(self, losses, wrt_tensors):
|
||||
"""Create a CurvatureMatrixVectorProductComputer object.
|
||||
|
||||
Args:
|
||||
losses: A list of LossFunction instances whose sum defines the total loss.
|
||||
wrt_tensors: A list of Tensors to compute the differential quantities
|
||||
(defining the matrices) with respect to. See class description for more
|
||||
info.
|
||||
"""
|
||||
self._losses = losses
|
||||
self._inputs_to_losses = list(loss.inputs for loss in losses)
|
||||
self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses)
|
||||
self._wrt_tensors = wrt_tensors
|
||||
|
||||
@property
|
||||
def _total_loss(self):
|
||||
return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses))
|
||||
|
||||
# Jacobian multiplication functions:
|
||||
def _multiply_jacobian(self, vecs):
|
||||
"""Multiply vecs by the Jacobian of losses."""
|
||||
# We stop gradients at wrt_tensors to produce partial derivatives (which is
|
||||
# what we want for Jacobians).
|
||||
jacobian_vecs_flat = utils.fwd_gradients(
|
||||
self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,
|
||||
stop_gradients=self._wrt_tensors)
|
||||
return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)
|
||||
|
||||
def _multiply_jacobian_transpose(self, loss_vecs):
|
||||
"""Multiply vecs by the transpose Jacobian of losses."""
|
||||
loss_vecs_flat = nest.flatten(loss_vecs)
|
||||
# We stop gradients at wrt_tensors to produce partial derivatives (which is
|
||||
# what we want for Jacobians).
|
||||
return gradients_impl.gradients(
|
||||
self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat,
|
||||
stop_gradients=self._wrt_tensors)
|
||||
|
||||
# Losses Fisher/Hessian multiplication functions:
|
||||
def _multiply_loss_fisher(self, loss_vecs):
|
||||
"""Multiply loss_vecs by Fisher of total loss."""
|
||||
return tuple(
|
||||
loss.multiply_fisher(loss_vec)
|
||||
for loss, loss_vec in zip(self._losses, loss_vecs))
|
||||
|
||||
def _multiply_loss_fisher_factor(self, loss_inner_vecs):
|
||||
"""Multiply loss_inner_vecs by factor of Fisher of total loss."""
|
||||
return tuple(
|
||||
loss.multiply_fisher_factor(loss_vec)
|
||||
for loss, loss_vec in zip(self._losses, loss_inner_vecs))
|
||||
|
||||
def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
|
||||
"""Multiply loss_vecs by transpose factor of Fisher of total loss."""
|
||||
return tuple(
|
||||
loss.multiply_fisher_factor_transpose(loss_vec)
|
||||
for loss, loss_vec in zip(self._losses, loss_vecs))
|
||||
|
||||
def _multiply_loss_hessian(self, loss_vecs):
|
||||
"""Multiply loss_vecs by Hessian of total loss."""
|
||||
return tuple(
|
||||
loss.multiply_hessian(loss_vec)
|
||||
for loss, loss_vec in zip(self._losses, loss_vecs))
|
||||
|
||||
def _multiply_loss_hessian_factor(self, loss_inner_vecs):
|
||||
"""Multiply loss_inner_vecs by factor of Hessian of total loss."""
|
||||
return tuple(
|
||||
loss.multiply_hessian_factor(loss_vec)
|
||||
for loss, loss_vec in zip(self._losses, loss_inner_vecs))
|
||||
|
||||
def _multiply_loss_hessian_factor_transpose(self, loss_vecs):
|
||||
"""Multiply loss_vecs by transpose factor of Hessian of total loss."""
|
||||
return tuple(
|
||||
loss.multiply_hessian_factor_transpose(loss_vec)
|
||||
for loss, loss_vec in zip(self._losses, loss_vecs))
|
||||
|
||||
# Matrix-vector product functions:
|
||||
def multiply_fisher(self, vecs):
|
||||
"""Multiply vecs by Fisher of total loss."""
|
||||
jacobian_vecs = self._multiply_jacobian(vecs)
|
||||
loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)
|
||||
return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)
|
||||
|
||||
def multiply_fisher_factor_transpose(self, vecs):
|
||||
"""Multiply vecs by transpose of factor of Fisher of total loss."""
|
||||
jacobian_vecs = self._multiply_jacobian(vecs)
|
||||
return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)
|
||||
|
||||
def multiply_fisher_factor(self, loss_inner_vecs):
|
||||
"""Multiply loss_inner_vecs by factor of Fisher of total loss."""
|
||||
fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose(
|
||||
loss_inner_vecs)
|
||||
return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)
|
||||
|
||||
def multiply_hessian(self, vecs):
|
||||
"""Multiply vecs by Hessian of total loss."""
|
||||
return gradients_impl.gradients(
|
||||
gradients_impl.gradients(self._total_loss, self._wrt_tensors),
|
||||
self._wrt_tensors,
|
||||
grad_ys=vecs)
|
||||
|
||||
def multiply_generalized_gauss_newton(self, vecs):
|
||||
"""Multiply vecs by generalized Gauss-Newton of total loss."""
|
||||
jacobian_vecs = self._multiply_jacobian(vecs)
|
||||
loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs)
|
||||
return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs)
|
||||
|
||||
def multiply_generalized_gauss_newton_factor_transpose(self, vecs):
|
||||
"""Multiply vecs by transpose of factor of GGN of total loss."""
|
||||
jacobian_vecs = self._multiply_jacobian(vecs)
|
||||
return self._multiply_loss_hessian_factor_transpose(jacobian_vecs)
|
||||
|
||||
def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs):
|
||||
"""Multiply loss_inner_vecs by factor of GGN of total loss."""
|
||||
hessian_factor_transpose_vecs = (
|
||||
self._multiply_loss_hessian_factor_transpose(loss_inner_vecs))
|
||||
return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs)
|
||||
|
||||
# Shape properties for multiply_XXX_factor methods:
|
||||
@property
|
||||
def fisher_factor_inner_shapes(self):
|
||||
"""Shapes required by multiply_fisher_factor."""
|
||||
return tuple(loss.fisher_factor_inner_shape for loss in self._losses)
|
||||
|
||||
@property
|
||||
def generalized_gauss_newton_factor_inner_shapes(self):
|
||||
"""Shapes required by multiply_generalized_gauss_newton_factor."""
|
||||
return tuple(loss.hessian_factor_inner_shape for loss in self._losses)
|
@ -1,30 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Curvature matrix-vector multiplication."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
'CurvatureMatrixVectorProductComputer',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
@ -1,516 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Defines the high-level Fisher estimator class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import placement
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# The linter is confused.
|
||||
# pylint: disable=abstract-class-instantiated
|
||||
def make_fisher_estimator(placement_strategy=None, **kwargs):
|
||||
"""Creates Fisher estimator instances based on the placement strategy.
|
||||
|
||||
For example, if the `placement_strategy` is 'round_robin' then a
|
||||
`FisherEstimatorRoundRobin` instance is returned.
|
||||
|
||||
Args:
|
||||
placement_strategy: `string`, Strategy to be used for placing covariance
|
||||
variables, covariance ops and inverse ops. Check
|
||||
`placement.FisherEstimatorRoundRobin` for a concrete example.
|
||||
**kwargs: Arguments to be passed into `FisherEstimator` class initializer.
|
||||
|
||||
Returns:
|
||||
An instance of a class which inherits from `FisherEstimator` and the mixin
|
||||
which implements the specific placement strategy. See
|
||||
`FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and
|
||||
`RoundRobinPlacementMixin`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `placement_strategy` is not equal to 'round_robin'.
|
||||
"""
|
||||
if placement_strategy in [None, "round_robin"]:
|
||||
return FisherEstimatorRoundRobin(**kwargs)
|
||||
else:
|
||||
raise ValueError("Unimplemented vars and ops "
|
||||
"placement strategy : {}".format(placement_strategy))
|
||||
# pylint: enable=abstract-class-instantiated
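# A minimal usage sketch (the model-specific objects below are placeholders,
# not defined in this module):
#
#   estimator = make_fisher_estimator(
#       placement_strategy="round_robin",
#       variables=model_variables,
#       cov_ema_decay=0.95,
#       damping=0.001,
#       layer_collection=layer_collection)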
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class FisherEstimator(object):
|
||||
"""Fisher estimator class supporting various approximations of the Fisher.
|
||||
|
||||
This is an abstract base class which does not implement a strategy for
|
||||
placing covariance variables, covariance update ops and inverse update ops.
|
||||
The placement strategies are implemented in `placement.py`. See
|
||||
`FisherEstimatorRoundRobin` for an example of a concrete subclass with
|
||||
a round-robin placement strategy.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
variables,
|
||||
cov_ema_decay,
|
||||
damping,
|
||||
layer_collection,
|
||||
exps=(-1,),
|
||||
estimation_mode="gradients",
|
||||
colocate_gradients_with_ops=True,
|
||||
name="FisherEstimator",
|
||||
compute_cholesky=False,
|
||||
compute_cholesky_inverse=False):
|
||||
"""Create a FisherEstimator object.
|
||||
|
||||
Args:
|
||||
variables: A `list` of variables or `callable` which returns the variables
|
||||
for which to estimate the Fisher. This must match the variables
|
||||
registered in layer_collection (if it is not None).
|
||||
cov_ema_decay: The decay factor used when calculating the covariance
|
||||
estimate moving averages.
|
||||
damping: float. The damping factor used to stabilize training due to
|
||||
errors in the local approximation with the Fisher information matrix,
|
||||
and to regularize the update direction by making it closer to the
|
||||
gradient. (Higher damping means the update looks more like a standard
|
||||
gradient update - see Tikhonov regularization.)
|
||||
layer_collection: The layer collection object, which holds the Fisher
|
||||
blocks, Kronecker factors, and losses associated with the
|
||||
graph.
|
||||
exps: List of floats or ints. These represent the different matrix
|
||||
powers of the approximate Fisher that the FisherEstimator will be able
|
||||
to multiply vectors by. If the user asks for a matrix power other than
|
||||
one of these (or 1, which is always supported), there will be a
|
||||
failure. (Default: (-1,))
|
||||
estimation_mode: The type of estimator to use for the Fishers. Can be
|
||||
'gradients', 'empirical', 'curvature_prop', or 'exact'.
|
||||
(Default: 'gradients'). 'gradients' is the basic estimation approach
|
||||
from the original K-FAC paper. 'empirical' computes the 'empirical'
|
||||
Fisher information matrix (which uses the data's distribution for the
|
||||
targets, as opposed to the true Fisher which uses the model's
|
||||
distribution) and requires that each registered loss have specified
|
||||
targets. 'curvature_prop' (Curvature Propagation) estimates the
|
||||
Fisher using self-products of random 1/-1 vectors times "half-factors"
|
||||
of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
|
||||
Finally, 'exact' is the obvious generalization of Curvature
|
||||
Propagation to compute the exact Fisher (modulo any additional
|
||||
diagonal or Kronecker approximations) by looping over one-hot vectors
|
||||
for each coordinate of the output instead of using 1/-1 vectors. It
|
||||
is more expensive to compute than the other three options by a factor
|
||||
equal to the output dimension, roughly speaking.
|
||||
colocate_gradients_with_ops: Whether we should request gradients be
|
||||
colocated with their respective ops. (Default: True)
|
||||
name: A string. A name given to this estimator, which is added to the
|
||||
variable scope when constructing variables and ops.
|
||||
(Default: "FisherEstimator")
|
||||
compute_cholesky: Bool. Whether or not the FisherEstimator will be
|
||||
able to multiply vectors by the Cholesky factor.
|
||||
(Default: False)
|
||||
compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
|
||||
will be able to multiply vectors by the Cholesky factor inverse.
|
||||
(Default: False)
|
||||
Raises:
|
||||
ValueError: If no losses have been registered with layer_collection.
|
||||
"""
|
||||
self._variables = variables
|
||||
self._cov_ema_decay = cov_ema_decay
|
||||
self._damping = damping
|
||||
self._estimation_mode = estimation_mode
|
||||
self._layers = layer_collection
|
||||
self._gradient_fns = {
|
||||
"gradients": self._get_grads_lists_gradients,
|
||||
"empirical": self._get_grads_lists_empirical,
|
||||
"curvature_prop": self._get_grads_lists_curvature_prop,
|
||||
"exact": self._get_grads_lists_exact
|
||||
}
|
||||
self._colocate_gradients_with_ops = colocate_gradients_with_ops
|
||||
|
||||
self._made_vars = False
|
||||
self._exps = exps
|
||||
self._compute_cholesky = compute_cholesky
|
||||
self._compute_cholesky_inverse = compute_cholesky_inverse
|
||||
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
if callable(self._variables):
|
||||
return self._variables()
|
||||
else:
|
||||
return self._variables
|
||||
|
||||
@property
|
||||
def damping(self):
|
||||
return self._damping
|
||||
|
||||
@property
|
||||
def blocks(self):
|
||||
"""All registered FisherBlocks."""
|
||||
return self._layers.get_blocks()
|
||||
|
||||
@property
|
||||
def factors(self):
|
||||
"""All registered FisherFactors."""
|
||||
return self._layers.get_factors()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@abc.abstractmethod
|
||||
def make_vars_and_create_op_thunks(self, scope=None):
|
||||
"""Make vars and create op thunks with a specific placement strategy.
|
||||
|
||||
For each factor, all of that factor's cov variables and their associated
|
||||
update ops will be placed on a particular device. A new device is chosen
|
||||
for each factor by cycling through the list of devices in the cov_devices
|
||||
argument. If cov_devices is None then no explicit device placement occurs.
|
||||
|
||||
An analogous strategy is followed for inverse update ops, with the list of
|
||||
devices being given by the inv_devices argument.
|
||||
|
||||
Inverse variables, on the other hand, are not placed on any specific device
|
||||
(they will just use the current device placement context, whatever
|
||||
that happens to be). The idea is that the inverse variables belong where
|
||||
they will be accessed most often, which is the device that actually applies
|
||||
the preconditioner to the gradient. The user will be responsible for setting
|
||||
the device context for this.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All variables will be created,
|
||||
and all thunks will execute, inside of a variable scope of the given
|
||||
name. (Default: None)
|
||||
|
||||
Returns:
|
||||
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _apply_transformation(self, vecs_and_vars, transform):
|
||||
"""Applies an block-wise transformation to the corresponding vectors.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
transform: A function of the form f(fb, vec), where vec is the vector
|
||||
to transform and fb is its corresponding block in the matrix, that
|
||||
returns the transformed vector.
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
|
||||
vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
|
||||
|
||||
trans_vecs = utils.SequenceDict()
|
||||
|
||||
for params, fb in self._layers.fisher_blocks.items():
|
||||
trans_vecs[params] = transform(fb, vecs[params])
|
||||
|
||||
return [(trans_vecs[var], var) for _, var in vecs_and_vars]
|
||||
|
||||
def multiply_inverse(self, vecs_and_vars):
|
||||
"""Multiplies the vecs by the corresponding (damped) inverses of the blocks.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
return self.multiply_matpower(-1, vecs_and_vars)
|
||||
|
||||
def multiply(self, vecs_and_vars):
|
||||
"""Multiplies the vectors by the corresponding (damped) blocks.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
return self.multiply_matpower(1, vecs_and_vars)
|
||||
|
||||
def multiply_matpower(self, exp, vecs_and_vars):
|
||||
"""Multiplies the vecs by the corresponding matrix powers of the blocks.
|
||||
|
||||
Args:
|
||||
exp: A float representing the power to raise the blocks by before
|
||||
multiplying it by the vector.
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
assert exp in self._exps
|
||||
|
||||
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
|
||||
|
||||
def multiply_cholesky(self, vecs_and_vars, transpose=False):
|
||||
"""Multiplies the vecs by the corresponding Cholesky factors.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
transpose: Bool. If true the Cholesky factors are transposed before
|
||||
multiplying the vecs. (Default: False)
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
assert self._compute_cholesky
|
||||
|
||||
fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
|
||||
|
||||
def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
|
||||
"""Mults the vecs by the inverses of the corresponding Cholesky factors.
|
||||
|
||||
Note: if you are using Cholesky inverse multiplication to sample from
|
||||
a matrix-variate Gaussian you will want to multiply by the transpose.
|
||||
Let L be the Cholesky factor of F and observe that
|
||||
|
||||
L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
|
||||
|
||||
Thus we want to multiply by L^-T in order to sample from Gaussian with
|
||||
covariance F^-1.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
transpose: Bool. If true the Cholesky factor inverses are transposed
|
||||
before multiplying the vecs. (Default: False)
|
||||
|
||||
Returns:
|
||||
A list of (transformed vector, var) pairs in the same order as
|
||||
vecs_and_vars.
|
||||
"""
|
||||
assert self._compute_cholesky_inverse
|
||||
|
||||
fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
|
||||
return self._apply_transformation(vecs_and_vars, fcn)
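# Sketch of the sampling use-case described in the docstring (names are
# illustrative, and compute_cholesky_inverse=True must have been passed at
# construction):
#
#   noise_and_vars = [(standard_normal_like(v), v) for v in variables]
#   samples = estimator.multiply_cholesky_inverse(
#       noise_and_vars, transpose=True)
#
# The resulting samples have covariance equal to the inverse of the damped
# approximate Fisher.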
|
||||
|
||||
def _instantiate_factors(self):
|
||||
"""Instantiates FisherFactors' variables.
|
||||
|
||||
Raises:
|
||||
ValueError: If estimation_mode was improperly specified at construction.
|
||||
"""
|
||||
blocks = self.blocks
|
||||
tensors_to_compute_grads = [
|
||||
block.tensors_to_compute_grads() for block in blocks
|
||||
]
|
||||
|
||||
try:
|
||||
grads_lists = self._gradient_fns[self._estimation_mode](
|
||||
tensors_to_compute_grads)
|
||||
except KeyError:
|
||||
raise ValueError("Unrecognized value {} for estimation_mode.".format(
|
||||
self._estimation_mode))
|
||||
|
||||
for grads_list, block in zip(grads_lists, blocks):
|
||||
block.instantiate_factors(grads_list, self.damping)
|
||||
|
||||
def _check_vars_unmade_and_set_made_flag(self):
|
||||
if self._made_vars:
|
||||
raise Exception("Already made variables.")
|
||||
self._made_vars = True
|
||||
|
||||
def made_vars(self):
|
||||
return self._made_vars
|
||||
|
||||
def _register_matrix_functions(self):
|
||||
for block in self.blocks:
|
||||
for exp in self._exps:
|
||||
block.register_matpower(exp)
|
||||
if self._compute_cholesky:
|
||||
block.register_cholesky()
|
||||
if self._compute_cholesky_inverse:
|
||||
block.register_cholesky_inverse()
|
||||
|
||||
def _finalize_layer_collection(self):
|
||||
self._layers.create_subgraph()
|
||||
self._layers.check_registration(self.variables)
|
||||
self._instantiate_factors()
|
||||
self._register_matrix_functions()
|
||||
|
||||
def create_ops_and_vars_thunks(self, scope=None):
|
||||
"""Create thunks that make the ops and vars on demand.
|
||||
|
||||
This function returns 4 lists of thunks: cov_variable_thunks,
|
||||
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
|
||||
|
||||
The length of each list is the number of factors and the i-th element of
|
||||
each list corresponds to the i-th factor (given by the "factors" property).
|
||||
|
||||
Note that the execution of these thunks must happen in a certain
|
||||
partial order. The i-th element of cov_variable_thunks must execute
|
||||
before the i-th element of cov_update_thunks (and also the i-th element
|
||||
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
|
||||
must execute before the i-th element of inv_update_thunks.
|
||||
|
||||
TL;DR (oversimplified): Execute the thunks according to the order that
|
||||
they are returned.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All thunks will execute inside
|
||||
of a variable scope of the given name. (Default: None)
|
||||
Returns:
|
||||
cov_variable_thunks: A list of thunks that make the cov variables.
|
||||
cov_update_thunks: A list of thunks that make the cov update ops.
|
||||
inv_variable_thunks: A list of thunks that make the inv variables.
|
||||
inv_update_thunks: A list of thunks that make the inv update ops.
|
||||
"""
|
||||
self._check_vars_unmade_and_set_made_flag()
|
||||
|
||||
self._finalize_layer_collection()
|
||||
|
||||
scope = self.name if scope is None else scope
|
||||
|
||||
cov_variable_thunks = [
|
||||
self._create_cov_variable_thunk(factor, scope)
|
||||
for factor in self.factors
|
||||
]
|
||||
cov_update_thunks = [
|
||||
self._create_cov_update_thunk(factor, scope) for factor in self.factors
|
||||
]
|
||||
inv_variable_thunks = [
|
||||
self._create_inv_variable_thunk(factor, scope)
|
||||
for factor in self.factors
|
||||
]
|
||||
inv_update_thunks = [
|
||||
self._create_inv_update_thunk(factor, scope) for factor in self.factors
|
||||
]
|
||||
|
||||
return (cov_variable_thunks, cov_update_thunks,
|
||||
inv_variable_thunks, inv_update_thunks)
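# Typical (illustrative) call pattern respecting the partial order described
# in the docstring:
#
#   (cov_var_thunks, cov_update_thunks,
#    inv_var_thunks, inv_update_thunks) = estimator.create_ops_and_vars_thunks()
#   for thunk in cov_var_thunks:
#     thunk()  # make the cov variables first
#   for thunk in inv_var_thunks:
#     thunk()  # then the inv variables
#   cov_update_ops = [thunk() for thunk in cov_update_thunks]
#   inv_update_ops = [thunk() for thunk in inv_update_thunks]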
|
||||
|
||||
def _create_cov_variable_thunk(self, factor, scope):
|
||||
"""Constructs a covariance variable thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with variable_scope.variable_scope(scope):
|
||||
return factor.instantiate_cov_variables()
|
||||
|
||||
return thunk
|
||||
|
||||
def _create_cov_update_thunk(self, factor, scope):
|
||||
"""Constructs a covariance update thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with variable_scope.variable_scope(scope):
|
||||
return factor.make_covariance_update_op(self._cov_ema_decay)
|
||||
|
||||
return thunk
|
||||
|
||||
def _create_inv_variable_thunk(self, factor, scope):
|
||||
"""Constructs a inverse variable thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with variable_scope.variable_scope(scope):
|
||||
return factor.instantiate_inv_variables()
|
||||
|
||||
return thunk
|
||||
|
||||
def _create_inv_update_thunk(self, factor, scope):
|
||||
"""Constructs an inverse update thunk for a single FisherFactor."""
|
||||
|
||||
def thunk():
|
||||
with variable_scope.variable_scope(scope):
|
||||
return control_flow_ops.group(factor.make_inverse_update_ops())
|
||||
|
||||
return thunk
|
||||
|
||||
def _get_grads_lists_gradients(self, tensors):
|
||||
# Passing in a list of loss values is better than passing in the sum as
# the latter creates unnecessary ops on the default device
|
||||
grads_flat = gradients_impl.gradients(
|
||||
self._layers.eval_losses_on_samples(),
|
||||
nest.flatten(tensors),
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all = nest.pack_sequence_as(tensors, grads_flat)
|
||||
return tuple((grad,) for grad in grads_all)
|
||||
|
||||
def _get_grads_lists_empirical(self, tensors):
|
||||
# Passing in a list of loss values is better than passing in the sum as
|
||||
# the latter creates unnecessary ops on the default device
|
||||
grads_flat = gradients_impl.gradients(
|
||||
self._layers.eval_losses(),
|
||||
nest.flatten(tensors),
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all = nest.pack_sequence_as(tensors, grads_flat)
|
||||
return tuple((grad,) for grad in grads_all)
|
||||
|
||||
def _get_transformed_random_signs(self):
|
||||
transformed_random_signs = []
|
||||
for loss in self._layers.losses:
|
||||
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
|
||||
transformed_random_signs.append(
|
||||
loss.multiply_fisher_factor(
|
||||
utils.generate_random_signs(loss.fisher_factor_inner_shape)))
|
||||
return transformed_random_signs
|
||||
|
||||
def _get_grads_lists_curvature_prop(self, tensors):
|
||||
loss_inputs = list(loss.inputs for loss in self._layers.losses)
|
||||
transformed_random_signs = self._get_transformed_random_signs()
|
||||
grads_flat = gradients_impl.gradients(
|
||||
nest.flatten(loss_inputs),
|
||||
nest.flatten(tensors),
|
||||
grad_ys=nest.flatten(transformed_random_signs),
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all = nest.pack_sequence_as(tensors, grads_flat)
|
||||
return tuple((grad,) for grad in grads_all)
|
||||
|
||||
def _get_grads_lists_exact(self, tensors):
|
||||
"""No docstring required."""
|
||||
# Loop over all coordinates of all losses.
|
||||
grads_all = []
|
||||
for loss in self._layers.losses:
|
||||
with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
|
||||
for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
|
||||
transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
|
||||
index)
|
||||
grads_flat = gradients_impl.gradients(
|
||||
loss.inputs,
|
||||
nest.flatten(tensors),
|
||||
grad_ys=transformed_one_hot,
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops)
|
||||
grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
|
||||
return zip(*grads_all)
|
||||
|
||||
|
||||
class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
|
||||
FisherEstimator):
|
||||
"""Fisher estimator which provides round robin device placement strategy."""
|
||||
pass
|
@ -1,31 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Defines the high-level Fisher estimator class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.estimator import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
'FisherEstimator',
|
||||
'make_fisher_estimator',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
File diff suppressed because it is too large
@ -1,45 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""FisherBlock definitions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.fisher_blocks import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
'FisherBlock',
|
||||
'FullFB',
|
||||
'NaiveDiagonalFB',
|
||||
'FullyConnectedDiagonalFB',
|
||||
'KroneckerProductFB',
|
||||
'EmbeddingKFACFB',
|
||||
'FullyConnectedKFACBasicFB',
|
||||
'ConvKFCBasicFB',
|
||||
'ConvDiagonalFB',
|
||||
'set_global_constants',
|
||||
'compute_pi_tracenorm',
|
||||
'compute_pi_adjusted_damping',
|
||||
'num_conv_locations',
|
||||
'normalize_damping',
|
||||
'LEFT_MULTIPLY',
|
||||
'RIGHT_MULTIPLY',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
File diff suppressed because it is too large
@ -1,38 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""FisherFactor definitions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.fisher_factors import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
"inverse_initializer", "covariance_initializer",
|
||||
"diagonal_covariance_initializer", "scope_string_from_params",
|
||||
"scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
|
||||
"InverseProvidingFactor", "FullFactor", "DiagonalFactor",
|
||||
"NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
|
||||
"FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
|
||||
"ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
|
||||
"ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
|
||||
"compute_cov", "append_homog"
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
File diff suppressed because it is too large
@ -1,46 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Registry for layers and their parameters/variables.
|
||||
|
||||
This represents the collection of all layers in the approximate Fisher
|
||||
information matrix to which a particular FisherBlock may belong. That is, we
|
||||
might have several layer collections for one TF graph (if we have multiple
K-FAC optimizers being used, for example).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.layer_collection import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
"get_default_layer_collection",
|
||||
"set_default_layer_collection",
|
||||
"LayerParametersDict",
|
||||
"LayerCollection",
|
||||
"APPROX_KRONECKER_NAME",
|
||||
"APPROX_DIAGONAL_NAME",
|
||||
"APPROX_FULL_NAME",
|
||||
"VARIABLE_SCOPE",
|
||||
"APPROX_KRONECKER_INDEP_NAME",
|
||||
"APPROX_KRONECKER_SERIES_1_NAME",
|
||||
"APPROX_KRONECKER_SERIES_2_NAME"
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
@ -1,95 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""SmartMatrices definitions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linalg
|
||||
from tensorflow.python.ops.linalg import linalg_impl
|
||||
from tensorflow.python.ops.linalg import linear_operator_util as lou
|
||||
|
||||
|
||||
class LinearOperatorExtras(object): # pylint: disable=missing-docstring
|
||||
|
||||
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
|
||||
|
||||
with self._name_scope(name, values=[x]):
|
||||
if isinstance(x, ops.IndexedSlices):
|
||||
return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
|
||||
self_dim = -2 if adjoint else -1
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
|
||||
|
||||
return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
|
||||
|
||||
with self._name_scope(name, values=[x]):
|
||||
|
||||
if isinstance(x, ops.IndexedSlices):
|
||||
return self._matmul_right_sparse(
|
||||
x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
|
||||
self_dim = -1 if adjoint else -2
|
||||
arg_dim = -2 if adjoint_arg else -1
|
||||
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
|
||||
|
||||
return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
|
||||
class LinearOperatorFullMatrix(LinearOperatorExtras,
|
||||
linalg.LinearOperatorFullMatrix):
|
||||
|
||||
# TODO(b/78117889) Remove this definition once core LinearOperator
|
||||
# has _matmul_right.
|
||||
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
|
||||
return lou.matmul_with_broadcast(
|
||||
x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
|
||||
|
||||
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
assert not adjoint and not adjoint_arg
|
||||
return utils.matmul_sparse_dense(x, self._matrix)
|
||||
|
||||
|
||||
class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
|
||||
linalg.LinearOperatorDiag):
|
||||
|
||||
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
|
||||
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
|
||||
x = linalg_impl.adjoint(x) if adjoint_arg else x
|
||||
return diag_mat * x
|
||||
|
||||
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
|
||||
assert not adjoint_arg
|
||||
return utils.matmul_diag_sparse(diag_mat, x)
|
||||
|
||||
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
|
||||
raise NotImplementedError
|
@ -1,754 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Loss functions to be used by LayerCollection."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import onehot_categorical
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.distributions import bernoulli
|
||||
from tensorflow.python.ops.distributions import categorical
|
||||
from tensorflow.python.ops.distributions import normal
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class LossFunction(object):
|
||||
"""Abstract base class for loss functions.
|
||||
|
||||
Note that unlike typical loss functions used in neural networks these are
|
||||
summed and not averaged across cases in the batch, since this is what the
|
||||
users of this class (FisherEstimator and MatrixVectorProductComputer) will
|
||||
be expecting. The implication of this is that you may want to
|
||||
normalize things like Fisher-vector products by the batch size when you
|
||||
use this class. It depends on the use case.
|
||||
"""
|
||||
|
||||
@abc.abstractproperty
|
||||
def targets(self):
|
||||
"""The targets being predicted by the model.
|
||||
|
||||
Returns:
|
||||
None or Tensor of appropriate shape for calling self._evaluate() on.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def inputs(self):
|
||||
"""The inputs to the loss function (excluding the targets)."""
|
||||
pass
|
||||
|
||||
def evaluate(self):
|
||||
"""Evaluate the loss function on the targets."""
|
||||
if self.targets is not None:
|
||||
# We treat the targets as "constant". It's only the inputs that get
|
||||
# "back-propped" through.
|
||||
return self._evaluate(array_ops.stop_gradient(self.targets))
|
||||
else:
|
||||
raise Exception("Cannot evaluate losses with unspecified targets.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _evaluate(self, targets):
|
||||
"""Evaluates the negative log probability of the targets.
|
||||
|
||||
Args:
|
||||
targets: Tensor that distribution can calculate log_prob() of.
|
||||
|
||||
Returns:
|
||||
negative log probability of each target, summed across all targets.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_hessian(self, vector):
|
||||
"""Right-multiply a vector by the Hessian.
|
||||
|
||||
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
|
||||
of the loss function with respect to its inputs.
|
||||
|
||||
Args:
|
||||
vector: The vector to multiply. Must be the same shape(s) as the
|
||||
'inputs' property.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by the Hessian. Will be of the same shape(s)
|
||||
as the 'inputs' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_hessian_factor(self, vector):
|
||||
"""Right-multiply a vector by a factor B of the Hessian.
|
||||
|
||||
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
|
||||
of the loss function with respect to its inputs. Typically this will be
|
||||
block-diagonal across different cases in the batch, since the loss function
|
||||
is typically summed across cases.
|
||||
|
||||
Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
|
||||
but will agree with the one used in the other methods of this class.
|
||||
|
||||
Args:
|
||||
vector: The vector to multiply. Must be of the shape given by the
|
||||
'hessian_factor_inner_shape' property.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by B. Will be of the same shape(s) as the
|
||||
'inputs' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_hessian_factor_transpose(self, vector):
|
||||
"""Right-multiply a vector by the transpose of a factor B of the Hessian.
|
||||
|
||||
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
|
||||
of the loss function with respect to its inputs. Typically this will be
|
||||
block-diagonal across different cases in the batch, since the loss function
|
||||
is typically summed across cases.
|
||||
|
||||
Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
|
||||
but will agree with the one used in the other methods of this class.
|
||||
|
||||
Args:
|
||||
vector: The vector to multiply. Must be the same shape(s) as the
|
||||
'inputs' property.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by B^T. Will be of the shape given by the
|
||||
'hessian_factor_inner_shape' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_hessian_factor_replicated_one_hot(self, index):
|
||||
"""Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
|
||||
|
||||
Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
|
||||
of the loss function with respect to its inputs. Typically this will be
|
||||
block-diagonal across different cases in the batch, since the loss function
|
||||
is typically summed across cases.
|
||||
|
||||
A 'replicated-one-hot' vector means a tensor which, for each slice along the
|
||||
batch dimension (assumed to be dimension 0), is 1.0 in the entry
|
||||
corresponding to the given index and 0 elsewhere.
|
||||
|
||||
Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
|
||||
but will agree with the one used in the other methods of this class.
|
||||
|
||||
Args:
|
||||
index: A tuple representing the index of the entry in each slice that
|
||||
is 1.0. Note that len(index) must be equal to the number of elements
|
||||
of the 'hessian_factor_inner_shape' tensor minus one.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by B^T. Will be of the same shape(s) as the
|
||||
'inputs' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def hessian_factor_inner_shape(self):
|
||||
"""The shape of the tensor returned by multiply_hessian_factor."""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def hessian_factor_inner_static_shape(self):
|
||||
"""Static version of hessian_factor_inner_shape."""
|
||||
pass
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class NegativeLogProbLoss(LossFunction):
|
||||
"""Abstract base class for loss functions that are negative log probs."""
|
||||
|
||||
def __init__(self, seed=None):
|
||||
self._default_seed = seed
|
||||
super(NegativeLogProbLoss, self).__init__()
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return self.params
|
||||
|
||||
@abc.abstractproperty
|
||||
def params(self):
|
||||
"""Parameters to the underlying distribution."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_fisher(self, vector):
|
||||
"""Right-multiply a vector by the Fisher.
|
||||
|
||||
Args:
|
||||
vector: The vector to multiply. Must be the same shape(s) as the
|
||||
'inputs' property.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by the Fisher. Will be of the same shape(s)
|
||||
as the 'inputs' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_fisher_factor(self, vector):
|
||||
"""Right-multiply a vector by a factor B of the Fisher.
|
||||
|
||||
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
|
||||
product of gradients) with respect to the parameters of the underlying
|
||||
probability distribution (whose log-prob defines the loss). Typically this
|
||||
will be block-diagonal across different cases in the batch, since the
|
||||
distribution is usually (but not always) conditionally iid across different
|
||||
cases.
|
||||
|
||||
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
|
||||
but will agree with the one used in the other methods of this class.
|
||||
|
||||
Args:
|
||||
vector: The vector to multiply. Must be of the shape given by the
|
||||
'fisher_factor_inner_shape' property.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by B. Will be of the same shape(s) as the
|
||||
'inputs' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_fisher_factor_transpose(self, vector):
|
||||
"""Right-multiply a vector by the transpose of a factor B of the Fisher.
|
||||
|
||||
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
|
||||
product of gradients) with respect to the parameters of the underlying
|
||||
probability distribution (whose log-prob defines the loss). Typically this
|
||||
will be block-diagonal across different cases in the batch, since the
|
||||
distribution is usually (but not always) conditionally iid across different
|
||||
cases.
|
||||
|
||||
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
|
||||
but will agree with the one used in the other methods of this class.
|
||||
|
||||
Args:
|
||||
vector: The vector to multiply. Must be the same shape(s) as the
|
||||
'inputs' property.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by B^T. Will be of the shape given by the
|
||||
'fisher_factor_inner_shape' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def multiply_fisher_factor_replicated_one_hot(self, index):
|
||||
"""Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
|
||||
|
||||
Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
|
||||
product of gradients) with respect to the parameters of the underlying
|
||||
probability distribution (whose log-prob defines the loss). Typically this
|
||||
will be block-diagonal across different cases in the batch, since the
|
||||
distribution is usually (but not always) conditionally iid across different
|
||||
cases.
|
||||
|
||||
A 'replicated-one-hot' vector means a tensor which, for each slice along the
|
||||
batch dimension (assumed to be dimension 0), is 1.0 in the entry
|
||||
corresponding to the given index and 0 elsewhere.
|
||||
|
||||
Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
|
||||
but will agree with the one used in the other methods of this class.
|
||||
|
||||
Args:
|
||||
index: A tuple representing the index of the entry in each slice that
|
||||
is 1.0. Note that len(index) must be equal to the number of elements
|
||||
of the 'fisher_factor_inner_shape' tensor minus one.
|
||||
|
||||
Returns:
|
||||
The vector right-multiplied by B. Will be of the same shape(s) as the
|
||||
'inputs' property.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def fisher_factor_inner_shape(self):
|
||||
"""The shape of the tensor returned by multiply_fisher_factor."""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def fisher_factor_inner_static_shape(self):
|
||||
"""Static version of fisher_factor_inner_shape."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, seed):
|
||||
"""Sample 'targets' from the underlying distribution."""
|
||||
pass
|
||||
|
||||
def evaluate_on_sample(self, seed=None):
|
||||
"""Evaluates the log probability on a random sample.
|
||||
|
||||
Args:
|
||||
seed: int or None. Random seed for this draw from the distribution.
|
||||
|
||||
Returns:
|
||||
Log probability of sampled targets, summed across examples.
|
||||
"""
|
||||
if seed is None:
|
||||
seed = self._default_seed
|
||||
# We treat the targets as "constant". It's only the inputs that get
|
||||
# "back-propped" through.
|
||||
return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
|
||||
|
||||
|
||||
# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
|
||||
# inheritance, or is there a better way?
|
||||
class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
|
||||
"""Base class for neg log prob losses whose inputs are 'natural' parameters.
|
||||
|
||||
Note that the Hessian and Fisher for natural parameters of exponential-
|
||||
family models are the same, hence the purpose of this class.
|
||||
See here: https://arxiv.org/abs/1412.1193
|
||||
|
||||
'Natural parameters' are defined for exponential-family models. See for
|
||||
example: https://en.wikipedia.org/wiki/Exponential_family
|
||||
"""
|
||||
|
||||
def multiply_hessian(self, vector):
|
||||
return self.multiply_fisher(vector)
|
||||
|
||||
def multiply_hessian_factor(self, vector):
|
||||
return self.multiply_fisher_factor(vector)
|
||||
|
||||
def multiply_hessian_factor_transpose(self, vector):
|
||||
return self.multiply_fisher_factor_transpose(vector)
|
||||
|
||||
def multiply_hessian_factor_replicated_one_hot(self, index):
|
||||
return self.multiply_fisher_factor_replicated_one_hot(index)
|
||||
|
||||
@property
|
||||
def hessian_factor_inner_shape(self):
|
||||
return self.fisher_factor_inner_shape
|
||||
|
||||
@property
|
||||
def hessian_factor_inner_static_shape(self):
|
||||
return self.fisher_factor_inner_static_shape
|
||||
|
||||
|
||||
class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
|
||||
"""Base class for neg log prob losses that use the TF Distribution classes."""
|
||||
|
||||
def __init__(self, seed=None):
|
||||
super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
@abc.abstractproperty
|
||||
def dist(self):
|
||||
"""The underlying tf.distributions.Distribution."""
|
||||
pass
|
||||
|
||||
def _evaluate(self, targets):
|
||||
return -math_ops.reduce_sum(self.dist.log_prob(targets))
|
||||
|
||||
def sample(self, seed):
|
||||
return self.dist.sample(seed=seed)
|
||||
|
||||
|
||||
class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
|
||||
NaturalParamsNegativeLogProbLoss):
|
||||
"""Neg log prob loss for a normal distribution parameterized by a mean vector.
|
||||
|
||||
|
||||
Note that the covariance is treated as a constant 'var' times the identity.
|
||||
Also note that the Fisher for such a normal distribution with respect to the mean
|
||||
parameter is given by:
|
||||
|
||||
F = (1/var) * I
|
||||
|
||||
See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, var=0.5, targets=None, seed=None):
|
||||
self._mean = mean
|
||||
self._var = var
|
||||
self._targets = targets
|
||||
super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self._targets
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self._mean
|
||||
|
||||
def multiply_fisher(self, vector):
|
||||
return (1. / self._var) * vector
|
||||
|
||||
def multiply_fisher_factor(self, vector):
|
||||
return self._var**-0.5 * vector
|
||||
|
||||
def multiply_fisher_factor_transpose(self, vector):
|
||||
return self.multiply_fisher_factor(vector) # it's symmetric in this case
|
||||
|
||||
def multiply_fisher_factor_replicated_one_hot(self, index):
|
||||
assert len(index) == 1, "Length of index was {}".format(len(index))
|
||||
ones_slice = array_ops.expand_dims(
|
||||
array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
|
||||
axis=-1)
|
||||
output_slice = self._var**-0.5 * ones_slice
|
||||
return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
|
||||
index[0])
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_shape(self):
|
||||
return array_ops.shape(self._mean)
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_static_shape(self):
|
||||
return self._mean.shape
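# --- Illustrative check, not part of the original file. For this loss the
# Fisher factor is B = var**-0.5 * I, so composing multiply_fisher_factor with
# its transpose recovers multiply_fisher: B (B^T v) = (1/var) v. The helper
# below is a sketch; `mean`, `var` and `vector` are assumed caller-supplied
# tensors, with `vector` shaped like `mean`.
def _check_normal_mean_fisher_factorization(mean, var, vector):
  """Illustrative only: returns F v - B (B^T v), which should be ~0."""
  loss = NormalMeanNegativeLogProbLoss(mean, var=var)
  via_factor = loss.multiply_fisher_factor(
      loss.multiply_fisher_factor_transpose(vector))
  return loss.multiply_fisher(vector) - via_factor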
|
||||
|
||||
|
||||
class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
|
||||
"""Negative log prob loss for a normal distribution with mean and variance.
|
||||
|
||||
This class parameterizes a multivariate normal distribution with n independent
|
||||
dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
|
||||
assume the variance is held constant. The Fisher Information for n = 1
|
||||
is given by,
|
||||
|
||||
F = [[1 / variance,                0],
     [            0, 0.5 / variance^2]]
|
||||
|
||||
where the parameters of the distribution are concatenated into a single
|
||||
vector as [mean, variance]. For n > 1, the mean parameter vector is
|
||||
concatenated with the variance parameter vector.
|
||||
|
||||
See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, variance, targets=None, seed=None):
|
||||
assert len(mean.shape) == 2, "Expect 2D mean tensor."
|
||||
assert len(variance.shape) == 2, "Expect 2D variance tensor."
|
||||
self._mean = mean
|
||||
self._variance = variance
|
||||
self._targets = targets
|
||||
super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self._targets
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self._mean, self._variance
|
||||
|
||||
def _concat(self, mean, variance):
|
||||
return array_ops.concat([mean, variance], axis=-1)
|
||||
|
||||
def _split(self, params):
|
||||
return array_ops.split(params, 2, axis=-1)
|
||||
|
||||
@property
|
||||
def _fisher_mean(self):
|
||||
return 1. / self._variance
|
||||
|
||||
@property
|
||||
def _fisher_mean_factor(self):
|
||||
return 1. / math_ops.sqrt(self._variance)
|
||||
|
||||
@property
|
||||
def _fisher_var(self):
|
||||
return 1. / (2 * math_ops.square(self._variance))
|
||||
|
||||
@property
|
||||
def _fisher_var_factor(self):
|
||||
return 1. / (math_ops.sqrt(2.) * self._variance)
|
||||
|
||||
def multiply_fisher(self, vecs):
|
||||
mean_vec, var_vec = vecs
|
||||
return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
|
||||
|
||||
def multiply_fisher_factor(self, vecs):
|
||||
mean_vec, var_vec = self._split(vecs)
|
||||
return (self._fisher_mean_factor * mean_vec,
|
||||
self._fisher_var_factor * var_vec)
|
||||
|
||||
def multiply_fisher_factor_transpose(self, vecs):
|
||||
mean_vec, var_vec = vecs
|
||||
return self._concat(self._fisher_mean_factor * mean_vec,
|
||||
self._fisher_var_factor * var_vec)
|
||||
|
||||
def multiply_fisher_factor_replicated_one_hot(self, index):
|
||||
assert len(index) == 1, "Length of index was {}".format(len(index))
|
||||
index = index[0]
|
||||
|
||||
if index < int(self._mean.shape[-1]):
|
||||
# Index corresponds to mean parameter.
|
||||
mean_slice = self._fisher_mean_factor[:, index]
|
||||
mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
|
||||
mean_output = insert_slice_in_zeros(mean_slice, 1, int(
|
||||
self._mean.shape[1]), index)
|
||||
var_output = array_ops.zeros_like(mean_output)
|
||||
else:
|
||||
index -= int(self._mean.shape[-1])
|
||||
# Index corresponds to variance parameter.
|
||||
var_slice = self._fisher_var_factor[:, index]
|
||||
var_slice = array_ops.expand_dims(var_slice, axis=-1)
|
||||
var_output = insert_slice_in_zeros(var_slice, 1,
|
||||
int(self._variance.shape[1]), index)
|
||||
mean_output = array_ops.zeros_like(var_output)
|
||||
|
||||
return mean_output, var_output
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_shape(self):
|
||||
return array_ops.concat(
|
||||
[
|
||||
array_ops.shape(self._mean)[:-1],
|
||||
2 * array_ops.shape(self._mean)[-1:]
|
||||
],
|
||||
axis=0)
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_static_shape(self):
|
||||
shape = self._mean.shape.as_list()
|
||||
return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])
|
||||
|
||||
def multiply_hessian(self, vector):
|
||||
raise NotImplementedError()
|
||||
|
||||
def multiply_hessian_factor(self, vector):
|
||||
raise NotImplementedError()
|
||||
|
||||
def multiply_hessian_factor_transpose(self, vector):
|
||||
raise NotImplementedError()
|
||||
|
||||
def multiply_hessian_factor_replicated_one_hot(self, index):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def hessian_factor_inner_shape(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def hessian_factor_inner_static_shape(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
|
||||
NaturalParamsNegativeLogProbLoss):
|
||||
"""Neg log prob loss for a categorical distribution parameterized by logits.
|
||||
|
||||
|
||||
Note that the Fisher (for a single case) of a categorical distribution, with
|
||||
respect to the natural parameters (i.e. the logits), is given by:
|
||||
|
||||
F = diag(p) - p*p^T
|
||||
|
||||
where p = softmax(logits). F can be factorized as F = B * B^T where
|
||||
|
||||
B = diag(q) - p*q^T
|
||||
|
||||
where q is the entry-wise square root of p. This is easy to verify using the
|
||||
fact that q^T*q = 1.
|
||||
"""
|
||||
|
||||
def __init__(self, logits, targets=None, seed=None):
|
||||
"""Instantiates a CategoricalLogitsNegativeLogProbLoss.
|
||||
|
||||
Args:
|
||||
logits: Tensor of shape [batch_size, output_size]. Parameters for
|
||||
underlying distribution.
|
||||
targets: None or Tensor of shape [batch_size]. Each element contains an
|
||||
index in [0, output_size).
|
||||
seed: int or None. Default random seed when sampling.
|
||||
"""
|
||||
self._logits = logits
|
||||
self._targets = targets
|
||||
super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self._targets
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
return categorical.Categorical(logits=self._logits)
|
||||
|
||||
@property
|
||||
def _probs(self):
|
||||
return self.dist.probs
|
||||
|
||||
@property
|
||||
def _sqrt_probs(self):
|
||||
return math_ops.sqrt(self._probs)
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self._logits
|
||||
|
||||
def multiply_fisher(self, vector):
|
||||
probs = self._probs
|
||||
return vector * probs - probs * math_ops.reduce_sum(
|
||||
vector * probs, axis=-1, keepdims=True)
|
||||
|
||||
def multiply_fisher_factor(self, vector):
|
||||
probs = self._probs
|
||||
sqrt_probs = self._sqrt_probs
|
||||
return sqrt_probs * vector - probs * math_ops.reduce_sum(
|
||||
sqrt_probs * vector, axis=-1, keepdims=True)
|
||||
|
||||
def multiply_fisher_factor_transpose(self, vector):
|
||||
probs = self._probs
|
||||
sqrt_probs = self._sqrt_probs
|
||||
return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
|
||||
probs * vector, axis=-1, keepdims=True)
|
||||
|
||||
def multiply_fisher_factor_replicated_one_hot(self, index):
|
||||
assert len(index) == 1, "Length of index was {}".format(len(index))
|
||||
probs = self._probs
|
||||
sqrt_probs = self._sqrt_probs
|
||||
sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
|
||||
padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
|
||||
int(sqrt_probs.shape[1]), index[0])
|
||||
return padded_slice - probs * sqrt_probs_slice
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_shape(self):
|
||||
return array_ops.shape(self._logits)
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_static_shape(self):
|
||||
return self._logits.shape
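# --- Illustrative check, not part of the original file. It numerically
# verifies the factorization claimed in the class docstring: with q = sqrt(p),
# (diag(q) - p q^T)(diag(q) - p q^T)^T = diag(p) - p p^T, because q^T q = 1.
# `probs` is an assumed 1-D probability vector (e.g. the softmax of one row of
# logits); only ops already imported in this module are used.
def _check_categorical_fisher_factorization(probs):
  """Illustrative only: returns F - B B^T, which should be ~0."""
  p = array_ops.expand_dims(probs, -1)  # column vector of shape [k, 1]
  q = math_ops.sqrt(p)
  fisher = array_ops.matrix_diag(probs) - math_ops.matmul(p, p, transpose_b=True)
  factor = (array_ops.matrix_diag(math_ops.sqrt(probs))
            - math_ops.matmul(p, q, transpose_b=True))
  return fisher - math_ops.matmul(factor, factor, transpose_b=True)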
|
||||
|
||||
|
||||
class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
|
||||
NaturalParamsNegativeLogProbLoss):
|
||||
"""Neg log prob loss for multiple Bernoulli distributions param'd by logits.
|
||||
|
||||
Represents N independent Bernoulli distributions where N = len(logits). Its
|
||||
Fisher Information matrix is given by,
|
||||
|
||||
F = diag(p * (1-p))
|
||||
p = sigmoid(logits)
|
||||
|
||||
As F is diagonal with positive entries, its factor B is,
|
||||
|
||||
B = diag(sqrt(p * (1-p)))
|
||||
"""
|
||||
|
||||
def __init__(self, logits, targets=None, seed=None):
|
||||
self._logits = logits
|
||||
self._targets = targets
|
||||
super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return self._targets
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
return bernoulli.Bernoulli(logits=self._logits)
|
||||
|
||||
@property
|
||||
def _probs(self):
|
||||
return self.dist.probs
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
return self._logits
|
||||
|
||||
def multiply_fisher(self, vector):
|
||||
return self._probs * (1 - self._probs) * vector
|
||||
|
||||
def multiply_fisher_factor(self, vector):
|
||||
return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
|
||||
|
||||
def multiply_fisher_factor_transpose(self, vector):
|
||||
return self.multiply_fisher_factor(vector) # it's symmetric in this case
|
||||
|
||||
def multiply_fisher_factor_replicated_one_hot(self, index):
|
||||
assert len(index) == 1, "Length of index was {}".format(len(index))
|
||||
probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
|
||||
output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
|
||||
return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
|
||||
index[0])
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_shape(self):
|
||||
return array_ops.shape(self._logits)
|
||||
|
||||
@property
|
||||
def fisher_factor_inner_static_shape(self):
|
||||
return self._logits.shape
|
||||
|
||||
|
||||
def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
|
||||
"""Inserts slice into a larger tensor of zeros.
|
||||
|
||||
Forms a new tensor which is the same shape as slice_to_insert, except that
|
||||
the dimension given by 'dim' is expanded to the size given by 'dim_size'.
|
||||
'position' determines the position (index) at which to insert the slice within
|
||||
that dimension.
|
||||
|
||||
Assumes slice_to_insert.shape[dim] = 1.
|
||||
|
||||
Args:
|
||||
slice_to_insert: The slice to insert.
|
||||
dim: The dimension which to expand with zeros.
|
||||
dim_size: The new size of the 'dim' dimension.
|
||||
position: The position of 'slice_to_insert' in the new tensor.
|
||||
|
||||
Returns:
|
||||
The new tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If the slice's shape at the given dim is not 1.
|
||||
"""
|
||||
slice_shape = slice_to_insert.shape
|
||||
if slice_shape[dim] != 1:
|
||||
raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
|
||||
"was {}".format(dim, slice_to_insert.shape[dim]))
|
||||
|
||||
before = [0] * int(len(slice_shape))
|
||||
after = before[:]
|
||||
before[dim] = position
|
||||
after[dim] = dim_size - position - 1
|
||||
|
||||
return array_ops.pad(slice_to_insert, list(zip(before, after)))
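# --- Illustrative example, not part of the original file. A [batch, 1] slice
# expanded along dim 1 to size 4 with position 2 ends up in column 2 of a
# [batch, 4] tensor of zeros (before = [0, 2], after = [0, 1] in the pad call
# above).
def _example_insert_slice():
  """Illustrative only: returns a [2, 4] tensor whose column 2 is ones."""
  slice_to_insert = array_ops.ones([2, 1])
  return insert_slice_in_zeros(slice_to_insert, dim=1, dim_size=4, position=2)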
|
||||
|
||||
|
||||
class OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||
CategoricalLogitsNegativeLogProbLoss):
|
||||
"""Neg log prob loss for a categorical distribution with onehot targets.
|
||||
|
||||
Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
|
||||
distribution is OneHotCategorical as opposed to Categorical.
|
||||
"""
|
||||
|
||||
@property
|
||||
def dist(self):
|
||||
return onehot_categorical.OneHotCategorical(logits=self._logits)
|
@ -1,39 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Loss functions to be used by LayerCollection."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.loss_functions import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
"LossFunction",
|
||||
"NegativeLogProbLoss",
|
||||
"NaturalParamsNegativeLogProbLoss",
|
||||
"DistributionNegativeLogProbLoss",
|
||||
"NormalMeanNegativeLogProbLoss",
|
||||
"NormalMeanVarianceNegativeLogProbLoss",
|
||||
"CategoricalLogitsNegativeLogProbLoss",
|
||||
"OnehotCategoricalLogitsNegativeLogProbLoss",
|
||||
"MultiBernoulliNegativeLogProbLoss",
|
||||
"insert_slice_in_zeros",
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
@ -1,69 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helper for choosing which op to run next in a distributed setting."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
|
||||
|
||||
class OpQueue(object):
|
||||
"""Class for choosing which Op to run next.
|
||||
|
||||
Constructs an infinitely repeating sequence of Ops in shuffled order.
|
||||
|
||||
In K-FAC, this can be used to distribute inverse update operations among
|
||||
workers.
|
||||
"""
|
||||
|
||||
def __init__(self, ops, seed=None):
|
||||
"""Initializes an OpQueue.
|
||||
|
||||
Args:
|
||||
ops: list of TensorFlow Ops. Ops to be selected from. All workers must
|
||||
initialize with the same set of ops.
|
||||
seed: int or None. Random seed used when shuffling order of ops.
|
||||
"""
|
||||
self._ops_by_name = {op.name: op for op in ops}
|
||||
|
||||
# Construct a (shuffled) Dataset with Op names.
|
||||
op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops)))
|
||||
op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names)
|
||||
.shuffle(len(ops), seed=seed).repeat())
|
||||
self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
@property
|
||||
def ops(self):
|
||||
"""Ops this OpQueue can return in next_op()."""
|
||||
return self._ops_by_name.values()
|
||||
|
||||
def next_op(self, sess):
|
||||
"""Chooses which op to run next.
|
||||
|
||||
Note: This call will make a call to sess.run().
|
||||
|
||||
Args:
|
||||
sess: tf.Session.
|
||||
|
||||
Returns:
|
||||
Next Op chosen from 'ops'.
|
||||
"""
|
||||
# In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')
|
||||
# returns a str.
|
||||
next_op_name = sess.run(self._next_op_name).decode('ascii')
|
||||
return self._ops_by_name[next_op_name]
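# --- Illustrative sketch, not part of the original file. It shows how a
# worker loop might use OpQueue to spread inverse update ops across workers;
# `inv_update_ops`, `sess` and `num_steps` are assumptions of this example.
def _example_worker_loop(inv_update_ops, sess, num_steps):
  """Illustrative only: each step runs one op chosen by the shared shuffle."""
  queue = OpQueue(inv_update_ops, seed=0)
  for _ in range(num_steps):
    sess.run(queue.next_op(sess))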
|
@ -1,30 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helper for choosing which op to run next in a distributed setting."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.op_queue import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
'OpQueue',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
@ -1,727 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The KFAC optimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import warnings
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
|
||||
from tensorflow.contrib.kfac.python.ops import estimator as est
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
|
||||
"""The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
cov_ema_decay,
|
||||
damping,
|
||||
layer_collection,
|
||||
var_list=None,
|
||||
momentum=0.9,
|
||||
momentum_type="regular",
|
||||
norm_constraint=None,
|
||||
name="KFAC",
|
||||
estimation_mode="gradients",
|
||||
colocate_gradients_with_ops=True,
|
||||
batch_size=None,
|
||||
placement_strategy=None,
|
||||
**kwargs):
|
||||
"""Initializes the KFAC optimizer with the given settings.
|
||||
|
||||
Args:
|
||||
learning_rate: The base learning rate for the optimizer. Should probably
|
||||
be set to 1.0 when using momentum_type = 'qmodel', but can still be
lowered if desired (effectively lowering the trust in the
quadratic model).
|
||||
cov_ema_decay: The decay factor used when calculating the covariance
|
||||
estimate moving averages.
|
||||
damping: The damping factor used to stabilize training due to errors in
|
||||
the local approximation with the Fisher information matrix, and to
|
||||
regularize the update direction by making it closer to the gradient.
|
||||
If damping is adapted during training then this value is used for
|
||||
initializing damping variable.
|
||||
(Higher damping means the update looks more like a standard gradient
|
||||
update - see Tikhonov regularization.)
|
||||
layer_collection: The layer collection object, which holds the fisher
|
||||
blocks, Kronecker factors, and losses associated with the
|
||||
graph. The layer_collection cannot be modified after KfacOptimizer's
|
||||
initialization.
|
||||
var_list: Optional list or tuple of variables to train. Defaults to the
|
||||
list of variables collected in the graph under the key
|
||||
`GraphKeys.TRAINABLE_VARIABLES`.
|
||||
momentum: The momentum decay constant to use. Only applies when
|
||||
momentum_type is 'regular' or 'adam'. (Default: 0.9)
|
||||
momentum_type: The type of momentum to use in this optimizer, one of
|
||||
'regular', 'adam', or 'qmodel'. (Default: 'regular')
|
||||
norm_constraint: float or Tensor. If specified, the update is scaled down
|
||||
so that its approximate squared Fisher norm v^T F v is at most the
|
||||
specified value. May only be used with momentum type 'regular'.
|
||||
(Default: None)
|
||||
name: The name for this optimizer. (Default: 'KFAC')
|
||||
estimation_mode: The type of estimator to use for the Fishers. Can be
|
||||
'gradients', 'empirical', 'curvature_propagation', or 'exact'.
|
||||
(Default: 'gradients'). See the doc-string for FisherEstimator for
|
||||
a more detailed description of these options.
|
||||
colocate_gradients_with_ops: Whether we should request gradients we
|
||||
compute in the estimator be colocated with their respective ops.
|
||||
(Default: True)
|
||||
batch_size: The size of the mini-batch. Only needed when momentum_type
|
||||
== 'qmodel' or when automatic adjustment is used. (Default: None)
|
||||
placement_strategy: string, Device placement strategy used when creating
|
||||
covariance variables, covariance ops, and inverse ops.
|
||||
(Default: `None`)
|
||||
**kwargs: Arguments to be passed to specific placement
|
||||
strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
|
||||
|
||||
Raises:
|
||||
ValueError: If the momentum type is unsupported.
|
||||
ValueError: If clipping is used with momentum type other than 'regular'.
|
||||
ValueError: If no losses have been registered with layer_collection.
|
||||
ValueError: If momentum is non-zero and momentum_type is not 'regular'
|
||||
or 'adam'.
|
||||
"""
|
||||
warnings.warn(
|
||||
"third_party.tensorflow.contrib.kfac is deprecated."
|
||||
"This will be removed on 15-07-2018. Check README for further details.",
|
||||
DeprecationWarning)
|
||||
# Parameters to be passed to the Fisher estimator:
|
||||
self._variables = var_list or tf_variables.trainable_variables
|
||||
self._cov_ema_decay = cov_ema_decay
|
||||
self._layers = layer_collection
|
||||
self._estimation_mode = estimation_mode
|
||||
self._colocate_gradients_with_ops = colocate_gradients_with_ops
|
||||
|
||||
# The below parameters are required only if damping needs to be adapted.
|
||||
# These parameters can be set by calling
|
||||
# set_damping_adaptation_params() explicitly.
|
||||
self._damping_adaptation_decay = 0.95
|
||||
self._damping_adaptation_interval = 5
|
||||
# See Section 6.5 of the K-FAC paper: omega(1) = pow(damping_decay, interval).
|
||||
self._omega = (
|
||||
self._damping_adaptation_decay**self._damping_adaptation_interval)
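# For illustration (not in the original file): with the defaults above
# (decay 0.95, interval 5), omega = 0.95**5 ~= 0.774, so each adaptation
# interval the damping can shrink by a factor of ~0.774 or grow by
# ~1/0.774 ~= 1.29, depending on the measured reduction ratio rho.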
|
||||
self._adapt_damping = False
|
||||
self._min_damping = 1e-5
|
||||
self._prev_train_batch = None
|
||||
self._is_chief = False
|
||||
self._loss_fn = None
|
||||
self._damping_constant = damping
|
||||
self._damping = None
|
||||
self._rho = None
|
||||
self._prev_loss = None
|
||||
self._q_model_change = None
|
||||
self._update_damping_op = None
|
||||
|
||||
momentum_type = momentum_type.lower()
|
||||
legal_momentum_types = ["regular", "adam", "qmodel"]
|
||||
|
||||
if momentum_type not in legal_momentum_types:
|
||||
raise ValueError("Unsupported momentum type {}. Must be one of {}."
|
||||
.format(momentum_type, legal_momentum_types))
|
||||
if momentum_type != "regular" and norm_constraint is not None:
|
||||
raise ValueError("Update clipping is only supported with momentum "
|
||||
"type 'regular'.")
|
||||
if momentum_type not in ["regular", "adam"] and momentum != 0:
|
||||
raise ValueError("Momentum must be unspecified if using a momentum_type "
|
||||
"other than 'regular' or 'adam'.")
|
||||
|
||||
# Extra parameters of the optimizer
|
||||
self._momentum = momentum
|
||||
self._momentum_type = momentum_type
|
||||
self._norm_constraint = norm_constraint
|
||||
self._batch_size = batch_size
|
||||
self._placement_strategy = placement_strategy
|
||||
|
||||
with variable_scope.variable_scope(name):
|
||||
self._fisher_est = est.make_fisher_estimator(
|
||||
placement_strategy=placement_strategy,
|
||||
variables=self._variables,
|
||||
cov_ema_decay=self._cov_ema_decay,
|
||||
damping=self.damping,
|
||||
layer_collection=self._layers,
|
||||
exps=(-1,),
|
||||
estimation_mode=self._estimation_mode,
|
||||
colocate_gradients_with_ops=self._colocate_gradients_with_ops,
|
||||
**kwargs)
|
||||
|
||||
super(KfacOptimizer, self).__init__(learning_rate, name=name)
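# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hedged example of constructing this optimizer.  The names
# `layer_collection`, `model_loss`, and `global_step` are placeholders; the
# LayerCollection must already have layers and a loss registered before the
# constructor is called.
#
#   opt = KfacOptimizer(
#       learning_rate=1e-3,
#       cov_ema_decay=0.95,
#       damping=1e-3,
#       layer_collection=layer_collection,
#       momentum=0.9,
#       momentum_type="regular")
#   train_op = opt.minimize(model_loss, global_step=global_step)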
|
||||
|
||||
def set_damping_adaptation_params(self,
|
||||
is_chief,
|
||||
prev_train_batch,
|
||||
loss_fn,
|
||||
min_damping=1e-5,
|
||||
damping_adaptation_decay=0.99,
|
||||
damping_adaptation_interval=5):
|
||||
"""Sets parameters required to adapt damping during training.
|
||||
|
||||
When called, enables damping adaptation according to the Levenberg-Marquardt
|
||||
style rule described in Section 6.5 of "Optimizing Neural Networks with
|
||||
Kronecker-factored Approximate Curvature".
|
||||
|
||||
Note that this function creates TensorFlow variables which store a few
|
||||
scalars and are accessed by the ops which update the damping (as part
|
||||
of the training op returned by the minimize() method).
|
||||
|
||||
Args:
|
||||
is_chief: `Boolean`, `True` if the worker is chief.
|
||||
prev_train_batch: Training data used to minimize loss in the previous
|
||||
step. This will be used to evaluate loss by calling
|
||||
`loss_fn(prev_train_batch)`.
|
||||
loss_fn: `function` that takes as input training data tensor and returns
|
||||
a scalar loss.
|
||||
min_damping: `float`(Optional), Minimum value the damping parameter
|
||||
can take. Default value 1e-5.
|
||||
damping_adaptation_decay: `float`(Optional), The `damping` parameter is
|
||||
multiplied by the `damping_adaptation_decay` every
|
||||
`damping_adaptation_interval` number of iterations. Default value 0.99.
|
||||
damping_adaptation_interval: `int`(Optional), Number of steps in between
|
||||
updating the `damping` parameter. Default value 5.
|
||||
|
||||
Raises:
|
||||
ValueError: If `set_damping_adaptation_params` has already been called and
`adapt_damping` is `True`.
|
||||
"""
|
||||
if self._adapt_damping:
|
||||
raise ValueError("Damping adaptation parameters already set.")
|
||||
|
||||
with variable_scope.variable_scope(self.get_name()):
|
||||
self._adapt_damping = True
|
||||
self._is_chief = is_chief
|
||||
self._prev_train_batch = prev_train_batch
|
||||
self._loss_fn = loss_fn
|
||||
self._damping_adaptation_decay = damping_adaptation_decay
|
||||
self._damping_adaptation_interval = damping_adaptation_interval
|
||||
self._omega = (
|
||||
self._damping_adaptation_decay**self._damping_adaptation_interval)
|
||||
self._min_damping = min_damping
|
||||
|
||||
self._rho = variable_scope.get_variable(
|
||||
"rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
|
||||
self._prev_loss = variable_scope.get_variable(
|
||||
"prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
|
||||
self._q_model_change = variable_scope.get_variable(
|
||||
"q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
|
||||
self._damping = variable_scope.get_variable(
|
||||
"damping", initializer=self._damping_constant, trainable=False)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
return self._fisher_est.variables
|
||||
|
||||
@property
|
||||
def damping(self):
|
||||
if self._damping:
|
||||
return self._damping
|
||||
else:
|
||||
return self._damping_constant
|
||||
|
||||
@property
|
||||
def damping_adaptation_interval(self):
|
||||
return self._damping_adaptation_interval
|
||||
|
||||
def make_vars_and_create_op_thunks(self):
|
||||
"""Make vars and create op thunks.
|
||||
|
||||
Returns:
|
||||
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
"""
|
||||
scope = self.get_name() + "/" + self._fisher_est.name
|
||||
return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
|
||||
|
||||
def create_ops_and_vars_thunks(self):
|
||||
"""Create thunks that make the ops and vars on demand.
|
||||
|
||||
This function returns 4 lists of thunks: cov_variable_thunks,
|
||||
cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
|
||||
|
||||
The length of each list is the number of factors and the i-th element of
|
||||
each list corresponds to the i-th factor (given by the "factors" property).
|
||||
|
||||
Note that the execution of these thunks must happen in a certain
|
||||
partial order. The i-th element of cov_variable_thunks must execute
|
||||
before the i-th element of cov_update_thunks (and also the i-th element
|
||||
of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
|
||||
must execute before the i-th element of inv_update_thunks.
|
||||
|
||||
TL;DR (oversimplified): Execute the thunks according to the order that
|
||||
they are returned.
|
||||
|
||||
Returns:
|
||||
cov_variable_thunks: A list of thunks that make the cov variables.
|
||||
cov_update_thunks: A list of thunks that make the cov update ops.
|
||||
inv_variable_thunks: A list of thunks that make the inv variables.
|
||||
inv_update_thunks: A list of thunks that make the inv update ops.
|
||||
"""
|
||||
scope = self.get_name() + "/" + self._fisher_est.name
|
||||
return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
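# --- Illustrative sketch (not part of the original file) ---
# One way to honor the partial order described above.  `opt` is assumed to be
# an already-constructed KfacOptimizer; this helper is not part of the
# original API.  All variable thunks run before any update thunks, so the
# i-th cov/inv variable always exists before its update op is built.
def _run_thunks_in_order(opt):
  (cov_variable_thunks, cov_update_thunks,
   inv_variable_thunks, inv_update_thunks) = opt.create_ops_and_vars_thunks()
  for thunk in cov_variable_thunks + inv_variable_thunks:
    thunk()
  cov_update_ops = [thunk() for thunk in cov_update_thunks]
  inv_update_ops = [thunk() for thunk in inv_update_thunks]
  return cov_update_ops, inv_update_ops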
|
||||
|
||||
def minimize(self, *args, **kwargs):
|
||||
# Should this variable scope encompass everything below? Or will the super-
|
||||
# class make another copy of the same name scope?
|
||||
with variable_scope.variable_scope(self.get_name()):
|
||||
kwargs["var_list"] = kwargs.get("var_list") or self.variables
|
||||
if set(kwargs["var_list"]) != set(self.variables):
|
||||
raise ValueError("var_list doesn't match with set of Fisher-estimating "
|
||||
"variables.")
|
||||
if self._adapt_damping and self._is_chief:
|
||||
global_step = kwargs.get("global_step", None)
|
||||
if not global_step:
|
||||
raise KeyError("global_step needs to be passed to optimizer.minimize "
|
||||
"if damping parameter is adapted.")
|
||||
update_damping_op = self._update_damping(self._prev_train_batch,
|
||||
global_step)
|
||||
with ops.control_dependencies([update_damping_op]):
|
||||
loss = args[0]
|
||||
loss_assign_op = state_ops.assign(self._prev_loss, loss)
|
||||
train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
|
||||
return control_flow_ops.group(loss_assign_op, train_op)
|
||||
else:
|
||||
return super(KfacOptimizer, self).minimize(*args, **kwargs)
|
||||
|
||||
def compute_gradients(self, *args, **kwargs):
|
||||
# args[1] could be our var_list
|
||||
if len(args) > 1:
|
||||
var_list = args[1]
|
||||
else:
|
||||
kwargs["var_list"] = kwargs.get("var_list") or self.variables
|
||||
var_list = kwargs["var_list"]
|
||||
|
||||
if set(var_list) != set(self.variables):
|
||||
raise ValueError("var_list doesn't match with set of Fisher-estimating "
|
||||
"variables.")
|
||||
return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, *args, **kwargs):
|
||||
"""Applies gradients to variables.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
*args: Additional arguments for super.apply_gradients.
|
||||
**kwargs: Additional keyword arguments for super.apply_gradients.
|
||||
|
||||
Returns:
|
||||
An `Operation` that applies the specified gradients.
|
||||
"""
|
||||
# In Python 3, grads_and_vars can be a zip() object which can only be
|
||||
# iterated over once. By converting it to a list, we ensure that it can be
|
||||
# iterated over more than once.
|
||||
grads_and_vars = list(grads_and_vars)
|
||||
|
||||
# Compute step.
|
||||
steps_and_vars = self._compute_update_steps(grads_and_vars)
|
||||
|
||||
# Update trainable variables with this step.
|
||||
return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
|
||||
**kwargs)
|
||||
|
||||
def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
|
||||
"""Computes the squared (approximate) Fisher norm of the updates.
|
||||
|
||||
This is defined as v^T F v, where F is the approximate Fisher matrix
|
||||
as computed by the estimator, and v = F^{-1} g, where g is the gradient.
|
||||
This is computed efficiently as v^T g.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
|
||||
Must be the result of calling `self._fisher_est.multiply_inverse`
|
||||
on `grads_and_vars`.
|
||||
|
||||
Returns:
|
||||
Scalar representing the squared norm.
|
||||
|
||||
Raises:
|
||||
ValueError: if the two list arguments do not contain the same variables,
|
||||
in the same order.
|
||||
"""
|
||||
for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
|
||||
if gvar is not pgvar:
|
||||
raise ValueError("The variables referenced by the two arguments "
|
||||
"must match.")
|
||||
terms = [
|
||||
math_ops.reduce_sum(grad * pgrad)
|
||||
for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
|
||||
]
|
||||
return math_ops.reduce_sum(terms)
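# --- Illustrative numpy sketch (not part of the original file) ---
# A small check of the identity used above: if v = F^{-1} g then
# v^T F v = g^T F^{-1} g = v^T g, so the squared Fisher norm of the update
# reduces to an inner product of the gradient with the preconditioned
# gradient.  The numbers are made up.
import numpy as np
_F = np.array([[2.0, 0.5], [0.5, 1.0]])  # an SPD stand-in for the Fisher
_g = np.array([1.0, -2.0])               # gradient
_v = np.linalg.solve(_F, _g)             # preconditioned gradient F^{-1} g
assert np.isclose(_v.dot(_F).dot(_v), _v.dot(_g))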
|
||||
|
||||
def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
|
||||
"""Computes the scale factor for the update to satisfy the norm constraint.
|
||||
|
||||
Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
|
||||
F is the approximate Fisher matrix, and r is the update vector, i.e.
|
||||
-alpha * v, where alpha is the learning rate, and v is the preconditioned
|
||||
gradient.
|
||||
|
||||
This is based on Section 5 of Ba et al., Distributed Second-Order
|
||||
Optimization using Kronecker-Factored Approximations. Note that they
|
||||
absorb the learning rate alpha (which they denote eta_max) into the formula
|
||||
for the coefficient, while in our implementation, the rescaling is done
|
||||
before multiplying by alpha. Hence, our formula differs from theirs by a
|
||||
factor of alpha.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
|
||||
Must be the result of calling `self._fisher_est.multiply_inverse`
|
||||
on `grads_and_vars`.
|
||||
|
||||
Returns:
|
||||
Scalar representing the coefficient which should be applied to the
|
||||
preconditioned gradients to satisfy the norm constraint.
|
||||
"""
|
||||
sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
|
||||
precon_grads_and_vars)
|
||||
sq_norm_up = sq_norm_grad * self._learning_rate**2
|
||||
return math_ops.minimum(1.,
|
||||
math_ops.sqrt(self._norm_constraint / sq_norm_up))
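# --- Illustrative sketch (not part of the original file) ---
# Worked example of the coefficient above with made-up numbers:
# coeff = min(1, sqrt(C / (lr^2 * v^T g))), where v^T g is the squared
# Fisher norm computed by _squared_fisher_norm.
_lr, _C = 0.1, 1e-3          # learning rate and norm constraint
_sq_fisher_norm = 6.3        # stand-in value of v^T g
_coeff = min(1.0, (_C / (_lr ** 2 * _sq_fisher_norm)) ** 0.5)  # ~0.126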
|
||||
|
||||
def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
|
||||
"""Rescales the preconditioned gradients to satisfy the norm constraint.
|
||||
|
||||
Rescales the preconditioned gradients such that the resulting update r
|
||||
(after multiplying by the learning rate) will satisfy the norm constraint.
|
||||
This constraint is that r^T F r <= C, where F is the approximate Fisher
|
||||
matrix, and C is the norm_constraint attribute. See Section 5 of
|
||||
Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
|
||||
Approximations.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
|
||||
Must be the result of calling `self._fisher_est.multiply_inverse`
|
||||
on `grads_and_vars`.
|
||||
|
||||
Returns:
|
||||
List of (rescaled preconditioned gradient, variable) pairs.
|
||||
"""
|
||||
coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
|
||||
return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
|
||||
|
||||
def _compute_prev_updates(self, variables):
|
||||
"""Computes previous updates as negative velocities scaled by learning rate.
|
||||
|
||||
Args:
|
||||
variables: List of variables in the graph that the update will be
|
||||
applied to.
|
||||
|
||||
Returns:
|
||||
List of previous updates applied to the `variables`.
|
||||
"""
|
||||
return list(
|
||||
-1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
|
||||
for var in variables)
|
||||
|
||||
def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
|
||||
variables):
|
||||
"""Compute optimal update hyperparameters from the quadratic model.
|
||||
|
||||
More specifically, if L is the loss we minimize a quadratic approximation
|
||||
of L(theta + d) which we denote by qmodel(d) with
|
||||
d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where
|
||||
|
||||
qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) .
|
||||
|
||||
Unlike in the KL clipping approach we use the non-approximated quadratic
|
||||
model where the curvature matrix C is the true Fisher on the current
|
||||
mini-batch (computed without any approximations beyond mini-batch sampling),
|
||||
with the usual Tikhonov damping/regularization applied,
|
||||
|
||||
C = F + damping * I
|
||||
|
||||
See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
|
||||
the formula. See Appendix C for a discussion of the trick of using
|
||||
a factorized Fisher matrix to more efficiently compute the required
|
||||
vector-matrix-vector products.
|
||||
|
||||
Note that the elements of all 4 lists passed to this function must
|
||||
be in correspondence with each other.
|
||||
|
||||
Args:
|
||||
precon_grads: List of preconditioned gradients.
|
||||
prev_updates: List of updates computed at the previous iteration.
|
||||
grads: List of gradients.
|
||||
variables: List of variables in the graph that the update will be
|
||||
applied to. (Note that this function doesn't actually apply the
|
||||
update.)
|
||||
|
||||
Returns:
|
||||
(alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
|
||||
quadratic model, and
|
||||
qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
|
||||
= qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
|
||||
"""
|
||||
|
||||
cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
|
||||
variables)
|
||||
|
||||
# compute the matrix-vector products with the transposed Fisher factor
|
||||
fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
|
||||
fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
|
||||
batch_size = math_ops.cast(
|
||||
self._batch_size, dtype=fft_precon_grads[0].dtype)
|
||||
|
||||
# compute the entries of the 2x2 matrix
|
||||
m_11 = (
|
||||
_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
|
||||
self.damping * _inner_product_list(precon_grads, precon_grads))
|
||||
|
||||
m_21 = (
|
||||
_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
|
||||
self.damping * _inner_product_list(prev_updates, precon_grads))
|
||||
|
||||
m_22 = (
|
||||
_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
|
||||
self.damping * _inner_product_list(prev_updates, prev_updates))
|
||||
|
||||
def non_zero_prevupd_case():
|
||||
r"""Computes optimal (alpha, mu) given non-zero previous update.
|
||||
|
||||
We solve the full 2x2 linear system. See Martens & Grosse (2015),
|
||||
Section 7, definition of $\alpha^*$ and $\mu^*$.
|
||||
|
||||
Returns:
|
||||
(alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
|
||||
the quadratic model, and
|
||||
qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
|
||||
"""
|
||||
m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
|
||||
|
||||
c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
|
||||
[_inner_product_list(grads, prev_updates)]])
|
||||
|
||||
sol = -1. * _two_by_two_solve(m, c)
|
||||
alpha = sol[0]
|
||||
mu = sol[1]
|
||||
qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
|
||||
|
||||
return alpha, mu, qmodel_change
|
||||
|
||||
def zero_prevupd_case():
|
||||
r"""Computes optimal (alpha, mu) given all-zero previous update.
|
||||
|
||||
The linear system reduces to 1x1. See Martens & Grosse (2015),
|
||||
Section 6.4, definition of $\alpha^*$.
|
||||
|
||||
Returns:
|
||||
(alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
|
||||
quadratic model, and
|
||||
qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
|
||||
"""
|
||||
m = m_11
|
||||
c = _inner_product_list(grads, precon_grads)
|
||||
|
||||
alpha = -c / m
|
||||
mu = 0.0
|
||||
qmodel_change = 0.5 * alpha * c
|
||||
|
||||
return alpha, mu, qmodel_change
|
||||
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
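# --- Illustrative numpy sketch (not part of the original file) ---
# The 2x2 solve above with made-up scalars.  M holds the damped Fisher inner
# products, c the gradient inner products; the optimal (alpha, mu) is
# -M^{-1} c and the predicted change is 0.5 * sol^T c (negative for a useful
# step since M is positive definite).
import numpy as np
_m11, _m21, _m22 = 1.5, 0.2, 0.8
_M = np.array([[_m11, _m21], [_m21, _m22]])
_c = np.array([0.6, -0.1])                # [g.precon_grad, g.prev_update]
_sol = -np.linalg.solve(_M, _c)           # [alpha, mu]
_qmodel_change = 0.5 * np.dot(_sol, _c)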
|
||||
|
||||
def _assign_q_model_change(self, q_model_change):
|
||||
"""Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
|
||||
|
||||
Note only the chief worker does the assignment.
|
||||
|
||||
Args:
|
||||
q_model_change: Scalar tensor of type `float32`.
|
||||
|
||||
Returns:
|
||||
If `adapt_damping` is `True` then returns an assign op, otherwise returns
|
||||
a no_op().
|
||||
"""
|
||||
if self._adapt_damping and self._is_chief:
|
||||
q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
|
||||
else:
|
||||
q_model_assign_op = control_flow_ops.no_op()
|
||||
return q_model_assign_op
|
||||
|
||||
def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
|
||||
precon_grads_and_vars):
|
||||
"""Wrapper function for `self._compute_qmodel_hyperparams`.
|
||||
|
||||
Constructs a list of preconditioned gradients and variables. Also creates an
|
||||
op to assign the computed q model change to `self._q_model_change`.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
precon_grads_and_vars: List of (preconditioned gradients, variable)
|
||||
pairs.
|
||||
|
||||
Returns:
|
||||
(alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
|
||||
the quadratic model, `q_model_assign_op` assigns the computed q model
|
||||
change to `self._q_model_change`.
|
||||
"""
|
||||
precon_grads = list(
|
||||
precon_grad for (precon_grad, _) in precon_grads_and_vars)
|
||||
grads = list(grad for (grad, _) in grads_and_vars)
|
||||
variables = list(var for (_, var) in grads_and_vars)
|
||||
prev_updates = self._compute_prev_updates(variables)
|
||||
# Compute optimal velocity update parameters according to quadratic model
|
||||
alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
|
||||
precon_grads, prev_updates, grads, variables)
|
||||
|
||||
return alpha, mu, self._assign_q_model_change(q_model_change)
|
||||
|
||||
def _compute_update_steps(self, grads_and_vars):
|
||||
"""Computes the update steps for the variables given the gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs.
|
||||
|
||||
Returns:
|
||||
A list of tuples (assign_op, var) where `assign_op` assigns the update
|
||||
steps to `var`.
|
||||
"""
|
||||
|
||||
if self._momentum_type == "regular":
|
||||
# Compute "preconditioned" gradient.
|
||||
precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
|
||||
|
||||
# Apply "KL clipping" if asked for.
|
||||
if self._norm_constraint is not None:
|
||||
precon_grads_and_vars = self._clip_updates(grads_and_vars,
|
||||
precon_grads_and_vars)
|
||||
|
||||
# Update the velocity with this and return it as the step.
|
||||
if self._adapt_damping and self._is_chief:
|
||||
_, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
|
||||
grads_and_vars, precon_grads_and_vars)
|
||||
with ops.control_dependencies([q_model_assign_op]):
|
||||
return self._update_velocities(precon_grads_and_vars, self._momentum)
|
||||
else:
|
||||
return self._update_velocities(precon_grads_and_vars, self._momentum)
|
||||
elif self._momentum_type == "adam":
|
||||
# Update velocity.
|
||||
velocities_and_vars = self._update_velocities(grads_and_vars,
|
||||
self._momentum)
|
||||
# Return "preconditioned" velocity vector as the step.
|
||||
return self._fisher_est.multiply_inverse(velocities_and_vars)
|
||||
|
||||
elif self._momentum_type == "qmodel":
|
||||
# Compute "preconditioned" gradient.
|
||||
precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
|
||||
|
||||
# Compute optimal velocity update parameters according to quadratic model
|
||||
alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
|
||||
grads_and_vars, precon_grads_and_vars)
|
||||
|
||||
with ops.control_dependencies([q_model_assign_op]):
|
||||
return self._update_velocities(
|
||||
precon_grads_and_vars, mu, vec_coeff=-alpha)
|
||||
|
||||
def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
|
||||
"""Updates the velocities of the variables with the given vectors.
|
||||
|
||||
Args:
|
||||
vecs_and_vars: List of (vector, variable) pairs.
|
||||
decay: How much to decay the old velocity by. This is often referred to
|
||||
as the 'momentum constant'.
|
||||
vec_coeff: Coefficient to apply to the vectors before adding them to the
|
||||
velocity.
|
||||
|
||||
Returns:
|
||||
A list of (velocity, var) indicating the new velocity for each var.
|
||||
"""
|
||||
|
||||
def _update_velocity(vec, var):
|
||||
velocity = self._zeros_slot(var, "velocity", self._name)
|
||||
with ops.colocate_with(velocity):
|
||||
# NOTE(mattjj): read/modify/write race condition not suitable for async.
|
||||
|
||||
# Compute the new velocity for this variable.
|
||||
new_velocity = decay * velocity + vec_coeff * vec
|
||||
|
||||
# Save the updated velocity.
|
||||
return (array_ops.identity(velocity.assign(new_velocity)), var)
|
||||
|
||||
# Go through variable and update its associated part of the velocity vector.
|
||||
return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
|
||||
|
||||
def _update_damping(self, prev_batch, global_step):
|
||||
"""Adapts damping parameter. Check KFAC (Section 6.5) for the details.
|
||||
|
||||
The damping parameter is updated according to the Levenberg-Marquardt rule
|
||||
every `self._damping_adaptation_interval` iterations.
|
||||
|
||||
Args:
|
||||
prev_batch: Tensor or tuple of tensors which can be passed to
|
||||
`self._loss_fn` to evaluate loss.
|
||||
global_step: `Variable` which keeps track of number of times the training
|
||||
variables have been updated.
|
||||
Returns:
|
||||
A `tf.cond` op which updates the damping parameter.
|
||||
"""
|
||||
def compute_damping():
|
||||
""""Adapts damping parameter based on "reduction ratio".
|
||||
|
||||
Reduction ratio captures how closely the quadratic approximation to the
|
||||
loss function approximates the actual loss within a trust region. The
|
||||
damping update tries to make the damping as small as possible while
|
||||
maintaining the property that the quadratic model remains a good local
|
||||
approximation to the loss function.
|
||||
|
||||
Returns:
|
||||
An Op to assign newly computed damping value to `self._damping`.
|
||||
"""
|
||||
prev_batch_loss = self._loss_fn(prev_batch)
|
||||
with ops.control_dependencies([prev_batch_loss]):
|
||||
rho_assign = self._rho.assign(
|
||||
(prev_batch_loss - self._prev_loss) / self._q_model_change)
|
||||
with ops.control_dependencies([rho_assign]):
|
||||
new_damping = control_flow_ops.case(
|
||||
[(self._rho < 0.25, lambda: self.damping / self._omega),
|
||||
(self._rho > 0.75, lambda: self.damping * self._omega)],
|
||||
lambda: self.damping)
|
||||
with ops.control_dependencies([new_damping]):
|
||||
new_damping_min = math_ops.maximum(new_damping, self._min_damping)
|
||||
return control_flow_ops.group(self._damping.assign(new_damping_min))
|
||||
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(
|
||||
math_ops.mod(global_step + 1, self._damping_adaptation_interval),
|
||||
0), compute_damping, control_flow_ops.no_op)
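# --- Illustrative sketch (not part of the original file) ---
# The Levenberg-Marquardt rule above, restated in plain Python.  rho is the
# reduction ratio (actual loss change / predicted q-model change) and
# omega < 1, so dividing by omega increases damping and multiplying by it
# decreases damping.
def _lm_damping_update(damping, rho, omega, min_damping=1e-5):
  if rho < 0.25:     # quadratic model was too optimistic -> more damping
    damping = damping / omega
  elif rho > 0.75:   # quadratic model tracked the loss well -> less damping
    damping = damping * omega
  return max(damping, min_damping)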
|
||||
|
||||
|
||||
def _inner_product_list(list1, list2):
|
||||
return math_ops.add_n(
|
||||
[math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])
|
||||
|
||||
|
||||
def _two_by_two_solve(m, c):
|
||||
# it might be better just to crank out the exact formula for 2x2 inverses
|
||||
return math_ops.matmul(linalg_ops.matrix_inverse(m), c)
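# A hedged alternative that "cranks out" the exact 2x2 inverse suggested by
# the comment above.  Not part of the original file and not wired in
# anywhere; shown only as a sketch.
def _two_by_two_solve_explicit(m, c):
  det = m[0, 0] * m[1, 1] - m[0, 1] * m[1, 0]
  adjugate = ops.convert_to_tensor([[m[1, 1], -m[0, 1]],
                                    [-m[1, 0], m[0, 0]]])
  return math_ops.matmul(adjugate, c) / det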
|
@ -1,30 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The KFAC optimizer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.optimizer import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
"KfacOptimizer",
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
@ -1,114 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implements placement strategies for cov and inv ops, cov variables."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
|
||||
|
||||
def _make_thunk_on_device(func, device):
|
||||
def thunk():
|
||||
with tf_ops.device(device):
|
||||
return func()
|
||||
return thunk
|
||||
|
||||
|
||||
class RoundRobinPlacementMixin(object):
|
||||
"""Implements round robin placement strategy for ops and variables."""
|
||||
|
||||
def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
|
||||
"""Initializes the RoundRobinPlacementMixin class.
|
||||
|
||||
Args:
|
||||
cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
|
||||
computations will be placed on these devices in a round-robin fashion.
|
||||
Can be None, which means that no devices are specified.
|
||||
**kwargs: Passed on to the superclass constructor.
|
||||
|
||||
"""
|
||||
super(RoundRobinPlacementMixin, self).__init__(**kwargs)
|
||||
self._cov_devices = cov_devices
|
||||
self._inv_devices = inv_devices
|
||||
|
||||
def make_vars_and_create_op_thunks(self, scope=None):
|
||||
"""Make vars and create op thunks w/ a round-robin device placement start.
|
||||
|
||||
For each factor, all of that factor's cov variables and their associated
|
||||
update ops will be placed on a particular device. A new device is chosen
|
||||
for each factor by cycling through the list of devices in the
`self._cov_devices` attribute. If `self._cov_devices` is `None` then no
explicit device placement occurs.
|
||||
|
||||
An analogous strategy is followed for inverse update ops, with the list of
|
||||
devices being given by the `self._inv_devices` attribute.
|
||||
|
||||
Inverse variables on the other hand are not placed on any specific device
|
||||
(they will just use the current device placement context, whatever
that happens to be). The idea is that the inverse variables belong where
|
||||
they will be accessed most often, which is the device that actually applies
|
||||
the preconditioner to the gradient. The user will be responsible for setting
|
||||
the device context for this.
|
||||
|
||||
Args:
|
||||
scope: A string or None. If None it will be set to the name of this
|
||||
estimator (given by the name property). All variables will be created,
|
||||
and all thunks will execute, inside of a variable scope of the given
|
||||
name. (Default: None)
|
||||
|
||||
Returns:
|
||||
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
|
||||
the list of factors given by the "factors" property.
|
||||
"""
|
||||
# Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
|
||||
(cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
|
||||
inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
|
||||
|
||||
if self._cov_devices:
|
||||
cov_update_thunks = []
|
||||
for cov_variable_thunk, cov_update_thunk, device in zip(
|
||||
cov_variable_thunks_raw, cov_update_thunks_raw,
|
||||
itertools.cycle(self._cov_devices)):
|
||||
with tf_ops.device(device):
|
||||
cov_variable_thunk()
|
||||
cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
|
||||
device))
|
||||
else:
|
||||
for cov_variable_thunk in cov_variable_thunks_raw:
|
||||
cov_variable_thunk()
|
||||
cov_update_thunks = cov_update_thunks_raw
|
||||
|
||||
for inv_variable_thunk in inv_variable_thunks_raw:
|
||||
inv_variable_thunk()
|
||||
|
||||
if self._inv_devices:
|
||||
inv_update_thunks = []
|
||||
for inv_update_thunk, device in zip(inv_update_thunks_raw,
|
||||
itertools.cycle(self._inv_devices)):
|
||||
inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
|
||||
device))
|
||||
else:
|
||||
inv_update_thunks = inv_update_thunks_raw
|
||||
|
||||
return cov_update_thunks, inv_update_thunks
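# --- Illustrative sketch (not part of the original file) ---
# The round-robin assignment performed above, in isolation: factor i is
# paired with device i modulo the number of devices (itertools is imported at
# the top of this file).
_devices = ["/gpu:0", "/gpu:1"]
_factors = ["conv1", "conv2", "fc1"]
_assignment = list(zip(_factors, itertools.cycle(_devices)))
# -> [("conv1", "/gpu:0"), ("conv2", "/gpu:1"), ("fc1", "/gpu:0")]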
|
@ -1,709 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
# Method used for inverting matrices.
|
||||
POSDEF_INV_METHOD = "cholesky"
|
||||
POSDEF_EIG_METHOD = "self_adjoint"
|
||||
|
||||
|
||||
def set_global_constants(posdef_inv_method=None):
|
||||
"""Sets various global constants used by the classes in this module."""
|
||||
global POSDEF_INV_METHOD
|
||||
|
||||
if posdef_inv_method is not None:
|
||||
POSDEF_INV_METHOD = posdef_inv_method
|
||||
|
||||
|
||||
class SequenceDict(object):
|
||||
"""A dict convenience wrapper that allows getting/setting with sequences."""
|
||||
|
||||
def __init__(self, iterable=None):
|
||||
self._dict = dict(iterable or [])
|
||||
|
||||
def __getitem__(self, key_or_keys):
|
||||
if isinstance(key_or_keys, (tuple, list)):
|
||||
return list(map(self.__getitem__, key_or_keys))
|
||||
else:
|
||||
return self._dict[key_or_keys]
|
||||
|
||||
def __setitem__(self, key_or_keys, val_or_vals):
|
||||
if isinstance(key_or_keys, (tuple, list)):
|
||||
for key, value in zip(key_or_keys, val_or_vals):
|
||||
self[key] = value
|
||||
else:
|
||||
self._dict[key_or_keys] = val_or_vals
|
||||
|
||||
def items(self):
|
||||
return list(self._dict.items())
|
||||
|
||||
|
||||
def tensors_to_column(tensors):
|
||||
"""Converts a tensor or list of tensors to a column vector.
|
||||
|
||||
Args:
|
||||
tensors: A tensor or list of tensors.
|
||||
|
||||
Returns:
|
||||
The tensors reshaped into vectors and stacked on top of each other.
|
||||
"""
|
||||
if isinstance(tensors, (tuple, list)):
|
||||
return array_ops.concat(
|
||||
tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
|
||||
else:
|
||||
return array_ops.reshape(tensors, [-1, 1])
|
||||
|
||||
|
||||
def column_to_tensors(tensors_template, colvec):
|
||||
"""Converts a column vector back to the shape of the given template.
|
||||
|
||||
Args:
|
||||
tensors_template: A tensor or list of tensors.
|
||||
colvec: A 2d column vector with the same shape as the value of
|
||||
tensors_to_column(tensors_template).
|
||||
|
||||
Returns:
|
||||
X, where X is tensor or list of tensors with the properties:
|
||||
1) tensors_to_column(X) = colvec
|
||||
2) X (or its elements) have the same shape as tensors_template (or its
|
||||
elements)
|
||||
"""
|
||||
if isinstance(tensors_template, (tuple, list)):
|
||||
offset = 0
|
||||
tensors = []
|
||||
for tensor_template in tensors_template:
|
||||
sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
|
||||
tensor = array_ops.reshape(colvec[offset:(offset + sz)],
|
||||
tensor_template.shape)
|
||||
tensors.append(tensor)
|
||||
offset += sz
|
||||
|
||||
tensors = tuple(tensors)
|
||||
else:
|
||||
tensors = array_ops.reshape(colvec, tensors_template.shape)
|
||||
|
||||
return tensors
|
||||
|
||||
|
||||
def kronecker_product(mat1, mat2):
|
||||
"""Computes the Kronecker product two matrices."""
|
||||
m1, n1 = mat1.get_shape().as_list()
|
||||
mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
|
||||
m2, n2 = mat2.get_shape().as_list()
|
||||
mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
|
||||
return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
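# --- Illustrative sketch (not part of the original file) ---
# kronecker_product of a [2, 2] matrix with a [3, 1] matrix yields a [6, 2]
# matrix, matching np.kron on the same inputs.  The constants are made up.
_example_a = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
_example_b = ops.convert_to_tensor([[0.0], [1.0], [2.0]])
_example_kron = kronecker_product(_example_a, _example_b)  # shape [6, 2]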
|
||||
|
||||
|
||||
def layer_params_to_mat2d(vector):
|
||||
"""Converts a vector shaped like layer parameters to a 2D matrix.
|
||||
|
||||
In particular, we reshape the weights/filter component of the vector to be
|
||||
2D, flattening all leading (input) dimensions. If there is a bias component,
|
||||
we concatenate it to the reshaped weights/filter component.
|
||||
|
||||
Args:
|
||||
vector: A Tensor or pair of Tensors shaped like layer parameters.
|
||||
|
||||
Returns:
|
||||
A 2D Tensor with the same coefficients and the same output dimension.
|
||||
"""
|
||||
if isinstance(vector, (tuple, list)):
|
||||
w_part, b_part = vector
|
||||
w_part_reshaped = array_ops.reshape(w_part,
|
||||
[-1, w_part.shape.as_list()[-1]])
|
||||
return array_ops.concat(
|
||||
(w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
|
||||
elif isinstance(vector, ops.IndexedSlices):
|
||||
return vector
|
||||
else: # Tensor or Tensor-like.
|
||||
return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
|
||||
|
||||
|
||||
def mat2d_to_layer_params(vector_template, mat2d):
|
||||
"""Converts a canonical 2D matrix representation back to a vector.
|
||||
|
||||
Args:
|
||||
vector_template: A Tensor or pair of Tensors shaped like layer parameters.
|
||||
mat2d: A 2D Tensor with the same shape as the value of
|
||||
layer_params_to_mat2d(vector_template).
|
||||
|
||||
Returns:
|
||||
A Tensor or pair of Tensors with the same coefficients as mat2d and the same
|
||||
shape as vector_template.
|
||||
"""
|
||||
if isinstance(vector_template, (tuple, list)):
|
||||
w_part, b_part = mat2d[:-1], mat2d[-1]
|
||||
return array_ops.reshape(w_part, vector_template[0].shape), b_part
|
||||
elif isinstance(vector_template, ops.IndexedSlices):
|
||||
if not isinstance(mat2d, ops.IndexedSlices):
|
||||
raise TypeError(
|
||||
"If vector_template is an IndexedSlices, so should mat2d.")
|
||||
return mat2d
|
||||
else:
|
||||
return array_ops.reshape(mat2d, vector_template.shape)
|
||||
|
||||
|
||||
def posdef_inv(tensor, damping):
|
||||
"""Computes the inverse of tensor + damping * identity."""
|
||||
identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
|
||||
damping = math_ops.cast(damping, dtype=tensor.dtype)
|
||||
return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
|
||||
|
||||
|
||||
def posdef_inv_matrix_inverse(tensor, identity, damping):
|
||||
"""Computes inverse(tensor + damping * identity) directly."""
|
||||
return linalg_ops.matrix_inverse(tensor + damping * identity)
|
||||
|
||||
|
||||
def posdef_inv_cholesky(tensor, identity, damping):
|
||||
"""Computes inverse(tensor + damping * identity) with Cholesky."""
|
||||
chol = linalg_ops.cholesky(tensor + damping * identity)
|
||||
return linalg_ops.cholesky_solve(chol, identity)
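# --- Illustrative numpy sketch (not part of the original file) ---
# What posdef_inv_cholesky computes: solve (T + damping * I) X = I from the
# Cholesky factor instead of forming a general inverse (np is imported at the
# top of this file).
_T = np.array([[2.0, 0.3], [0.3, 1.0]])
_damping = 0.1
_chol = np.linalg.cholesky(_T + _damping * np.eye(2))
_X = np.linalg.solve(_chol.T, np.linalg.solve(_chol, np.eye(2)))
assert np.allclose(_X, np.linalg.inv(_T + _damping * np.eye(2)))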
|
||||
|
||||
|
||||
def posdef_inv_eig(tensor, identity, damping):
|
||||
"""Computes inverse(tensor + damping * identity) with eigendecomposition."""
|
||||
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
|
||||
tensor + damping * identity)
|
||||
return math_ops.matmul(
|
||||
eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
|
||||
|
||||
|
||||
posdef_inv_functions = {
|
||||
"matrix_inverse": posdef_inv_matrix_inverse,
|
||||
"cholesky": posdef_inv_cholesky,
|
||||
"eig": posdef_inv_eig,
|
||||
}
|
||||
|
||||
|
||||
def posdef_eig(mat):
|
||||
"""Computes the eigendecomposition of a positive semidefinite matrix."""
|
||||
return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
|
||||
|
||||
|
||||
def posdef_eig_svd(mat):
|
||||
"""Computes the singular values and left singular vectors of a matrix."""
|
||||
evals, evecs, _ = linalg_ops.svd(mat)
|
||||
|
||||
return evals, evecs
|
||||
|
||||
|
||||
def posdef_eig_self_adjoint(mat):
|
||||
"""Computes eigendecomposition using self_adjoint_eig."""
|
||||
evals, evecs = linalg_ops.self_adjoint_eig(mat)
|
||||
evals = math_ops.abs(evals) # Should be equivalent to svd approach.
|
||||
|
||||
return evals, evecs
|
||||
|
||||
|
||||
posdef_eig_functions = {
|
||||
"self_adjoint": posdef_eig_self_adjoint,
|
||||
"svd": posdef_eig_svd,
|
||||
}
|
||||
|
||||
|
||||
def cholesky(tensor, damping):
|
||||
"""Computes the inverse of tensor + damping * identity."""
|
||||
identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
|
||||
damping = math_ops.cast(damping, dtype=tensor.dtype)
|
||||
return linalg_ops.cholesky(tensor + damping * identity)
|
||||
|
||||
|
||||
class SubGraph(object):
|
||||
"""Defines a subgraph given by all the dependencies of a given set of outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, outputs):
|
||||
# Set of all ancestor Tensors, Ops to 'outputs'.
|
||||
self._members = set()
|
||||
|
||||
self._iter_add(outputs)
|
||||
|
||||
def _iter_add(self, root):
|
||||
"""Iteratively adds all of nodes' ancestors using depth first search."""
|
||||
stack = [root]
|
||||
while stack:
|
||||
nodes = stack.pop()
|
||||
for node in nodes:
|
||||
if node in self._members:
|
||||
continue
|
||||
self._members.add(node)
|
||||
|
||||
if isinstance(node, ops.Tensor):
|
||||
stack.append((node.op,))
|
||||
elif isinstance(node, ops.Operation):
|
||||
stack.append(node.inputs)
|
||||
|
||||
def is_member(self, node):
|
||||
"""Check if 'node' is in this subgraph."""
|
||||
return node in self._members
|
||||
|
||||
def variable_uses(self, var):
|
||||
"""Computes number of times a variable is used.
|
||||
|
||||
Args:
|
||||
var: Variable or ResourceVariable instance.
|
||||
|
||||
Returns:
|
||||
Number of times a variable is used within this subgraph.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'var' is not a variable type.
|
||||
"""
|
||||
if isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
var = var.handle
|
||||
elif isinstance(var, variables.Variable):
|
||||
var = var.value()
|
||||
else:
|
||||
raise ValueError("%s does not appear to be a variable." % str(var))
|
||||
|
||||
return len(self._members.intersection(set(var.consumers())))
|
||||
|
||||
def filter_list(self, node_list):
|
||||
"""Filters 'node_list' to nodes in this subgraph."""
|
||||
filtered_list = []
|
||||
for node in node_list:
|
||||
if self.is_member(node):
|
||||
filtered_list.append(node)
|
||||
return filtered_list
|
||||
|
||||
|
||||
def generate_random_signs(shape, dtype=dtypes.float32):
|
||||
"""Generate a random tensor with {-1, +1} entries."""
|
||||
ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
|
||||
return 2 * math_ops.cast(ints, dtype=dtype) - 1
|
||||
|
||||
|
||||
def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
|
||||
"""Compute forward-mode gradients."""
|
||||
# See b/37888268.
|
||||
|
||||
# This version of forward-mode autodiff is based on code by Tim Cooijmans
|
||||
# and handles list arguments and certain special cases such as when the
|
||||
# ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are
|
||||
# generated by the first gradients_impl.gradients call.
|
||||
|
||||
us = [array_ops.zeros_like(y) + float("nan") for y in ys]
|
||||
dydxs = gradients_impl.gradients(
|
||||
ys, xs, grad_ys=us, stop_gradients=stop_gradients)
|
||||
|
||||
# Convert the awkward types (IndexedSlices, None) that gradients_impl.gradients
# returns but cannot itself accept as inputs to a second gradients call.
|
||||
dydxs = [
|
||||
ops.convert_to_tensor(dydx)
|
||||
if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
|
||||
]
|
||||
dydxs = [
|
||||
array_ops.zeros_like(x) if dydx is None else dydx
|
||||
for x, dydx in zip(xs, dydxs)
|
||||
]
|
||||
|
||||
dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
|
||||
|
||||
return dysdx
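# --- Illustrative sketch (not part of the original file) ---
# The essence of the double reverse-mode trick above, without the special
# cases: differentiate the VJP u -> J^T u with respect to the dummy vector u
# to obtain the JVP J v.
def _fwd_gradients_minimal(ys, xs, vs):
  us = [array_ops.zeros_like(y) for y in ys]              # dummy cotangents
  dydxs = gradients_impl.gradients(ys, xs, grad_ys=us)    # J^T u, linear in u
  return gradients_impl.gradients(dydxs, us, grad_ys=vs)  # d(J^T u)/du . v = J v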
|
||||
|
||||
|
||||
def on_tpu():
|
||||
"""Returns True when building a TPU computation."""
|
||||
return tpu_function.get_tpu_context().number_of_shards is not None
|
||||
|
||||
|
||||
def cross_replica_mean(tensor, name=None):
|
||||
"""Takes mean value of a Tensor across all TPU cores.
|
||||
|
||||
Args:
|
||||
tensor: Tensor to be synchronized.
|
||||
name: None or string. Name of Op.
|
||||
|
||||
Returns:
|
||||
Average of Tensor across all TPU cores.
|
||||
|
||||
Raises:
|
||||
ValueError: If called outside of TPU context.
|
||||
"""
|
||||
with ops.name_scope(name, "cross_replica_mean", [tensor]):
|
||||
num_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_shards is None:
|
||||
raise ValueError(
|
||||
"Cannot take cross_replica_mean() outside of TPU Context.")
|
||||
if num_shards == 1:
|
||||
return tensor
|
||||
return tpu_ops.cross_replica_sum(tensor / num_shards)
|
||||
|
||||
|
||||
def ensure_sequence(obj):
|
||||
"""If `obj` isn't a tuple or list, return a tuple containing `obj`."""
|
||||
if isinstance(obj, (tuple, list)):
|
||||
return obj
|
||||
else:
|
||||
return (obj,)
|
||||
|
||||
|
||||
def batch_execute(global_step, thunks, batch_size, name=None):
|
||||
"""Executes a subset of ops per global step.
|
||||
|
||||
Given a list of thunks, each of which produces a single stateful op,
|
||||
ensures that exactly 'batch_size' ops are run per global step. Ops are
|
||||
scheduled in a round-robin fashion. For example, with 3 ops
|
||||
|
||||
global_step | op0 | op1 | op2
|
||||
------------+-----+-----+-----
|
||||
0 | x | x |
|
||||
------------+-----+-----+-----
|
||||
1 | x | | x
|
||||
------------+-----+-----+-----
|
||||
2 | | x | x
|
||||
------------+-----+-----+-----
|
||||
3 | x | x |
|
||||
------------+-----+-----+-----
|
||||
4 | x | | x
|
||||
|
||||
Does not guarantee order of op execution within a single global step.
|
||||
|
||||
Args:
|
||||
global_step: Tensor indicating time. Determines which ops run.
|
||||
thunks: List of thunks. Each thunk encapsulates one op. Return values are
|
||||
ignored.
|
||||
batch_size: int. Number of ops to execute per global_step.
|
||||
name: string or None. Name scope for newly added ops.
|
||||
|
||||
Returns:
|
||||
List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
|
||||
every global step.
|
||||
"""
|
||||
|
||||
def true_fn(thunk):
|
||||
"""Ensures thunk is executed and returns an Op (not a Tensor)."""
|
||||
|
||||
def result():
|
||||
with ops.control_dependencies([thunk()]):
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
return result
|
||||
|
||||
def false_fn(_):
|
||||
"""Executes a no-op."""
|
||||
|
||||
def result():
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
return result
|
||||
|
||||
with ops.name_scope(name, "batch_execute"):
|
||||
true_fns = [true_fn(thunk) for thunk in thunks]
|
||||
false_fns = [false_fn(thunk) for thunk in thunks]
|
||||
num_thunks = len(thunks)
|
||||
conditions = [
|
||||
math_ops.less(
|
||||
math_ops.mod(batch_size - 1 + global_step * batch_size - j,
|
||||
num_thunks), batch_size) for j in range(num_thunks)
|
||||
]
|
||||
result = [
|
||||
control_flow_ops.cond(condition, true_fn, false_fn)
|
||||
for (condition, true_fn,
|
||||
false_fn) in zip(conditions, true_fns, false_fns)
|
||||
]
|
||||
return result
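# --- Illustrative sketch (not part of the original file) ---
# The round-robin schedule used above, restated in plain Python: thunk j runs
# at step t iff (batch_size - 1 + t * batch_size - j) mod num_thunks is less
# than batch_size.  With 3 thunks and batch_size=2 this reproduces the table
# in the docstring: step 0 -> [0, 1], step 1 -> [0, 2], step 2 -> [1, 2].
def _thunks_run_at_step(t, num_thunks, batch_size):
  return [j for j in range(num_thunks)
          if (batch_size - 1 + t * batch_size - j) % num_thunks < batch_size]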
|
||||
|
||||
|
||||
def extract_convolution_patches(inputs,
|
||||
filter_shape,
|
||||
padding,
|
||||
strides=None,
|
||||
dilation_rate=None,
|
||||
name=None,
|
||||
data_format=None):
|
||||
"""Extracts inputs to each output coordinate in tf.nn.convolution.
|
||||
|
||||
This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
|
||||
where the number of spatial dimensions may be something other than 2.
|
||||
|
||||
Assumes,
|
||||
- First dimension of inputs is batch_size
|
||||
- Convolution filter is applied to all input channels.
|
||||
|
||||
Args:
|
||||
inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
in_channels]. Inputs to tf.nn.convolution().
|
||||
filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
|
||||
padding: string. Padding method. One of "VALID", "SAME".
|
||||
strides: None or list of ints. Strides along spatial dimensions.
|
||||
dilation_rate: None or list of ints. Dilation along spatial dimensions.
|
||||
name: None or str. Name of Op.
|
||||
data_format: None or str. Format of data.
|
||||
|
||||
Returns:
|
||||
Tensor of shape [batch_size, ..spatial_image_shape..,
|
||||
..spatial_filter_shape.., in_channels]
|
||||
|
||||
Raises:
|
||||
ValueError: If data_format does not put channel last.
|
||||
ValueError: If inputs and filter disagree on in_channels.
|
||||
"""
|
||||
if not is_data_format_channel_last(data_format):
|
||||
raise ValueError("Channel must be last dimension.")
|
||||
with ops.name_scope(name, "extract_convolution_patches",
|
||||
[inputs, filter_shape, padding, strides, dilation_rate]):
|
||||
batch_size = inputs.shape.as_list()[0]
|
||||
in_channels = inputs.shape.as_list()[-1]
|
||||
|
||||
# filter_shape = spatial_filter_shape + [in_channels, out_channels]
|
||||
spatial_filter_shape = filter_shape[:-2]
|
||||
if in_channels != filter_shape[-2]:
|
||||
raise ValueError("inputs and filter_shape must agree on in_channels.")
|
||||
|
||||
# Map each input feature to a location in the output.
|
||||
out_channels = np.prod(spatial_filter_shape) * in_channels
|
||||
filters = linalg_ops.eye(out_channels)
|
||||
filters = array_ops.reshape(
|
||||
filters,
|
||||
list(spatial_filter_shape) + [in_channels, out_channels])
|
||||
|
||||
result = nn_ops.convolution(
|
||||
inputs,
|
||||
filters,
|
||||
padding=padding,
|
||||
strides=strides,
|
||||
dilation_rate=dilation_rate)
|
||||
spatial_output_shape = result.shape.as_list()[1:-1]
|
||||
result = array_ops.reshape(result,
|
||||
[batch_size or -1] + spatial_output_shape +
|
||||
list(spatial_filter_shape) + [in_channels])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def extract_pointwise_conv2d_patches(inputs,
|
||||
filter_shape,
|
||||
name=None,
|
||||
data_format=None):
|
||||
"""Extract patches for a 1x1 conv2d.
|
||||
|
||||
Args:
|
||||
inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
|
||||
filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
|
||||
name: None or str. Name for Op.
|
||||
data_format: None or str. Format for data. See 'data_format' in
|
||||
tf.nn.conv2d() for details.
|
||||
|
||||
Returns:
|
||||
Tensor of shape [batch_size, ..spatial_input_shape..,
|
||||
..spatial_filter_shape.., in_channels]
|
||||
|
||||
Raises:
|
||||
ValueError: if inputs is not 4-D.
|
||||
ValueError: if filter_shape is not [1, 1, ?, ?]
|
||||
ValueError: if data_format is not channels-last.
|
||||
"""
|
||||
if inputs.shape.ndims != 4:
|
||||
raise ValueError("inputs must have 4 dims.")
|
||||
if len(filter_shape) != 4:
|
||||
raise ValueError("filter_shape must have 4 dims.")
|
||||
if filter_shape[0] != 1 or filter_shape[1] != 1:
|
||||
raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
|
||||
if not is_data_format_channel_last(data_format):
|
||||
raise ValueError("data_format must be channels last.")
|
||||
with ops.name_scope(name, "extract_pointwise_conv2d_patches",
|
||||
[inputs, filter_shape]):
|
||||
ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
|
||||
strides = [1, 1, 1, 1] # Operate on all pixels.
|
||||
rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
|
||||
padding = "VALID" # Doesn't matter.
|
||||
result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
|
||||
padding)
|
||||
|
||||
batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
|
||||
filter_height, filter_width, in_channels, _ = filter_shape
|
||||
return array_ops.reshape(result, [
|
||||
batch_size, input_height, input_width, filter_height, filter_width,
|
||||
in_channels
|
||||
])
|
||||
|
||||
|
||||
def is_data_format_channel_last(data_format):
|
||||
"""True if data_format puts channel last."""
|
||||
if data_format is None:
|
||||
return True
|
||||
return data_format.endswith("C")
|
||||
|
||||
|
||||
def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
|
||||
"""Computes matmul(A, B) where A is sparse, B is dense.
|
||||
|
||||
Args:
|
||||
A: tf.IndexedSlices with dense shape [m, n].
|
||||
B: tf.Tensor with shape [n, k].
|
||||
name: str. Name of op.
|
||||
transpose_a: Bool. If true we transpose A before multiplying it by B.
|
||||
(Default: False)
|
||||
transpose_b: Bool. If true we transpose B before multiplying it by A.
|
||||
(Default: False)
|
||||
|
||||
Returns:
|
||||
tf.IndexedSlices resulting from matmul(A, B).
|
||||
|
||||
Raises:
|
||||
ValueError: If A doesn't represent a matrix.
|
||||
ValueError: If B is not rank-2.
|
||||
"""
|
||||
with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
|
||||
if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
|
||||
raise ValueError("A must represent a matrix. Found: %s." % A)
|
||||
if B.shape.ndims != 2:
|
||||
raise ValueError("B must be a matrix.")
|
||||
new_values = math_ops.matmul(
|
||||
A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
|
||||
return ops.IndexedSlices(
|
||||
new_values,
|
||||
A.indices,
|
||||
dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
|
||||
|
||||
|
||||
def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
|
||||
"""Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
|
||||
|
||||
Args:
|
||||
A_diag: diagonal entries of matrix A of shape [m, m].
|
||||
B: tf.IndexedSlices. Represents matrix of shape [m, n].
|
||||
name: str. Name of op.
|
||||
|
||||
Returns:
|
||||
tf.IndexedSlices resulting from matmul(A, B).
|
||||
|
||||
Raises:
|
||||
ValueError: If A_diag is not rank-1.
|
||||
ValueError: If B doesn't represent a matrix.
|
||||
"""
|
||||
with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
|
||||
A_diag = ops.convert_to_tensor(A_diag)
|
||||
if A_diag.shape.ndims != 1:
|
||||
raise ValueError("A_diag must be a rank-1 Tensor.")
|
||||
if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
|
||||
raise ValueError("B must represent a matrix. Found: %s." % B)
|
||||
a = array_ops.gather(A_diag, B.indices)
|
||||
a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
|
||||
return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
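# --- Illustrative numpy sketch (not part of the original file) ---
# Multiplying by a diagonal matrix only rescales the rows of B that the
# IndexedSlices actually stores, which is exactly what the gather above does
# (np is imported at the top of this file).
_A_diag = np.array([2.0, 3.0, 5.0])
_B_values = np.array([[1.0, 1.0], [4.0, -1.0]])  # stored rows 0 and 2 of B
_B_indices = np.array([0, 2])
_dense_B = np.zeros((3, 2))
_dense_B[_B_indices] = _B_values
_expected = np.diag(_A_diag).dot(_dense_B)
_sparse_rows = _A_diag[_B_indices][:, None] * _B_values
assert np.allclose(_expected[_B_indices], _sparse_rows)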
|
||||
|
||||
|
||||
class PartitionedTensor(object):
|
||||
"""A Tensor partitioned across its 0-th dimension."""
|
||||
|
||||
def __init__(self, tensors):
|
||||
"""Initializes PartitionedTensor.
|
||||
|
||||
Args:
|
||||
tensors: List of Tensors. All Tensors must agree on shape (excepting
|
||||
batch dimension) and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'tensors' has length zero.
|
||||
ValueError: if contents of 'tensors' don't agree on shape or dtype.
|
||||
"""
|
||||
if not tensors:
|
||||
raise ValueError("tensors must be a list of 1+ Tensors.")
|
||||
|
||||
dtype = tensors[0].dtype
|
||||
if not all(tensor.dtype == dtype for tensor in tensors):
|
||||
raise ValueError("all tensors must have dtype = %s." % dtype)
|
||||
|
||||
shape = tensors[0].shape[1:]
|
||||
if not all(tensor.shape[1:] == shape for tensor in tensors):
|
||||
raise ValueError("All tensors must have shape = %s (excluding batch "
|
||||
"dimension)." % shape)
|
||||
|
||||
self.tensors = tensors
|
||||
self._concats = {} # {device: Tensor}
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
feature_shape = self.tensors[0].shape[1:]
|
||||
batch_size = sum([tensor.shape[0] for tensor in self.tensors],
|
||||
tensor_shape.Dimension(0))
|
||||
return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
|
||||
|
||||
def get_shape(self):
|
||||
return self.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
def __str__(self):
|
||||
return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
|
||||
self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(self.tensors))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, PartitionedTensor):
|
||||
return False
|
||||
return self.tensors == other.tensors
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other # pylint: disable=g-comparison-negation
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.as_tensor()[key]
|
||||
|
||||
def as_tensor(self, dtype=None, name=None, as_ref=False):
|
||||
with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
|
||||
assert not as_ref
|
||||
assert dtype in [None, self.dtype]
|
||||
result = array_ops.concat(self.tensors, axis=0)
|
||||
|
||||
# Cache 'result' if we haven't already cached a value for this device.
|
||||
if result.device not in self._concats:
|
||||
self._concats[result.device] = result
|
||||
return self._concats[result.device]
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# PartitionedTensors in general do not live on a single device. If the
|
||||
# device cannot be determined unambiguously this property will return None.
|
||||
device = self.tensors[0].device
|
||||
if all(tensor.device == device for tensor in self.tensors):
|
||||
return device
|
||||
return None
|
||||
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
PartitionedTensor,
|
||||
lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
|
||||
|
||||
|
||||
# TODO(b/69623235): Add a function for finding tensors that share gradients
|
||||
# to eliminate redundant fisher factor computations.
|
@ -1,50 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.kfac.python.ops.utils import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
"set_global_constants",
|
||||
"SequenceDict",
|
||||
"tensors_to_column",
|
||||
"column_to_tensors",
|
||||
"kronecker_product",
|
||||
"layer_params_to_mat2d",
|
||||
"mat2d_to_layer_params",
|
||||
"posdef_inv",
|
||||
"posdef_inv_matrix_inverse",
|
||||
"posdef_inv_cholesky",
|
||||
"posdef_inv_funcs",
|
||||
"SubGraph",
|
||||
"generate_random_signs",
|
||||
"fwd_gradients",
|
||||
"ensure_sequence",
|
||||
"batch_execute",
|
||||
"extract_convolution_patches",
|
||||
"extract_pointwise_conv2d_patches",
|
||||
"is_data_format_channel_last",
|
||||
"matmul_sparse_dense",
|
||||
"matmul_diag_sparse",
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)