Fix copyrights and a few other lint errors.
Change: 123250570

commit 5c145f0e3c
parent fd5ebfa768
@@ -1,5 +1,4 @@
"""Main Scikit Flow module."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""High level API for learning with TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -1,5 +1,4 @@
"""Base utilities for loading datasets."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base utilities for loading datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -1,5 +1,4 @@
"""Scikit Flow Estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier

@@ -1,5 +1,4 @@
"""sklearn cross-support."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""sklearn cross-support."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -20,6 +22,8 @@ import collections
import os

import numpy as np
import six


def _pprint(d):
return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
@@ -102,6 +106,7 @@ class _BaseEstimator(object):
_pprint(self.get_params(deep=False)),)


# pylint: disable=old-style-class
class _ClassifierMixin():
"""Mixin class for all classifiers."""
pass
@@ -111,9 +116,11 @@ class _RegressorMixin():
"""Mixin class for all regression estimators."""
pass


class _TransformerMixin():
"""Mixin class for all transformer estimators."""


class _NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.

@@ -134,6 +141,8 @@ class _NotFittedError(ValueError, AttributeError):
https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py
"""

# pylint: enable=old-style-class


def _accuracy_score(y_true, y_pred):
score = y_true == y_pred
@@ -149,8 +158,7 @@ def _mean_squared_error(y_true, y_pred):


def _train_test_split(*args, **options):
n_array = len(args)

# pylint: disable=missing-docstring
test_size = options.pop('test_size', None)
train_size = options.pop('train_size', None)
random_state = options.pop('random_state', None)
@@ -159,7 +167,7 @@ def _train_test_split(*args, **options):
train_size = 0.75
elif train_size is None:
train_size = 1 - test_size
train_size = train_size * args[0].shape[0]
train_size *= args[0].shape[0]

np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0])
@@ -173,6 +181,7 @@ def _train_test_split(*args, **options):
# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
if TRY_IMPORT_SKLEARN:
# pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
from sklearn.cross_validation import train_test_split

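For context, the _train_test_split hunks above are the module's own fallback for scikit-learn's train_test_split (the sklearn versions are only imported when the TENSORFLOW_SKLEARN flag is set): default to a 75% train share, turn the fraction into a row count, and split a seeded random permutation of the indices. A minimal self-contained sketch of that logic in plain NumPy, not the library code itself (the name simple_train_test_split and the fixed (X, y) signature are illustrative):

import numpy as np

def simple_train_test_split(X, y, test_size=None, train_size=None,
                            random_state=None):
  # Default mirrors the hunk above: 75% train when neither size is given.
  if test_size is None and train_size is None:
    train_size = 0.75
  elif train_size is None:
    train_size = 1 - test_size
  # The fraction becomes a row count (the `train_size *= args[0].shape[0]` step).
  n_train = int(train_size * X.shape[0])
  np.random.seed(random_state)
  indices = np.random.permutation(X.shape[0])
  train_idx, test_idx = indices[:n_train], indices[n_train:]
  return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Example: 100 rows split into 75 train / 25 test rows.
X = np.arange(200).reshape(100, 2)
y = np.arange(100)
X_train, X_test, y_train, y_test = simple_train_test_split(X, y, random_state=42)
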
@@ -1,5 +1,4 @@
"""Deep Autoencoder estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Deep Autoencoder estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import nn
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
import numpy as np

from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.python.ops import nn


class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
@@ -54,6 +58,7 @@ class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
dropout: When not None, the probability we will drop out a given
coordinate.
"""

def __init__(self, hidden_units, n_classes=0, batch_size=32,
steps=200, optimizer="Adagrad", learning_rate=0.1,
clip_gradients=5.0, activation=nn.relu, add_noise=None,
@@ -83,11 +88,11 @@ class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
return autoencoder_estimator

def generate(self, hidden=None):
"""Generate new data using trained construction layer"""
"""Generate new data using trained construction layer."""
if hidden is None:
last_layer = len(self.hidden_units) - 1
bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer)
import numpy as np
bias = self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Bias:0" % last_layer)
hidden = np.random.normal(size=bias.shape)
hidden = np.reshape(hidden, (1, len(hidden)))
return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
@@ -97,10 +102,12 @@ class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
"""Returns weights of the autoencoder's weight layers."""
weights = []
for layer in range(len(self.hidden_units)):
weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer))
weights.append(self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Matrix:0" % layer))
for layer in range(len(self.hidden_units)):
weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer))
weights.append(self.get_tensor_value('linear_regression/weights:0'))
weights.append(self.get_tensor_value(
"decoder/dnn/layer%d/Linear/Matrix:0" % layer))
weights.append(self.get_tensor_value("linear_regression/weights:0"))
return weights

@property
@@ -108,9 +115,11 @@ class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
"""Returns bias of the autoencoder's bias layers."""
biases = []
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer))
biases.append(self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Bias:0" % layer))
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer))
biases.append(self.get_tensor_value('linear_regression/bias:0'))
biases.append(self.get_tensor_value(
"decoder/dnn/layer%d/Linear/Bias:0" % layer))
biases.append(self.get_tensor_value("linear_regression/bias:0"))
return biases

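The weights_ and bias_ properties above read variables back by name: encoder layers first, then decoder layers, then the final linear_regression variables. A small standalone sketch of the naming scheme they rely on (pure Python, no session needed; the helper name is illustrative):

def autoencoder_weight_names(hidden_units):
  # Same ordering as the weights_ property: encoder layers, decoder layers,
  # then the final linear_regression weights.
  names = ['encoder/dnn/layer%d/Linear/Matrix:0' % layer
           for layer in range(len(hidden_units))]
  names += ['decoder/dnn/layer%d/Linear/Matrix:0' % layer
            for layer in range(len(hidden_units))]
  names.append('linear_regression/weights:0')
  return names

print(autoencoder_weight_names([10, 20]))
# ['encoder/dnn/layer0/Linear/Matrix:0', 'encoder/dnn/layer1/Linear/Matrix:0',
#  'decoder/dnn/layer0/Linear/Matrix:0', 'decoder/dnn/layer1/Linear/Matrix:0',
#  'linear_regression/weights:0']
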
@@ -1,5 +1,4 @@
"""Base estimator class."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base estimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import json
import os
import shutil
from six import string_types

import numpy as np

from google.protobuf import text_format
from tensorflow.python.platform import gfile

@@ -1,5 +1,4 @@
"""Deep Neural Network estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Deep Neural Network estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -1,5 +1,4 @@
"""Linear Estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -1,5 +1,4 @@
"""Recurrent Neural Network estimators."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Recurrent Neural Network estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -1,6 +1,4 @@
"""Implementations of different data feeders to provide data for TF trainer."""

# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementations of different data feeders to provide data for TF trainer."""

# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.

from __future__ import absolute_import

@@ -1,5 +1,4 @@
"""Various high level TF models."""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Various high level TF models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.contrib.learn.python.learn.ops import dnn_ops
from tensorflow.contrib.learn.python.learn.ops import losses_ops
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_
@@ -29,8 +31,7 @@ from tensorflow.python.ops import variable_scope as vs


def linear_regression_zero_init(X, y):
"""Creates a linear regression TensorFlow subgraph, in which weights and
bias terms are initialized to exactly zero.
"""Linear regression subgraph with zero-value initial weights and bias.

Args:
X: tensor or placeholder for input features.
@@ -43,8 +44,7 @@ def linear_regression_zero_init(X, y):


def logistic_regression_zero_init(X, y):
"""Creates a logistic regression TensorFlow subgraph, in which weights and
bias terms are initialized to exactly zero.
"""Logistic regression subgraph with zero-value initial weights and bias.

Args:
X: tensor or placeholder for input features.
@@ -85,7 +85,7 @@ def linear_regression(X, y, init_mean=None, init_stddev=1.0):
else:
output_shape = y_shape[1]
# Set up the requested initialization.
if (init_mean is None):
if init_mean is None:
weights = vs.get_variable('weights', [X.get_shape()[1], output_shape])
bias = vs.get_variable('bias', [output_shape])
else:
@@ -134,7 +134,7 @@ def logistic_regression(X,
logging_ops.histogram_summary('logistic_regression.X', X)
logging_ops.histogram_summary('logistic_regression.y', y)
# Set up the requested initialization.
if (init_mean is None):
if init_mean is None:
weights = vs.get_variable('weights',
[X.get_shape()[1], y.get_shape()[-1]])
bias = vs.get_variable('bias', [y.get_shape()[-1]])
@@ -188,17 +188,18 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None):

return dnn_estimator


def get_autoencoder_model(hidden_units, target_predictor_fn,
activation, add_noise=None, dropout=None):
"""Returns a function that creates a Autoencoder TensorFlow subgraph with given
params.
"""Returns a function that creates a Autoencoder TensorFlow subgraph.

Args:
hidden_units: List of values of hidden units for layers.
target_predictor_fn: Function that will predict target from input
features. This can be logistic regression,
linear regression or any other model,
that takes X, y and returns predictions and loss tensors.
that takes X, y and returns predictions and loss
tensors.
activation: activation function used to map inner latent layer onto
reconstruction layer.
add_noise: a function that adds noise to tensor_in,
@@ -218,6 +219,7 @@ def get_autoencoder_model(hidden_units, target_predictor_fn,
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator


## This will be in Tensorflow 0.7.
## TODO(ilblackdragon): Clean this up when it's released

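As a rough reference for what the reworded docstrings describe: linear regression with zero-value initial weights and bias reduces to predictions = X·w + b with w and b starting at zero, paired here with a mean-squared-error loss. A NumPy sketch of that behavior only (illustrative; the real functions build TensorFlow subgraphs and return tensors, and the loss choice is an assumption):

import numpy as np

def linear_regression_zero_init_sketch(X, y):
  # Weights and bias start at exactly zero, as the docstring above states.
  weights = np.zeros((X.shape[1], 1))
  bias = np.zeros(1)
  predictions = X.dot(weights) + bias
  loss = np.mean((predictions - y.reshape(-1, 1)) ** 2)
  return predictions, loss

X = np.random.randn(8, 3)
y = np.random.randn(8)
preds, loss = linear_regression_zero_init_sketch(X, y)  # preds start at all zeros
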
@@ -1,5 +1,4 @@
"""Monitors to track model training, report on progress and request early stopping"""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Monitors to track training, report progress and request early stopping."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,40 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example builds deep residual network for mnist data.
"""This example builds deep residual network for mnist data.

Reference Paper: http://arxiv.org/pdf/1512.03385.pdf

Note that this is still a work-in-progress. Feel free to submit a PR
to make this better.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from collections import namedtuple
from math import sqrt
import os

from sklearn import metrics
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import learn
from tensorflow.examples.tutorials.mnist import input_data


def res_net(x, y, activation=tf.nn.relu):
"""Builds a residual network. Note that if the input tensor is 2D, it must be
square in order to be converted to a 4D tensor.
"""Builds a residual network.

Borrowed structure from here: https://github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
Note that if the input tensor is 2D, it must be square in order to be
converted to a 4D tensor.

Borrowed structure from:
github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py

Args:
x: Input of the network
y: Output of the network
activation: Activation function to apply after each convolution

Returns:
Predictions and loss tensors.
"""

# Configurations for each bottleneck block
# Configurations for each bottleneck block.
BottleneckBlock = namedtuple(
'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
blocks = [BottleneckBlock(3, 128, 32),
@@ -137,8 +144,8 @@ def res_net(x, y, activation=tf.nn.relu):
mnist = input_data.read_data_sets('MNIST_data')

# Restore model if graph is saved into a folder.
if os.path.exists("models/resnet/graph.pbtxt"):
classifier = learn.TensorFlowEstimator.restore("models/resnet/")
if os.path.exists('models/resnet/graph.pbtxt'):
classifier = learn.TensorFlowEstimator.restore('models/resnet/')
else:
# Create a new resnet classifier.
classifier = learn.TensorFlowEstimator(
@@ -147,7 +154,8 @@ else:

while True:
# Train model and save summaries into logdir.
classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/")
classifier.fit(
mnist.train.images, mnist.train.labels, logdir='models/resnet/')

# Calculate accuracy.
score = metrics.accuracy_score(
@@ -155,4 +163,4 @@ while True:
print('Accuracy: {0:f}'.format(score))

# Save model graph and checkpoints.
classifier.save("models/resnet/")
classifier.save('models/resnet/')

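Taken together, the hunks converge on one module-header layout: the TensorFlow copyright and Apache license block first, the module docstring after it, then the __future__ imports, followed by standard-library imports and third-party imports one per line. A minimal skeleton of that layout for a hypothetical new module (license text abbreviated, imports chosen only for illustration):

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# ...
# limitations under the License.

"""Example module following the header layout used in this change."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np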