Fix copyrights and a few other lint errors.

Change: 123250570
A. Unique TensorFlower 2016-05-25 13:01:01 -08:00 committed by TensorFlower Gardener
parent fd5ebfa768
commit 5c145f0e3c
17 changed files with 294 additions and 256 deletions

View File

@@ -1,5 +1,4 @@
"""Main Scikit Flow module."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""High level API for learning with TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,5 +1,4 @@
"""Base utilities for loading datasets."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base utilities for loading datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,5 +1,4 @@
"""Scikit Flow Estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier

View File

@@ -1,5 +1,4 @@
"""sklearn cross-support."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""sklearn cross-support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -20,6 +22,8 @@ import collections
import os
import numpy as np
import six
def _pprint(d):
return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
@@ -102,6 +106,7 @@ class _BaseEstimator(object):
_pprint(self.get_params(deep=False)),)
# pylint: disable=old-style-class
class _ClassifierMixin():
"""Mixin class for all classifiers."""
pass
@@ -111,8 +116,10 @@ class _RegressorMixin():
"""Mixin class for all regression estimators."""
pass
class _TransformerMixin():
"""Mixin class for all transformer estimators."""
"""Mixin class for all transformer estimators."""
class _NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.
@@ -134,6 +141,8 @@ class _NotFittedError(ValueError, AttributeError):
https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py
"""
# pylint: enable=old-style-class
def _accuracy_score(y_true, y_pred):
score = y_true == y_pred
@@ -149,8 +158,7 @@ def _mean_squared_error(y_true, y_pred):
def _train_test_split(*args, **options):
n_array = len(args)
# pylint: disable=missing-docstring
test_size = options.pop('test_size', None)
train_size = options.pop('train_size', None)
random_state = options.pop('random_state', None)
@@ -159,7 +167,7 @@ def _train_test_split(*args, **options):
train_size = 0.75
elif train_size is None:
train_size = 1 - test_size
train_size = train_size * args[0].shape[0]
train_size *= args[0].shape[0]
np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0])
@@ -173,6 +181,7 @@ def _train_test_split(*args, **options):
# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
if TRY_IMPORT_SKLEARN:
# pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
from sklearn.cross_validation import train_test_split
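
The _train_test_split fallback in this file reimplements just enough of sklearn's splitter for internal use: resolve a train fraction, permute the row indices once, then slice every array with the same index split. A minimal self-contained sketch of that logic (toy_train_test_split is an illustrative name, not part of the module):

import numpy as np

def toy_train_test_split(*arrays, **options):
    """Shuffle row indices once, then slice every array the same way."""
    train_size = options.pop('train_size', 0.75)
    random_state = options.pop('random_state', None)
    n_samples = arrays[0].shape[0]
    n_train = int(train_size * n_samples)
    rng = np.random.RandomState(random_state)
    indices = rng.permutation(n_samples)
    result = []
    for a in arrays:
        result.extend([a[indices[:n_train]], a[indices[n_train:]]])
    return tuple(result)

X = np.arange(20).reshape(10, 2)
y = np.arange(10)
X_train, X_test, y_train, y_test = toy_train_test_split(X, y, random_state=42)
print(X_train.shape, X_test.shape)  # (7, 2) (3, 2)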

View File

@@ -1,5 +1,4 @@
"""Deep Autoencoder estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,105 +11,115 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deep Autoencoder estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import nn
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
import numpy as np
from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.python.ops import nn
class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
"""TensorFlow Autoencoder Regressor model.
"""TensorFlow Autoencoder Regressor model.
Parameters:
hidden_units: List of hidden units per layer.
batch_size: Mini batch size.
activation: activation function used to map the inner latent layer onto
the reconstruction layer.
add_noise: a function that adds noise to tensor_in,
e.g. def add_noise(x):
return x + np.random.normal(0, 0.1, (len(x), len(x[0])))
steps: Number of steps to run over data.
optimizer: Optimizer name (or class), for example "SGD", "Adam",
"Adagrad".
learning_rate: If this is a constant float value, no decay function is used.
Instead, a customized decay function can be passed that accepts
global_step as a parameter and returns a Tensor.
e.g. exponential decay function:
def exp_decay(global_step):
return tf.train.exponential_decay(
0.1, global_step,
decay_steps=2, decay_rate=0.001)
continue_training: when continue_training is True, once initialized,
the model will be continually trained on every call of fit.
config: RunConfig object that controls the configurations of the session,
e.g. num_cores, gpu_memory_fraction, etc.
verbose: Controls the verbosity; possible values:
0: the algorithm and debug information are muted.
1: trainer prints the progress.
2: log device placement is printed.
dropout: When not None, the probability we will drop out a given
coordinate.
"""
def __init__(self, hidden_units, n_classes=0, batch_size=32,
steps=200, optimizer="Adagrad", learning_rate=0.1,
clip_gradients=5.0, activation=nn.relu, add_noise=None,
continue_training=False, config=None,
verbose=1, dropout=None):
self.hidden_units = hidden_units
self.dropout = dropout
self.activation = activation
self.add_noise = add_noise
super(TensorFlowDNNAutoencoder, self).__init__(
model_fn=self._model_fn,
n_classes=n_classes,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, clip_gradients=clip_gradients,
continue_training=continue_training,
config=config, verbose=verbose)
Parameters:
hidden_units: List of hidden units per layer.
batch_size: Mini batch size.
activation: activation function used to map the inner latent layer onto
the reconstruction layer.
add_noise: a function that adds noise to tensor_in,
e.g. def add_noise(x):
return x + np.random.normal(0, 0.1, (len(x), len(x[0])))
steps: Number of steps to run over data.
optimizer: Optimizer name (or class), for example "SGD", "Adam",
"Adagrad".
learning_rate: If this is a constant float value, no decay function is used.
Instead, a customized decay function can be passed that accepts
global_step as a parameter and returns a Tensor.
e.g. exponential decay function:
def exp_decay(global_step):
return tf.train.exponential_decay(
0.1, global_step,
decay_steps=2, decay_rate=0.001)
continue_training: when continue_training is True, once initialized,
the model will be continually trained on every call of fit.
config: RunConfig object that controls the configurations of the session,
e.g. num_cores, gpu_memory_fraction, etc.
verbose: Controls the verbosity; possible values:
0: the algorithm and debug information are muted.
1: trainer prints the progress.
2: log device placement is printed.
dropout: When not None, the probability we will drop out a given
coordinate.
"""
def _model_fn(self, X, y):
encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
self.hidden_units,
models.linear_regression,
activation=self.activation,
add_noise=self.add_noise,
dropout=self.dropout)(X)
self.encoder = encoder
self.decoder = decoder
return autoencoder_estimator
def __init__(self, hidden_units, n_classes=0, batch_size=32,
steps=200, optimizer="Adagrad", learning_rate=0.1,
clip_gradients=5.0, activation=nn.relu, add_noise=None,
continue_training=False, config=None,
verbose=1, dropout=None):
self.hidden_units = hidden_units
self.dropout = dropout
self.activation = activation
self.add_noise = add_noise
super(TensorFlowDNNAutoencoder, self).__init__(
model_fn=self._model_fn,
n_classes=n_classes,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, clip_gradients=clip_gradients,
continue_training=continue_training,
config=config, verbose=verbose)
def generate(self, hidden=None):
"""Generate new data using trained construction layer"""
if hidden is None:
last_layer = len(self.hidden_units) - 1
bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer)
import numpy as np
hidden = np.random.normal(size=bias.shape)
hidden = np.reshape(hidden, (1, len(hidden)))
return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
def _model_fn(self, X, y):
encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
self.hidden_units,
models.linear_regression,
activation=self.activation,
add_noise=self.add_noise,
dropout=self.dropout)(X)
self.encoder = encoder
self.decoder = decoder
return autoencoder_estimator
@property
def weights_(self):
"""Returns weights of the autoencoder's weight layers."""
weights = []
for layer in range(len(self.hidden_units)):
weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer))
for layer in range(len(self.hidden_units)):
weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer))
weights.append(self.get_tensor_value('linear_regression/weights:0'))
return weights
def generate(self, hidden=None):
"""Generate new data using trained construction layer."""
if hidden is None:
last_layer = len(self.hidden_units) - 1
bias = self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Bias:0" % last_layer)
hidden = np.random.normal(size=bias.shape)
hidden = np.reshape(hidden, (1, len(hidden)))
return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
@property
def bias_(self):
"""Returns bias of the autoencoder's bias layers."""
biases = []
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer))
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer))
biases.append(self.get_tensor_value('linear_regression/bias:0'))
return biases
@property
def weights_(self):
"""Returns weights of the autoencoder's weight layers."""
weights = []
for layer in range(len(self.hidden_units)):
weights.append(self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Matrix:0" % layer))
for layer in range(len(self.hidden_units)):
weights.append(self.get_tensor_value(
"decoder/dnn/layer%d/Linear/Matrix:0" % layer))
weights.append(self.get_tensor_value("linear_regression/weights:0"))
return weights
@property
def bias_(self):
"""Returns bias of the autoencoder's bias layers."""
biases = []
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Bias:0" % layer))
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value(
"decoder/dnn/layer%d/Linear/Bias:0" % layer))
biases.append(self.get_tensor_value("linear_regression/bias:0"))
return biases
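
For context, the class in this file is used like any other skflow transformer. A rough usage sketch against the 2016-era contrib.learn API shown in this diff (the API has since been removed from TensorFlow, and the exact fit signature is assumed here); add_noise and exp_decay are the helpers described in the class docstring:

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder

def add_noise(x):
    # Gaussian input corruption, per the docstring's suggestion.
    return x + np.random.normal(0, 0.1, (len(x), len(x[0])))

def exp_decay(global_step):
    # A customized decay function passed instead of a constant rate.
    return tf.train.exponential_decay(0.1, global_step,
                                      decay_steps=2, decay_rate=0.001)

X = np.random.rand(100, 10)
autoencoder = TensorFlowDNNAutoencoder(hidden_units=[8, 4],
                                       add_noise=add_noise,
                                       learning_rate=exp_decay,
                                       steps=500)
autoencoder.fit(X)               # unsupervised: the target is the input
sample = autoencoder.generate()  # decode a random latent vector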

View File

@@ -1,5 +1,4 @@
"""Base estimator class."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base estimator class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import json
import os
import shutil
from six import string_types
import numpy as np
from google.protobuf import text_format
from tensorflow.python.platform import gfile

View File

@@ -1,5 +1,4 @@
"""Deep Neural Network estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deep Neural Network estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,5 +1,4 @@
"""Linear Estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,5 +1,4 @@
"""Recurrent Neural Network estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recurrent Neural Network estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,6 +1,4 @@
"""Implementations of different data feeders to provide data for TF trainer."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of different data feeders to provide data for TF trainer."""
# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
from __future__ import absolute_import

View File

@@ -1,5 +1,4 @@
"""Various high level TF models."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various high level TF models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.contrib.learn.python.learn.ops import dnn_ops
from tensorflow.contrib.learn.python.learn.ops import losses_ops
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_
@@ -29,8 +31,7 @@ from tensorflow.python.ops import variable_scope as vs
def linear_regression_zero_init(X, y):
"""Creates a linear regression TensorFlow subgraph, in which weights and
bias terms are initialized to exactly zero.
"""Linear regression subgraph with zero-value initial weights and bias.
Args:
X: tensor or placeholder for input features.
@@ -43,8 +44,7 @@ def linear_regression_zero_init(X, y):
def logistic_regression_zero_init(X, y):
"""Creates a logistic regression TensorFlow subgraph, in which weights and
bias terms are initialized to exactly zero.
"""Logistic regression subgraph with zero-value initial weights and bias.
Args:
X: tensor or placeholder for input features.
@@ -85,7 +85,7 @@ def linear_regression(X, y, init_mean=None, init_stddev=1.0):
else:
output_shape = y_shape[1]
# Set up the requested initialization.
if (init_mean is None):
if init_mean is None:
weights = vs.get_variable('weights', [X.get_shape()[1], output_shape])
bias = vs.get_variable('bias', [output_shape])
else:
@@ -134,7 +134,7 @@ def logistic_regression(X,
logging_ops.histogram_summary('logistic_regression.X', X)
logging_ops.histogram_summary('logistic_regression.y', y)
# Set up the requested initialization.
if (init_mean is None):
if init_mean is None:
weights = vs.get_variable('weights',
[X.get_shape()[1], y.get_shape()[-1]])
bias = vs.get_variable('bias', [y.get_shape()[-1]])
@@ -188,35 +188,37 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None):
return dnn_estimator
def get_autoencoder_model(hidden_units, target_predictor_fn,
activation, add_noise=None, dropout=None):
"""Returns a function that creates a Autoencoder TensorFlow subgraph with given
params.
"""Returns a function that creates a Autoencoder TensorFlow subgraph.
Args:
hidden_units: List of values of hidden units for layers.
target_predictor_fn: Function that will predict target from input
features. This can be logistic regression,
linear regression or any other model
that takes X, y and returns predictions and loss tensors.
activation: activation function used to map the inner latent layer onto
the reconstruction layer.
add_noise: a function that adds noise to tensor_in,
e.g. def add_noise(x):
return x + np.random.normal(0, 0.1, (len(x), len(x[0])))
dropout: When not None, causes dropout regularization to be used,
with the specified probability of removing a given coordinate.
Args:
hidden_units: List of values of hidden units for layers.
target_predictor_fn: Function that will predict target from input
features. This can be logistic regression,
linear regression or any other model
that takes X, y and returns predictions and loss
tensors.
activation: activation function used to map the inner latent layer onto
the reconstruction layer.
add_noise: a function that adds noise to tensor_in,
e.g. def add_noise(x):
return x + np.random.normal(0, 0.1, (len(x), len(x[0])))
dropout: When not None, causes dropout regularization to be used,
with the specified probability of removing a given coordinate.
Returns:
A function that creates the subgraph.
"""
def dnn_autoencoder_estimator(X):
"""Autoencoder estimator with target predictor function on top."""
encoder, decoder = autoencoder_ops.dnn_autoencoder(
X, hidden_units, activation,
add_noise=add_noise, dropout=dropout)
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator
Returns:
A function that creates the subgraph.
"""
def dnn_autoencoder_estimator(X):
"""Autoencoder estimator with target predictor function on top."""
encoder, decoder = autoencoder_ops.dnn_autoencoder(
X, hidden_units, activation,
add_noise=add_noise, dropout=dropout)
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator
## This will be in Tensorflow 0.7.
## TODO(ilblackdragon): Clean this up when it's released
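
get_autoencoder_model follows the same factory pattern as get_dnn_model in this file: it returns a graph-builder closure that the estimator later calls with its input tensor, getting back the encoder, the decoder, and whatever the target predictor returns (a predictions/loss pair for the regression heads defined here). A rough sketch of that calling convention, with illustrative shapes, assuming the 2016-era API shown in this diff:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn import models

# The factory only captures configuration; no graph is built yet.
model_fn = models.get_autoencoder_model([64, 32],
                                        models.linear_regression,
                                        activation=tf.nn.relu)

X = tf.placeholder(tf.float32, [None, 128])
# Calling the closure builds the subgraph and returns the three handles.
encoder, decoder, (predictions, loss) = model_fn(X)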

View File

@@ -1,5 +1,4 @@
"""Monitors to track model training, report on progress and request early stopping"""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Monitors to track training, report progress and request early stopping."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,147 +12,155 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example builds a deep residual network for MNIST data.
"""This example builds a deep residual network for MNIST data.
Reference Paper: http://arxiv.org/pdf/1512.03385.pdf
Note that this is still a work-in-progress. Feel free to submit a PR
to make this better.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from collections import namedtuple
from math import sqrt
import os
from sklearn import metrics
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import learn
from tensorflow.examples.tutorials.mnist import input_data
def res_net(x, y, activation=tf.nn.relu):
"""Builds a residual network. Note that if the input tensor is 2D, it must be
square in order to be converted to a 4D tensor.
"""Builds a residual network.
Borrowed structure from here: https://github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
Note that if the input tensor is 2D, it must be square in order to be
converted to a 4D tensor.
Args:
x: Input of the network
y: Output of the network
activation: Activation function to apply after each convolution
"""
Borrowed structure from:
github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
# Configurations for each bottleneck block
BottleneckBlock = namedtuple(
'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
blocks = [BottleneckBlock(3, 128, 32),
BottleneckBlock(3, 256, 64),
BottleneckBlock(3, 512, 128),
BottleneckBlock(3, 1024, 256)]
Args:
x: Input of the network
y: Output of the network
activation: Activation function to apply after each convolution
input_shape = x.get_shape().as_list()
Returns:
Predictions and loss tensors.
"""
# Reshape the input into the right shape if it's 2D tensor
if len(input_shape) == 2:
ndim = int(sqrt(input_shape[1]))
x = tf.reshape(x, [-1, ndim, ndim, 1])
# Configurations for each bottleneck block.
BottleneckBlock = namedtuple(
'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
blocks = [BottleneckBlock(3, 128, 32),
BottleneckBlock(3, 256, 64),
BottleneckBlock(3, 512, 128),
BottleneckBlock(3, 1024, 256)]
# First convolution expands to 64 channels
with tf.variable_scope('conv_layer1'):
net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
activation=activation, bias=False)
input_shape = x.get_shape().as_list()
# Max pool
net = tf.nn.max_pool(
net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
# Reshape the input into the right shape if it's 2D tensor
if len(input_shape) == 2:
ndim = int(sqrt(input_shape[1]))
x = tf.reshape(x, [-1, ndim, ndim, 1])
# First chain of resnets
with tf.variable_scope('conv_layer2'):
net = learn.ops.conv2d(net, blocks[0].num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID', bias=True)
# First convolution expands to 64 channels
with tf.variable_scope('conv_layer1'):
net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
activation=activation, bias=False)
# Create each bottleneck building block for each layer
for block_i, block in enumerate(blocks):
for layer_i in range(block.num_layers):
# Max pool
net = tf.nn.max_pool(
net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
name = 'block_%d/layer_%d' % (block_i, layer_i)
# First chain of resnets
with tf.variable_scope('conv_layer2'):
net = learn.ops.conv2d(net, blocks[0].num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID', bias=True)
# 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_in'):
conv = learn.ops.conv2d(net, block.bottleneck_size,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
batch_norm=True,
bias=False)
# Create each bottleneck building block for each layer
for block_i, block in enumerate(blocks):
for layer_i in range(block.num_layers):
with tf.variable_scope(name + '/conv_bottleneck'):
conv = learn.ops.conv2d(conv, block.bottleneck_size,
[3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
batch_norm=True,
bias=False)
name = 'block_%d/layer_%d' % (block_i, layer_i)
# 1x1 convolution responsible for restoring dimension
with tf.variable_scope(name + '/conv_out'):
conv = learn.ops.conv2d(conv, block.num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
batch_norm=True,
bias=False)
# 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_in'):
conv = learn.ops.conv2d(net, block.bottleneck_size,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
batch_norm=True,
bias=False)
# shortcut connections that turn the network into its counterpart
# residual function (identity shortcut)
net = conv + net
with tf.variable_scope(name + '/conv_bottleneck'):
conv = learn.ops.conv2d(conv, block.bottleneck_size,
[3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
batch_norm=True,
bias=False)
try:
# upscale to the next block size
next_block = blocks[block_i + 1]
with tf.variable_scope('block_%d/conv_upscale' % block_i):
net = learn.ops.conv2d(net, next_block.num_filters,
[1, 1], [1, 1, 1, 1],
bias=False,
padding='SAME')
except IndexError:
pass
# 1x1 convolution responsible for restoring dimension
with tf.variable_scope(name + '/conv_out'):
conv = learn.ops.conv2d(conv, block.num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
batch_norm=True,
bias=False)
net_shape = net.get_shape().as_list()
net = tf.nn.avg_pool(net,
ksize=[1, net_shape[1], net_shape[2], 1],
strides=[1, 1, 1, 1], padding='VALID')
# shortcut connections that turn the network into its counterpart
# residual function (identity shortcut)
net = conv + net
net_shape = net.get_shape().as_list()
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
try:
# upscale to the next block size
next_block = blocks[block_i + 1]
with tf.variable_scope('block_%d/conv_upscale' % block_i):
net = learn.ops.conv2d(net, next_block.num_filters,
[1, 1], [1, 1, 1, 1],
bias=False,
padding='SAME')
except IndexError:
pass
return learn.models.logistic_regression(net, y)
net_shape = net.get_shape().as_list()
net = tf.nn.avg_pool(net,
ksize=[1, net_shape[1], net_shape[2], 1],
strides=[1, 1, 1, 1], padding='VALID')
net_shape = net.get_shape().as_list()
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
return learn.models.logistic_regression(net, y)
# Download and load MNIST data.
mnist = input_data.read_data_sets('MNIST_data')
# Restore model if graph is saved into a folder.
if os.path.exists("models/resnet/graph.pbtxt"):
classifier = learn.TensorFlowEstimator.restore("models/resnet/")
if os.path.exists('models/resnet/graph.pbtxt'):
classifier = learn.TensorFlowEstimator.restore('models/resnet/')
else:
# Create a new resnet classifier.
classifier = learn.TensorFlowEstimator(
model_fn=res_net, n_classes=10, batch_size=100, steps=100,
learning_rate=0.001, continue_training=True)
# Create a new resnet classifier.
classifier = learn.TensorFlowEstimator(
model_fn=res_net, n_classes=10, batch_size=100, steps=100,
learning_rate=0.001, continue_training=True)
while True:
# Train model and save summaries into logdir.
classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/")
# Train model and save summaries into logdir.
classifier.fit(
mnist.train.images, mnist.train.labels, logdir='models/resnet/')
# Calculate accuracy.
score = metrics.accuracy_score(
mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
print('Accuracy: {0:f}'.format(score))
# Calculate accuracy.
score = metrics.accuracy_score(
mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
print('Accuracy: {0:f}'.format(score))
# Save model graph and checkpoints.
classifier.save("models/resnet/")
# Save model graph and checkpoints.
classifier.save('models/resnet/')
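
A note on the bottleneck arithmetic in the example above: the identity shortcut net = conv + net only type-checks because every block ends with a 1x1 convolution that restores the block's filter count. A standalone shape check with plain tf.nn ops (toy sizes; not the tutorial's learn.ops API):

import tensorflow as tf

def conv(inp, out_channels, k):
    """k x k convolution + ReLU; SAME padding keeps the spatial dims."""
    in_channels = int(inp.get_shape()[-1])
    w = tf.Variable(tf.truncated_normal([k, k, in_channels, out_channels],
                                        stddev=0.1))
    return tf.nn.relu(tf.nn.conv2d(inp, w, [1, 1, 1, 1], 'SAME'))

x = tf.placeholder(tf.float32, [None, 28, 28, 128])  # block input: 128 filters
reduced = conv(x, 32, 1)              # 1x1 reduce: 128 -> 32 (the bottleneck)
transformed = conv(reduced, 32, 3)    # 3x3 transform at the cheap width
restored = conv(transformed, 128, 1)  # 1x1 restore: 32 -> 128
out = restored + x                    # identity shortcut: shapes match exactly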