diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py
index 8de7797e6b7..1d72243f992 100644
--- a/tensorflow/contrib/learn/python/learn/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/__init__.py
@@ -1,5 +1,4 @@
-"""Main Scikit Flow module."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+"""High level API for learning with TensorFlow."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py
index 7f78b2dced9..9c29b9eeb11 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/base.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/base.py
@@ -1,5 +1,4 @@
-"""Base utilities for loading datasets."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+"""Base utilities for loading datasets."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index e714c15f2e0..1b0d0aef6f5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -1,5 +1,4 @@
-"""Scikit Flow Estimators."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,12 +11,16 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Estimators."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
-from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
 from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier
diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
index dcd1d81056b..5032ea966d4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py
@@ -1,5 +1,4 @@
-"""sklearn cross-support."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""sklearn cross-support."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -20,6 +22,8 @@ import collections
 import os
 
 import numpy as np
+import six
+
 
 def _pprint(d):
   return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
@@ -102,6 +106,7 @@ class _BaseEstimator(object):
                        _pprint(self.get_params(deep=False)),)
 
 
+# pylint: disable=old-style-class
 class _ClassifierMixin():
   """Mixin class for all classifiers."""
   pass
@@ -111,8 +116,10 @@ class _RegressorMixin():
   """Mixin class for all regression estimators."""
   pass
 
+
 class _TransformerMixin():
-    """Mixin class for all transformer estimators."""
+  """Mixin class for all transformer estimators."""
+
 
 class _NotFittedError(ValueError, AttributeError):
   """Exception class to raise if estimator is used before fitting.
@@ -134,6 +141,8 @@ class _NotFittedError(ValueError, AttributeError):
   https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py
   """
 
+# pylint: enable=old-style-class
+
 
 def _accuracy_score(y_true, y_pred):
   score = y_true == y_pred
@@ -149,8 +158,7 @@ def _mean_squared_error(y_true, y_pred):
 
 
 def _train_test_split(*args, **options):
-  n_array = len(args)
-
+  # pylint: disable=missing-docstring
   test_size = options.pop('test_size', None)
   train_size = options.pop('train_size', None)
   random_state = options.pop('random_state', None)
@@ -159,7 +167,7 @@ def _train_test_split(*args, **options):
     train_size = 0.75
   elif train_size is None:
     train_size = 1 - test_size
-  train_size = train_size * args[0].shape[0]
+  train_size *= args[0].shape[0]
 
   np.random.seed(random_state)
   indices = np.random.permutation(args[0].shape[0])
@@ -173,6 +181,7 @@ def _train_test_split(*args, **options):
 # If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
 TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
 if TRY_IMPORT_SKLEARN:
+  # pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
   from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
   from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
   from sklearn.cross_validation import train_test_split
diff --git a/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py b/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py
index 690bac8f196..a3f41697680 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/autoencoder.py
@@ -1,5 +1,4 @@
-"""Deep Autoencoder estimators."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,105 +11,115 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Deep Autoencoder estimators."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.ops import nn
-from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
+import numpy as np
+
 from tensorflow.contrib.learn.python.learn import models
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
+from tensorflow.python.ops import nn
 
 
 class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
-    """TensorFlow Autoencoder Regressor model.
+  """TensorFlow Autoencoder Regressor model.
 
-    Parameters:
-        hidden_units: List of hidden units per layer.
-        batch_size: Mini batch size.
-        activation: activation function used to map inner latent layer onto
-                    reconstruction layer.
-        add_noise: a function that adds noise to tensor_in, 
-               e.g. def add_noise(x):
-                        return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
-        steps: Number of steps to run over data.
-        optimizer: Optimizer name (or class), for example "SGD", "Adam",
-                   "Adagrad".
-        learning_rate: If this is constant float value, no decay function is used.
-            Instead, a customized decay function can be passed that accepts
-            global_step as parameter and returns a Tensor.
-            e.g. exponential decay function:
-            def exp_decay(global_step):
-                return tf.train.exponential_decay(
-                    learning_rate=0.1, global_step,
-                    decay_steps=2, decay_rate=0.001)
-        continue_training: when continue_training is True, once initialized
-            model will be continuely trained on every call of fit.
-        config: RunConfig object that controls the configurations of the session,
-            e.g. num_cores, gpu_memory_fraction, etc.
-        verbose: Controls the verbosity, possible values:
-                 0: the algorithm and debug information is muted.
-                 1: trainer prints the progress.
-                 2: log device placement is printed.
-        dropout: When not None, the probability we will drop out a given
-                 coordinate.
-    """
-    def __init__(self, hidden_units, n_classes=0, batch_size=32,
-                 steps=200, optimizer="Adagrad", learning_rate=0.1,
-                 clip_gradients=5.0, activation=nn.relu, add_noise=None,
-                 continue_training=False, config=None,
-                 verbose=1, dropout=None):
-        self.hidden_units = hidden_units
-        self.dropout = dropout
-        self.activation = activation
-        self.add_noise = add_noise
-        super(TensorFlowDNNAutoencoder, self).__init__(
-            model_fn=self._model_fn,
-            n_classes=n_classes,
-            batch_size=batch_size, steps=steps, optimizer=optimizer,
-            learning_rate=learning_rate, clip_gradients=clip_gradients,
-            continue_training=continue_training,
-            config=config, verbose=verbose)
+  Parameters:
+      hidden_units: List of hidden units per layer.
+      batch_size: Mini batch size.
+      activation: activation function used to map inner latent layer onto
+                  reconstruction layer.
+      add_noise: a function that adds noise to tensor_in,
+             e.g. def add_noise(x):
+                      return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
+      steps: Number of steps to run over data.
+      optimizer: Optimizer name (or class), for example "SGD", "Adam",
+                 "Adagrad".
+      learning_rate: If this is constant float value, no decay function is used.
+          Instead, a customized decay function can be passed that accepts
+          global_step as parameter and returns a Tensor.
+          e.g. exponential decay function:
+          def exp_decay(global_step):
+              return tf.train.exponential_decay(
+                  learning_rate=0.1, global_step,
+                  decay_steps=2, decay_rate=0.001)
+      continue_training: when continue_training is True, once initialized
+          model will be continuely trained on every call of fit.
+      config: RunConfig object that controls the configurations of the session,
+          e.g. num_cores, gpu_memory_fraction, etc.
+      verbose: Controls the verbosity, possible values:
+               0: the algorithm and debug information is muted.
+               1: trainer prints the progress.
+               2: log device placement is printed.
+      dropout: When not None, the probability we will drop out a given
+               coordinate.
+  """
 
-    def _model_fn(self, X, y):
-        encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
-            self.hidden_units,
-            models.linear_regression,
-            activation=self.activation,
-            add_noise=self.add_noise,
-            dropout=self.dropout)(X)
-        self.encoder = encoder
-        self.decoder = decoder
-        return autoencoder_estimator
+  def __init__(self, hidden_units, n_classes=0, batch_size=32,
+               steps=200, optimizer="Adagrad", learning_rate=0.1,
+               clip_gradients=5.0, activation=nn.relu, add_noise=None,
+               continue_training=False, config=None,
+               verbose=1, dropout=None):
+    self.hidden_units = hidden_units
+    self.dropout = dropout
+    self.activation = activation
+    self.add_noise = add_noise
+    super(TensorFlowDNNAutoencoder, self).__init__(
+        model_fn=self._model_fn,
+        n_classes=n_classes,
+        batch_size=batch_size, steps=steps, optimizer=optimizer,
+        learning_rate=learning_rate, clip_gradients=clip_gradients,
+        continue_training=continue_training,
+        config=config, verbose=verbose)
 
-    def generate(self, hidden=None):
-        """Generate new data using trained construction layer"""
-        if hidden is None:
-            last_layer = len(self.hidden_units) - 1
-            bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer)
-            import numpy as np
-            hidden = np.random.normal(size=bias.shape)
-            hidden = np.reshape(hidden, (1, len(hidden)))
-        return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
+  def _model_fn(self, X, y):
+    encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
+        self.hidden_units,
+        models.linear_regression,
+        activation=self.activation,
+        add_noise=self.add_noise,
+        dropout=self.dropout)(X)
+    self.encoder = encoder
+    self.decoder = decoder
+    return autoencoder_estimator
 
-    @property
-    def weights_(self):
-        """Returns weights of the autoencoder's weight layers."""
-        weights = []
-        for layer in range(len(self.hidden_units)):
-            weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer))
-        for layer in range(len(self.hidden_units)):
-            weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer))
-        weights.append(self.get_tensor_value('linear_regression/weights:0'))
-        return weights
+  def generate(self, hidden=None):
+    """Generate new data using trained construction layer."""
+    if hidden is None:
+      last_layer = len(self.hidden_units) - 1
+      bias = self.get_tensor_value(
+          "encoder/dnn/layer%d/Linear/Bias:0" % last_layer)
+      hidden = np.random.normal(size=bias.shape)
+      hidden = np.reshape(hidden, (1, len(hidden)))
+    return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
 
-    @property
-    def bias_(self):
-        """Returns bias of the autoencoder's bias layers."""
-        biases = []
-        for layer in range(len(self.hidden_units)):
-            biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer))
-        for layer in range(len(self.hidden_units)):
-            biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer))
-        biases.append(self.get_tensor_value('linear_regression/bias:0'))
-        return biases
+  @property
+  def weights_(self):
+    """Returns weights of the autoencoder's weight layers."""
+    weights = []
+    for layer in range(len(self.hidden_units)):
+      weights.append(self.get_tensor_value(
+          "encoder/dnn/layer%d/Linear/Matrix:0" % layer))
+    for layer in range(len(self.hidden_units)):
+      weights.append(self.get_tensor_value(
+          "decoder/dnn/layer%d/Linear/Matrix:0" % layer))
+    weights.append(self.get_tensor_value("linear_regression/weights:0"))
+    return weights
+
+  @property
+  def bias_(self):
+    """Returns bias of the autoencoder's bias layers."""
+    biases = []
+    for layer in range(len(self.hidden_units)):
+      biases.append(self.get_tensor_value(
+          "encoder/dnn/layer%d/Linear/Bias:0" % layer))
+    for layer in range(len(self.hidden_units)):
+      biases.append(self.get_tensor_value(
+          "decoder/dnn/layer%d/Linear/Bias:0" % layer))
+    biases.append(self.get_tensor_value("linear_regression/bias:0"))
+    return biases
 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py
index 39131f059b0..ab00ae76f78 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/base.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/base.py
@@ -1,5 +1,4 @@
-"""Base estimator class."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,18 +11,17 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Base estimator class."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import datetime
 import json
 import os
-import shutil
 from six import string_types
 
-import numpy as np
-
 from google.protobuf import text_format
 from tensorflow.python.platform import gfile
 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index 017667699bc..5447d9ec052 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -1,5 +1,4 @@
-"""Deep Neural Network estimators."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Deep Neural Network estimators."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index d58ab35f5ee..ef73c44013a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -1,5 +1,4 @@
-"""Linear Estimators."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Linear Estimators."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn.py b/tensorflow/contrib/learn/python/learn/estimators/rnn.py
index b703f607657..719a19a5bc8 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/rnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn.py
@@ -1,5 +1,4 @@
-"""Recurrent Neural Network estimators."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,6 +11,9 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Recurrent Neural Network estimators."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
index 04bbd997482..b3ed3bc7d92 100644
--- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
@@ -1,6 +1,4 @@
-"""Implementations of different data feeders to provide data for TF trainer."""
-
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -14,6 +12,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+"""Implementations of different data feeders to provide data for TF trainer."""
+
 # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
 
 from __future__ import absolute_import
diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py
index 8cabd390fc7..dddd152f368 100644
--- a/tensorflow/contrib/learn/python/learn/models.py
+++ b/tensorflow/contrib/learn/python/learn/models.py
@@ -1,5 +1,4 @@
-"""Various high level TF models."""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,13 +11,16 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+
+"""Various high level TF models."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
 from tensorflow.contrib.learn.python.learn.ops import dnn_ops
 from tensorflow.contrib.learn.python.learn.ops import losses_ops
-from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops as array_ops_
@@ -29,8 +31,7 @@ from tensorflow.python.ops import variable_scope as vs
 
 
 def linear_regression_zero_init(X, y):
-  """Creates a linear regression TensorFlow subgraph, in which weights and
-    bias terms are initialized to exactly zero.
+  """Linear regression subgraph with zero-value initial weights and bias.
 
   Args:
     X: tensor or placeholder for input features.
@@ -43,8 +44,7 @@ def linear_regression_zero_init(X, y):
 
 
 def logistic_regression_zero_init(X, y):
-  """Creates a logistic regression TensorFlow subgraph, in which weights and
-       bias terms are initialized to exactly zero.
+  """Logistic regression subgraph with zero-value initial weights and bias.
 
   Args:
     X: tensor or placeholder for input features.
@@ -85,7 +85,7 @@ def linear_regression(X, y, init_mean=None, init_stddev=1.0):
     else:
       output_shape = y_shape[1]
     # Set up the requested initialization.
-    if (init_mean is None):
+    if init_mean is None:
       weights = vs.get_variable('weights', [X.get_shape()[1], output_shape])
       bias = vs.get_variable('bias', [output_shape])
     else:
@@ -134,7 +134,7 @@ def logistic_regression(X,
     logging_ops.histogram_summary('logistic_regression.X', X)
     logging_ops.histogram_summary('logistic_regression.y', y)
     # Set up the requested initialization.
-    if (init_mean is None):
+    if init_mean is None:
       weights = vs.get_variable('weights',
                                 [X.get_shape()[1], y.get_shape()[-1]])
       bias = vs.get_variable('bias', [y.get_shape()[-1]])
@@ -188,35 +188,37 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None):
 
   return dnn_estimator
 
+
 def get_autoencoder_model(hidden_units, target_predictor_fn,
                           activation, add_noise=None, dropout=None):
-    """Returns a function that creates a Autoencoder TensorFlow subgraph with given
-    params.
+  """Returns a function that creates a Autoencoder TensorFlow subgraph.
 
-    Args:
-        hidden_units: List of values of hidden units for layers.
-        target_predictor_fn: Function that will predict target from input
-                             features. This can be logistic regression,
-                             linear regression or any other model,
-                             that takes X, y and returns predictions and loss tensors.
-        activation: activation function used to map inner latent layer onto
-                    reconstruction layer.
-        add_noise: a function that adds noise to tensor_in, 
-               e.g. def add_noise(x):
-                        return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
-        dropout: When not none, causes dropout regularization to be used,
-                 with the specified probability of removing a given coordinate.
+  Args:
+    hidden_units: List of values of hidden units for layers.
+    target_predictor_fn: Function that will predict target from input
+                         features. This can be logistic regression,
+                         linear regression or any other model,
+                         that takes X, y and returns predictions and loss
+                         tensors.
+    activation: activation function used to map inner latent layer onto
+                reconstruction layer.
+    add_noise: a function that adds noise to tensor_in,
+           e.g. def add_noise(x):
+                    return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
+    dropout: When not none, causes dropout regularization to be used,
+             with the specified probability of removing a given coordinate.
+
+  Returns:
+      A function that creates the subgraph.
+  """
+  def dnn_autoencoder_estimator(X):
+    """Autoencoder estimator with target predictor function on top."""
+    encoder, decoder = autoencoder_ops.dnn_autoencoder(
+        X, hidden_units, activation,
+        add_noise=add_noise, dropout=dropout)
+    return encoder, decoder, target_predictor_fn(X, decoder)
+  return dnn_autoencoder_estimator
 
-    Returns:
-        A function that creates the subgraph.
-    """
-    def dnn_autoencoder_estimator(X):
-        """Autoencoder estimator with target predictor function on top."""
-        encoder, decoder = autoencoder_ops.dnn_autoencoder(
-          X, hidden_units, activation,
-          add_noise=add_noise, dropout=dropout)
-        return encoder, decoder, target_predictor_fn(X, decoder)
-    return dnn_autoencoder_estimator
 
 ## This will be in Tensorflow 0.7.
 ## TODO(ilblackdragon): Clean this up when it's released
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 861db1758f5..79c629d9491 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -1,5 +1,4 @@
-"""Monitors to track model training, report on progress and request early stopping"""
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+"""Monitors to track training, report progress and request early stopping."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/tensorflow/examples/skflow/boston.py b/tensorflow/examples/skflow/boston.py
index bf2066770c7..9d895bd8e38 100644
--- a/tensorflow/examples/skflow/boston.py
+++ b/tensorflow/examples/skflow/boston.py
@@ -1,4 +1,4 @@
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/iris.py b/tensorflow/examples/skflow/iris.py
index c6c566b10fd..ea44428d541 100644
--- a/tensorflow/examples/skflow/iris.py
+++ b/tensorflow/examples/skflow/iris.py
@@ -1,4 +1,4 @@
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
index f9c172725d9..b8b1a1dd140 100644
--- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
@@ -1,4 +1,4 @@
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/mnist.py b/tensorflow/examples/skflow/mnist.py
index 082ecb2f839..d1288a31e98 100644
--- a/tensorflow/examples/skflow/mnist.py
+++ b/tensorflow/examples/skflow/mnist.py
@@ -1,4 +1,4 @@
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
diff --git a/tensorflow/examples/skflow/resnet.py b/tensorflow/examples/skflow/resnet.py
index f1f39568d46..03a5d5e5191 100644
--- a/tensorflow/examples/skflow/resnet.py
+++ b/tensorflow/examples/skflow/resnet.py
@@ -1,4 +1,4 @@
-#  Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -12,147 +12,155 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
-"""
-This example builds deep residual network for mnist data.
+"""This example builds deep residual network for mnist data.
+
 Reference Paper: http://arxiv.org/pdf/1512.03385.pdf
 
 Note that this is still a work-in-progress. Feel free to submit a PR
 to make this better.
 """
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 from collections import namedtuple
 from math import sqrt
+import os
 
 from sklearn import metrics
 import tensorflow as tf
-from tensorflow.examples.tutorials.mnist import input_data
 from tensorflow.contrib import learn
+from tensorflow.examples.tutorials.mnist import input_data
 
 
 def res_net(x, y, activation=tf.nn.relu):
-    """Builds a residual network. Note that if the input tensor is 2D, it must be
-    square in order to be converted to a 4D tensor. 
+  """Builds a residual network.
 
-    Borrowed structure from here: https://github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
+  Note that if the input tensor is 2D, it must be square in order to be
+  converted to a 4D tensor.
 
-    Args:
-        x: Input of the network
-        y: Output of the network
-        activation: Activation function to apply after each convolution
-    """
+  Borrowed structure from:
+  github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
 
-    # Configurations for each bottleneck block
-    BottleneckBlock = namedtuple(
-        'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
-    blocks = [BottleneckBlock(3, 128, 32),
-              BottleneckBlock(3, 256, 64),
-              BottleneckBlock(3, 512, 128),
-              BottleneckBlock(3, 1024, 256)]
+  Args:
+    x: Input of the network
+    y: Output of the network
+    activation: Activation function to apply after each convolution
 
-    input_shape = x.get_shape().as_list()
+  Returns:
+    Predictions and loss tensors.
+  """
 
-    # Reshape the input into the right shape if it's 2D tensor
-    if len(input_shape) == 2:
-        ndim = int(sqrt(input_shape[1]))
-        x = tf.reshape(x, [-1, ndim, ndim, 1])
+  # Configurations for each bottleneck block.
+  BottleneckBlock = namedtuple(
+      'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
+  blocks = [BottleneckBlock(3, 128, 32),
+            BottleneckBlock(3, 256, 64),
+            BottleneckBlock(3, 512, 128),
+            BottleneckBlock(3, 1024, 256)]
 
-    # First convolution expands to 64 channels
-    with tf.variable_scope('conv_layer1'):
-        net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
-                                activation=activation, bias=False)
+  input_shape = x.get_shape().as_list()
 
-    # Max pool
-    net = tf.nn.max_pool(
-        net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
+  # Reshape the input into the right shape if it's 2D tensor
+  if len(input_shape) == 2:
+    ndim = int(sqrt(input_shape[1]))
+    x = tf.reshape(x, [-1, ndim, ndim, 1])
 
-    # First chain of resnets
-    with tf.variable_scope('conv_layer2'):
-        net = learn.ops.conv2d(net, blocks[0].num_filters,
-                               [1, 1], [1, 1, 1, 1],
-                               padding='VALID', bias=True)
+  # First convolution expands to 64 channels
+  with tf.variable_scope('conv_layer1'):
+    net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
+                           activation=activation, bias=False)
 
-    # Create each bottleneck building block for each layer
-    for block_i, block in enumerate(blocks):
-        for layer_i in range(block.num_layers):
+  # Max pool
+  net = tf.nn.max_pool(
+      net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
 
-            name = 'block_%d/layer_%d' % (block_i, layer_i)
+  # First chain of resnets
+  with tf.variable_scope('conv_layer2'):
+    net = learn.ops.conv2d(net, blocks[0].num_filters,
+                           [1, 1], [1, 1, 1, 1],
+                           padding='VALID', bias=True)
 
-            # 1x1 convolution responsible for reducing dimension
-            with tf.variable_scope(name + '/conv_in'):
-                conv = learn.ops.conv2d(net, block.bottleneck_size,
-                                         [1, 1], [1, 1, 1, 1],
-                                         padding='VALID',
-                                         activation=activation,
-                                         batch_norm=True,
-                                         bias=False)
+  # Create each bottleneck building block for each layer
+  for block_i, block in enumerate(blocks):
+    for layer_i in range(block.num_layers):
 
-            with tf.variable_scope(name + '/conv_bottleneck'):
-                conv = learn.ops.conv2d(conv, block.bottleneck_size,
-                                         [3, 3], [1, 1, 1, 1],
-                                         padding='SAME',
-                                         activation=activation,
-                                         batch_norm=True,
-                                         bias=False)
+      name = 'block_%d/layer_%d' % (block_i, layer_i)
 
-            # 1x1 convolution responsible for restoring dimension
-            with tf.variable_scope(name + '/conv_out'):
-                conv = learn.ops.conv2d(conv, block.num_filters,
-                                         [1, 1], [1, 1, 1, 1],
-                                         padding='VALID',
-                                         activation=activation,
-                                         batch_norm=True,
-                                         bias=False)
+      # 1x1 convolution responsible for reducing dimension
+      with tf.variable_scope(name + '/conv_in'):
+        conv = learn.ops.conv2d(net, block.bottleneck_size,
+                                [1, 1], [1, 1, 1, 1],
+                                padding='VALID',
+                                activation=activation,
+                                batch_norm=True,
+                                bias=False)
 
-            # shortcut connections that turn the network into its counterpart
-            # residual function (identity shortcut)
-            net = conv + net
+      with tf.variable_scope(name + '/conv_bottleneck'):
+        conv = learn.ops.conv2d(conv, block.bottleneck_size,
+                                [3, 3], [1, 1, 1, 1],
+                                padding='SAME',
+                                activation=activation,
+                                batch_norm=True,
+                                bias=False)
 
-        try:
-            # upscale to the next block size
-            next_block = blocks[block_i + 1]
-            with tf.variable_scope('block_%d/conv_upscale' % block_i):
-                net = learn.ops.conv2d(net, next_block.num_filters,
-                                        [1, 1], [1, 1, 1, 1],
-                                        bias=False,
-                                        padding='SAME')
-        except IndexError:
-            pass
+      # 1x1 convolution responsible for restoring dimension
+      with tf.variable_scope(name + '/conv_out'):
+        conv = learn.ops.conv2d(conv, block.num_filters,
+                                [1, 1], [1, 1, 1, 1],
+                                padding='VALID',
+                                activation=activation,
+                                batch_norm=True,
+                                bias=False)
 
-    net_shape = net.get_shape().as_list()
-    net = tf.nn.avg_pool(net,
-                         ksize=[1, net_shape[1], net_shape[2], 1],
-                         strides=[1, 1, 1, 1], padding='VALID')
+      # shortcut connections that turn the network into its counterpart
+      # residual function (identity shortcut)
+      net = conv + net
 
-    net_shape = net.get_shape().as_list()
-    net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
+      try:
+        # upscale to the next block size
+        next_block = blocks[block_i + 1]
+        with tf.variable_scope('block_%d/conv_upscale' % block_i):
+          net = learn.ops.conv2d(net, next_block.num_filters,
+                                 [1, 1], [1, 1, 1, 1],
+                                 bias=False,
+                                 padding='SAME')
+      except IndexError:
+        pass
 
-    return learn.models.logistic_regression(net, y)
+  net_shape = net.get_shape().as_list()
+  net = tf.nn.avg_pool(net,
+                       ksize=[1, net_shape[1], net_shape[2], 1],
+                       strides=[1, 1, 1, 1], padding='VALID')
+
+  net_shape = net.get_shape().as_list()
+  net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
+
+  return learn.models.logistic_regression(net, y)
 
 
 # Download and load MNIST data.
 mnist = input_data.read_data_sets('MNIST_data')
 
 # Restore model if graph is saved into a folder.
-if os.path.exists("models/resnet/graph.pbtxt"):
-    classifier = learn.TensorFlowEstimator.restore("models/resnet/")
+if os.path.exists('models/resnet/graph.pbtxt'):
+  classifier = learn.TensorFlowEstimator.restore('models/resnet/')
 else:
-    # Create a new resnet classifier.
-    classifier = learn.TensorFlowEstimator(
-        model_fn=res_net, n_classes=10, batch_size=100, steps=100,
-        learning_rate=0.001, continue_training=True)
+  # Create a new resnet classifier.
+  classifier = learn.TensorFlowEstimator(
+      model_fn=res_net, n_classes=10, batch_size=100, steps=100,
+      learning_rate=0.001, continue_training=True)
 
 while True:
-    # Train model and save summaries into logdir.
-    classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/")
+  # Train model and save summaries into logdir.
+  classifier.fit(
+      mnist.train.images, mnist.train.labels, logdir='models/resnet/')
 
-    # Calculate accuracy.
-    score = metrics.accuracy_score(
-        mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
-    print('Accuracy: {0:f}'.format(score))
+  # Calculate accuracy.
+  score = metrics.accuracy_score(
+      mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
+  print('Accuracy: {0:f}'.format(score))
 
-    # Save model graph and checkpoints.
-    classifier.save("models/resnet/")
+  # Save model graph and checkpoints.
+  classifier.save('models/resnet/')