Fix copyrights and a few other lint errors.

Change: 123250570
This commit is contained in:
A. Unique TensorFlower 2016-05-25 13:01:01 -08:00 committed by TensorFlower Gardener
parent fd5ebfa768
commit 5c145f0e3c
17 changed files with 294 additions and 256 deletions

View File

@ -1,5 +1,4 @@
"""Main Scikit Flow module.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""High level API for learning with TensorFlow."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,5 +1,4 @@
"""Base utilities for loading datasets.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Base utilities for loading datasets."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,5 +1,4 @@
"""Scikit Flow Estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier

View File

@ -1,5 +1,4 @@
"""sklearn cross-support.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""sklearn cross-support."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -20,6 +22,8 @@ import collections
import os import os
import numpy as np import numpy as np
import six
def _pprint(d): def _pprint(d):
return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()]) return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
@ -102,6 +106,7 @@ class _BaseEstimator(object):
_pprint(self.get_params(deep=False)),) _pprint(self.get_params(deep=False)),)
# pylint: disable=old-style-class
class _ClassifierMixin(): class _ClassifierMixin():
"""Mixin class for all classifiers.""" """Mixin class for all classifiers."""
pass pass
@ -111,8 +116,10 @@ class _RegressorMixin():
"""Mixin class for all regression estimators.""" """Mixin class for all regression estimators."""
pass pass
class _TransformerMixin(): class _TransformerMixin():
"""Mixin class for all transformer estimators.""" """Mixin class for all transformer estimators."""
class _NotFittedError(ValueError, AttributeError): class _NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting. """Exception class to raise if estimator is used before fitting.
@ -134,6 +141,8 @@ class _NotFittedError(ValueError, AttributeError):
https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py
""" """
# pylint: enable=old-style-class
def _accuracy_score(y_true, y_pred): def _accuracy_score(y_true, y_pred):
score = y_true == y_pred score = y_true == y_pred
@ -149,8 +158,7 @@ def _mean_squared_error(y_true, y_pred):
def _train_test_split(*args, **options): def _train_test_split(*args, **options):
n_array = len(args) # pylint: disable=missing-docstring
test_size = options.pop('test_size', None) test_size = options.pop('test_size', None)
train_size = options.pop('train_size', None) train_size = options.pop('train_size', None)
random_state = options.pop('random_state', None) random_state = options.pop('random_state', None)
@ -159,7 +167,7 @@ def _train_test_split(*args, **options):
train_size = 0.75 train_size = 0.75
elif train_size is None: elif train_size is None:
train_size = 1 - test_size train_size = 1 - test_size
train_size = train_size * args[0].shape[0] train_size *= args[0].shape[0]
np.random.seed(random_state) np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0]) indices = np.random.permutation(args[0].shape[0])
@ -173,6 +181,7 @@ def _train_test_split(*args, **options):
# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn. # If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False) TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
if TRY_IMPORT_SKLEARN: if TRY_IMPORT_SKLEARN:
# pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
from sklearn.cross_validation import train_test_split from sklearn.cross_validation import train_test_split

View File

@ -1,5 +1,4 @@
"""Deep Autoencoder estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,105 +11,115 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Deep Autoencoder estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.ops import nn import numpy as np
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn import models from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.python.ops import nn
class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer): class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
"""TensorFlow Autoencoder Regressor model. """TensorFlow Autoencoder Regressor model.
Parameters: Parameters:
hidden_units: List of hidden units per layer. hidden_units: List of hidden units per layer.
batch_size: Mini batch size. batch_size: Mini batch size.
activation: activation function used to map inner latent layer onto activation: activation function used to map inner latent layer onto
reconstruction layer. reconstruction layer.
add_noise: a function that adds noise to tensor_in, add_noise: a function that adds noise to tensor_in,
e.g. def add_noise(x): e.g. def add_noise(x):
return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
steps: Number of steps to run over data. steps: Number of steps to run over data.
optimizer: Optimizer name (or class), for example "SGD", "Adam", optimizer: Optimizer name (or class), for example "SGD", "Adam",
"Adagrad". "Adagrad".
learning_rate: If this is constant float value, no decay function is used. learning_rate: If this is constant float value, no decay function is used.
Instead, a customized decay function can be passed that accepts Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor. global_step as parameter and returns a Tensor.
e.g. exponential decay function: e.g. exponential decay function:
def exp_decay(global_step): def exp_decay(global_step):
return tf.train.exponential_decay( return tf.train.exponential_decay(
learning_rate=0.1, global_step, learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001) decay_steps=2, decay_rate=0.001)
continue_training: when continue_training is True, once initialized continue_training: when continue_training is True, once initialized
model will be continuely trained on every call of fit. model will be continuely trained on every call of fit.
config: RunConfig object that controls the configurations of the session, config: RunConfig object that controls the configurations of the session,
e.g. num_cores, gpu_memory_fraction, etc. e.g. num_cores, gpu_memory_fraction, etc.
verbose: Controls the verbosity, possible values: verbose: Controls the verbosity, possible values:
0: the algorithm and debug information is muted. 0: the algorithm and debug information is muted.
1: trainer prints the progress. 1: trainer prints the progress.
2: log device placement is printed. 2: log device placement is printed.
dropout: When not None, the probability we will drop out a given dropout: When not None, the probability we will drop out a given
coordinate. coordinate.
""" """
def __init__(self, hidden_units, n_classes=0, batch_size=32,
steps=200, optimizer="Adagrad", learning_rate=0.1,
clip_gradients=5.0, activation=nn.relu, add_noise=None,
continue_training=False, config=None,
verbose=1, dropout=None):
self.hidden_units = hidden_units
self.dropout = dropout
self.activation = activation
self.add_noise = add_noise
super(TensorFlowDNNAutoencoder, self).__init__(
model_fn=self._model_fn,
n_classes=n_classes,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, clip_gradients=clip_gradients,
continue_training=continue_training,
config=config, verbose=verbose)
def _model_fn(self, X, y): def __init__(self, hidden_units, n_classes=0, batch_size=32,
encoder, decoder, autoencoder_estimator = models.get_autoencoder_model( steps=200, optimizer="Adagrad", learning_rate=0.1,
self.hidden_units, clip_gradients=5.0, activation=nn.relu, add_noise=None,
models.linear_regression, continue_training=False, config=None,
activation=self.activation, verbose=1, dropout=None):
add_noise=self.add_noise, self.hidden_units = hidden_units
dropout=self.dropout)(X) self.dropout = dropout
self.encoder = encoder self.activation = activation
self.decoder = decoder self.add_noise = add_noise
return autoencoder_estimator super(TensorFlowDNNAutoencoder, self).__init__(
model_fn=self._model_fn,
n_classes=n_classes,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, clip_gradients=clip_gradients,
continue_training=continue_training,
config=config, verbose=verbose)
def generate(self, hidden=None): def _model_fn(self, X, y):
"""Generate new data using trained construction layer""" encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
if hidden is None: self.hidden_units,
last_layer = len(self.hidden_units) - 1 models.linear_regression,
bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer) activation=self.activation,
import numpy as np add_noise=self.add_noise,
hidden = np.random.normal(size=bias.shape) dropout=self.dropout)(X)
hidden = np.reshape(hidden, (1, len(hidden))) self.encoder = encoder
return self._session.run(self.decoder, feed_dict={self.encoder: hidden}) self.decoder = decoder
return autoencoder_estimator
@property def generate(self, hidden=None):
def weights_(self): """Generate new data using trained construction layer."""
"""Returns weights of the autoencoder's weight layers.""" if hidden is None:
weights = [] last_layer = len(self.hidden_units) - 1
for layer in range(len(self.hidden_units)): bias = self.get_tensor_value(
weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer)) "encoder/dnn/layer%d/Linear/Bias:0" % last_layer)
for layer in range(len(self.hidden_units)): hidden = np.random.normal(size=bias.shape)
weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer)) hidden = np.reshape(hidden, (1, len(hidden)))
weights.append(self.get_tensor_value('linear_regression/weights:0')) return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
return weights
@property @property
def bias_(self): def weights_(self):
"""Returns bias of the autoencoder's bias layers.""" """Returns weights of the autoencoder's weight layers."""
biases = [] weights = []
for layer in range(len(self.hidden_units)): for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer)) weights.append(self.get_tensor_value(
for layer in range(len(self.hidden_units)): "encoder/dnn/layer%d/Linear/Matrix:0" % layer))
biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer)) for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('linear_regression/bias:0')) weights.append(self.get_tensor_value(
return biases "decoder/dnn/layer%d/Linear/Matrix:0" % layer))
weights.append(self.get_tensor_value("linear_regression/weights:0"))
return weights
@property
def bias_(self):
"""Returns bias of the autoencoder's bias layers."""
biases = []
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Bias:0" % layer))
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value(
"decoder/dnn/layer%d/Linear/Bias:0" % layer))
biases.append(self.get_tensor_value("linear_regression/bias:0"))
return biases

View File

@ -1,5 +1,4 @@
"""Base estimator class.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Base estimator class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import datetime
import json import json
import os import os
import shutil
from six import string_types from six import string_types
import numpy as np
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile

View File

@ -1,5 +1,4 @@
"""Deep Neural Network estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Deep Neural Network estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,5 +1,4 @@
"""Linear Estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Linear Estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,5 +1,4 @@
"""Recurrent Neural Network estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Recurrent Neural Network estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,6 +1,4 @@
"""Implementations of different data feeders to provide data for TF trainer.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Implementations of different data feeders to provide data for TF trainer."""
# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues. # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
from __future__ import absolute_import from __future__ import absolute_import

View File

@ -1,5 +1,4 @@
"""Various high level TF models.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,13 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Various high level TF models."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.contrib.learn.python.learn.ops import dnn_ops from tensorflow.contrib.learn.python.learn.ops import dnn_ops
from tensorflow.contrib.learn.python.learn.ops import losses_ops from tensorflow.contrib.learn.python.learn.ops import losses_ops
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import array_ops as array_ops_
@ -29,8 +31,7 @@ from tensorflow.python.ops import variable_scope as vs
def linear_regression_zero_init(X, y): def linear_regression_zero_init(X, y):
"""Creates a linear regression TensorFlow subgraph, in which weights and """Linear regression subgraph with zero-value initial weights and bias.
bias terms are initialized to exactly zero.
Args: Args:
X: tensor or placeholder for input features. X: tensor or placeholder for input features.
@ -43,8 +44,7 @@ def linear_regression_zero_init(X, y):
def logistic_regression_zero_init(X, y): def logistic_regression_zero_init(X, y):
"""Creates a logistic regression TensorFlow subgraph, in which weights and """Logistic regression subgraph with zero-value initial weights and bias.
bias terms are initialized to exactly zero.
Args: Args:
X: tensor or placeholder for input features. X: tensor or placeholder for input features.
@ -85,7 +85,7 @@ def linear_regression(X, y, init_mean=None, init_stddev=1.0):
else: else:
output_shape = y_shape[1] output_shape = y_shape[1]
# Set up the requested initialization. # Set up the requested initialization.
if (init_mean is None): if init_mean is None:
weights = vs.get_variable('weights', [X.get_shape()[1], output_shape]) weights = vs.get_variable('weights', [X.get_shape()[1], output_shape])
bias = vs.get_variable('bias', [output_shape]) bias = vs.get_variable('bias', [output_shape])
else: else:
@ -134,7 +134,7 @@ def logistic_regression(X,
logging_ops.histogram_summary('logistic_regression.X', X) logging_ops.histogram_summary('logistic_regression.X', X)
logging_ops.histogram_summary('logistic_regression.y', y) logging_ops.histogram_summary('logistic_regression.y', y)
# Set up the requested initialization. # Set up the requested initialization.
if (init_mean is None): if init_mean is None:
weights = vs.get_variable('weights', weights = vs.get_variable('weights',
[X.get_shape()[1], y.get_shape()[-1]]) [X.get_shape()[1], y.get_shape()[-1]])
bias = vs.get_variable('bias', [y.get_shape()[-1]]) bias = vs.get_variable('bias', [y.get_shape()[-1]])
@ -188,35 +188,37 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None):
return dnn_estimator return dnn_estimator
def get_autoencoder_model(hidden_units, target_predictor_fn, def get_autoencoder_model(hidden_units, target_predictor_fn,
activation, add_noise=None, dropout=None): activation, add_noise=None, dropout=None):
"""Returns a function that creates a Autoencoder TensorFlow subgraph with given """Returns a function that creates a Autoencoder TensorFlow subgraph.
params.
Args: Args:
hidden_units: List of values of hidden units for layers. hidden_units: List of values of hidden units for layers.
target_predictor_fn: Function that will predict target from input target_predictor_fn: Function that will predict target from input
features. This can be logistic regression, features. This can be logistic regression,
linear regression or any other model, linear regression or any other model,
that takes X, y and returns predictions and loss tensors. that takes X, y and returns predictions and loss
activation: activation function used to map inner latent layer onto tensors.
reconstruction layer. activation: activation function used to map inner latent layer onto
add_noise: a function that adds noise to tensor_in, reconstruction layer.
e.g. def add_noise(x): add_noise: a function that adds noise to tensor_in,
return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) e.g. def add_noise(x):
dropout: When not none, causes dropout regularization to be used, return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
with the specified probability of removing a given coordinate. dropout: When not none, causes dropout regularization to be used,
with the specified probability of removing a given coordinate.
Returns:
A function that creates the subgraph.
"""
def dnn_autoencoder_estimator(X):
"""Autoencoder estimator with target predictor function on top."""
encoder, decoder = autoencoder_ops.dnn_autoencoder(
X, hidden_units, activation,
add_noise=add_noise, dropout=dropout)
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator
Returns:
A function that creates the subgraph.
"""
def dnn_autoencoder_estimator(X):
"""Autoencoder estimator with target predictor function on top."""
encoder, decoder = autoencoder_ops.dnn_autoencoder(
X, hidden_units, activation,
add_noise=add_noise, dropout=dropout)
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator
## This will be in Tensorflow 0.7. ## This will be in Tensorflow 0.7.
## TODO(ilblackdragon): Clean this up when it's released ## TODO(ilblackdragon): Clean this up when it's released

View File

@ -1,5 +1,4 @@
"""Monitors to track model training, report on progress and request early stopping""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Monitors to track training, report progress and request early stopping."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,147 +12,155 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """This example builds deep residual network for mnist data.
This example builds deep residual network for mnist data.
Reference Paper: http://arxiv.org/pdf/1512.03385.pdf Reference Paper: http://arxiv.org/pdf/1512.03385.pdf
Note that this is still a work-in-progress. Feel free to submit a PR Note that this is still a work-in-progress. Feel free to submit a PR
to make this better. to make this better.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
from collections import namedtuple from collections import namedtuple
from math import sqrt from math import sqrt
import os
from sklearn import metrics from sklearn import metrics
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import learn from tensorflow.contrib import learn
from tensorflow.examples.tutorials.mnist import input_data
def res_net(x, y, activation=tf.nn.relu): def res_net(x, y, activation=tf.nn.relu):
"""Builds a residual network. Note that if the input tensor is 2D, it must be """Builds a residual network.
square in order to be converted to a 4D tensor.
Borrowed structure from here: https://github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py Note that if the input tensor is 2D, it must be square in order to be
converted to a 4D tensor.
Args: Borrowed structure from:
x: Input of the network github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
y: Output of the network
activation: Activation function to apply after each convolution
"""
# Configurations for each bottleneck block Args:
BottleneckBlock = namedtuple( x: Input of the network
'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size']) y: Output of the network
blocks = [BottleneckBlock(3, 128, 32), activation: Activation function to apply after each convolution
BottleneckBlock(3, 256, 64),
BottleneckBlock(3, 512, 128),
BottleneckBlock(3, 1024, 256)]
input_shape = x.get_shape().as_list() Returns:
Predictions and loss tensors.
"""
# Reshape the input into the right shape if it's 2D tensor # Configurations for each bottleneck block.
if len(input_shape) == 2: BottleneckBlock = namedtuple(
ndim = int(sqrt(input_shape[1])) 'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
x = tf.reshape(x, [-1, ndim, ndim, 1]) blocks = [BottleneckBlock(3, 128, 32),
BottleneckBlock(3, 256, 64),
BottleneckBlock(3, 512, 128),
BottleneckBlock(3, 1024, 256)]
# First convolution expands to 64 channels input_shape = x.get_shape().as_list()
with tf.variable_scope('conv_layer1'):
net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
activation=activation, bias=False)
# Max pool # Reshape the input into the right shape if it's 2D tensor
net = tf.nn.max_pool( if len(input_shape) == 2:
net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME') ndim = int(sqrt(input_shape[1]))
x = tf.reshape(x, [-1, ndim, ndim, 1])
# First chain of resnets # First convolution expands to 64 channels
with tf.variable_scope('conv_layer2'): with tf.variable_scope('conv_layer1'):
net = learn.ops.conv2d(net, blocks[0].num_filters, net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
[1, 1], [1, 1, 1, 1], activation=activation, bias=False)
padding='VALID', bias=True)
# Create each bottleneck building block for each layer # Max pool
for block_i, block in enumerate(blocks): net = tf.nn.max_pool(
for layer_i in range(block.num_layers): net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
name = 'block_%d/layer_%d' % (block_i, layer_i) # First chain of resnets
with tf.variable_scope('conv_layer2'):
net = learn.ops.conv2d(net, blocks[0].num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID', bias=True)
# 1x1 convolution responsible for reducing dimension # Create each bottleneck building block for each layer
with tf.variable_scope(name + '/conv_in'): for block_i, block in enumerate(blocks):
conv = learn.ops.conv2d(net, block.bottleneck_size, for layer_i in range(block.num_layers):
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
batch_norm=True,
bias=False)
with tf.variable_scope(name + '/conv_bottleneck'): name = 'block_%d/layer_%d' % (block_i, layer_i)
conv = learn.ops.conv2d(conv, block.bottleneck_size,
[3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
batch_norm=True,
bias=False)
# 1x1 convolution responsible for restoring dimension # 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_out'): with tf.variable_scope(name + '/conv_in'):
conv = learn.ops.conv2d(conv, block.num_filters, conv = learn.ops.conv2d(net, block.bottleneck_size,
[1, 1], [1, 1, 1, 1], [1, 1], [1, 1, 1, 1],
padding='VALID', padding='VALID',
activation=activation, activation=activation,
batch_norm=True, batch_norm=True,
bias=False) bias=False)
# shortcut connections that turn the network into its counterpart with tf.variable_scope(name + '/conv_bottleneck'):
# residual function (identity shortcut) conv = learn.ops.conv2d(conv, block.bottleneck_size,
net = conv + net [3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
batch_norm=True,
bias=False)
try: # 1x1 convolution responsible for restoring dimension
# upscale to the next block size with tf.variable_scope(name + '/conv_out'):
next_block = blocks[block_i + 1] conv = learn.ops.conv2d(conv, block.num_filters,
with tf.variable_scope('block_%d/conv_upscale' % block_i): [1, 1], [1, 1, 1, 1],
net = learn.ops.conv2d(net, next_block.num_filters, padding='VALID',
[1, 1], [1, 1, 1, 1], activation=activation,
bias=False, batch_norm=True,
padding='SAME') bias=False)
except IndexError:
pass
net_shape = net.get_shape().as_list() # shortcut connections that turn the network into its counterpart
net = tf.nn.avg_pool(net, # residual function (identity shortcut)
ksize=[1, net_shape[1], net_shape[2], 1], net = conv + net
strides=[1, 1, 1, 1], padding='VALID')
net_shape = net.get_shape().as_list() try:
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]]) # upscale to the next block size
next_block = blocks[block_i + 1]
with tf.variable_scope('block_%d/conv_upscale' % block_i):
net = learn.ops.conv2d(net, next_block.num_filters,
[1, 1], [1, 1, 1, 1],
bias=False,
padding='SAME')
except IndexError:
pass
return learn.models.logistic_regression(net, y) net_shape = net.get_shape().as_list()
net = tf.nn.avg_pool(net,
ksize=[1, net_shape[1], net_shape[2], 1],
strides=[1, 1, 1, 1], padding='VALID')
net_shape = net.get_shape().as_list()
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
return learn.models.logistic_regression(net, y)
# Download and load MNIST data. # Download and load MNIST data.
mnist = input_data.read_data_sets('MNIST_data') mnist = input_data.read_data_sets('MNIST_data')
# Restore model if graph is saved into a folder. # Restore model if graph is saved into a folder.
if os.path.exists("models/resnet/graph.pbtxt"): if os.path.exists('models/resnet/graph.pbtxt'):
classifier = learn.TensorFlowEstimator.restore("models/resnet/") classifier = learn.TensorFlowEstimator.restore('models/resnet/')
else: else:
# Create a new resnet classifier. # Create a new resnet classifier.
classifier = learn.TensorFlowEstimator( classifier = learn.TensorFlowEstimator(
model_fn=res_net, n_classes=10, batch_size=100, steps=100, model_fn=res_net, n_classes=10, batch_size=100, steps=100,
learning_rate=0.001, continue_training=True) learning_rate=0.001, continue_training=True)
while True: while True:
# Train model and save summaries into logdir. # Train model and save summaries into logdir.
classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/") classifier.fit(
mnist.train.images, mnist.train.labels, logdir='models/resnet/')
# Calculate accuracy. # Calculate accuracy.
score = metrics.accuracy_score( score = metrics.accuracy_score(
mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64)) mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
print('Accuracy: {0:f}'.format(score)) print('Accuracy: {0:f}'.format(score))
# Save model graph and checkpoints. # Save model graph and checkpoints.
classifier.save("models/resnet/") classifier.save('models/resnet/')