diff --git a/RELEASE.md b/RELEASE.md
index c4d730e991e..0a7609c0bcc 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -10,6 +10,9 @@
 ## Bug Fixes and Other Changes
 * TensorBoard now displays graphs with only one data point
 * TensorBoard now visually displays NaN values
+* `tf.nn.moments()` now accepts a `shift` argument. Shifting by a good estimate
+  of the mean improves numerical stability. This change also alters the
+  behavior of the `shift` argument to `tf.nn.sufficient_statistics()`.
 
 # Release 0.8.0
 
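For context, a minimal numpy sketch (not part of the patch) of the shifted-moments identity behind the new `shift` argument; the array values and the shift choice here are illustrative only:

```python
import numpy as np

# For any constant s:  mean = E[x - s] + s,  var = E[(x - s)^2] - (E[x - s])^2.
# Choosing s near the true mean keeps the squared terms small, avoiding
# catastrophic cancellation in low precision.
x = (10000.0 + np.array([0.1, 0.2, 0.3])).astype(np.float32)

naive_var = np.mean(x * x) - np.mean(x) ** 2    # cancels badly in float32

shift = np.float32(10000.0)                     # any rough estimate of the mean
d = x - shift
shifted_var = np.mean(d * d) - np.mean(d) ** 2  # ~0.00667, close to the true value

print(naive_var, shifted_var)
```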
diff --git a/WORKSPACE b/WORKSPACE
index d1709b715d3..ffebcde5541 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -119,7 +119,7 @@ new_git_repository(
   name = "iron_fit_behavior",
   build_file = "bower.BUILD",
   remote = "https://github.com/polymerelements/iron-fit-behavior.git",
-  tag = "v1.2.1",
+  tag = "v1.2.2",
 )
 
 new_git_repository(
@@ -196,7 +196,7 @@ new_git_repository(
   name = "iron_range_behavior",
   build_file = "bower.BUILD",
   remote = "https://github.com/polymerelements/iron-range-behavior.git",
-  tag = "v1.0.4",
+  tag = "v1.0.5",
 )
 
 new_git_repository(
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 6ca45e84980..261103a746f 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -201,7 +201,7 @@ def batch_norm(inputs,
         collections=moving_variance_collections)
     if is_training:
       # Calculate the moments based on the individual batch.
-      mean, variance = nn.moments(inputs, axis)
+      mean, variance = nn.moments(inputs, axis, shift=moving_mean)
       # Update the moving_mean and moving_variance moments.
       update_moving_mean = moving_averages.assign_moving_average(
           moving_mean, mean, decay)
@@ -710,7 +710,7 @@ legacy_relu6 = functools.partial(legacy_fully_connected, activation_fn=nn.relu6)
 # Simple alias for fully_connected which removes the activation_fn parameter.
 legacy_linear = functools.partial(legacy_fully_connected, activation_fn=None)
 
-linear = functools.partial(fully_connected, activation_fn=None)
-relu = functools.partial(fully_connected, activation_fn=nn.relu)
-relu6 = functools.partial(fully_connected, activation_fn=nn.relu6)
+linear = legacy_linear
+relu = legacy_relu
+relu6 = legacy_relu6
 
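The batch_norm change above follows the pattern the release note describes: the moving mean, which already tracks the activation mean, doubles as the shift. A hedged sketch (not part of the patch; shapes and names are illustrative):

```python
import tensorflow as tf

# Moments over the batch dimension of a [batch, depth] activation, shifted
# by the running mean estimate, which is usually close to the batch mean.
inputs = tf.placeholder(tf.float32, [None, 64])
moving_mean = tf.Variable(tf.zeros([64]), trainable=False)

mean, variance = tf.nn.moments(inputs, [0], shift=moving_mean)
```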
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 8b89ce4c8d3..052d2d81b35 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -75,6 +75,42 @@ py_test(
     ],
 )
 
+py_test(
+    name = "test_dataframe",
+    size = "small",
+    srcs = ["python/learn/tests/dataframe/test_dataframe.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":learn",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
+py_test(
+    name = "test_column",
+    size = "small",
+    srcs = ["python/learn/tests/dataframe/test_column.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":learn",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
+py_test(
+    name = "test_transform",
+    size = "small",
+    srcs = ["python/learn/tests/dataframe/test_transform.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":learn",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 py_test(
     name = "test_early_stopping",
     size = "medium",
@@ -124,7 +160,7 @@ py_test(
 
 py_test(
     name = "dnn_linear_combined_test",
-    size = "small",
+    size = "medium",
     srcs = ["python/learn/estimators/dnn_linear_combined_test.py"],
     srcs_version = "PY2AND3",
     deps = [
diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py
index 79a0badb7fa..8de7797e6b7 100644
--- a/tensorflow/contrib/learn/python/learn/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/__init__.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.learn.python.learn import ops
 from tensorflow.contrib.learn.python.learn import preprocessing
 from tensorflow.contrib.learn.python.learn import utils
 # pylint: disable=wildcard-import
+from tensorflow.contrib.learn.python.learn.dataframe import *
 from tensorflow.contrib.learn.python.learn.estimators import *
 from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
 from tensorflow.contrib.learn.python.learn.graph_actions import infer
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/__init__.py b/tensorflow/contrib/learn/python/learn/dataframe/__init__.py
new file mode 100644
index 00000000000..c83bc658521
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/dataframe/__init__.py
@@ -0,0 +1,26 @@
+"""DataFrames for ingesting and preprocessing data."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.learn.python.learn.dataframe.column import Column
+from tensorflow.contrib.learn.python.learn.dataframe.column import TransformedColumn
+from tensorflow.contrib.learn.python.learn.dataframe.dataframe import DataFrame
+from tensorflow.contrib.learn.python.learn.dataframe.transform import parameter
+from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
+
+__all__ = ['Column', 'TransformedColumn', 'DataFrame', 'parameter', 'Transform']
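A quick sketch of the new public surface, assuming the package is importable as below:

```python
from tensorflow.contrib.learn.python.learn.dataframe import (
    Column, DataFrame, Transform, TransformedColumn, parameter)

# DataFrame declares ABCMeta but no abstract methods, so it can be
# instantiated directly; it starts empty.
df = DataFrame()
print(len(df))  # 0
```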
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/column.py b/tensorflow/contrib/learn/python/learn/dataframe/column.py
new file mode 100644
index 00000000000..077013f969c
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/dataframe/column.py
@@ -0,0 +1,93 @@
+"""A Column represents a deferred Tensor computation in a DataFrame."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from abc import ABCMeta
+
+
+class Column(object):
+  """A single output column.
+
+  Represents the deferred construction of a graph that computes the column
+  values.
+
+  Note that every `Column` should be a `TransformedColumn`, except when mocked.
+  """
+
+  __metaclass__ = ABCMeta
+
+  def build(self, cache):
+    """Returns a Tensor."""
+    raise NotImplementedError()
+
+
+class TransformedColumn(Column):
+  """A `Column` that results from applying a `Transform` to a list of inputs."""
+
+  def __init__(self, input_columns, transform, output_name):
+    super(TransformedColumn, self).__init__()
+    self._input_columns = input_columns
+    self._transform = transform
+    self._output_name = output_name
+
+    if output_name is None:
+      raise ValueError("output_name must be provided")
+
+    if len(input_columns) != transform.input_valency:
+      raise ValueError("Expected %s input Columns but received %s." %
+                       (transform.input_valency, len(input_columns)))
+
+    self._repr = TransformedColumn.make_repr(
+        self._input_columns, self._transform, self._output_name)
+
+  def build(self, cache=None):
+    if cache is None:
+      cache = {}
+    all_outputs = self._transform.apply_transform(self._input_columns, cache)
+    return getattr(all_outputs, self._output_name)
+
+  def __repr__(self):
+    return self._repr
+
+  # Note we need to generate column reprs from Transform, without needing the
+  # columns themselves.  So we just make this public.  Alternatively we could
+  # create throwaway columns just in order to call repr() on them.
+  @staticmethod
+  def make_repr(input_columns, transform, output_name):
+    """Generate a key for caching Tensors produced for a TransformedColumn.
+
+    Generally we need a deterministic, unique key representing which transform
+    was applied to which inputs, and which output was selected.
+
+    Args:
+      input_columns: the input `Columns` for the `Transform`
+      transform: the `Transform` being applied
+      output_name: the name of the specific output from the `Transform` that is
+        to be cached
+
+    Returns:
+      A string suitable for use as a cache key for Tensors produced via a
+        TransformedColumn
+    """
+    input_column_keys = [repr(column) for column in input_columns]
+    input_column_keys_joined = ", ".join(input_column_keys)
+
+    return "%s(%s)[%s]" % (
+        repr(transform), input_column_keys_joined, output_name)
+
+
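To illustrate the cache-key scheme, a small sketch using a hypothetical mock column (per the class docstring, non-`TransformedColumn` columns exist only when mocked); assumes `Column` and `TransformedColumn` above are in scope:

```python
class MockColumn(Column):
  """Hypothetical leaf column, used only for illustration."""

  def __init__(self, name):
    super(MockColumn, self).__init__()
    self._name = name

  def __repr__(self):
    return self._name

  def build(self, cache):
    return cache[self._name]  # pretend a Tensor was registered upstream

# make_repr is deterministic in (inputs, transform, output), so it can serve
# as a cache key without constructing a throwaway column:
key = TransformedColumn.make_repr(
    [MockColumn("a"), MockColumn("b")], "SomeTransform", "out")
print(key)  # 'SomeTransform'(a, b)[out]
```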
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py b/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py
new file mode 100644
index 00000000000..b27ce6d9f49
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/dataframe/dataframe.py
@@ -0,0 +1,172 @@
+"""A DataFrame is a container for ingesting and preprocessing data."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from abc import ABCMeta
+import collections
+
+from .column import Column
+from .transform import Transform
+
+
+class DataFrame(object):
+  """A DataFrame is a container for ingesting and preprocessing data."""
+  __metaclass__ = ABCMeta
+
+  def __init__(self):
+    self._columns = {}
+
+  def columns(self):
+    """Set of the column names."""
+    return frozenset(self._columns.keys())
+
+  def __len__(self):
+    """The number of columns in the DataFrame."""
+    return len(self._columns)
+
+  def assign(self, **kwargs):
+    """Adds columns to DataFrame.
+
+    Args:
+      **kwargs: assignments of the form key=value where key is a string
+      and value is a `Column` or a zero-input `Transform` (support for
+      `pandas.Series` and `numpy.ndarray` values is pending).
+
+    Raises:
+      TypeError: keys are not strings.
+      TypeError: values are not `Column`s or zero-input `Transform`s.
+
+    TODO(jamieas): pandas assign method returns a new DataFrame. Consider
+    switching to this behavior, changing the name or adding in_place as an
+    argument.
+    """
+    for k, v in kwargs.items():
+      if not isinstance(k, str):
+        raise TypeError("The only supported type for keys is string; got %s" %
+                        type(k))
+      if isinstance(v, Column):
+        s = v
+      elif isinstance(v, Transform) and v.input_valency == 0:
+        s = v()
+      # TODO(jamieas): hook up these special cases again
+      # TODO(soergel): can these special cases be generalized?
+      # elif isinstance(v, pd.Series):
+      #   s = series.NumpySeries(v.values)
+      # elif isinstance(v, np.ndarray):
+      #   s = series.NumpySeries(v)
+      else:
+        raise TypeError(
+            "Column in assignment must be an inflow.Column, pandas.Series or a"
+            " numpy array; got type '%s'." % type(v).__name__)
+      self._columns[k] = s
+
+  def select(self, keys):
+    """Returns a new DataFrame with a subset of columns.
+
+    Args:
+      keys: A list of strings. Each should be the name of a column in the
+        DataFrame.
+    Returns:
+      A new DataFrame containing only the specified columns.
+    """
+    result = type(self)()
+    for key in keys:
+      result[key] = self._columns[key]
+    return result
+
+  def __getitem__(self, key):
+    """Indexing functionality for DataFrames.
+
+    Args:
+      key: a string or an iterable of strings.
+
+    Returns:
+      A `Column` or list of `Column`s corresponding to the given keys.
+    """
+    if isinstance(key, str):
+      return self._columns[key]
+    elif isinstance(key, collections.Iterable):
+      for i in key:
+        if not isinstance(i, str):
+          raise TypeError("Expected a String; entry %s has type %s." %
+                          (i, type(i).__name__))
+      return [self.__getitem__(i) for i in key]
+    raise TypeError(
+        "Invalid index: %s of type %s. Only strings or lists of strings are "
+        "supported." % (key, type(key)))
+
+  def __setitem__(self, key, value):
+    if isinstance(key, str):
+      key = [key]
+    if isinstance(value, Column):
+      value = [value]
+    self.assign(**dict(zip(key, value)))
+
+  def build(self):
+    # We do not allow passing a cache here, because that would encourage
+    # working around the rule that DataFrames cannot be expected to be
+    # synced with each other (e.g., they shuffle independently).
+    cache = {}
+    tensors = {name: c.build(cache) for name, c in self._columns.items()}
+    return tensors
+
+  def to_input_fn(self, feature_keys=None, target_keys=None):
+    """Build an input_fn suitable for use with Estimator.
+
+    Args:
+      feature_keys: the names of columns to be used as features.  If None, all
+        columns except those in target_keys are used.
+      target_keys: the names of columns to be used as targets.  None is
+        acceptable for unsupervised learning.
+
+    Returns:
+      A function that returns a pair of dicts (features, targets), each mapping
+        string names to Tensors.
+
+    Raises:
+      ValueError: when the feature and target key sets are non-disjoint
+    """
+    if target_keys is None:
+      target_keys = []
+
+    if feature_keys is None:
+      if target_keys:
+        feature_keys = self.columns() - set(target_keys)
+      else:
+        feature_keys = self.columns()
+    else:
+      in_both = set(feature_keys) & set(target_keys)
+      if in_both:
+        raise ValueError(
+            "Columns cannot be used for both features and targets: %s" %
+            ", ".join(in_both))
+
+    def input_fn():
+      # It's important to build all the tensors together in one DataFrame.
+      # If we did df.select() for both key sets and then build those, the two
+      # resulting DataFrames would be shuffled independently.
+      tensors = self.build()
+
+      # Note that (for now at least) we provide our columns to Estimator keyed
+      # by strings, so they are base features as far as Estimator is concerned.
+      # TODO(soergel): reconcile with FeatureColumn keys, Transformer etc.
+      features = {key: tensors[key] for key in feature_keys}
+      targets = {key: tensors[key] for key in target_keys}
+      return features, targets
+
+    return input_fn
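An end-to-end sketch of the new DataFrame API, assuming `Column` and `DataFrame` above are in scope; `ConstantColumn` is a hypothetical helper invented for illustration:

```python
import tensorflow as tf

class ConstantColumn(Column):
  """Hypothetical Column wrapping a constant Tensor."""

  def __init__(self, name, values):
    super(ConstantColumn, self).__init__()
    self._name, self._values = name, values

  def __repr__(self):
    return "Constant(%s)" % self._name

  def build(self, cache):
    if repr(self) not in cache:
      cache[repr(self)] = tf.constant(self._values)
    return cache[repr(self)]

df = DataFrame()
df.assign(age=ConstantColumn("age", [21.0, 35.0]),
          label=ConstantColumn("label", [0, 1]))

# Features default to every column not named in target_keys.
input_fn = df.to_input_fn(target_keys=["label"])
features, targets = input_fn()  # {'age': <Tensor>}, {'label': <Tensor>}
```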
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transform.py b/tensorflow/contrib/learn/python/learn/dataframe/transform.py
new file mode 100644
index 00000000000..a3a3a83a19c
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/dataframe/transform.py
@@ -0,0 +1,287 @@
+"""A Transform takes a list of `Column` and returns a namedtuple of `Column`."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from abc import ABCMeta
+from abc import abstractmethod
+from abc import abstractproperty
+
+import collections
+import inspect
+
+from .column import Column
+from .column import TransformedColumn
+
+
+def _make_list_of_column(x):
+  """Converts `x` into a list of `Column` if possible.
+
+  Args:
+    x: a `Column`, a list of `Column` or `None`.
+
+  Returns:
+    `list(x)` if `x` is a list or tuple of `Column`, `[x]` if `x` is a
+    `Column`, `[]` if `x` is `None`.
+
+  Raises:
+    TypeError: `x` is not a `Column`, a list of `Column`, or `None`.
+  """
+  if x is None:
+    return []
+  elif isinstance(x, Column):
+    return [x]
+  elif isinstance(x, (list, tuple)):
+    for i, y in enumerate(x):
+      if not isinstance(y, Column):
+        raise TypeError(
+            "Expected a tuple or list of Columns; entry %s has type %s." %
+            (i, type(y).__name__))
+    return list(x)
+  raise TypeError("Expected a Column or list of Column; got %s" %
+                  type(x).__name__)
+
+
+def _make_tuple_of_string(x):
+  """Converts `x` into a list of `str` if possible.
+
+  Args:
+    x: a `str`, a list of `str`, a tuple of `str`, or `None`.
+
+  Returns:
+    `x` if it is a tuple of str, `tuple(x)` if it is a list of str,
+    `(x,)` if `x` is a `str`, `()` if `x` is `None`.
+
+  Raises:
+    TypeError: `x` is not a `str`, a list or tuple of `str`, or `None`.
+  """
+  if x is None:
+    return ()
+  elif isinstance(x, str):
+    return (x,)
+  elif isinstance(x, (list, tuple)):
+    for i, y in enumerate(x):
+      if not isinstance(y, str):
+        raise TypeError(
+            "Expected a tuple or list of strings; entry %s has type %s." %
+            (i, type(y).__name__))
+    return tuple(x)
+  raise TypeError("Expected a string or list of strings or tuple of strings; " +
+                  "got %s" % type(x).__name__)
+
+
+def parameter(func):
+  """Tag functions annotated with `@parameter` for later retrieval.
+
+  Note that all `@parameter`s are automatically `@property`s as well.
+
+  Args:
+    func: the getter function to tag and wrap
+
+  Returns:
+    A `@property` whose getter function is marked with is_parameter = True
+  """
+  func.is_parameter = True
+  return property(func)
+
+
+class Transform(object):
+  """A function from a list of `Column` to a namedtuple of `Column`.
+
+  Transforms map zero or more columns of a DataFrame to new columns.
+  """
+
+  __metaclass__ = ABCMeta
+
+  def __init__(self):
+    self._return_type = None
+
+  @abstractproperty
+  def name(self):
+    """Name of the transform."""
+    raise NotImplementedError()
+
+  def parameters(self):
+    """A dict of names to values of properties marked with `@parameter`."""
+    property_param_names = [name
+                            for name, func in inspect.getmembers(type(self))
+                            if (hasattr(func, "fget") and hasattr(
+                                getattr(func, "fget"), "is_parameter"))]
+    return {name: getattr(self, name) for name in property_param_names}
+
+  @abstractproperty
+  def input_valency(self):
+    """The number of `Column`s that the `Transform` should expect as input.
+
+    `None` indicates that the transform can take a variable number of inputs.
+
+    This function should depend only on `@parameter`s of this `Transform`.
+
+    Returns:
+      The number of expected inputs.
+    """
+    raise NotImplementedError()
+
+  @property
+  def output_names(self):
+    """The names of `Column`s output by the `Transform`.
+
+    This function should depend only on `@parameter`s of this `Transform`.
+
+    Returns:
+      A tuple of names of outputs provided by this Transform.
+    """
+    return _make_tuple_of_string(self._output_names)
+
+  @abstractproperty
+  def _output_names(self):
+    """The names of `Column`s output by the `Transform`.
+
+    This function should depend only on `@parameter`s of this `Transform`.
+
+    Returns:
+      Names of outputs provided by this Transform, as a string, tuple, or list.
+    """
+    raise NotImplementedError()
+
+  @property
+  def return_type(self):
+    """Provides a namedtuple type which will be used for output.
+
+    A Transform generates one or many outputs, named according to
+    _output_names.  This method creates (and caches) a namedtuple type using
+    those names as the keys.  The Transform output is then generated by
+    instantiating an object of this type with corresponding values.
+
+    Note this output type is used both for `__call__`, in which case the
+    values are `TransformedColumn`s, and for `apply_transform`, in which case
+    the values are `Tensor`s.
+
+    Returns:
+      A namedtuple type fixing the order and names of the outputs of this
+        transform.
+    """
+    if self._return_type is None:
+      # TODO(soergel): pylint 3 chokes on this, but it is legit and preferred.
+      # return_type_name = "%sReturnType" % type(self).__name__
+      return_type_name = "ReturnType"
+      self._return_type = collections.namedtuple(return_type_name,
+                                                 self.output_names)
+    return self._return_type
+
+  def _check_output_tensors(self, output_tensors):
+    """Helper for `build(...)`; verifies the output of `_build_transform`.
+
+    Args:
+      output_tensors: value returned by a call to `_build_transform`.
+
+    Raises:
+      TypeError: `transform_output` is not a list.
+      ValueError: `transform_output` does not match `output_names`.
+    """
+    if not isinstance(output_tensors, self.return_type):
+      raise TypeError(
+          "Expected a NamedTuple of Tensors with elements %s; got %s." %
+          (self.output_names, type(output_tensors).__name__))
+
+  def __call__(self, input_columns=None):
+    """Apply this `Transform` to the provided `Column`s, producing 'Column's.
+
+    Args:
+      input_columns: None, a `Column`, or a list of input `Column`s, acting as
+         positional arguments.
+
+    Returns:
+      A namedtuple of the output Columns.
+
+    Raises:
+      ValueError: `input_columns` does not have expected length
+    """
+    input_columns = _make_list_of_column(input_columns)
+    if len(input_columns) != self.input_valency:
+      raise ValueError("Expected %s input Columns but received %s." %
+                       (self.input_valency, len(input_columns)))
+    output_columns = [TransformedColumn(input_columns, self, output_name)
+                      for output_name in self.output_names]
+
+    # pylint: disable=not-callable
+    return self.return_type(*output_columns)
+
+  def apply_transform(self, input_columns, cache=None):
+    """Apply this `Transform` to the provided `Column`s, producing 'Tensor's.
+
+    Args:
+      input_columns: None, a `Column`, or a list of input `Column`s, acting as
+         positional arguments.
+      cache: a dict from Column reprs to Tensors.
+
+    Returns:
+      A namedtuple of the output Tensors.
+
+    Raises:
+      ValueError: `input_columns` does not have expected length
+    """
+    # pylint: disable=not-callable
+    if cache is None:
+      cache = {}
+
+    if len(input_columns) != self.input_valency:
+      raise ValueError("Expected %s input Columns but received %s." %
+                       (self.input_valency, len(input_columns)))
+    input_tensors = [input_column.build(cache)
+                     for input_column in input_columns]
+
+    # Note we cache each output individually, not just the entire output
+    # tuple.  This allows using the graph as the cache, since it can sensibly
+    # cache only individual Tensors.
+    output_reprs = [TransformedColumn.make_repr(input_columns, self,
+                                                output_name)
+                    for output_name in self.output_names]
+    output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
+
+    if None in output_tensors:
+      result = self._apply_transform(input_tensors)
+      for output_name, output_repr in zip(self.output_names, output_reprs):
+        cache[output_repr] = getattr(result, output_name)
+    else:
+      result = self.return_type(*output_tensors)
+
+    self._check_output_tensors(result)
+    return result
+
+  @abstractmethod
+  def _apply_transform(self, input_tensors):
+    """Applies the transformation to the `transform_input`.
+
+    Args:
+        input_tensors: a list of Tensors representing the input to
+        the Transform.
+
+    Returns:
+        A namedtuple of Tensors representing the transformed output.
+    """
+    raise NotImplementedError()
+
+  def __str__(self):
+    return self.name
+
+  def __repr__(self):
+    parameters_sorted = ["%s: %s" % (repr(k), repr(v))
+                         for k, v in sorted(self.parameters().items())]
+    parameters_joined = ", ".join(parameters_sorted)
+
+    return "%s({%s})" % (self.name, parameters_joined)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/base.py b/tensorflow/contrib/learn/python/learn/estimators/base.py
index a94ee9e7170..39131f059b0 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/base.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/base.py
@@ -240,9 +240,9 @@ class TensorFlowEstimator(estimator.Estimator):
         input_fn=predict_data_feeder.input_builder,
         feed_fn=predict_data_feeder.get_feed_dict_fn())
     if self.n_classes > 1 and axis != -1:
-      preds = preds['predictions'].argmax(axis=axis)
+      preds = preds.argmax(axis=axis)
     else:
-      preds = preds['predictions']
+      preds = preds
     return preds
 
   def predict(self, x, axis=1, batch_size=None):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 73e29299d60..b1f34fa7400 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -17,8 +17,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import inspect
 import math
 
+import numpy as np
 import six
 
 from tensorflow.contrib import layers
@@ -70,6 +72,7 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
       deep part of the model. If `None`, will use an Adagrad optimizer.
     dnn_activation_fn: Activation function applied to each layer. If `None`,
       will use `tf.nn.relu`.
+    config: RunConfig object to configure the runtime settings.
 
     Raises:
       ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -85,8 +88,10 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
                dnn_feature_columns=None,
                dnn_optimizer=None,
                dnn_hidden_units=None,
-               dnn_activation_fn=nn.relu):
-    super(_DNNLinearCombinedBaseEstimator, self).__init__(model_dir=model_dir)
+               dnn_activation_fn=nn.relu,
+               config=None):
+    super(_DNNLinearCombinedBaseEstimator, self).__init__(model_dir=model_dir,
+                                                          config=config)
     self._n_classes = n_classes
     self._weight_column_name = weight_column_name
     self._linear_feature_columns = linear_feature_columns
@@ -100,6 +105,37 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
     self._dnn_weight_collection = "DNNLinearCombined_dnn"
     self._linear_weight_collection = "DNNLinearCombined_linear"
 
+  def predict(self, x=None, input_fn=None, batch_size=None):
+    """Returns predictions for given features.
+
+    Args:
+      x: features.
+      input_fn: Input function. If set, x must be None.
+      batch_size: Override default batch size.
+
+    Returns:
+      Numpy array of predicted classes or regression values.
+    """
+    predictions = self._infer_model(x=x,
+                                    input_fn=input_fn,
+                                    batch_size=batch_size)
+    if self._n_classes > 1:
+      predictions = np.argmax(predictions, axis=1)
+    return predictions
+
+  def predict_proba(self, x=None, input_fn=None, batch_size=None):
+    """Returns prediction probabilities for given features (classification).
+
+    Args:
+      x: features.
+      input_fn: Input function. If set, x must be None.
+      batch_size: Override default batch size.
+
+    Returns:
+      Numpy array of predicted probabilities.
+    """
+    return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
+
   def _get_train_ops(self, features, targets):
     """See base class."""
     global_step = variables.get_global_step()
@@ -123,39 +159,55 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
       with ops.get_default_graph().colocate_with(global_step):
         return state_ops.assign_add(global_step, 1).op, loss
 
-  def _get_eval_ops(self, features, targets, metrics):
-    """See base class."""
-    predictions = self._get_predict_ops(features)
+  def _run_metrics(self, predictions, targets, metrics, weights):
     result = {}
+    targets = math_ops.cast(targets, predictions.dtype)
     for name, metric in six.iteritems(metrics):
-      result[name] = metric(predictions, targets,
-                            self._get_weight_tensor(features))
+      if "weights" in inspect.getargspec(metric)[0]:
+        result[name] = metric(predictions, targets, weights=weights)
+      else:
+        result[name] = metric(predictions, targets)
+
     return result
 
-  def _get_default_metric_functions(self):
+  def _get_eval_ops(self, features, targets, metrics=None):
     """See base class."""
-    def _compute_loss(logits, targets, weights=None):
-      return metrics_lib.streaming_mean(self._loss(
-          logits, targets, weight_tensor=weights))
+    logits = self._logits(features)
+    result = {"loss": metrics_lib.streaming_mean(self._loss(
+        logits, targets,
+        weight_tensor=self._get_weight_tensor(features)))}
 
-    def _compute_accuracy(logits, targets, weights=None):
-      if self._n_classes > 2:
-        _, predictions = nn.top_k(logits, 1)
-      else:
-        predictions = array_ops.reshape(logits, [-1])
-        predictions = math_ops.greater(predictions,
-                                       array_ops.zeros_like(predictions))
-        targets = array_ops.reshape(targets, [-1])
-      return metrics_lib.streaming_accuracy(
-          math_ops.to_int32(predictions), math_ops.to_int32(targets), weights)
+    # Adding default metrics
+    if metrics is None and self._n_classes > 1:
+      metrics = {"accuracy": metrics_lib.streaming_accuracy}
+
+    if self._n_classes == 2:
+      predictions = math_ops.sigmoid(logits)
+      result["eval_auc"] = metrics_lib.streaming_auc(predictions, targets)
+
+    if metrics:
+      predictions = self._logits_to_predictions(logits, proba=False)
+      result.update(self._run_metrics(predictions, targets, metrics,
+                                      self._get_weight_tensor(features)))
 
-    result = {"loss": _compute_loss}
-    if self._n_classes > 1:
-      result["accuracy"] = _compute_accuracy
     return result
 
   def _get_predict_ops(self, features):
-    return self._logits(features)
+    """See base class."""
+    logits = self._logits(features)
+    return self._logits_to_predictions(logits, proba=True)
+
+  def _logits_to_predictions(self, logits, proba=False):
+    if self._n_classes < 2:
+      return array_ops.reshape(logits, [-1])
+
+    if self._n_classes == 2:
+      logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
+
+    if proba:
+      return nn.softmax(logits)
+    else:
+      return math_ops.argmax(logits, 1)
 
   def _get_feature_ops_from_example(self, examples_batch):
     column_types = layers.create_dict_for_parse_example(
@@ -367,6 +419,7 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
       deep part of the model. If `None`, will use an Adagrad optimizer.
     dnn_activation_fn: Activation function applied to each layer. If `None`,
       will use `tf.nn.relu`.
+    config: RunConfig object to configure the runtime settings.
 
     Raises:
       ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -383,7 +436,8 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
                dnn_feature_columns=None,
                dnn_optimizer=None,
                dnn_hidden_units=None,
-               dnn_activation_fn=nn.relu):
+               dnn_activation_fn=nn.relu,
+               config=None):
     if n_classes < 2:
       raise ValueError("n_classes should be greater than 1. Given: {}".format(
           n_classes))
@@ -397,7 +451,8 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
         dnn_feature_columns=dnn_feature_columns,
         dnn_optimizer=dnn_optimizer,
         dnn_hidden_units=dnn_hidden_units,
-        dnn_activation_fn=dnn_activation_fn)
+        dnn_activation_fn=dnn_activation_fn,
+        config=config)
 
 
 class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
@@ -466,6 +521,7 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
       deep part of the model. If `None`, will use an Adagrad optimizer.
     dnn_activation_fn: Activation function applied to each layer. If None, will
       use `tf.nn.relu`.
+    config: RunConfig object to configure the runtime settings.
 
     Raises:
       ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -480,7 +536,8 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
                dnn_feature_columns=None,
                dnn_optimizer=None,
                dnn_hidden_units=None,
-               dnn_activation_fn=nn.relu):
+               dnn_activation_fn=nn.relu,
+               config=None):
     super(DNNLinearCombinedRegressor, self).__init__(
         model_dir=model_dir,
         n_classes=0,
@@ -490,4 +547,5 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
         dnn_feature_columns=dnn_feature_columns,
         dnn_optimizer=dnn_optimizer,
         dnn_hidden_units=dnn_hidden_units,
-        dnn_activation_fn=dnn_activation_fn)
+        dnn_activation_fn=dnn_activation_fn,
+        config=config)
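A sketch of the combined estimators' new surface; the feature column and RunConfig arguments here are illustrative:

```python
import tensorflow as tf

age = tf.contrib.layers.real_valued_column("age")

# The estimators now accept a RunConfig and expose predict/predict_proba.
classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
    linear_feature_columns=[age],
    dnn_feature_columns=[age],
    dnn_hidden_units=[8, 8],
    config=tf.contrib.learn.RunConfig())

# classifier.train(input_fn=train_input_fn, steps=100)
# classes = classifier.predict(input_fn=predict_input_fn)      # argmax classes
# probs = classifier.predict_proba(input_fn=predict_input_fn)  # softmax output
```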
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index 16d86bf73d7..d4188a44ac9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 import tensorflow as tf
+from tensorflow.contrib.learn.python.learn.estimators import _sklearn
 
 
 def _get_quantile_based_buckets(feature_values, num_buckets):
@@ -229,6 +230,58 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
     scores = classifier.evaluate(input_fn=_iris_input_fn, steps=100)
     self.assertGreater(scores['accuracy'], 0.9)
 
+  def testPredict(self):
+    """Tests weight column in evaluation."""
+
+    def _input_fn_train():
+      # Create 4 rows, one of them (y = x), three of them (y=Not(x))
+      target = tf.constant([[1], [0], [0], [0]])
+      features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
+      return features, target
+
+    def _input_fn_predict():
+      features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
+      return features
+
+    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+        linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+        dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+        dnn_hidden_units=[3, 3])
+
+    classifier.train(input_fn=_input_fn_train, steps=100)
+    probs = classifier.predict_proba(input_fn=_input_fn_predict)
+    self.assertAllClose([[0.75, 0.25]] * 4, probs, 0.01)
+    classes = classifier.predict(input_fn=_input_fn_predict)
+    self.assertListEqual([0] * 4, list(classes))
+
+  def testCustomMetrics(self):
+    """Tests weight column in evaluation."""
+
+    def _input_fn_train():
+      # Create 4 rows, one of them (y = x), three of them (y=Not(x))
+      target = tf.constant([[1], [0], [0], [0]])
+      features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
+      return features, target
+
+    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+        linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+        dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+        dnn_hidden_units=[3, 3])
+
+    classifier.train(input_fn=_input_fn_train, steps=100)
+    scores = classifier.evaluate(
+        input_fn=_input_fn_train,
+        steps=100,
+        metrics={
+            'my_accuracy': tf.contrib.metrics.streaming_accuracy,
+            'my_precision': tf.contrib.metrics.streaming_precision
+        })
+    self.assertTrue(set(['loss', 'my_accuracy', 'my_precision']).issubset(set(
+        scores.keys())))
+    predictions = classifier.predict(input_fn=_input_fn_train)
+    self.assertEqual(_sklearn.accuracy_score([1, 0, 0, 0], predictions),
+                     scores['my_accuracy'])
+
 
 class DNNLinearCombinedRegressorTest(tf.test.TestCase):
 
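The custom-metrics path above relies on `_run_metrics` inspecting each metric's signature: `weights` is passed only when the function accepts it. Both of the following are therefore valid entries for `evaluate(metrics=...)`; the names are illustrative:

```python
import tensorflow as tf

def my_accuracy(predictions, targets):
  # No `weights` argument: called as metric(predictions, targets).
  return tf.contrib.metrics.streaming_accuracy(predictions, targets)

def my_weighted_accuracy(predictions, targets, weights=None):
  # Accepts `weights`: called as metric(predictions, targets, weights=...).
  return tf.contrib.metrics.streaming_accuracy(
      predictions, targets, weights=weights)
```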
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index e2340167131..0fce7d140f1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -29,6 +29,7 @@ import six
 from tensorflow.contrib import framework as contrib_framework
 from tensorflow.contrib import layers
 from tensorflow.contrib import losses
+from tensorflow.contrib import metrics as metrics_lib
 from tensorflow.contrib.learn.python.learn import graph_actions
 from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
 from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
@@ -51,7 +52,7 @@ from tensorflow.python.training import saver
 # Default metrics for evaluation.
 _EVAL_METRICS = {
     'regression': {
-        'mean_squared_error': losses.sum_of_squares,
+        'mean_squared_error': metrics_lib.streaming_mean_squared_error,
     },
     'classification': {
         'logistic': losses.sigmoid_cross_entropy,
@@ -74,28 +75,15 @@ class ModeKeys(object):
 
 
 def _get_input_fn(x, y, batch_size):
-  # TODO(ipoloshukin): Remove this when refactor of data_feeder is done
-  if hasattr(x, 'create_graph') and hasattr(y, 'create_graph'):
-    def input_fn():
-      return x.create_graph(), y.create_graph()
-    return input_fn, None
-
-  df = data_feeder.setup_train_data_feeder(x, y,
-                                           n_classes=None,
-                                           batch_size=batch_size)
+  df = data_feeder.setup_train_data_feeder(
+      x, y, n_classes=None, batch_size=batch_size)
   return df.input_builder, df.get_feed_dict_fn()
 
 
-def _get_predict_input_fn(x, batch_size):
-  # TODO(ipoloshukin): Remove this when refactor of data_feeder is done
-  if hasattr(x, 'create_graph'):
-    def input_fn():
-      return x.create_graph()
-    return input_fn, None
-
-  df = data_feeder.setup_train_data_feeder(x, None,
-                                           n_classes=None,
-                                           batch_size=batch_size, epochs=1)
+def _get_predict_input_fn(x, y, batch_size):
+  df = data_feeder.setup_train_data_feeder(
+      x, y, n_classes=None, batch_size=batch_size,
+      shuffle=False, epochs=1)
   return df.input_builder, df.get_feed_dict_fn()
 
 
@@ -147,78 +135,6 @@ class BaseEstimator(sklearn.BaseEstimator):
 
     self._graph = None
 
-  @property
-  def model_dir(self):
-    return self._model_dir
-
-  @abc.abstractproperty
-  def _get_train_ops(self, features, targets):
-    """Method that builds model graph and returns trainer ops.
-
-    Expected to be overriden by sub-classes that require custom support.
-
-    Args:
-      features: `Tensor` or `dict` of `Tensor` objects.
-      targets: `Tensor` or `dict` of `Tensor` objects.
-
-    Returns:
-      Tuple of train `Operation` and loss `Tensor`.
-    """
-    pass
-
-  @abc.abstractproperty
-  def _get_predict_ops(self, features):
-    """Method that builds model graph and returns prediction ops.
-
-    Args:
-      features: `Tensor` or `dict` of `Tensor` objects.
-
-    Returns:
-      predictions: `Tensor` or `dict` of `Tensor` objects.
-    """
-    pass
-
-  def _get_eval_ops(self, features, targets, metrics):
-    """Method that builds model graph and returns evaluation ops.
-
-    Args:
-      features: `Tensor` or `dict` of `Tensor` objects.
-      targets: `Tensor` or `dict` of `Tensor` objects.
-      metrics: `dict` of functions that take predictions and targets.
-
-    Returns:
-      metrics: `dict` of `Tensor` objects.
-    """
-    predictions = self._get_predict_ops(features)
-    result = {}
-    for name, metric in six.iteritems(metrics):
-      result[name] = metric(predictions, targets)
-    return result
-
-  def _get_feature_ops_from_example(self, examples_batch):
-    """Method that returns features given the batch of examples.
-
-    This method will be used to export model into a server.
-
-    Args:
-      examples_batch: batch of tf.Example
-
-    Returns:
-      features: `Tensor` or `dict` of `Tensor` objects.
-    """
-    raise NotImplementedError('_get_feature_ops_from_example not implemented '
-                              'in BaseEstimator')
-
-  def _get_default_metric_functions(self):
-    """Method that provides default metric operations.
-
-    This functions is intented to be overridden by sub-classes.
-    Returns:
-      `dict` of functions that take predictions and targets `Tensor` objects and
-      return `Tensor`.
-    """
-    return {}
-
   def fit(self, x, y, steps, batch_size=32, monitors=None):
     """Trains a model given training data X and y.
 
@@ -296,7 +212,7 @@ class BaseEstimator(sklearn.BaseEstimator):
                input_fn=None,
                feed_fn=None,
                batch_size=32,
-               steps=100,
+               steps=None,
                metrics=None,
                name=None):
     """Evaluates given model with provided evaluation data.
@@ -325,37 +241,85 @@ class BaseEstimator(sklearn.BaseEstimator):
       raise ValueError('Either x and y or input_fn must be None.')
     if input_fn is None:
       assert x is not None
-      input_fn, feed_fn = _get_input_fn(x, y, batch_size)
+      input_fn, feed_fn = _get_predict_input_fn(x, y, batch_size)
     return self._evaluate_model(input_fn=input_fn,
                                 feed_fn=feed_fn,
                                 steps=steps,
                                 metrics=metrics,
                                 name=name)
 
-  def predict(self, x, axis=None, batch_size=None):
+  def predict(self, x=None, input_fn=None, batch_size=None):
     """Returns predictions for given features.
 
     Args:
       x: features.
-      axis: Axis on which to argmax. (for classification).
+      input_fn: Input function. If set, x must be None.
       batch_size: Override default batch size.
 
     Returns:
       Numpy array of predicted classes or regression values.
     """
-    return self._infer_model(x=x, batch_size=batch_size, axis=axis)
+    return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
 
-  def predict_proba(self, x, batch_size=None):
-    """Returns prediction probabilities for given features (classification).
+  @property
+  def model_dir(self):
+    return self._model_dir
+
+  @abc.abstractproperty
+  def _get_train_ops(self, features, targets):
+    """Method that builds model graph and returns trainer ops.
+
+    Expected to be overridden by sub-classes that require custom support.
 
     Args:
-      x: features.
-      batch_size: Override default batch size.
+      features: `Tensor` or `dict` of `Tensor` objects.
+      targets: `Tensor` or `dict` of `Tensor` objects.
 
     Returns:
-      Numpy array of predicted probabilities.
+      Tuple of train `Operation` and loss `Tensor`.
     """
-    return self._infer_model(x=x, batch_size=batch_size, proba=True)
+    pass
+
+  @abc.abstractproperty
+  def _get_predict_ops(self, features):
+    """Method that builds model graph and returns prediction ops.
+
+    Args:
+      features: `Tensor` or `dict` of `Tensor` objects.
+
+    Returns:
+      predictions: `Tensor` or `dict` of `Tensor` objects.
+    """
+    pass
+
+  def _get_eval_ops(self, features, targets, metrics):
+    """Method that builds model graph and returns evaluation ops.
+
+    Expected to be overridden by sub-classes that require custom support.
+
+    Args:
+      features: `Tensor` or `dict` of `Tensor` objects.
+      targets: `Tensor` or `dict` of `Tensor` objects.
+      metrics: `dict` of functions that take predictions and targets.
+
+    Returns:
+      metrics: `dict` of `Tensor` objects.
+    """
+    raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator')
+
+  def _get_feature_ops_from_example(self, examples_batch):
+    """Method that returns features given the batch of examples.
+
+    This method will be used to export model into a server.
+
+    Args:
+      examples_batch: batch of tf.Example
+
+    Returns:
+      features: `Tensor` or `dict` of `Tensor` objects.
+    """
+    raise NotImplementedError('_get_feature_ops_from_example not implemented '
+                              'in BaseEstimator')
 
   def _check_inputs(self, features, targets):
     if self._features_info is not None:
@@ -416,6 +380,11 @@ class BaseEstimator(sklearn.BaseEstimator):
           summary_op=logging_ops.get_summary_op(),
           save_summary_steps=100)
 
+      is_chief = self._config.task == 0
+      if not is_chief:
+        # Run monitors only on chief.
+        monitors = []
+
       # Setup monitors.
       for monitor in monitors:
         monitor.set_estimator(self)
@@ -430,7 +399,7 @@ class BaseEstimator(sklearn.BaseEstimator):
           init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
           init_fn=init_fn,
           log_every_steps=log_every_steps,
-          supervisor_is_chief=(self._config.task == 0),
+          supervisor_is_chief=is_chief,
           supervisor_master=self._config.master,
           feed_fn=feed_fn,
           max_steps=steps,
@@ -450,6 +419,7 @@ class BaseEstimator(sklearn.BaseEstimator):
           logging.warning(
               'Ignoring metric {}. It returned a list|tuple with len {}, '
               'expected 2'.format(name, len(metric_ops)))
+          value_ops[name] = metric_ops
       else:
         value_ops[name] = metric_ops
 
@@ -469,7 +439,7 @@ class BaseEstimator(sklearn.BaseEstimator):
     if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
       return
 
-    checkpoint_path = saver.latest_checkpoint(self._model_dir)
+    checkpoint_path = self._model_dir
     eval_dir = os.path.join(self._model_dir, 'eval' if not name else
                             'eval_' + name)
     with ops.Graph().as_default() as g:
@@ -477,9 +447,7 @@ class BaseEstimator(sklearn.BaseEstimator):
       global_step = contrib_framework.create_global_step(g)
       features, targets = input_fn()
       self._check_inputs(features, targets)
-      eval_dict = self._get_eval_ops(features, targets,
-                                     metrics if metrics is not None else
-                                     self._get_default_metric_functions())
+      eval_dict = self._get_eval_ops(features, targets, metrics)
       update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
       eval_results, _ = evaluate(graph=g,
                                  output_dir=eval_dir,
@@ -492,41 +460,48 @@ class BaseEstimator(sklearn.BaseEstimator):
                                  max_steps=steps)
       return eval_results
 
-  def _infer_model(self,
-                   x=None, input_fn=None, feed_fn=None,
-                   batch_size=None, axis=None, proba=False):
+  def _get_features_from_input_fn(self, input_fn):
+    result = input_fn()
+    if isinstance(result, (list, tuple)):
+      return result[0]
+    return result
+
+  def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None):
     # Converts inputs into tf.DataFrame / tf.Series.
     batch_size = -1 if batch_size is None else batch_size
     if x is not None:
-      input_fn, feed_fn = _get_predict_input_fn(x, batch_size)
+      input_fn, feed_fn = _get_predict_input_fn(x, None, batch_size)
 
     checkpoint_path = saver.latest_checkpoint(self._model_dir)
     with ops.Graph().as_default() as g:
       random_seed.set_random_seed(self._config.tf_random_seed)
       contrib_framework.create_global_step(g)
-      features, _ = input_fn()
+      features = self._get_features_from_input_fn(input_fn)
       predictions = self._get_predict_ops(features)
+      return_dict = True
       if not isinstance(predictions, dict):
-        predictions = {'predictions': predictions}
-      # TODO(ipolosukhin): Support batching
+        predictions, return_dict = {'predictions': predictions}, False
       if feed_fn is None:
-        return infer(checkpoint_path, predictions)
-      preds = {}
-      while True:
-        try:
-          feed_dict = feed_fn()
-        except StopIteration:
-          break
-        if feed_dict is None:
-          break
-        outputs = infer(checkpoint_path, predictions, feed_dict=feed_dict)
-        for key in outputs:
-          if key not in preds:
-            preds[key] = []
-          preds[key].append(outputs[key])
-      for key in preds:
-        preds[key] = np.concatenate(preds[key], axis=0)
-      return preds
+        preds = infer(checkpoint_path, predictions)
+      else:
+        preds = {}
+        while True:
+          try:
+            feed_dict = feed_fn()
+          except StopIteration:
+            break
+          if feed_dict is None:
+            break
+          outputs = infer(checkpoint_path, predictions, feed_dict=feed_dict)
+          for key in outputs:
+            if key not in preds:
+              preds[key] = []
+            preds[key].append(outputs[key])
+        for key in preds:
+          preds[key] = np.concatenate(preds[key], axis=0)
+      if return_dict:
+        return preds
+      return preds['predictions']
 
 
 class Estimator(BaseEstimator):
@@ -571,6 +546,41 @@ class Estimator(BaseEstimator):
     self.learning_rate = learning_rate
     self.clip_gradients = clip_gradients
 
+  def predict(self, x=None, input_fn=None, axis=None, batch_size=None):
+    """Returns predictions for given features.
+
+    Args:
+      x: features.
+      input_fn: Input function. If set, x must be None.
+      axis: Axis on which to argmax (for classification).
+            Last axis is used by default.
+      batch_size: Override default batch size.
+
+    Returns:
+      Numpy array of predicted classes or regression values.
+    """
+    predictions = self._infer_model(x=x,
+                                    input_fn=input_fn,
+                                    batch_size=batch_size)
+    if self._classification:
+      for key in predictions:
+        cur_axis = (len(predictions[key].shape) - 1) if axis is None else axis
+        predictions[key] = np.argmax(predictions[key], axis=cur_axis)
+    return predictions
+
+  def predict_proba(self, x=None, input_fn=None, batch_size=None):
+    """Returns prediction probabilities for given features (classification).
+
+    Args:
+      x: features.
+      input_fn: Input function. If set, x must be None.
+      batch_size: Override default batch size.
+
+    Returns:
+      Numpy array of predicted probabilities.
+    """
+    return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
+
   def _get_train_ops(self, features, targets):
     """Method that builds model graph and returns trainer ops.
 
@@ -624,6 +634,9 @@ class Estimator(BaseEstimator):
     """
     predictions, loss = self._model_fn(features, targets, ModeKeys.EVAL)
     result = {'loss': loss}
+    if metrics is None:
+      metrics = _EVAL_METRICS[
+          'classification' if self._classification else 'regression']
     if isinstance(targets, dict) and len(targets) == 1:
       # Unpack single target into just tensor.
       targets = targets[targets.keys()[0]]
@@ -650,15 +663,6 @@ class Estimator(BaseEstimator):
     predictions, _ = self._model_fn(features, targets, ModeKeys.INFER)
     return predictions
 
-  def _get_default_metric_functions(self):
-    """Method that provides default metric operations.
-
-    Returns:
-      a dictionary of metric operations.
-    """
-    return _EVAL_METRICS[
-        'classification' if self._classification else 'regression']
-
   def _get_feature_ops_from_example(self, examples_batch):
     """Unimplemented.
 
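A sketch of the reworked prediction path: predict() now accepts an input_fn, never shuffles its input, and returns a plain array rather than a {'predictions': ...} dict. The toy data here is illustrative:

```python
import numpy as np
import tensorflow as tf

def model_fn(features, target, mode):
  return tf.contrib.learn.models.linear_regression_zero_init(features, target)

est = tf.contrib.learn.Estimator(model_fn=model_fn, classification=False)
x = np.random.rand(100, 3).astype(np.float32)
y = np.random.rand(100).astype(np.float32)

est.fit(x=x, y=y, steps=10)
preds = est.predict(x=x)       # plain ndarray, no dict unwrapping needed
assert preds.shape[0] == 100   # row order matches x because shuffle=False
```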
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index feb68782fa8..40a455c6bf1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -19,8 +19,11 @@ from __future__ import print_function
 
 import tempfile
 
+import numpy as np
 import tensorflow as tf
 
+from tensorflow.contrib.learn.python.learn.estimators._sklearn import mean_squared_error
+
 
 def boston_input_fn():
   boston = tf.contrib.learn.datasets.load_boston()
@@ -33,6 +36,18 @@ def boston_input_fn():
   return features, target
 
 
+def boston_eval_fn():
+  boston = tf.contrib.learn.datasets.load_boston()
+  n_examples = len(boston.target)
+  features = tf.cast(
+      tf.reshape(
+          tf.constant(boston.data), [n_examples, 13]), tf.float32)
+  target = tf.cast(
+      tf.reshape(
+          tf.constant(boston.target), [n_examples, 1]), tf.float32)
+  return tf.concat(0, [features, features]), tf.concat(0, [target, target])
+
+
 def linear_model_fn(features, target, unused_mode):
   return tf.contrib.learn.models.linear_regression_zero_init(features, target)
 
@@ -57,12 +72,23 @@ class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
 
 class EstimatorTest(tf.test.TestCase):
 
-  def testTrain(self):
-    output_dir = tempfile.mkdtemp()
+  def testBostonAll(self):
+    boston = tf.contrib.learn.datasets.load_boston()
     est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
-                                     classification=False, model_dir=output_dir)
+                                     classification=False)
+    est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
+    scores = est.evaluate(
+        x=boston.data,
+        y=boston.target.astype(np.float32))
+    predictions = est.predict(x=boston.data)
+    other_score = mean_squared_error(predictions, boston.target)
+    self.assertAllClose(other_score, scores['mean_squared_error'])
+
+  def testTrainInputFn(self):
+    est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+                                     classification=False)
     est.train(input_fn=boston_input_fn, steps=1)
-    _ = est.evaluate(input_fn=boston_input_fn, steps=1)
+    _ = est.evaluate(input_fn=boston_eval_fn, steps=1)
 
   def testPredict(self):
     est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
@@ -70,7 +96,15 @@ class EstimatorTest(tf.test.TestCase):
     boston = tf.contrib.learn.datasets.load_boston()
     est.train(input_fn=boston_input_fn, steps=1)
     output = est.predict(boston.data)
-    self.assertEqual(output['predictions'].shape[0], boston.target.shape[0])
+    self.assertEqual(output.shape[0], boston.target.shape[0])
+
+  def testPredictFn(self):
+    est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
+                                     classification=False)
+    boston = tf.contrib.learn.datasets.load_boston()
+    est.train(input_fn=boston_input_fn, steps=1)
+    output = est.predict(input_fn=boston_input_fn)
+    self.assertEqual(output.shape[0], boston.target.shape[0])
 
   def testWrongInput(self):
     def other_input_fn():
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index c9aa8d32b0c..f3d158433ae 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -97,11 +97,6 @@ class LinearClassifierTest(tf.test.TestCase):
     classifier = tf.contrib.learn.LinearClassifier(
         feature_columns=[age, language])
 
-    # Evaluate on untrained model
-    classifier.evaluate(input_fn=input_fn, steps=2)
-    # TODO(ispir): Enable accuracy check after resolving the randomness issue.
-    # self.assertAlmostEqual(.5, evaluated_values['accuracy/mean'])
-
     # Evaluate on trained model
     classifier.train(input_fn, steps=100)
     classifier.evaluate(input_fn=input_fn, steps=2)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py
index 49b99ba7fa0..9a5c17ef342 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 
 
@@ -43,10 +44,23 @@ class TensorSignature(collections.namedtuple(
 
   def is_compatible_with(self, other):
     """Returns True if signatures are compatible."""
+
+    def _shape_is_compatible_0dim(this, other):
+      other = tensor_shape.as_shape(other)
+      if this.ndims != other.ndims:
+        return False
+      for dim, (x_dim, y_dim) in enumerate(zip(this.dims, other.dims)):
+        if dim == 0:
+          continue
+        if not x_dim.is_compatible_with(y_dim):
+          return False
+      return True
+
     if other.is_sparse:
       return self.is_sparse and self.dtype.is_compatible_with(other.dtype)
     return (self.dtype.is_compatible_with(other.dtype) and
-            self.shape.is_compatible_with(other.shape) and not self.is_sparse)
+            _shape_is_compatible_0dim(self.shape, other.shape) and
+            not self.is_sparse)
 
   def get_placeholder(self):
     if self.is_sparse:
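The new helper exempts dimension 0 (the batch dimension) from the compatibility check, so signatures created with one batch size accept feeds with another, as the `placeholder_d` test below exercises. The same logic, sketched without TensorFlow (dimensions as ints, or `None` for unknown):

```python
# Batch-size-agnostic shape compatibility, mirroring _shape_is_compatible_0dim.
def shape_is_compatible_0dim(this, other):
  if len(this) != len(other):
    return False
  for dim, (x_dim, y_dim) in enumerate(zip(this, other)):
    if dim == 0:
      continue  # skip the batch dimension
    if x_dim is not None and y_dim is not None and x_dim != y_dim:
      return False
  return True


assert shape_is_compatible_0dim([128, 100], [256, 100])     # batch may differ
assert not shape_is_compatible_0dim([128, 100], [128, 50])  # others may not
```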
diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py
index 6bd79a9110f..bd1e18bd8d9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature_test.py
@@ -35,6 +35,9 @@ class TensorSignatureTest(tf.test.TestCase):
     placeholder_c = tf.placeholder(name='mismatch',
                                    shape=[256, 100],
                                    dtype=tf.float32)
+    placeholder_d = tf.placeholder(name='mismatch',
+                                   shape=[128, 100],
+                                   dtype=tf.int32)
     signatures = tensor_signature.create_signatures(placeholder_a)
     self.assertTrue(tensor_signature.tensors_compatible(placeholder_a,
                                                         signatures))
@@ -42,6 +45,8 @@ class TensorSignatureTest(tf.test.TestCase):
                                                         signatures))
     self.assertFalse(tensor_signature.tensors_compatible(placeholder_c,
                                                          signatures))
+    self.assertTrue(tensor_signature.tensors_compatible(placeholder_d,
+                                                        signatures))
 
     inputs = {'a': placeholder_a}
     signatures = tensor_signature.create_signatures(inputs)
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index d40809ec1ff..898590e52ff 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -25,15 +25,16 @@ import time
 import numpy as np
 
 from six import reraise
+
 from tensorflow.contrib.framework.python.ops import ops as contrib_ops
 from tensorflow.contrib.framework.python.ops import variables as contrib_variables
 from tensorflow.contrib.layers.python.layers import summaries
 from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
-from tensorflow.core.util.event_pb2 import SessionLog
 from tensorflow.python.client import session as tf_session
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import variables
@@ -42,10 +43,10 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import queue_runner
 from tensorflow.python.training import saver as tf_saver
+from tensorflow.python.training import session_manager as session_manager_lib
 from tensorflow.python.training import summary_io
 from tensorflow.python.training import supervisor as tf_supervisor
 
-
 # pylint: disable=invalid-name
 Supervisor = tf_supervisor.Supervisor
 Coordinator = coordinator.Coordinator
@@ -83,6 +84,7 @@ def _run_dict(session, run_dict, feed_dict=None):
     session: The session to evaluate.
     run_dict: A dict of tensors to be run in the session.
     feed_dict: Feed dict to be used in running the session.
+
   Returns:
     A dict containing the result of evaluating the tensors.
   Raises:
@@ -91,40 +93,11 @@ def _run_dict(session, run_dict, feed_dict=None):
   if run_dict is None:
     raise ValueError('Invalid run_dict %s.' % run_dict)
   keys = run_dict.keys()
-  values = session.run([run_dict[key] for key in keys], feed_dict=feed_dict)
+  tensors = [run_dict[key] for key in keys]
+  values = session.run(tensors, feed_dict=feed_dict)
   return dict(zip(keys, values))
 
 
-def _prepare_session(graph,
-                     output_dir,
-                     start_services,
-                     global_step_tensor,
-                     init_op=None,
-                     init_feed_dict=None,
-                     init_fn=None,
-                     supervisor_is_chief=True,
-                     supervisor_master='',
-                     supervisor_save_model_secs=600):
-  """Starts a session using the supervisor."""
-  if global_step_tensor is None:
-    global_step_tensor = Supervisor.USE_DEFAULT
-  supervisor = Supervisor(
-      graph,
-      init_op=init_op or Supervisor.USE_DEFAULT,
-      init_feed_dict=init_feed_dict,
-      is_chief=supervisor_is_chief,
-      logdir=output_dir,
-      saver=_make_saver(graph),
-      global_step=global_step_tensor,
-      summary_op=None,
-      save_model_secs=supervisor_save_model_secs,
-      init_fn=init_fn)
-  session = supervisor.PrepareSession(master=supervisor_master,
-                                      start_standard_services=start_services)
-  supervisor.StartQueueRunners(session)
-  return supervisor, session
-
-
 def _run_with_monitors(session, step, tensors, feed_dict, monitors):
   """Runs session for given tensors with monitor callbacks."""
   for monitor in monitors:
@@ -233,17 +206,20 @@ def train(graph,
   for monitor in monitors:
     monitor.begin(max_steps=max_steps)
 
-  supervisor, session = _prepare_session(
-      graph=graph,
-      output_dir=output_dir,
-      start_services=True,
-      global_step_tensor=global_step_tensor,
-      init_op=init_op,
+  supervisor = Supervisor(
+      graph,
+      init_op=init_op or Supervisor.USE_DEFAULT,
       init_feed_dict=init_feed_dict,
-      init_fn=init_fn,
-      supervisor_is_chief=supervisor_is_chief,
-      supervisor_master=supervisor_master,
-      supervisor_save_model_secs=supervisor_save_model_secs)
+      is_chief=supervisor_is_chief,
+      logdir=output_dir,
+      saver=_make_saver(graph),
+      global_step=global_step_tensor,
+      summary_op=None,
+      save_model_secs=supervisor_save_model_secs,
+      init_fn=init_fn)
+  session = supervisor.PrepareSession(master=supervisor_master,
+                                      start_standard_services=True)
+  supervisor.StartQueueRunners(session)
 
   with session:
     get_current_step = lambda: session.run(global_step_tensor)
@@ -338,6 +314,56 @@ def train(graph,
     return loss_value
 
 
+def _get_first_op_from_collection(collection_name):
+  """Returns the first op in the given collection, or None if it is empty."""
+  elements = ops.get_collection(collection_name)
+  if elements:
+    return elements[0]
+  return None
+
+
+def _get_saver():
+  """Returns a saver from the SAVERS collection, creating one if needed."""
+  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
+  if saver is None and variables.all_variables():
+    saver = tf_saver.Saver()
+    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
+  return saver
+
+
+def _get_ready_op():
+  ready_op = _get_first_op_from_collection(ops.GraphKeys.READY_OP)
+  if ready_op is None:
+    ready_op = variables.report_uninitialized_variables()
+    ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
+  return ready_op
+
+
+def _get_local_init_op():
+  local_init_op = _get_first_op_from_collection(
+      ops.GraphKeys.LOCAL_INIT_OP)
+  if local_init_op is None:
+    op_list = [variables.initialize_local_variables(),
+               data_flow_ops.initialize_all_tables()]
+    if op_list:
+      local_init_op = control_flow_ops.group(*op_list)
+      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
+  return local_init_op
+
+
+def _start_queue_runners(session, coord):
+  queue_runners = ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
+  threads = []
+  for qr in queue_runners:
+    threads.extend(qr.create_threads(session, coord=coord, daemon=True,
+                                     start=True))
+  return threads
+
+
 # TODO(ptucker): Add unit test.
 def evaluate(graph,
              output_dir,
@@ -345,7 +371,6 @@ def evaluate(graph,
              eval_dict,
              update_op=None,
              global_step_tensor=None,
-             init_op=None,
              supervisor_master='',
              log_every_steps=10,
              feed_fn=None,
@@ -367,14 +392,13 @@ def evaluate(graph,
     output_dir: A string containing the directory to write a summary to.
     checkpoint_path: A string containing the path to a checkpoint to restore.
       Can be `None` if the graph doesn't require loading any variables.
-    eval_dict: A `dict` mapping string names to tensors to evaluate for in every
-      eval step.
-    update_op: A 'Tensor' which is run before evaluating 'eval_dict'.
+    eval_dict: A `dict` mapping string names to tensors to evaluate. It is
+      evaluated in every logging step. The result of the final evaluation is
+      returned. If `update_op` is `None`, then `eval_dict` is evaluated in
+      every step.
+    update_op: A `Tensor` which is run in every step.
     global_step_tensor: A `Variable` containing the global step. If `None`,
       one is extracted from the graph using the same logic as in `Supervisor`.
       Used to place eval summaries on training curves.
-    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
-      default.
     supervisor_master: The master string to use when preparing the session.
     log_every_steps: Integer. Output logs every `log_every_steps` evaluation
       steps. The logs contain the `eval_dict` and timing information.
@@ -385,7 +409,7 @@ def evaluate(graph,
   Returns:
     A tuple `(eval_results, global_step)`:
     eval_results: A `dict` mapping `string` to numeric values (`int`, `float`)
-      that are the eval results from the last step of the eval.  None if no
+      that are the result of running `eval_dict` in the last step. `None` if no
       eval steps were run.
     global_step: The global step this evaluation corresponds to.
   """
@@ -402,50 +426,67 @@ def evaluate(graph,
     if isinstance(value, ops.Tensor):
       summaries.summarize_tensor(value, tag=key)
 
-  # Create or get summary op.
+  # Create or get summary op, global_step and saver.
   summary_op = logging_ops.get_summary_op()
+  saver = _get_saver()
+  local_init_op = _get_local_init_op()
+  ready_op = _get_ready_op()
 
-  # TODO(wicke): Don't use supervisor here, or switch to output_dir=eval_dir.
-  supervisor, session = _prepare_session(
-      graph=graph,
-      output_dir=None,  # Must be None to avoid writing an event file
-      start_services=False,
-      global_step_tensor=global_step_tensor,
-      init_op=init_op,
-      supervisor_is_chief=True,
-      supervisor_master=supervisor_master,
-      supervisor_save_model_secs=None)
-  global_step_tensor = supervisor.global_step
+  session_manager = session_manager_lib.SessionManager(
+      local_init_op=local_init_op,
+      ready_op=ready_op)
+  session, initialized = session_manager.recover_session(
+      master=supervisor_master,
+      saver=saver,
+      checkpoint_dir=checkpoint_path)
+
+  # Start queue runners.
+  coord = coordinator.Coordinator()
+  threads = _start_queue_runners(session, coord)
 
   with session:
-    if checkpoint_path:
-      _restore_from_checkpoint(
-          session, graph, checkpoint_path, supervisor.saver)
+    if not initialized:
+      logging.warning('Failed to initialize from %s.', checkpoint_path)
+      # TODO(ipolosukhin): This should fail, but old code relies on this
+      # behavior.
+      session.run(variables.initialize_all_variables())
+      if checkpoint_path:
+        _restore_from_checkpoint(session, graph, checkpoint_path, saver)
 
     current_global_step = session.run(global_step_tensor)
     eval_results = None
     # TODO(amodei): Fix this to run through the eval set exactly once.
     step = 0
-    logging.info('Eval steps [%d,%s)', step, 'inf' if max_steps is None
-                 else str(max_steps))
+    logging.info('Eval steps [%d,%s) for training step %d.', step,
+                 'inf' if max_steps is None
+                 else str(max_steps), current_global_step)
     try:
       try:
-        while not supervisor.ShouldStop() and (
-            (max_steps is None) or (step < max_steps)):
+        while (max_steps is None) or (step < max_steps):
           start_time = time.time()
           feed_dict = feed_fn() if feed_fn is not None else None
-          if update_op:
-            session.run(update_op)
-          eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+          eval_results = None
+          if update_op is not None:
+            session.run(update_op, feed_dict=feed_dict)
+          else:
+            eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+
           # TODO(wicke): We should assert that the global step hasn't changed.
           step += 1
           if step % log_every_steps == 0:
+            if eval_results is None:
+              eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
             duration = time.time() - start_time
             logging.info('Results after %d steps (%.3f sec/batch): %s.',
                          step, float(duration),
                          ', '.join('%s = %s' % (k, v)
                                    for k, v in eval_results.items()))
       finally:
+        if eval_results is None:
+          eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+        # Stop queue runners.
+        coord.request_stop()
+        coord.join(threads, stop_grace_period_secs=120)
+
         # Make our own summary writer and write a summary to the eval dir.
         # Only if feed_fn is not provided.
         # TODO(ipolosukhin): Convert evaluation to use streaming_metrics,
@@ -462,17 +503,19 @@ def evaluate(graph,
           finally:
             if summary_writer:
               summary_writer.close()
-
-        # Call supervisor.Stop() from within a try block because it re-raises
-        # exceptions thrown by the supervised threads.
-        supervisor.Stop()
     # catch OutOfRangeError which is thrown when queue is out of data (and for
     # other reasons as well).
     except errors.OutOfRangeError as e:
-      logging.warn('Input queue is exhausted: %s.', e)
+      if max_steps is None:
+        logging.info('Input queue is exhausted.')
+      else:
+        logging.warn('Input queue is exhausted: %s.', e)
    # catch StopIteration which is thrown when DataReader is out of data.
     except StopIteration as e:
-      logging.info('Input iterator is exhausted: %s.', e)
+      if max_steps is None:
+        logging.info('Input iterator is exhausted.')
+      else:
+        logging.warn('Input iterator is exhausted: %s.', e)
 
   return eval_results, current_global_step
 
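The reworked loop gives `update_op` and `eval_dict` distinct roles: `update_op` (if given) runs every step, while `eval_dict` runs at logging steps and once more at the end. A pure-Python sketch of that control flow, with `session.run` stubbed as a plain callable (names and defaults here are illustrative):

```python
def eval_loop(run, eval_dict, update_op=None, max_steps=20, log_every_steps=10):
  # `run` stands in for session.run; it maps a fetch to its value.
  eval_results = None
  step = 0
  while max_steps is None or step < max_steps:
    eval_results = None
    if update_op is not None:
      run(update_op)
    else:
      eval_results = run(eval_dict)
    step += 1
    if step % log_every_steps == 0:
      if eval_results is None:
        eval_results = run(eval_dict)
      print('Results after %d steps: %s' % (step, eval_results))
  if eval_results is None:
    eval_results = run(eval_dict)  # final evaluation, as in the finally block
  return eval_results


final = eval_loop(lambda fetch: {'loss': 0.0},
                  eval_dict={'loss': 'loss_tensor'},
                  update_op='update/op', max_steps=25)
```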
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
index babb216d9dc..04bbd997482 100644
--- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
@@ -46,7 +46,7 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size):
   # Skip first dimension if it is 1.
   if y_shape and y_shape[0] == 1:
     y_shape = y_shape[1:]
-  if n_classes > 1:
+  if n_classes is not None and n_classes > 1:
     output_shape = [batch_size] + y_shape + [n_classes]
   else:
     output_shape = [batch_size] + y_shape
@@ -441,7 +441,7 @@ class StreamingDataFeeder(DataFeeder):
 
         if self.y is not None:
           y = six.next(self.y)
-          if self.n_classes > 1:
+          if self.n_classes is not None and self.n_classes > 1:
             if len(self.output_shape) == 2:
               out.itemset((i, y), 1.0)
             else:
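A standalone sketch of the guarded shape logic above (the helper name is illustrative): with `n_classes=None`, as in regression, no class dimension is appended to the target shape, whereas `n_classes > 1` yields a trailing one-hot dimension:

```python
def target_output_shape(y_shape, n_classes, batch_size):
  if y_shape and y_shape[0] == 1:
    y_shape = y_shape[1:]  # drop the leading singleton dimension
  if n_classes is not None and n_classes > 1:
    return [batch_size] + y_shape + [n_classes]
  return [batch_size] + y_shape


assert target_output_shape([1], None, 32) == [32]    # regression: no class dim
assert target_output_shape([1], 10, 32) == [32, 10]  # 10-class one-hot targets
```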
diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
index a3a5386e7ea..f78c242d6f6 100644
--- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
@@ -27,8 +27,8 @@ def mean_squared_error_regressor(tensor_in, labels, weights, biases, name=None):
   """Returns prediction and loss for mean squared error regression."""
   with ops.op_scope([tensor_in, labels], name, "mean_squared_error_regressor"):
     predictions = nn.xw_plus_b(tensor_in, weights, biases)
-    if len(labels.get_shape()) == 1:
-      labels = array_ops_.reshape(labels, [-1, 1])
+    if len(labels.get_shape()) == 1 and len(predictions.get_shape()) == 2:
+      predictions = array_ops_.squeeze(predictions, squeeze_dims=[1])
     return predictions, loss_ops.sum_of_squares(predictions, labels)
 
 
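The fix squeezes rank-2 `[N, 1]` predictions down to rank-1 rather than reshaping rank-1 labels up to `[N, 1]`, so the loss compares tensors of matching rank. A sketch using the era's `squeeze_dims` keyword (tensor names are illustrative, not from the diff):

```python
import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])          # rank 1: [N]
predictions = tf.placeholder(tf.float32, [None, 1])  # rank 2: [N, 1]
if len(labels.get_shape()) == 1 and len(predictions.get_shape()) == 2:
  # Bring predictions down to the labels' rank instead of reshaping labels up.
  predictions = tf.squeeze(predictions, squeeze_dims=[1])
```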
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/__init__.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/__init__.py
new file mode 100644
index 00000000000..3645584a43a
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/__init__.py
@@ -0,0 +1,18 @@
+"""Tests for DataFrames."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/mocks.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/mocks.py
new file mode 100644
index 00000000000..62006b39f05
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/mocks.py
@@ -0,0 +1,122 @@
+"""Mock DataFrame constituents for testing."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from abc import ABCMeta
+
+from tensorflow.contrib.learn.python import learn
+
+
+class MockColumn(learn.Column):
+  """A mock column for use in testing."""
+
+  def __init__(self, cachekey, mock_tensors):
+    super(MockColumn, self).__init__()
+    self._cachekey = cachekey
+    self._mock_tensors = mock_tensors
+
+  def build(self, cache):
+    return self._mock_tensors
+
+  def __repr__(self):
+    return self._cachekey
+
+
+class MockTransform(learn.Transform):
+  """A mock transform for use in testing."""
+
+  __metaclass__ = ABCMeta
+
+  def __init__(self, param_one, param_two):
+    super(MockTransform, self).__init__()
+    self._param_one = param_one
+    self._param_two = param_two
+
+  @property
+  def name(self):
+    return "MockTransform"
+
+  @learn.parameter
+  def param_one(self):
+    return self._param_one
+
+  @learn.parameter
+  def param_two(self):
+    return self._param_two
+
+  @property
+  def input_valency(self):
+    return 1
+
+
+class MockZeroOutputTransform(MockTransform):
+  """A mock transform for use in testing."""
+
+  _mock_output_names = []
+
+  def __init__(self, param_one, param_two):
+    super(MockZeroOutputTransform, self).__init__(param_one, param_two)
+
+  @property
+  def _output_names(self):
+    return MockZeroOutputTransform._mock_output_names
+
+  def _apply_transform(self, input_tensors):
+    # pylint: disable=not-callable
+    return self.return_type()
+
+
+class MockOneOutputTransform(MockTransform):
+  """A mock transform for use in testing."""
+
+  _mock_output_names = ["out1"]
+
+  def __init__(self, param_one, param_two):
+    super(MockOneOutputTransform, self).__init__(param_one, param_two)
+
+  @property
+  def _output_names(self):
+    return MockOneOutputTransform._mock_output_names
+
+  def _apply_transform(self, input_tensors):
+    # pylint: disable=not-callable
+    return self.return_type("Fake Tensor 1")
+
+
+class MockTwoOutputTransform(MockTransform):
+  """A mock transform for use in testing."""
+
+  _mock_output_names = ["out1", "out2"]
+
+  @learn.parameter
+  def param_three(self):
+    return self._param_three
+
+  def __init__(self, param_one, param_two, param_three):
+    super(MockTwoOutputTransform, self).__init__(param_one, param_two)
+    self._param_three = param_three
+
+  @property
+  def _output_names(self):
+    return MockTwoOutputTransform._mock_output_names
+
+  def _apply_transform(self, input_tensors):
+    # pylint: disable=not-callable
+    return self.return_type("Fake Tensor 1", "Fake Tensor 2")
+
+
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/test_column.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/test_column.py
new file mode 100644
index 00000000000..ae4f36cceb9
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/test_column.py
@@ -0,0 +1,68 @@
+"""Tests of the Column class."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python import learn
+from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
+
+
+class TransformedColumnTest(tf.test.TestCase):
+  """Test of `TransformedColumn`."""
+
+  def test_repr(self):
+    col = learn.TransformedColumn(
+        [mocks.MockColumn("foobar", [])],
+        mocks.MockTwoOutputTransform("thb", "nth", "snt"), "qux")
+
+    # note params are sorted by name
+    expected = ("MockTransform({'param_one': 'thb', 'param_three': 'snt', "
+                "'param_two': 'nth'})"
+                "(foobar)[qux]")
+    self.assertEqual(expected, repr(col))
+
+  def test_build_no_output(self):
+    def create_no_output_column():
+      return learn.TransformedColumn(
+          [mocks.MockColumn("foobar", [])],
+          mocks.MockZeroOutputTransform("thb", "nth"), None)
+
+    self.assertRaises(ValueError, create_no_output_column)
+
+  def test_build_single_output(self):
+    col = learn.TransformedColumn(
+        [mocks.MockColumn("foobar", [])],
+        mocks.MockOneOutputTransform("thb", "nth"), "out1")
+
+    result = col.build()
+    expected = "Fake Tensor 1"
+    self.assertEqual(expected, result)
+
+  def test_build_multiple_output(self):
+    col = learn.TransformedColumn(
+        [mocks.MockColumn("foobar", [])],
+        mocks.MockTwoOutputTransform("thb", "nth", "snt"), "out2")
+
+    result = col.build()
+    expected = "Fake Tensor 2"
+    self.assertEqual(expected, result)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/test_dataframe.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/test_dataframe.py
new file mode 100644
index 00000000000..2385e4a4c3d
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/test_dataframe.py
@@ -0,0 +1,147 @@
+"""Tests of the DataFrame class."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python import learn
+from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
+
+
+def setup_test_df():
+  """Create a dataframe populated with some test columns."""
+  df = learn.DataFrame()
+  df["a"] = learn.TransformedColumn(
+      [mocks.MockColumn("foobar", [])],
+      mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out1")
+  df["b"] = learn.TransformedColumn(
+      [mocks.MockColumn("foobar", [])],
+      mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out2")
+  df["c"] = learn.TransformedColumn(
+      [mocks.MockColumn("foobar", [])],
+      mocks.MockTwoOutputTransform("iue", "eui", "snt"), "out1")
+  return df
+
+
+class DataFrameTest(tf.test.TestCase):
+  """Test of `DataFrame`."""
+
+  def test_create(self):
+    df = setup_test_df()
+    self.assertEqual(df.columns(), frozenset(["a", "b", "c"]))
+
+  def test_select(self):
+    df = setup_test_df()
+    df2 = df.select(["a", "c"])
+    self.assertEqual(df2.columns(), frozenset(["a", "c"]))
+
+  def test_get_item(self):
+    df = setup_test_df()
+    c1 = df["b"]
+    self.assertEqual("Fake Tensor 2", c1.build())
+
+  def test_set_item_column(self):
+    df = setup_test_df()
+    self.assertEqual(3, len(df))
+    col1 = mocks.MockColumn("QuackColumn", [])
+    df["quack"] = col1
+    self.assertEqual(4, len(df))
+    col2 = df["quack"]
+    self.assertEqual(col1, col2)
+
+  def test_set_item_column_multi(self):
+    df = setup_test_df()
+    self.assertEqual(3, len(df))
+    col1 = mocks.MockColumn("QuackColumn", [])
+    col2 = mocks.MockColumn("MooColumn", [])
+    df["quack", "moo"] = [col1, col2]
+    self.assertEqual(5, len(df))
+    col3 = df["quack"]
+    self.assertEqual(col1, col3)
+    col4 = df["moo"]
+    self.assertEqual(col2, col4)
+
+  def test_set_item_pandas(self):
+    # TODO(jamieas)
+    pass
+
+  def test_set_item_numpy(self):
+    # TODO(jamieas)
+    pass
+
+  def test_build(self):
+    df = setup_test_df()
+    result = df.build()
+    expected = {"a": "Fake Tensor 1",
+                "b": "Fake Tensor 2",
+                "c": "Fake Tensor 1"}
+    self.assertEqual(expected, result)
+
+  def test_to_input_fn_all_features(self):
+    df = setup_test_df()
+    input_fn = df.to_input_fn()
+    f, t = input_fn()
+    expected_f = {"a": "Fake Tensor 1",
+                  "b": "Fake Tensor 2",
+                  "c": "Fake Tensor 1"}
+    self.assertEqual(expected_f, f)
+
+    expected_t = {}
+    self.assertEqual(expected_t, t)
+
+  def test_to_input_fn_features_only(self):
+    df = setup_test_df()
+    input_fn = df.to_input_fn(["b", "c"])
+    f, t = input_fn()
+    expected_f = {"b": "Fake Tensor 2", "c": "Fake Tensor 1"}
+    self.assertEqual(expected_f, f)
+
+    expected_t = {}
+    self.assertEqual(expected_t, t)
+
+  def test_to_input_fn_targets_only(self):
+    df = setup_test_df()
+    input_fn = df.to_input_fn(target_keys=["b", "c"])
+    f, t = input_fn()
+    expected_f = {"a": "Fake Tensor 1"}
+    self.assertEqual(expected_f, f)
+
+    expected_t = {"b": "Fake Tensor 2", "c": "Fake Tensor 1"}
+    self.assertEqual(expected_t, t)
+
+  def test_to_input_fn_both(self):
+    df = setup_test_df()
+    input_fn = df.to_input_fn(feature_keys=["a"], target_keys=["b"])
+    f, t = input_fn()
+    expected_f = {"a": "Fake Tensor 1"}
+    self.assertEqual(expected_f, f)
+
+    expected_t = {"b": "Fake Tensor 2"}
+    self.assertEqual(expected_t, t)
+
+  def test_to_input_fn_not_disjoint(self):
+    df = setup_test_df()
+
+    def get_not_disjoint():
+      df.to_input_fn(feature_keys=["a", "b"], target_keys=["b"])
+
+    self.assertRaises(ValueError, get_not_disjoint)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/test_transform.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/test_transform.py
new file mode 100644
index 00000000000..5e69c465d5c
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/test_transform.py
@@ -0,0 +1,91 @@
+"""Tests of the Transform class."""
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python import learn
+from tensorflow.contrib.learn.python.learn.dataframe.transform import _make_list_of_column
+from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
+
+
+class TransformTest(tf.test.TestCase):
+  """Tests of the Transform class."""
+
+  def test_make_list_of_column(self):
+    col1 = mocks.MockColumn("foo", [])
+    col2 = mocks.MockColumn("bar", [])
+
+    self.assertEqual([], _make_list_of_column(None))
+    self.assertEqual([col1], _make_list_of_column(col1))
+    self.assertEqual([col1], _make_list_of_column([col1]))
+    self.assertEqual([col1, col2], _make_list_of_column([col1, col2]))
+    self.assertEqual([col1, col2], _make_list_of_column((col1, col2)))
+
+  def test_cache(self):
+    z = mocks.MockColumn("foobar", [])
+    t = mocks.MockTwoOutputTransform("thb", "nth", "snt")
+    cache = {}
+    t.apply_transform([z], cache)
+    self.assertEqual(2, len(cache))
+
+    expected_keys = [
+        "MockTransform("
+        "{'param_one': 'thb', 'param_three': 'snt', 'param_two': 'nth'})"
+        "(foobar)[out1]",
+        "MockTransform("
+        "{'param_one': 'thb', 'param_three': 'snt', 'param_two': 'nth'})"
+        "(foobar)[out2]"]
+
+    self.assertEqual(expected_keys, sorted(cache.keys()))
+
+  def test_parameters(self):
+    t = mocks.MockTwoOutputTransform("a", "b", "c")
+    self.assertEqual({"param_one": "a", "param_three": "c", "param_two": "b"},
+                     t.parameters())
+
+  def test_parameters_inherited_combined(self):
+    t = mocks.MockTwoOutputTransform("thb", "nth", "snt")
+
+    expected = {"param_one": "thb", "param_two": "nth", "param_three": "snt"}
+    self.assertEqual(expected, t.parameters())
+
+  def test_return_type(self):
+    t = mocks.MockTwoOutputTransform("a", "b", "c")
+
+    rt = t.return_type
+    self.assertEqual("ReturnType", rt.__name__)
+    self.assertEqual(("out1", "out2"), rt._fields)
+
+  def test_call(self):
+    t = mocks.MockTwoOutputTransform("a", "b", "c")
+    # MockTwoOutputTransform has input valency 1
+    input1 = mocks.MockColumn("foobar", [])
+    out1, out2 = t([input1])  # pylint: disable=not-callable
+
+    self.assertEqual(learn.TransformedColumn, type(out1))
+    # self.assertEqual(out1.transform, t)
+    # self.assertEqual(out1.output_name, "output1")
+
+    self.assertEqual(learn.TransformedColumn, type(out2))
+    # self.assertEqual(out2.transform, t)
+    # self.assertEqual(out2.output_name, "output2")
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index e8c7599dc37..c9947a4ec2b 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -14,6 +14,9 @@
 # ==============================================================================
 """## Loss operations for use in neural networks.
 
+Note: By default all the losses are collected into the `GraphKeys.LOSSES`
+collection.
+
 All of the loss functions take a pair of predictions and ground truth labels,
 from which the loss is computed. It is assumed that the shape of both these
 tensors is of the form [batch_size, d1, ... dN] where `batch_size` is the number
@@ -32,6 +35,9 @@ implement this as:
   # Uses default weight of 1.0
   tf.contrib.losses.sum_of_squares(predictions, targets)
 
+  # All the losses are collected into the `GraphKeys.LOSSES` collection.
+  losses = tf.get_collection(tf.GraphKeys.LOSSES)
+
 While specifying a scalar loss rescales the loss over the entire batch,
 we sometimes want to rescale the loss per batch sample. For example, if we have
 certain examples that are more important to get right, we might want to have
@@ -75,7 +81,7 @@ these predictions.
   predictions = MyModelPredictions(images)
 
   weight = tf.cast(tf.greater(depths, 0), tf.float32)
-  tf.contrib.losses.sum_of_squares(predictions, depths, weight)
+  loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight)
 
 Note that when using weights for the losses, the final average is computed
 by rescaling the losses by the weights and then dividing by the total number of
@@ -96,7 +102,7 @@ weighted average over the individual prediction errors:
 
   weight = MyComplicatedWeightingFunction(labels)
   weight = tf.div(weight, tf.size(weight))
-  tf.contrib.losses.sum_of_squares(predictions, depths, weight)
+  loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight)
 
 
 @@absolute_difference
@@ -189,7 +195,9 @@ def _compute_weighted_loss(losses, weight):
 
   total_loss = _scale_losses(losses, weight)
   num_present = _num_present(losses, weight)
-  return _safe_mean(total_loss, num_present)
+  mean_loss = _safe_mean(total_loss, num_present)
+  ops.add_to_collection(ops.GraphKeys.LOSSES, mean_loss)
+  return mean_loss
 
 
 def _num_present(losses, weight, per_batch=False):
@@ -516,10 +524,12 @@ def sum_of_pairwise_squares(predictions, targets, weight=1.0, scope=None):
 
     loss = _scale_losses(term1 - term2, weight)
 
-    return math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
-                           loss,
-                           array_ops.zeros_like(loss),
-                           name="value")
+    mean_loss = math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
+                                loss,
+                                array_ops.zeros_like(loss),
+                                name="value")
+    ops.add_to_collection(ops.GraphKeys.LOSSES, mean_loss)
+    return mean_loss
 
 
 def cosine_distance(predictions, targets, dim, weight=1.0, scope=None):
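With `_compute_weighted_loss` and `sum_of_pairwise_squares` now adding their mean losses to `GraphKeys.LOSSES`, a total loss can be assembled from the collection. A minimal sketch (the `tf.add_n` aggregation is an assumed usage pattern, not part of this change; placeholders are illustrative):

```python
import tensorflow as tf

predictions = tf.placeholder(tf.float32, [None, 1])
targets = tf.placeholder(tf.float32, [None, 1])

# Each call now registers its mean loss in the LOSSES collection.
tf.contrib.losses.sum_of_squares(predictions, targets)
tf.contrib.losses.absolute_difference(predictions, targets)

losses = tf.get_collection(tf.GraphKeys.LOSSES)
total_loss = tf.add_n(losses, name='total_loss')
```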
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c142efd7855..354d1239447 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -211,6 +211,7 @@ cc_library(
         "platform/mutex.h",
         "platform/protobuf.h",  # TODO(josh11b): make internal
         "platform/regexp.h",
+        "platform/strong_hash.h",
         "platform/thread_annotations.h",
         "platform/types.h",
     ],
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 3f06bfb5a55..80f07d19c5c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -82,8 +82,11 @@ void ReEncodeConsts(GraphDef* gdef) {
 
 Status GrpcSession::CreateImpl(CallOptions* call_options,
                                const GraphDef& graph) {
-  if (!handle_.empty()) {
-    return errors::InvalidArgument("A session is alive.");
+  {
+    mutex_lock l(mu_);
+    if (!handle_.empty()) {
+      return errors::InvalidArgument("A session is alive.");
+    }
   }
   CreateSessionRequest req;
   *req.mutable_config() = options_.config;
@@ -114,7 +117,12 @@ Status GrpcSession::Create(const RunOptions& run_options,
 
 Status GrpcSession::ExtendImpl(CallOptions* call_options,
                                const GraphDef& graph) {
-  if (handle_.empty()) {
+  bool handle_is_empty;
+  {
+    mutex_lock l(mu_);
+    handle_is_empty = handle_.empty();
+  }
+  if (handle_is_empty) {
     // Session was uninitialized, so simply initialize the session with 'graph'.
     return Create(graph);
   }
@@ -213,11 +221,14 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
 
 Status GrpcSession::RunProto(CallOptions* call_options, RunStepRequest* req,
                              RunStepResponse* resp) {
-  if (handle_.empty()) {
-    return errors::InvalidArgument("A session is not created yet....");
-  }
+  {
+    mutex_lock l(mu_);
+    if (handle_.empty()) {
+      return errors::InvalidArgument("A session is not created yet....");
+    }
 
-  req->set_session_handle(handle_);
+    req->set_session_handle(handle_);
+  }
   return master_->RunStep(call_options, req, resp);
 }
 
@@ -236,12 +247,15 @@ Status GrpcSession::PRun(const string& handle,
 }
 
 Status GrpcSession::Close() {
-  if (handle_.empty()) {
-    return errors::InvalidArgument("A session is not created yet....");
-  }
   CloseSessionRequest req;
-  req.set_session_handle(handle_);
-  handle_.clear();
+  {
+    mutex_lock l(mu_);
+    if (handle_.empty()) {
+      return errors::InvalidArgument("A session is not created yet....");
+    }
+    req.set_session_handle(handle_);
+    handle_.clear();
+  }
   CloseSessionResponse resp;
   CallOptions call_options;
   call_options.SetTimeout(options_.config.operation_timeout_in_ms());
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index 54acf21a423..d6f680bf9d5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -102,7 +102,7 @@ class GrpcSession : public Session {
   mutex mu_;
 
   // handle_ returned by the master to identify this session.
-  string handle_;
+  string handle_ GUARDED_BY(mu_);
 
   // The current version of the graph.
   int64 current_graph_version_ GUARDED_BY(mu_);
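The discipline here is standard: every read or write of `handle_` happens under `mu_`, and the RPC itself is issued outside the critical section. A Python analogue of the `Close()` path (illustrative only; the real implementation is the C++ above):

```python
import threading


class SessionState(object):
  """Toy analogue of the handle_ locking discipline in GrpcSession."""

  def __init__(self):
    self._mu = threading.Lock()
    self._handle = ''

  def close(self):
    with self._mu:
      if not self._handle:
        raise ValueError('A session is not created yet.')
      handle, self._handle = self._handle, ''
    # The RPC would be issued here, outside the critical section,
    # using the local copy of the handle.
    return handle
```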
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 470c6c025b8..9cd596436fd 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1437,6 +1437,7 @@ tf_kernel_libraries(
         "sparse_reduce_sum_op",
         "sparse_dense_binary_op_shared",
         "sparse_reorder_op",
+        "sparse_softmax",
         "sparse_split_op",
         "sparse_tensor_dense_add_op",
         "sparse_tensor_dense_matmul_op",
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index d8ca83f8f7c..f10885509d5 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -150,19 +150,11 @@ class CpuCastOp : public CastOpBase {
       work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
         int64 N = out->NumElements();
         auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
-        int num_threads = static_cast<int>(std::min(
-            static_cast<int64>(std::min(4, worker_threads->num_threads)),
-            N / 4096));
-        if (num_threads < 1) {
-          BFloat16ToFloat(inp.flat<bfloat16>().data(),
-                          out->flat<float>().data(), N);
-        } else {
-          auto work = [&inp, &out](int64 start, int64 end) {
-            BFloat16ToFloat(inp.flat<bfloat16>().data() + start,
-                            out->flat<float>().data() + start, end - start);
-          };
-          Shard(num_threads, worker_threads->workers, N, 100, work);
-        }
+        auto work = [&inp, &out](int64 start, int64 end) {
+          BFloat16ToFloat(inp.flat<bfloat16>().data() + start,
+                          out->flat<float>().data() + start, end - start);
+        };
+        Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work);
       };
       return Status::OK();
     }
@@ -170,19 +162,11 @@ class CpuCastOp : public CastOpBase {
       work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
         int64 N = out->NumElements();
         auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
-        int num_threads = static_cast<int>(std::min(
-            static_cast<int64>(std::min(4, worker_threads->num_threads)),
-            N / 4096));
-        if (num_threads < 1) {
-          FloatToBFloat16(inp.flat<float>().data(),
-                          out->flat<bfloat16>().data(), N);
-        } else {
-          auto work = [&inp, &out](int64 start, int64 end) {
-            FloatToBFloat16(inp.flat<float>().data() + start,
-                            out->flat<bfloat16>().data() + start, end - start);
-          };
-          Shard(num_threads, worker_threads->workers, N, 100, work);
-        }
+        auto work = [&inp, &out](int64 start, int64 end) {
+          FloatToBFloat16(inp.flat<float>().data() + start,
+                          out->flat<bfloat16>().data() + start, end - start);
+        };
+        Shard(worker_threads->num_threads, worker_threads->workers, N, 2, work);
       };
       return Status::OK();
     }
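Both cast kernels now drop the hand-tuned thread cap and let `Shard` derive parallelism from a per-element cost estimate. A toy sketch of cost-driven sharding (a simplified, sequential stand-in, not the real `Shard`, which dispatches to a thread pool; the cost threshold is an assumption):

```python
def shard(num_workers, total, cost_per_unit, work, min_shard_cost=10000):
  # Derive useful parallelism from the estimated total cost, so cheap casts
  # stay single-threaded and expensive ones fan out.
  max_useful = max(1, (total * cost_per_unit) // min_shard_cost)
  num_shards = int(min(num_workers, max_useful))
  per_shard = (total + num_shards - 1) // num_shards
  for i in range(num_shards):
    start = i * per_shard
    end = min(start + per_shard, total)
    if start < end:
      work(start, end)


out = []
shard(4, 10, cost_per_unit=2, work=lambda s, e: out.append((s, e)))
assert out == [(0, 10)]  # too cheap to split: one shard
```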
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index fa3c9f5fbf3..330cea8222d 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -123,7 +123,8 @@ void ConcatCPU(DeviceBase* d,
       }
     }
   };
-  Shard(num_threads, worker_threads->workers, output->size(), sizeof(T), work);
+  Shard(worker_threads->num_threads, worker_threads->workers, output->size(),
+        sizeof(T), work);
 }
 
 #define REGISTER(T)                                                            \
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 419ba4dfc6d..14e7d033eb9 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 
+#include <tuple>
 #include "tensorflow/core/platform/stream_executor.h"
 
 namespace tensorflow {
@@ -95,8 +96,18 @@ struct ConvParameters {
   int64 padding_cols;
   int device_id;
 
+  typedef std::tuple<int64, int64, int64, int64, int64, int64, int64, int64,
+                     int64, int64, int64, int>
+      DataType;
+
+  DataType get_data_as_tuple() const {
+    return std::make_tuple(batch, in_depths, in_rows, in_cols, out_depths,
+                           filter_rows, filter_cols, stride_rows, stride_cols,
+                           padding_rows, padding_cols, device_id);
+  }
+
   bool operator==(const ConvParameters& other) const {
-    return memcmp(this, &other, sizeof(ConvParameters)) == 0;
+    return this->get_data_as_tuple() == other.get_data_as_tuple();
   }
 
   bool operator!=(const ConvParameters& other) const {
@@ -104,7 +115,7 @@ struct ConvParameters {
   }
 
   bool operator<(const ConvParameters& other) const {
-    return memcmp(this, &other, sizeof(ConvParameters)) < 0;
+    return this->get_data_as_tuple() < other.get_data_as_tuple();
   }
 };
 
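`memcmp` over a struct is fragile: padding bytes have unspecified values, so byte-wise comparison can disagree with field-wise equality. The fix compares a canonical field tuple instead. The same idea in Python (an illustrative analogue with a reduced field set):

```python
class ConvParams(object):
  """Compare by a canonical tuple of fields, never by raw memory."""

  def __init__(self, batch, in_depths, in_rows, device_id):
    self.batch = batch
    self.in_depths = in_depths
    self.in_rows = in_rows
    self.device_id = device_id

  def _as_tuple(self):
    return (self.batch, self.in_depths, self.in_rows, self.device_id)

  def __eq__(self, other):
    return self._as_tuple() == other._as_tuple()

  def __lt__(self, other):
    # Lexicographic field-wise ordering, like std::tuple's operator<.
    return self._as_tuple() < other._as_tuple()
```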
diff --git a/tensorflow/core/kernels/cwise_op_mod.cc b/tensorflow/core/kernels/cwise_op_mod.cc
index 4f2430f9964..c98630b624f 100644
--- a/tensorflow/core/kernels/cwise_op_mod.cc
+++ b/tensorflow/core/kernels/cwise_op_mod.cc
@@ -18,4 +18,17 @@ limitations under the License.
 namespace tensorflow {
 REGISTER2(BinaryOp, CPU, "Mod", functor::safe_mod, int32, int64);
 REGISTER2(BinaryOp, CPU, "Mod", functor::fmod, float, double);
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Mod")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("z")
+                            .TypeConstraint<int32>("T"),
+                        BinaryOp<CPUDevice, functor::safe_mod<int32>>);
+#endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_png_op.cc b/tensorflow/core/kernels/decode_png_op.cc
index 827d8d23f71..6cd4d7e66fe 100644
--- a/tensorflow/core/kernels/decode_png_op.cc
+++ b/tensorflow/core/kernels/decode_png_op.cc
@@ -70,8 +70,8 @@ class DecodePngOp : public OpKernel {
     //   verify single dimension is not too large.
     // - verify when width and height are multiplied together, there are a few
     //   bits to spare as well.
-    const int width = decode.width;
-    const int height = decode.height;
+    const int width = static_cast<int>(decode.width);
+    const int height = static_cast<int>(decode.height);
     const int64 total_size =
         static_cast<int64>(width) * static_cast<int64>(height);
     if (width != static_cast<int64>(decode.width) || width <= 0 ||
diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc
index 8691ab0449d..36667df161d 100644
--- a/tensorflow/core/kernels/draw_bounding_box_op.cc
+++ b/tensorflow/core/kernels/draw_bounding_box_op.cc
@@ -25,12 +25,11 @@ limitations under the License.
 
 namespace tensorflow {
 
+template <class T>
 class DrawBoundingBoxesOp : public OpKernel {
  public:
   explicit DrawBoundingBoxesOp(OpKernelConstruction* context)
       : OpKernel(context) {
-    OP_REQUIRES_OK(context,
-                   context->MatchSignature({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -57,8 +56,8 @@ class DrawBoundingBoxesOp : public OpKernel {
         context->allocate_output(
             0, TensorShape({batch_size, height, width, depth}), &output));
 
-    output->tensor<float, 4>() = images.tensor<float, 4>();
-    auto canvas = output->tensor<float, 4>();
+    output->tensor<T, 4>() = images.tensor<T, 4>();
+    auto canvas = output->tensor<T, 4>();
 
     for (int64 b = 0; b < batch_size; ++b) {
       const int64 num_boxes = boxes.dim_size(1);
@@ -122,29 +121,35 @@ class DrawBoundingBoxesOp : public OpKernel {
         // Draw top line.
         if (min_box_row >= 0) {
           for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
-            canvas(b, min_box_row, j, 0) = nanf("");
+            canvas(b, min_box_row, j, 0) = T(nanf(""));
         }
         // Draw bottom line.
         if (max_box_row < height) {
           for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
-            canvas(b, max_box_row, j, 0) = nanf("");
+            canvas(b, max_box_row, j, 0) = T(nanf(""));
         }
         // Draw left line.
         if (min_box_col >= 0) {
           for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
-            canvas(b, i, min_box_col, 0) = nanf("");
+            canvas(b, i, min_box_col, 0) = T(nanf(""));
         }
         // Draw right line.
         if (max_box_col < width) {
           for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
-            canvas(b, i, max_box_col, 0) = nanf("");
+            canvas(b, i, max_box_col, 0) = T(nanf(""));
         }
       }
     }
   }
 };
 
-REGISTER_KERNEL_BUILDER(Name("DrawBoundingBoxes").Device(DEVICE_CPU),
-                        DrawBoundingBoxesOp);
+REGISTER_KERNEL_BUILDER(
+    Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+    DrawBoundingBoxesOp<float>);
+
+REGISTER_KERNEL_BUILDER(Name("DrawBoundingBoxes")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<Eigen::half>("T"),
+                        DrawBoundingBoxesOp<Eigen::half>);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc
index b4d14e8c628..7f0b73e6a2c 100644
--- a/tensorflow/core/kernels/edit_distance_op.cc
+++ b/tensorflow/core/kernels/edit_distance_op.cc
@@ -144,7 +144,7 @@ class EditDistanceOp : public OpKernel {
     std::iota(group_dims.begin(), group_dims.end(), 0);
 
     TensorShape output_shape;
-    for (size_t d = 0; d < group_dims.size(); ++d) {
+    for (int d = 0; d < static_cast<int>(group_dims.size()); ++d) {
       output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d),
                                    truth_st_shape.dim_size(d)));
     }
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index b7e0109b75d..2dcf079dbfa 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -48,6 +48,8 @@ class ExampleParserOp : public OpKernel {
                 errors::InvalidArgument("len(dense_keys) != len(dense_types"));
     OP_REQUIRES(ctx, static_cast<size_t>(num_dense_) == dense_shapes_.size(),
                 errors::InvalidArgument("len(dense_keys) != len(dense_shapes"));
+    OP_REQUIRES(ctx, num_dense_ <= std::numeric_limits<int32>::max(),
+                errors::InvalidArgument("num_dense_ too large"));
     for (const DataType& type : dense_types_) {
       OP_REQUIRES_OK(ctx, CheckValidType(type));
     }
@@ -108,7 +110,7 @@ class ExampleParserOp : public OpKernel {
                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
                     dense_defaults.size(), " vs. ", num_dense_));
 
-    for (int d = 0; d < num_dense_; ++d) {
+    for (int d = 0; d < static_cast<int>(num_dense_); ++d) {
       const Tensor& def_value = dense_defaults[d];
       if (def_value.NumElements() > 0) {
         OP_REQUIRES(ctx, def_value.shape() == dense_shapes_[d],
@@ -126,7 +128,7 @@ class ExampleParserOp : public OpKernel {
 
     auto serialized_t = serialized->vec<string>();
 
-    const int batch_size = serialized_t.size();
+    const int64 batch_size = serialized_t.size();
 
     OpOutputList sparse_indices;
     OpOutputList sparse_values;
@@ -146,7 +148,8 @@ class ExampleParserOp : public OpKernel {
       // Preallocate dense_values, since we know their sizes
       TensorShape out_shape;
       out_shape.AddDim(batch_size);
-      for (const int dim : dense_shapes_[d].dim_sizes()) out_shape.AddDim(dim);
+      for (const int64 dim : dense_shapes_[d].dim_sizes())
+        out_shape.AddDim(dim);
       Tensor* out = nullptr;
       dense_values.allocate(d, out_shape, &out);
 
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc
index 891f7888aba..9e221efac97 100644
--- a/tensorflow/core/kernels/listdiff_op.cc
+++ b/tensorflow/core/kernels/listdiff_op.cc
@@ -42,20 +42,24 @@ class ListDiffOp : public OpKernel {
     OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()),
                 errors::InvalidArgument("y should be a 1D vector."));
 
-    std::unordered_set<T> y_set;
+    const auto Tx = x.vec<T>();
+    const size_t x_size = Tx.size();
     const auto Ty = y.vec<T>();
-    const int y_size = Ty.size();
+    const size_t y_size = Ty.size();
+
+    OP_REQUIRES(context, x_size < std::numeric_limits<int32>::max(),
+                errors::InvalidArgument("x too large for int32 indexing"));
+
+    std::unordered_set<T> y_set;
     y_set.reserve(y_size);
-    for (int i = 0; i < y_size; ++i) {
+    for (size_t i = 0; i < y_size; ++i) {
       y_set.insert(Ty(i));
     }
 
     // Compute the size of the output.
-    const auto Tx = x.vec<T>();
-    const int x_size = Tx.size();
 
-    int out_size = 0;
-    for (int i = 0; i < x_size; ++i) {
+    int64 out_size = 0;
+    for (size_t i = 0; i < x_size; ++i) {
       if (y_set.count(Tx(i)) == 0) {
         ++out_size;
       }
@@ -70,7 +74,7 @@ class ListDiffOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
     auto Tindices = indices->vec<int32>();
 
-    for (int i = 0, p = 0; i < x_size; ++i) {
+    for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) {
       if (y_set.count(Tx(i)) == 0) {
         OP_REQUIRES(context, p < out_size,
                     errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index 5a4930ec63e..8a1b7b3de00 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -33,10 +33,10 @@ namespace tensorflow {
 
 namespace {
 
-// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the
-// main band matrix approach used below. Benchmarks suggest switching to
-// MognetLRN when depth > 384.
-const int kMognetLRNDepthCutoff = 384;
+// When the depth is large and beta_ is 0.5 or 1.0, single-threaded LRN is
+// faster than the main band matrix approach used below. Benchmarks suggest
+// switching to SingleThreadedLRN when depth > 384.
+const int kSingleThreadedLRNDepthCutoff = 384;
 
 // Create a depth-by-depth band matrix with 1s along a swath of size (2 *
 // depth_radius + 1) around the diagonal.
@@ -88,10 +88,11 @@ class LRNOp : public OpKernel {
                        0, TensorShape({batch, rows, cols, depth}), &output));
 
 #if defined(__ANDROID__)
-    MognetLRN(in, batch, rows, cols, depth, output);
+    SingleThreadedLRN(in, batch, rows, cols, depth, output);
 #else
-    if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) {
-      MognetLRN(in, batch, rows, cols, depth, output);
+    if (depth > kSingleThreadedLRNDepthCutoff &&
+        (beta_ == 0.5f || beta_ == 1.0f)) {
+      SingleThreadedLRN(in, batch, rows, cols, depth, output);
       return;
     }
 
@@ -124,8 +125,8 @@ class LRNOp : public OpKernel {
  private:
   typedef Eigen::Tensor<float, 1, Eigen::RowMajor>::DimensionPair DimPair;
 
-  void MognetLRN(const Tensor& in, const int batch, const int rows,
-                 const int cols, const int depth, Tensor* out) {
+  void SingleThreadedLRN(const Tensor& in, const int batch, const int rows,
+                         const int cols, const int depth, Tensor* out) {
     Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>>
     data_in(in.flat<float>().data(), depth, batch * rows * cols);
 
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index bd70663eb3a..670a041e188 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -161,13 +161,11 @@ struct FillPhiloxRandom<CPUDevice, Distribution> {
 
     int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
 
-    // Limit to maximum six threads for now. The performance scaling is very
-    // sub-linear. Too many threads causes a much worse overall performance.
-    int num_workers = 6;
     const int kGroupCost =
         random::PhiloxRandom::kResultElementCount *
         (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
-    Shard(num_workers, worker_threads.workers, total_group_count, kGroupCost,
+    Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
+          kGroupCost,
           [&gen, data, size, dist](int64 start_group, int64 limit_group) {
             FillPhiloxRandomTask<
                 Distribution,
@@ -399,8 +397,10 @@ class MultinomialOp : public OpKernel {
                     sizeof(int64) * num_samples);
       }
     };
-    Shard(std::min(batch_size, worker_threads.num_threads),
-          worker_threads.workers, batch_size, num_samples * num_classes * 2,
+    // Rough estimate: log2() takes 58-680 cycles on Haswell.
+    // The functor here calls log twice for each element.
+    const int64 cost = 500 * num_samples * num_classes;
+    Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost,
           DoWork);
   }
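
Where the `cost` constant comes from, as back-of-envelope arithmetic (the per-call cycle figure is an assumed mid-range value from the comment's 58-680 range, not a new measurement):

```python
num_samples, num_classes = 1024, 10
cycles_per_log = 250     # assumed mid-range cost of one log() call
logs_per_element = 2     # the functor calls log twice for each element
cost = cycles_per_log * logs_per_element * num_samples * num_classes
# Rounded to 500 * num_samples * num_classes cycles per batch row, which
# Shard() uses to decide how many worker threads are worth waking up.
```
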
 
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index 47bb9de411a..c3b0881b0c4 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -63,31 +63,34 @@ struct SoftmaxEigenImpl {
     Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
     one_by_class.set(1, num_classes);
 #endif
-    //shifted_logits = logits - max(logits along classes);
-    auto shifted_logits = (logits - logits.maximum(along_class)
-                                      .eval()
-                                      .reshape(batch_by_one)
-                                      .broadcast(one_by_class));
+    // shifted_logits = logits - max(logits along classes);
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class));
     if (log) {
       // Calculate the log of the softmax
       // softmax = logits - max(logits along classes);
       softmax.device(d) = shifted_logits;
       // softmax = softmax - log(sum(exp(softmax along classes)));
       softmax.device(d) = (softmax -
-                           softmax.exp().sum(along_class)
-                              .eval()
-                              .reshape(batch_by_one)
-                              .broadcast(one_by_class)
-                              .log());
+                           softmax.exp()
+                               .sum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class)
+                               .log());
     } else {
       // NOTE(touts): If you modify this implementation please run
       // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
       //
       // softmax = exp(logits - max(logits along classes));
       softmax.device(d) = shifted_logits.exp();
-      // softmax = softmax / sum(softmax along classes);
-      softmax.device(d) = (softmax /
+      // softmax = softmax * (1 / sum(softmax along classes));
+      softmax.device(d) = (softmax *
                            softmax.sum(along_class)
+                               .inverse()
                                .eval()
                                .reshape(batch_by_one)
                                .broadcast(one_by_class));
diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc
index ce4c0eee65f..560771b6a76 100644
--- a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc
+++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc
@@ -171,7 +171,10 @@ class SparseDenseBinaryOpShared : public OpKernel {
                                                                              \
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("SparseDenseCwiseDiv").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      SparseDenseBinaryOpShared<CPUDevice, T, functor::div<T>>)
+      SparseDenseBinaryOpShared<CPUDevice, T, functor::div<T>>)              \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("SparseDenseCwiseAdd").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseDenseBinaryOpShared<CPUDevice, T, functor::add<T>>)
 
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc
new file mode 100644
index 00000000000..05a52b4e736
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_softmax_op.cc
@@ -0,0 +1,130 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/sparse_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+using tensorflow::gtl::ArraySlice;
+using tensorflow::sparse::SparseTensor;
+
+namespace tensorflow {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <typename Device, typename T>
+class SparseSoftmaxOp : public OpKernel {
+ public:
+  explicit SparseSoftmaxOp(OpKernelConstruction *context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext *context) override {
+    const Tensor *indices_t, *values_t, *shape_t;
+    OP_REQUIRES_OK(context, context->input("sp_indices", &indices_t));
+    OP_REQUIRES_OK(context, context->input("sp_values", &values_t));
+    OP_REQUIRES_OK(context, context->input("sp_shape", &shape_t));
+
+    // Validations.
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t->shape()),
+                errors::InvalidArgument(
+                    "Input sp_indices should be a matrix but received shape: ",
+                    indices_t->shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t->shape()) &&
+                             TensorShapeUtils::IsVector(shape_t->shape()),
+                errors::InvalidArgument(
+                    "Inputs sp_values and sp_shape should be vectors "
+                    "but received shapes: ",
+                    values_t->shape().DebugString(), " and ",
+                    shape_t->shape().DebugString()));
+    OP_REQUIRES(context, shape_t->NumElements() >= 2,
+                errors::InvalidArgument(
+                    "Input should have rank >= 2, but received shape: ",
+                    shape_t->SummarizeValue(3)));
+    OP_REQUIRES(context,
+                indices_t->dim_size(0) < std::numeric_limits<int>::max(),
+                errors::InvalidArgument(
+                    "Number of non-zero elements exceeds int32 range"));
+
+    const int nnz = static_cast<int>(indices_t->dim_size(0));
+    const int rank = static_cast<int>(indices_t->dim_size(1));
+    SparseTensor st(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+                    TensorShape(shape_t->flat<int64>()));
+
+    Tensor *output_values = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({nnz}),
+                                                     &output_values));
+    typename TTypes<T>::Flat output_flat = output_values->flat<T>();
+
+    Tensor tmp_t;
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   TensorShape({}), &tmp_t));
+    typename TTypes<T>::Scalar tmp_scalar = tmp_t.scalar<T>();
+
+    gtl::InlinedVector<int64, 4> dims(rank);
+    // Fill dims with {0, ..., rank - 1}.
+    std::iota(dims.begin(), dims.end(), 0);
+    const ArraySlice<int64> kReorderDims(dims);
+    // All but the last dim -- the class dimension to be max-reduced along.
+    const ArraySlice<int64> kGroupByDims(kReorderDims, 0, rank - 1);
+    st.Reorder<T>(kReorderDims);
+    int count = 0;
+
+    // The SparseTensor has logical shape [..., b, c], where the
+    // innermost size-"c" dimension is the class dimension to be max-reduced.
+    // Therefore we group by the first (rank - 1) dimensions.
+    const Device &device = context->eigen_device<Device>();
+    for (const auto &g : st.group(kGroupByDims)) {
+      const auto group_vals = g.values<T>();
+      const int group_size = group_vals.size();
+
+      // Shifts by max, exponentiates, then renormalizes.
+      tmp_scalar.device(device) = group_vals.maximum();
+      const T group_max = tmp_scalar();
+
+      Eigen::Tensor<T, 1, Eigen::RowMajor> tmp(group_size);
+      tmp.device(device) = (group_vals - tmp.constant(group_max)).exp();
+
+      tmp_scalar.device(device) = tmp.sum().inverse();
+      tmp.device(device) = tmp * tmp.constant(tmp_scalar());
+
+      // Assigns back to output[count, count + group_size).
+      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> output_part(
+          output_flat.data() + count, group_size);
+      output_part.device(device) = tmp;
+
+      count += group_size;
+    }
+  }
+};
+
+#define REGISTER_KERNEL(T)                                             \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("SparseSoftmax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseSoftmaxOp<CPUDevice, T>)
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+}  // namespace tensorflow
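
A compact numpy sketch of the algorithm in `SparseSoftmaxOp::Compute`, grouping with a dict where the kernel reorders and groups by the leading `rank - 1` index columns:

```python
import numpy as np

def sparse_softmax(indices, values):
    """Softmax each group of values sharing all but the last index; implicit
    zeros never participate, so indices and shape are unchanged."""
    out = np.empty_like(values)
    groups = {}
    for pos, idx in enumerate(indices):
        groups.setdefault(tuple(idx[:-1]), []).append(pos)
    for positions in groups.values():
        v = values[positions]
        e = np.exp(v - v.max())       # shift by the group max
        out[positions] = e / e.sum()  # renormalize within the group
    return out

indices = np.array([[0, 0], [0, 3], [1, 2]])
values = np.array([1.0, 1.0, 5.0])
print(sparse_softmax(indices, values))  # [0.5, 0.5, 1.0]
```
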
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
index 3a2429d4cd0..e00cd25f455 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/strong_hash.h"
 
 namespace tensorflow {
 
@@ -57,11 +58,14 @@ class LegacyStringToHashBuckeOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp);
 };
 
-// StringToHashBucket is deprecated in favor of StringToHashBucketStable.
+// StringToHashBucket is deprecated in favor of StringToHashBucketFast/Strong.
 REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
                         LegacyStringToHashBuckeOp);
 
 REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
                         StringToHashBucketOp<Fingerprint64>);
 
+REGISTER_KERNEL_BUILDER(Name("StringToHashBucketStrong").Device(DEVICE_CPU),
+                        StringToKeyedHashBucketOp<StrongKeyedHash>);
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.h b/tensorflow/core/kernels/string_to_hash_bucket_op.h
index 9c6c0a89e42..0c3acbebbfb 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.h
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.h
@@ -61,6 +61,49 @@ class StringToHashBucketOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
 };
 
+template <uint64 hash(const uint64 (&)[2], const string&)>
+class StringToKeyedHashBucketOp : public OpKernel {
+ public:
+  explicit StringToKeyedHashBucketOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
+
+    std::vector<int64> key;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key));
+    OP_REQUIRES(ctx, key.size() == 2,
+                errors::InvalidArgument("Key must have 2 elements"));
+    std::memcpy(key_, key.data(), sizeof(key_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output("output", input_tensor->shape(),
+                                            &output_tensor));
+    auto output_flat = output_tensor->flat<int64>();
+
+    typedef decltype(input_flat.size()) Index;
+    for (Index i = 0; i < input_flat.size(); ++i) {
+      const uint64 input_hash = hash(key_, input_flat(i));
+      const uint64 bucket_id = input_hash % num_buckets_;
+      // num_buckets_ is always within the positive range of int64, so the
+      // resulting bucket_id is too; casting bucket_id from uint64 to int64
+      // is therefore safe.
+      output_flat(i) = static_cast<int64>(bucket_id);
+    }
+  }
+
+ private:
+  int64 num_buckets_;
+  uint64 key_[2];
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StringToKeyedHashBucketOp);
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_
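
At the Python level the two hashing kernels surface roughly as follows (the `key` values are placeholders; a real deployment would keep the key secret so adversaries cannot precompute collisions):

```python
import tensorflow as tf

strings = tf.constant(["hello", "world"])

# Fast, unkeyed fingerprint: fine when inputs are trusted.
fast = tf.string_to_hash_bucket_fast(strings, num_buckets=1024)

# Keyed SipHash: collisions are hard to construct without the key.
strong = tf.string_to_hash_bucket_strong(strings, num_buckets=1024,
                                         key=[123, 456])
```
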
diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc
index 74465c49056..c5000a8596c 100644
--- a/tensorflow/core/kernels/summary_image_op.cc
+++ b/tensorflow/core/kernels/summary_image_op.cc
@@ -85,27 +85,12 @@ class SummaryImageOp : public OpKernel {
             &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
       };
       AddImages(base_tag, batch_size, w, h, depth, ith_image, &s);
+    } else if (tensor.dtype() == DT_HALF) {
+      NormalizeAndAddImages<Eigen::half>(c, tensor, h, w, hw, depth, batch_size,
+                                         base_tag, &s);
     } else {  // tensor.dtype() == DT_FLOAT
-      // For float images, nans and infs are replaced with bad_color.
-      OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
-                  errors::InvalidArgument(
-                      "expected depth <= bad_color.size, got depth = ", depth,
-                      ", bad_color.size = ", bad_color_.dim_size(0)));
-      auto bad_color_full = bad_color_.vec<uint8>();
-      typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
-
-      // Float images must be scaled and translated.
-      Uint8Image image(hw, depth);
-      auto ith_image = [&tensor, &image, bad_color, batch_size, hw,
-                        depth](int i) {
-        auto tensor_eigen = tensor.shaped<float, 3>({batch_size, hw, depth});
-        typename TTypes<float>::ConstMatrix values(
-            &tensor_eigen(i, 0, 0),
-            Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
-        NormalizeFloatImage(hw, depth, values, bad_color, &image);
-        return image;
-      };
-      AddImages(base_tag, batch_size, w, h, depth, ith_image, &s);
+      NormalizeAndAddImages<float>(c, tensor, h, w, hw, depth, batch_size,
+                                   base_tag, &s);
     }
 
     Tensor* summary_tensor = nullptr;
@@ -113,6 +98,32 @@ class SummaryImageOp : public OpKernel {
     CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
   }
 
+  template <class T>
+  void NormalizeAndAddImages(OpKernelContext* c, const Tensor& tensor, int h,
+                             int w, int hw, int depth, int batch_size,
+                             const string& base_tag, Summary* s) {
+    // For float and half images, nans and infs are replaced with bad_color.
+    OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
+                errors::InvalidArgument(
+                    "expected depth <= bad_color.size, got depth = ", depth,
+                    ", bad_color.size = ", bad_color_.dim_size(0)));
+    auto bad_color_full = bad_color_.vec<uint8>();
+    typename TTypes<uint8>::ConstVec bad_color(bad_color_full.data(), depth);
+
+    // Float images must be scaled and translated.
+    Uint8Image image(hw, depth);
+    auto ith_image = [&tensor, &image, bad_color, batch_size, hw,
+                      depth](int i) {
+      auto tensor_eigen = tensor.template shaped<T, 3>({batch_size, hw, depth});
+      typename TTypes<T>::ConstMatrix values(
+          &tensor_eigen(i, 0, 0),
+          Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
+      NormalizeFloatImage<T>(hw, depth, values, bad_color, &image);
+      return image;
+    };
+    AddImages(base_tag, batch_size, w, h, depth, ith_image, s);
+  }
+
   // Add the sequence of images specified by ith_image to the summary.
   //
   // Factoring this loop out into a helper function lets ith_image behave
@@ -153,15 +164,16 @@ class SummaryImageOp : public OpKernel {
     return Status::OK();
   }
 
+  template <class T>
   static void NormalizeFloatImage(int hw, int depth,
-                                  typename TTypes<float>::ConstMatrix values,
+                                  typename TTypes<T>::ConstMatrix values,
                                   typename TTypes<uint8>::ConstVec bad_color,
                                   Uint8Image* image) {
     if (!image->size()) return;  // Nothing to do for empty images
 
     // Rescale the image to uint8 range.
     //
-    // We are trying to generate an RGB image from a float tensor.  We do
+    // We are trying to generate an RGB image from a float/half tensor.  We do
     // not have any info about the expected range of values in the tensor
     // but the generated image needs to have all RGB values within [0, 255].
     //
@@ -179,14 +191,14 @@ class SummaryImageOp : public OpKernel {
     for (int i = 0; i < hw; i++) {
       bool finite = true;
       for (int j = 0; j < depth; j++) {
-        if (!std::isfinite(values(i, j))) {
+        if (!Eigen::numext::isfinite(values(i, j))) {
           finite = false;
           break;
         }
       }
       if (finite) {
         for (int j = 0; j < depth; j++) {
-          float value = values(i, j);
+          float value(values(i, j));
           image_min = std::min(image_min, value);
           image_max = std::max(image_max, value);
         }
@@ -195,27 +207,28 @@ class SummaryImageOp : public OpKernel {
 
     // Pick an affine transform into uint8
     const float kZeroThreshold = 1e-6;
-    float scale, offset;
+    T scale, offset;
     if (image_min < 0) {
       float max_val = std::max(std::abs(image_min), std::abs(image_max));
-      scale = max_val < kZeroThreshold ? 0.0f : 127.0f / max_val;
-      offset = 128.0f;
+      scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val);
+      offset = T(128.0f);
     } else {
-      scale = image_max < kZeroThreshold ? 0.0f : 255.0f / image_max;
-      offset = 0.0f;
+      scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max);
+      offset = T(0.0f);
     }
 
     // Transform image, turning nonfinite values to bad_color
     for (int i = 0; i < hw; i++) {
       bool finite = true;
       for (int j = 0; j < depth; j++) {
-        if (!std::isfinite(values(i, j))) {
+        if (!Eigen::numext::isfinite(values(i, j))) {
           finite = false;
           break;
         }
       }
       if (finite) {
-        image->chip<0>(i) = (values.chip<0>(i) * scale + offset).cast<uint8>();
+        image->chip<0>(i) = (values.template chip<0>(i) * scale + offset)
+                                .template cast<uint8>();
       } else {
         image->chip<0>(i) = bad_color;
       }
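
A numpy sketch of the normalization that `NormalizeAndAddImages` now shares between `float` and `half` inputs (the function and argument names are illustrative):

```python
import numpy as np

def rescale_to_uint8(values, bad_color=(255, 0, 0, 255)):
    """values is (hw, depth). Pick one affine transform per image from its
    finite pixels, map them into [0, 255], and paint any pixel containing a
    NaN/Inf with bad_color (truncated to depth channels)."""
    finite = np.isfinite(values).all(axis=1)
    lo = values[finite].min() if finite.any() else 0.0
    hi = values[finite].max() if finite.any() else 0.0
    if lo < 0:
        # Symmetric range: zero maps to mid-gray 128.
        max_val = max(abs(lo), abs(hi))
        scale, offset = (0.0 if max_val < 1e-6 else 127.0 / max_val), 128.0
    else:
        scale, offset = (0.0 if hi < 1e-6 else 255.0 / hi), 0.0
    depth = values.shape[1]
    out = np.empty(values.shape, np.uint8)
    out[finite] = (values[finite] * scale + offset).astype(np.uint8)
    out[~finite] = np.asarray(bad_color[:depth], np.uint8)
    return out
```
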
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 6afb9d25397..5f68637f273 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -173,7 +173,7 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
             [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
 
   for (auto input : acquire_order) {
-    locks.push_back(mutex_lock(*ctx->input_ref_mutex(input)));
+    locks.emplace_back(*ctx->input_ref_mutex(input));
   }
   return locks;
 }
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 2ccee283b4c..af5d44565cb 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -87,14 +87,11 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
         num_threads_(num_threads) {}
 
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-    CHECK_GT(max_parallelism, 0);
     CHECK_GE(total, 0);
     CHECK_EQ(total, (int64)(Eigen::Index)total);
-    Eigen::ThreadPoolDevice device(this,
-                                   std::min(num_threads_, max_parallelism));
+    Eigen::ThreadPoolDevice device(this, num_threads_);
     device.parallelFor(
         total, Eigen::TensorOpCost(0, 0, cost_per_unit),
         [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
@@ -103,6 +100,8 @@ struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
 #endif
   }
 
+  int NumThreads() const { return num_threads_; }
+
   const int num_threads_;
 };
 
@@ -114,11 +113,12 @@ struct ThreadPool::Impl {
   ~Impl();
   void Schedule(std::function<void()> fn);
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max) {
+                   std::function<void(int64, int64)> fn) {
     CHECK(0);  // should not be used with the old thread pool
   }
 
+  int NumThreads() const { return threads_.size(); }
+
  private:
   struct Waiter {
     condition_variable cv;
@@ -242,10 +242,11 @@ void ThreadPool::Schedule(std::function<void()> fn) {
 }
 
 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
-                             std::function<void(int64, int64)> fn,
-                             int32 max_parallelism) {
-  impl_->ParallelFor(total, cost_per_unit, std::move(fn), max_parallelism);
+                             std::function<void(int64, int64)> fn) {
+  impl_->ParallelFor(total, cost_per_unit, std::move(fn));
 }
 
+int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
+
 }  // namespace thread
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 30049fb2520..fe7f2d0d86b 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -51,12 +51,11 @@ class ThreadPool {
   // having roughly "cost_per_unit" cost, in cycles. Each unit of work is
   // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
   // and the total cost of each shard is roughly the same.
-  // Max_parallelism optionally caps the number of threads used.
-  //
-  // REQUIRES: max_parallelism > 0.
   void ParallelFor(int64 total, int64 cost_per_unit,
-                   std::function<void(int64, int64)> fn,
-                   int32 max_parallelism = kint32max);
+                   std::function<void(int64, int64)> fn);
+
+  // Returns the number of threads in the pool.
+  int NumThreads() const;
 
   struct Impl;
 
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 524af800d87..5043e54459a 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -66,22 +66,17 @@ TEST(ThreadPool, ParallelFor) {
     const int kWorkItems = 15;
     bool work[kWorkItems];
     ThreadPool pool(Env::Default(), "test", num_threads);
-    for (int max_parallelism = 1; max_parallelism <= kNumThreads + 1;
-         max_parallelism++) {
-      for (int i = 0; i < kWorkItems; i++) {
-        work[i] = false;
-      }
-      pool.ParallelFor(kWorkItems, kHugeCost,
-                       [&work](int64 begin, int64 end) {
-                         for (int64 i = begin; i < end; ++i) {
-                           ASSERT_FALSE(work[i]);
-                           work[i] = true;
-                         }
-                       },
-                       max_parallelism);
-      for (int i = 0; i < kWorkItems; i++) {
-        ASSERT_TRUE(work[i]);
+    for (int i = 0; i < kWorkItems; i++) {
+      work[i] = false;
+    }
+    pool.ParallelFor(kWorkItems, kHugeCost, [&work](int64 begin, int64 end) {
+      for (int64 i = begin; i < end; ++i) {
+        ASSERT_FALSE(work[i]);
+        work[i] = true;
       }
+    });
+    for (int i = 0; i < kWorkItems; i++) {
+      ASSERT_TRUE(work[i]);
     }
   }
 }
diff --git a/tensorflow/core/ops/compat/op_compatibility_lib.cc b/tensorflow/core/ops/compat/op_compatibility_lib.cc
index b8df96d5e8f..36b6620f437 100644
--- a/tensorflow/core/ops/compat/op_compatibility_lib.cc
+++ b/tensorflow/core/ops/compat/op_compatibility_lib.cc
@@ -98,7 +98,6 @@ Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
     const string& op_name = op_list_.op(cur).name();
     if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) {
       // Ignore unstable op.
-      ++cur;
       for (++cur; cur < op_list_.op_size(); ++cur) {
         if (op_list_.op(cur).name() != op_name) break;
       }
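
The deleted `++cur` advanced the cursor twice before the loop's first name check, so when the unstable op had a single history entry the scan could jump straight over the next op's first entry. A Python sketch of the fixed control flow:

```python
def skip_unstable(op_names, cur):
    """Advance past every consecutive entry sharing the unstable op's name."""
    name = op_names[cur]
    cur += 1
    while cur < len(op_names) and op_names[cur] == name:
        cur += 1
    return cur

assert skip_unstable(["Unstable", "NextOp", "NextOp"], 0) == 1
# The old double increment would have returned 2, silently skipping the
# first "NextOp" history entry.
```
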
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index d95b0d0b59e..1f0b4de6311 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -7149,6 +7149,34 @@ op {
     type: DT_FLOAT
   }
 }
+op {
+  name: "DrawBoundingBoxes"
+  input_arg {
+    name: "images"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "boxes"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+}
 op {
   name: "DynamicPartition"
   input_arg {
@@ -9239,6 +9267,62 @@ op {
     }
   }
 }
+op {
+  name: "ImageSummary"
+  input_arg {
+    name: "tag"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "tensor"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "summary"
+    type: DT_STRING
+  }
+  attr {
+    name: "max_images"
+    type: "int"
+    default_value {
+      i: 3
+    }
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_UINT8
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "bad_color"
+    type: "tensor"
+    default_value {
+      tensor {
+        dtype: DT_UINT8
+        tensor_shape {
+          dim {
+            size: 4
+          }
+        }
+        int_val: 255
+        int_val: 0
+        int_val: 0
+        int_val: 255
+      }
+    }
+  }
+}
 op {
   name: "ImmutableConst"
   output_arg {
@@ -12709,6 +12793,53 @@ op {
     type: "type"
   }
 }
+op {
+  name: "OneHot"
+  input_arg {
+    name: "indices"
+    type_attr: "TI"
+  }
+  input_arg {
+    name: "depth"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "on_value"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "off_value"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "axis"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "TI"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
 op {
   name: "Pack"
   input_arg {
@@ -19482,6 +19613,51 @@ op {
     type: "type"
   }
 }
+op {
+  name: "SparseDenseCwiseAdd"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+}
 op {
   name: "SparseDenseCwiseDiv"
   input_arg {
@@ -19992,6 +20168,35 @@ op {
     }
   }
 }
+op {
+  name: "SparseSoftmax"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
 op {
   name: "SparseSoftmaxCrossEntropyWithLogits"
   input_arg {
@@ -20771,6 +20976,27 @@ op {
     minimum: 1
   }
 }
+op {
+  name: "StringToHashBucketStrong"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "key"
+    type: "list(int)"
+  }
+}
 op {
   name: "StringToNumber"
   input_arg {
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index ed0a5faa51e..3a25fde488d 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -387,9 +387,10 @@ output: `images` converted to RGB.
 
 // --------------------------------------------------------------------------
 REGISTER_OP("DrawBoundingBoxes")
-    .Input("images: float")
+    .Input("images: T")
     .Input("boxes: float")
-    .Output("output: float")
+    .Output("output: T")
+    .Attr("T: {float, half} = DT_FLOAT")
     .Doc(R"doc(
 Draw bounding boxes on a batch of images.
 
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 69855df41df..cd66c0894fa 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -101,7 +101,7 @@ REGISTER_OP("ImageSummary")
     .Input("tensor: T")
     .Output("summary: string")
     .Attr("max_images: int >= 1 = 3")
-    .Attr("T: {uint8, float} = DT_FLOAT")
+    .Attr("T: {uint8, float, half} = DT_FLOAT")
     .Attr(
         "bad_color: tensor = { dtype: DT_UINT8 "
         "tensor_shape: { dim { size: 4 } } "
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 9cad7911392..0b5997fbaad 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -3764,7 +3764,7 @@ op {
   input_arg {
     name: "images"
     description: "4-D with shape `[batch, height, width, depth]`. A batch of images."
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "boxes"
@@ -3774,7 +3774,20 @@ op {
   output_arg {
     name: "output"
     description: "4-D with the same shape as `images`. The batch of input images with\nbounding boxes drawn on the images."
-    type: DT_FLOAT
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_HALF
+      }
+    }
   }
   summary: "Draw bounding boxes on a batch of images."
   description: "Outputs a copy of `images` but draws on top of the pixels zero or more bounding\nboxes specified by the locations in `boxes`. The coordinates of the each\nbounding box in `boxes are encoded as `[y_min, x_min, y_max, x_max]`. The\nbounding box coordinates are floats in `[0.0, 1.0]` relative to the width and\nheight of the underlying image.\n\nFor example, if an image is 100 x 200 pixels and the bounding box is\n`[0.1, 0.5, 0.2, 0.9]`, the bottom-left and upper-right coordinates of the\nbounding box will be `(10, 40)` to `(50, 180)`.\n\nParts of the bounding box may fall outside the image."
@@ -5119,6 +5132,7 @@ op {
       list {
         type: DT_UINT8
         type: DT_FLOAT
+        type: DT_HALF
       }
     }
   }
@@ -7150,7 +7164,7 @@ op {
   input_arg {
     name: "indices"
     description: "A tensor of indices."
-    type: DT_INT64
+    type_attr: "TI"
   }
   input_arg {
     name: "depth"
@@ -7184,6 +7198,19 @@ op {
     name: "T"
     type: "type"
   }
+  attr {
+    name: "TI"
+    type: "type"
+    default_value {
+      type: DT_INT64
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
   summary: "Returns a one-hot tensor."
   description: "The locations represented by indices in `indices` take value `on_value`,\nwhile all other locations take value `off_value`.\n\nIf the input `indices` is rank `N`, the output will have rank `N+1`,\nThe new axis is created at dimension `axis` (default: the new axis is\nappended at the end).\n\nIf `indices` is a scalar the output shape will be a vector of length `depth`.\n\nIf `indices` is a vector of length `features`, the output shape will be:\n```\n  features x depth if axis == -1\n  depth x features if axis == 0\n```\n\nIf `indices` is a matrix (batch) with shape `[batch, features]`,\nthe output shape will be:\n```\n  batch x features x depth if axis == -1\n  batch x depth x features if axis == 1\n  depth x batch x features if axis == 0\n```\n\n\nExamples\n=========\n\nSuppose that\n\n```\n  indices = [0, 2, -1, 1]\n  depth = 3\n  on_value = 5.0\n  off_value = 0.0\n  axis = -1\n```\n\nThen output is `[4 x 3]`:\n\n    ```output =\n      [5.0 0.0 0.0]  // one_hot(0)\n      [0.0 0.0 5.0]  // one_hot(2)\n      [0.0 0.0 0.0]  // one_hot(-1)\n      [0.0 5.0 0.0]  // one_hot(1)\n    ```\n\nSuppose that\n\n```\n  indices = [0, 2, -1, 1]\n  depth = 3\n  on_value = 0.0\n  off_value = 3.0\n  axis = 0\n```\n\nThen output is `[3 x 4]`:\n\n    ```output =\n      [0.0 3.0 3.0 3.0]\n      [3.0 3.0 3.0 0.0]\n      [3.0 3.0 3.0 3.0]\n      [3.0 0.0 3.0 3.0]\n    //  ^                one_hot(0)\n    //      ^            one_hot(2)\n    //          ^        one_hot(-1)\n    //              ^    one_hot(1)\n    ```\nSuppose that\n\n```\n  indices = [[0, 2], [1, -1]]\n  depth = 3\n  on_value = 1.0\n  off_value = 0.0\n  axis = -1\n```\n\nThen output is `[2 x 2 x 3]`:\n\n    ```output =\n      [\n        [1.0, 0.0, 0.0]  // one_hot(0)\n        [0.0, 0.0, 1.0]  // one_hot(2)\n      ][\n        [0.0, 1.0, 0.0]  // one_hot(1)\n        [0.0, 0.0, 0.0]  // one_hot(-1)\n      ]```"
 }
@@ -11164,6 +11191,58 @@ op {
   summary: "Concatenates a list of `SparseTensor` along the specified dimension."
   description: "Concatenation is with respect to the dense versions of these sparse tensors.\nIt is assumed that each input is a `SparseTensor` whose elements are ordered\nalong increasing dimension number.\n\nAll inputs\' shapes must match, except for the concat dimension.  The\n`indices`, `values`, and `shapes` lists must have the same length.\n\nThe output shape is identical to the inputs\', except along the concat\ndimension, where it is the sum of the inputs\' sizes along that dimension.\n\nThe output elements will be resorted to preserve the sort order along\nincreasing dimension number.\n\nThis op runs in `O(M log M)` time, where `M` is the total number of non-empty\nvalues across all inputs. This is due to the need for an internal sort in\norder to concatenate efficiently across an arbitrary dimension.\n\nFor example, if `concat_dim = 1` and the inputs are\n\n    sp_inputs[0]: shape = [2, 3]\n    [0, 2]: \"a\"\n    [1, 0]: \"b\"\n    [1, 1]: \"c\"\n\n    sp_inputs[1]: shape = [2, 4]\n    [0, 1]: \"d\"\n    [0, 2]: \"e\"\n\nthen the output will be\n\n    shape = [2, 7]\n    [0, 2]: \"a\"\n    [0, 4]: \"d\"\n    [0, 5]: \"e\"\n    [1, 0]: \"b\"\n    [1, 1]: \"c\"\n\nGraphically this is equivalent to doing\n\n    [    a] concat [  d e  ] = [    a   d e  ]\n    [b c  ]        [       ]   [b c          ]"
 }
+op {
+  name: "SparseDenseCwiseAdd"
+  input_arg {
+    name: "sp_indices"
+    description: "2-D.  `N x R` matrix with the indices of non-empty values in a\nSparseTensor, possibly not in canonical ordering."
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    description: "1-D.  `N` non-empty values corresponding to `sp_indices`."
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    description: "1-D.  Shape of the input SparseTensor."
+    type: DT_INT64
+  }
+  input_arg {
+    name: "dense"
+    description: "`R`-D.  The dense Tensor operand."
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    description: "1-D.  The `N` values that are operated on."
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  summary: "Adds up a SparseTensor and a dense Tensor, using these special rules:"
+  description: "(1) Broadcasts the dense side to have the same shape as the sparse side, if\n    eligible;\n(2) Then, only the dense values pointed to by the indices of the SparseTensor\n    participate in the cwise addition.\n\nBy these rules, the result is a logical SparseTensor with exactly the same\nindices and shape, but possibly with different non-zero values.  The output of\nthis Op is the resultant non-zero values."
+}
 op {
   name: "SparseDenseCwiseDiv"
   input_arg {
@@ -11266,7 +11345,7 @@ op {
     }
   }
   summary: "Component-wise multiplies a SparseTensor by a dense Tensor."
-  description: "*Limitation*: this Op only broadcasts the dense side to the sparse side, but not\nthe other direction."
+  description: "The output locations corresponding to the implicitly zero elements in the sparse\ntensor will be zero (i.e., will not take up storage space), regardless of the\ncontents of the dense tensor (even if it\'s +/-INF and that INF*0 == NaN).\n\n*Limitation*: this Op only broadcasts the dense side to the sparse side, but not\nthe other direction."
 }
 op {
   name: "SparseMatMul"
@@ -11620,6 +11699,37 @@ op {
   summary: "Computes the sum along sparse segments of a tensor."
   description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n  ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n  ==> [[ 1  2  3  4]\n       [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n  ==> [[0 0 0 0]\n       [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```"
 }
+op {
+  name: "SparseSoftmax"
+  input_arg {
+    name: "sp_indices"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sp_values"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "sp_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  summary: "Applies softmax to a batched N-D `SparseTensor`."
+  description: "The inputs represent an N-D SparseTensor  with logical shape `[..., B, C]`\n(where `N >= 2`), and with indices sorted in the canonical lexicographic order.\n\nThis op is equivalent to applying the normal `tf.nn.softmax()` to each innermost\nlogical submatrix with shape `[B, C]`, but with the catch that *the implicitly\nzero elements do not participate*.  Specifically, the algorithm is equivalent"
+}
 op {
   name: "SparseSoftmaxCrossEntropyWithLogits"
   input_arg {
@@ -12157,14 +12267,14 @@ op {
   description: "The hash function is deterministic on the content of the string within the\nprocess.\n\nNote that the hash function may change from time to time."
   deprecation {
     version: 10
-    explanation: "Use tf.string_to_hash_bucket_fast()"
+    explanation: "Use `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`"
   }
 }
 op {
   name: "StringToHashBucketFast"
   input_arg {
     name: "input"
-    description: "The strings to assing a hash bucket."
+    description: "The strings to assign a hash bucket."
     type: DT_STRING
   }
   output_arg {
@@ -12180,7 +12290,34 @@ op {
     minimum: 1
   }
   summary: "Converts each string in the input Tensor to its hash mod by a number of buckets."
-  description: "The hash function is deterministic on the content of the string within the\nprocess and will never change. However, it is not suitable for cryptography."
+  description: "The hash function is deterministic on the content of the string within the\nprocess and will never change. However, it is not suitable for cryptography.\nThis function may be used when CPU time is scarce and inputs are trusted or\nunimportant. There is a risk of adversaries constructing inputs that all hash\nto the same bucket. To prevent this problem, use a strong hash function with\n`tf.string_to_hash_bucket_strong`."
+}
+op {
+  name: "StringToHashBucketStrong"
+  input_arg {
+    name: "input"
+    description: "The strings to assign a hash bucket."
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    description: "A Tensor of the same shape as the input `string_tensor`."
+    type: DT_INT64
+  }
+  attr {
+    name: "num_buckets"
+    type: "int"
+    description: "The number of buckets."
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "key"
+    type: "list(int)"
+    description: "The key for the keyed hash function passed as a list of two uint64\nelements."
+  }
+  summary: "Converts each string in the input Tensor to its hash mod by a number of buckets."
+  description: "The hash function is deterministic on the content of the string within the\nprocess. The hash function is a keyed hash function, where attribute `key`\ndefines the key of the hash function. `key` is an array of 2 elements.\n\nA strong hash is important when inputs may be malicious, e.g. URLs with\nadditional components. Adversaries could try to make their inputs hash to the\nsame bucket for a denial-of-service attack or to skew the results. A strong\nhash prevents this by making it dificult, if not infeasible, to compute inputs\nthat hash to the same bucket. This comes at a cost of roughly 4x higher compute\ntime than tf.string_to_hash_bucket_fast."
 }
 op {
   name: "StringToNumber"
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 716444ae251..c8f4f8d25b0 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -441,16 +441,21 @@ keep_dims: If true, retain reduced dimensions with length 1.
 output: `R-K`-D.  The reduced Tensor.
 )doc");
 
-REGISTER_OP("SparseDenseCwiseMul")
-    .Input("sp_indices: int64")
-    .Input("sp_values: T")
-    .Input("sp_shape: int64")
-    .Input("dense: T")
-    .Output("output: T")
-    .Attr("T: numbertype")
-    .Doc(R"doc(
+#define SPARSE_DENSE_CWISE_SIGNATURE() \
+  Input("sp_indices: int64")           \
+      .Input("sp_values: T")           \
+      .Input("sp_shape: int64")        \
+      .Input("dense: T")               \
+      .Output("output: T")             \
+      .Attr("T: numbertype")
+
+REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE().Doc(R"doc(
 Component-wise multiplies a SparseTensor by a dense Tensor.
 
+The output locations corresponding to the implicitly zero elements in the sparse
+tensor will be zero (i.e., will not take up storage space), regardless of the
+contents of the dense tensor (even if it's +/-INF, for which INF*0 == NaN).
+
 *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
 the other direction.
 
@@ -462,14 +467,7 @@ dense: `R`-D.  The dense Tensor operand.
 output: 1-D.  The `N` values that are operated on.
 )doc");
 
-REGISTER_OP("SparseDenseCwiseDiv")
-    .Input("sp_indices: int64")
-    .Input("sp_values: T")
-    .Input("sp_shape: int64")
-    .Input("dense: T")
-    .Output("output: T")
-    .Attr("T: numbertype")
-    .Doc(R"doc(
+REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE().Doc(R"doc(
 Component-wise divides a SparseTensor by a dense Tensor.
 
 *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
@@ -483,4 +481,56 @@ dense: `R`-D.  The dense Tensor operand.
 output: 1-D.  The `N` values that are operated on.
 )doc");
 
+REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE().Doc(R"doc(
+Adds up a SparseTensor and a dense Tensor, using these special rules:
+
+(1) Broadcasts the dense side to have the same shape as the sparse side, if
+    eligible;
+(2) Then, only the dense values pointed to by the indices of the SparseTensor
+    participate in the cwise addition.
+
+By these rules, the result is a logical SparseTensor with exactly the same
+indices and shape, but possibly with different non-zero values.  The output of
+this Op is the resultant non-zero values.
+
+sp_indices: 2-D.  `N x R` matrix with the indices of non-empty values in a
+  SparseTensor, possibly not in canonical ordering.
+sp_values: 1-D.  `N` non-empty values corresponding to `sp_indices`.
+sp_shape: 1-D.  Shape of the input SparseTensor.
+dense: `R`-D.  The dense Tensor operand.
+output: 1-D.  The `N` values that are operated on.
+)doc");
+
+REGISTER_OP("SparseSoftmax")
+    .Input("sp_indices: int64")
+    .Input("sp_values: T")
+    .Input("sp_shape: int64")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .Doc(R"doc(
+Applies softmax to a batched N-D `SparseTensor`.
+
+The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
+(where `N >= 2`), and with indices sorted in the canonical lexicographic order.
+
+This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost
+logical submatrix with shape `[B, C]`, but with the catch that *the implicitly
+zero elements do not participate*.  Specifically, the algorithm is equivalent
+to:
+
+  (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix
+      with shape `[B, C]`, along the size-C dimension;
+  (2) Masks out the original implicitly-zero locations;
+  (3) Renormalizes the remaining elements.
+
+Hence, the `SparseTensor` result has exactly the same non-zero indices and
+shape.
+
+sp_indices: 2-D.  `NNZ x R` matrix with the indices of non-empty values in a
+  SparseTensor, in canonical ordering.
+sp_values: 1-D.  `NNZ` non-empty values corresponding to `sp_indices`.
+sp_shape: 1-D.  Shape of the input SparseTensor.
+output: 1-D.  The `NNZ` values for the result `SparseTensor`.
+)doc");
+
 }  // namespace tensorflow
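
A numpy sketch of the two `SparseDenseCwiseAdd` rules, assuming the dense operand has already been broadcast to the sparse tensor's shape:

```python
import numpy as np

def sparse_dense_cwise_add(sp_indices, sp_values, dense):
    """Only dense values pointed to by sp_indices participate; the result
    keeps exactly the same indices and shape as the sparse input."""
    gathered = np.array([dense[tuple(ix)] for ix in sp_indices])
    return sp_values + gathered

sp_indices = np.array([[0, 0], [1, 2]])
sp_values = np.array([10.0, 20.0])
dense = np.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
print(sparse_dense_cwise_add(sp_indices, sp_values, dense))  # [10. 25.]
```
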
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 1a274f1e68a..526fc35eb25 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -26,17 +26,49 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
 
 The hash function is deterministic on the content of the string within the
 process and will never change. However, it is not suitable for cryptography.
+This function may be used when CPU time is scarce and inputs are trusted or
+unimportant. There is a risk of adversaries constructing inputs that all hash
+to the same bucket. To prevent this problem, use a strong hash function with
+`tf.string_to_hash_bucket_strong`.
 
-input: The strings to assing a hash bucket.
+input: The strings to assign a hash bucket.
 num_buckets: The number of buckets.
 output: A Tensor of the same shape as the input `string_tensor`.
 )doc");
 
+REGISTER_OP("StringToHashBucketStrong")
+    .Input("input: string")
+    .Output("output: int64")
+    .Attr("num_buckets: int >= 1")
+    .Attr("key: list(int)")
+    .Doc(R"doc(
+Converts each string in the input Tensor to its hash mod by a number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process. The hash function is a keyed hash function, where attribute `key`
+defines the key of the hash function. `key` is an array of 2 elements.
+
+A strong hash is important when inputs may be malicious, e.g. URLs with
+additional components. Adversaries could try to make their inputs hash to the
+same bucket for a denial-of-service attack or to skew the results. A strong
+hash prevents this by making it difficult, if not infeasible, to compute inputs
+that hash to the same bucket. This comes at a cost of roughly 4x higher compute
+time than tf.string_to_hash_bucket_fast.
+
+input: The strings to assign a hash bucket.
+num_buckets: The number of buckets.
+key: The key for the keyed hash function passed as a list of two uint64
+  elements.
+output: A Tensor of the same shape as the input `string_tensor`.
+)doc");
+
 REGISTER_OP("StringToHashBucket")
     .Input("string_tensor: string")
     .Output("output: int64")
     .Attr("num_buckets: int >= 1")
-    .Deprecated(10, "Use tf.string_to_hash_bucket_fast()")
+    .Deprecated(10,
+                "Use `tf.string_to_hash_bucket_fast()` or "
+                "`tf.string_to_hash_bucket_strong()`")
     .Doc(R"doc(
 Converts each string in the input Tensor to its hash mod by a number of buckets.
 
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 109fd18e6b5..66e2c75934f 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -50,6 +50,7 @@ cc_library(
         "@farmhash_archive//:farmhash",
         "@jpeg_archive//:jpeg",
         "@png_archive//:png",
+        "@highwayhash//:sip_hash",
         "@re2//:re2",
         "//tensorflow/core:protos_cc",
     ],
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index 904e8a689fa..c3eb11f37f0 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -25,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/platform/thread_annotations.h"
 namespace tensorflow {
 
+#undef mutex_lock
+
 enum LinkerInitialized { LINKER_INITIALIZED };
 
 // A class that wraps around the std::mutex implementation, only adding an
@@ -53,6 +55,9 @@ class SCOPED_LOCKABLE mutex_lock : public std::unique_lock<std::mutex> {
   ~mutex_lock() RELEASE() {}
 };
 
+// Catch bug where variable name is omitted, e.g. mutex_lock (mu);
+#define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name");
+
 using std::condition_variable;
 
 inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
diff --git a/tensorflow/core/platform/default/strong_hash.h b/tensorflow/core/platform/default/strong_hash.h
new file mode 100644
index 00000000000..53d1dae98dd
--- /dev/null
+++ b/tensorflow/core/platform/default/strong_hash.h
@@ -0,0 +1,30 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
+
+#include "highwayhash/sip_hash.h"
+#include "highwayhash/state_helpers.h"
+
+namespace tensorflow {
+
+inline uint64 StrongKeyedHash(const uint64 (&key)[2], const string& s) {
+  return highwayhash::StringHasher<highwayhash::SipHashState>()(key, s);
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
diff --git a/tensorflow/core/platform/strong_hash.h b/tensorflow/core/platform/strong_hash.h
new file mode 100644
index 00000000000..7bd3eed6106
--- /dev/null
+++ b/tensorflow/core/platform/strong_hash.h
@@ -0,0 +1,45 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
+#define TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
+
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// This is a strong keyed hash function interface for strings.
+// The hash function is deterministic on the content of the string within the
+// process. The key of the hash is an array of 2 uint64 elements.
+// A strong hash makes it difficult, if not infeasible, to compute inputs that
+// hash to the same bucket.
+//
+// Usage:
+//   uint64 key[2] = {123, 456};
+//   string input = "input string";
+//   uint64 hash_value = StrongKeyedHash(key, input);
+//
+uint64 StrongKeyedHash(const uint64 (&)[2], const string&);
+
+}  // namespace tensorflow
+
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/strong_hash.h"
+#else
+#include "tensorflow/core/platform/default/strong_hash.h"
+#endif
+
+#endif  // TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc
index b82100c1320..5890348c86d 100644
--- a/tensorflow/core/util/use_cudnn.cc
+++ b/tensorflow/core/util/use_cudnn.cc
@@ -37,7 +37,7 @@ static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
 bool CanUseCudnn() { return ReadBoolFromEnvVar("TF_USE_CUDNN", true); }
 
 bool CudnnUseAutotune() {
-  return ReadBoolFromEnvVar("TF_CUDNN_USE_AUTOTUNE", false);
+  return ReadBoolFromEnvVar("TF_CUDNN_USE_AUTOTUNE", true);
 }
 
 }  // namespace tensorflow
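Since autotuning is now opt-out rather than opt-in, jobs that need the previous behavior can disable it through the environment variable read above. A minimal sketch; that `ReadBoolFromEnvVar` treats `"0"` as false is an assumption here:

```python
import os

# Must be set before TensorFlow creates its GPU devices, so that
# ReadBoolFromEnvVar("TF_CUDNN_USE_AUTOTUNE", true) sees it.
# "0" is assumed to parse as false.
os.environ["TF_CUDNN_USE_AUTOTUNE"] = "0"

import tensorflow as tf  # imported after the variable is set
```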
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index 192b1dcd7c4..1c454d08fa6 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -32,8 +32,11 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
     return;
   }
 #ifdef EIGEN_USE_NONBLOCKING_THREAD_POOL
-  workers->ParallelFor(total, cost_per_unit, work, max_parallelism);
-#else
+  if (max_parallelism >= workers->NumThreads()) {
+    workers->ParallelFor(total, cost_per_unit, work);
+    return;
+  }
+#endif
   cost_per_unit = std::max(1LL, cost_per_unit);
   // We shard [0, total) into "num_shards" shards.
   //   1 <= num_shards <= num worker threads
@@ -71,7 +74,6 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
   // Inline execute the 1st shard.
   work(0, std::min(block_size, total));
   counter.Wait();
-#endif
 }
 
 }  // end namespace tensorflow
diff --git a/tensorflow/examples/android/jni/tensorflow_jni.cc b/tensorflow/examples/android/jni/tensorflow_jni.cc
index f61eb0655c7..75d834b735b 100644
--- a/tensorflow/examples/android/jni/tensorflow_jni.cc
+++ b/tensorflow/examples/android/jni/tensorflow_jni.cc
@@ -48,7 +48,7 @@ static std::vector<std::string> g_label_strings;
 static bool g_compute_graph_initialized = false;
 //static mutex g_compute_graph_mutex(base::LINKER_INITIALIZED);
 
-static int g_tensorflow_input_size;  // The image size for the mognet input.
+static int g_tensorflow_input_size;  // The image size for the model input.
 static int g_image_mean;  // The image mean.
 static std::unique_ptr<StatSummarizer> g_stats;
 
@@ -82,11 +82,9 @@ inline static int64 CurrentThreadTimeUs() {
   return tv.tv_sec * 1000000 + tv.tv_usec;
 }
 
-JNIEXPORT jint JNICALL
-TENSORFLOW_METHOD(initializeTensorflow)(
-    JNIEnv* env, jobject thiz, jobject java_asset_manager,
-    jstring model, jstring labels,
-    jint num_classes, jint mognet_input_size, jint image_mean) {
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorflow)(
+    JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model,
+    jstring labels, jint num_classes, jint model_input_size, jint image_mean) {
   g_num_runs = 0;
   g_timing_total_us = 0;
   g_frequency_start.Reset();
@@ -103,7 +101,7 @@ TENSORFLOW_METHOD(initializeTensorflow)(
   const char* const model_cstr = env->GetStringUTFChars(model, NULL);
   const char* const labels_cstr = env->GetStringUTFChars(labels, NULL);
 
-  g_tensorflow_input_size = mognet_input_size;
+  g_tensorflow_input_size = model_input_size;
   g_image_mean = image_mean;
 
   LOG(INFO) << "Loading Tensorflow.";
diff --git a/tensorflow/examples/android/jni/tensorflow_jni.h b/tensorflow/examples/android/jni/tensorflow_jni.h
index 8c94e76a75a..7c714b986a3 100644
--- a/tensorflow/examples/android/jni/tensorflow_jni.h
+++ b/tensorflow/examples/android/jni/tensorflow_jni.h
@@ -30,11 +30,9 @@ extern "C" {
 #define TENSORFLOW_METHOD(METHOD_NAME) \
   Java_org_tensorflow_demo_TensorflowClassifier_##METHOD_NAME  // NOLINT
 
-JNIEXPORT jint JNICALL
-TENSORFLOW_METHOD(initializeTensorflow)(
-    JNIEnv* env, jobject thiz, jobject java_asset_manager,
-    jstring model, jstring labels,
-    jint num_classes, jint mognet_input_size, jint image_mean);
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorflow)(
+    JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model,
+    jstring labels, jint num_classes, jint model_input_size, jint image_mean);
 
 JNIEXPORT jstring JNICALL
 TENSORFLOW_METHOD(classifyImageBmp)(
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index 0c6cd7582ec..3662e7eaa3a 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -1439,14 +1439,21 @@ boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]]
 
 - - -
 
-### `tf.one_hot(indices, depth, on_value=1, off_value=0, axis=None, dtype=tf.float32, name=None)` {#one_hot}
+### `tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)` {#one_hot}
 
 Returns a one-hot tensor.
 
 The locations represented by indices in `indices` take value `on_value`,
-while all other locations take value `off_value`. By default, `on_value` is 1,
-and `off_value` is 0. The type of the output tensor is specified by `dtype`,
-which defaults to `tf.float32`.
+while all other locations take value `off_value`.
+
+`on_value` and `off_value` must have matching data types. If `dtype` is also
+provided, they must be the same data type as specified by `dtype`.
+
+If `on_value` is not provided, it will default to the value `1` with type
+`dtype`.
+
+If `off_value` is not provided, it will default to the value `0` with type
+`dtype`.
 
 If the input `indices` is rank `N`, the output will have rank `N+1`. The
 new axis is created at dimension `axis` (default: the new axis is appended
@@ -1468,6 +1475,13 @@ shape will be:
   depth x batch x features if axis == 0
 ```
 
+If `dtype` is not provided, it will be inferred from the data type of
+`on_value` or `off_value`, if one or both are passed in. If none of
+`on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
+value `tf.float32`.
+
+Note: If a non-numeric output data type is desired (`tf.string`, `tf.bool`,
+etc.), both `on_value` and `off_value` _must_ be provided to `one_hot`.
 
 Examples
 =========
@@ -1515,6 +1529,22 @@ Then output is `[2 x 2 x 3]`:
   ]
 ```
 
+Using default values for `on_value` and `off_value`:
+
+```
+  indices = [0, 1, 2]
+  depth = 3
+```
+
+The output will be
+
+```
+  output =
+  [[1., 0., 0.],
+   [0., 1., 0.],
+   [0., 0., 1.]]
+```
+
 ##### Args:
 
 
@@ -1535,7 +1565,8 @@ Then output is `[2 x 2 x 3]`:
 ##### Raises:
 
 
-*  <b>`TypeError`</b>: If dtype is `tf.string`
+*  <b>`TypeError`</b>: If the dtype of either `on_value` or `off_value` doesn't match `dtype`
+*  <b>`TypeError`</b>: If the dtypes of `on_value` and `off_value` don't match one another
 
 
 
diff --git a/tensorflow/g3doc/api_docs/python/check_ops.md b/tensorflow/g3doc/api_docs/python/check_ops.md
index 0e7e3128beb..88463cf092f 100644
--- a/tensorflow/g3doc/api_docs/python/check_ops.md
+++ b/tensorflow/g3doc/api_docs/python/check_ops.md
@@ -77,6 +77,27 @@ If `x` is empty this is trivially satisfied.
   Op raising `InvalidArgumentError` unless `x` is all positive.
 
 
+- - -
+
+### `tf.assert_proper_iterable(values)` {#assert_proper_iterable}
+
+Static assert that `values` is a "proper" iterable.
+
+`Ops` that expect iterables of `Tensor` can call this to validate input.
+Useful since `Tensor`, `ndarray`, and byte/text types are all iterables themselves.
+
+##### Args:
+
+
+*  <b>`values`</b>: Object to be checked.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: If `values` is not iterable or is one of
+    `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
+
+
 - - -
 
 ### `tf.assert_non_negative(x, data=None, summarize=None, name=None)` {#assert_non_negative}
@@ -189,6 +210,39 @@ If both `x` and `y` are empty, this is trivially satisfied.
   Op that raises `InvalidArgumentError` if `x == y` is False.
 
 
+- - -
+
+### `tf.assert_integer(x, data=None, summarize=None, name=None)` {#assert_integer}
+
+Assert that `x` is of integer dtype.
+
+Example of adding a dependency to an operation:
+
+```python
+with tf.control_dependencies([tf.assert_integer(x)]):
+  output = tf.reduce_sum(x)
+```
+
+Example of adding dependency to the tensor being checked:
+
+```python
+x = tf.with_dependencies([tf.assert_integer(x)], x)
+```
+
+##### Args:
+
+
+*  <b>`x`</b>: `Tensor` whose basetype is integer and is not quantized.
+*  <b>`data`</b>: The tensors to print out if the condition is False.  Defaults to
+    error message and first few entries of `x`.
+*  <b>`summarize`</b>: Print this many entries of each tensor.
+*  <b>`name`</b>: A name for this operation (optional).  Defaults to "assert_integer".
+
+##### Returns:
+
+  Op that raises `InvalidArgumentError` if `x` is not of integer dtype.
+
+
 - - -
 
 ### `tf.assert_less(x, y, data=None, summarize=None, name=None)` {#assert_less}
@@ -362,39 +416,6 @@ Asserts that the given `Tensor` is of the specified type.
 *  <b>`ValueError`</b>: If the tensors data type doesn't match tf_type.
 
 
-- - -
-
-### `tf.assert_integer(x, data=None, summarize=None, name=None)` {#assert_integer}
-
-Assert that `x` is of integer dtype.
-
-Example of adding a dependency to an operation:
-
-```python
-with tf.control_dependencies([tf.assert_integer(x)]):
-  output = tf.reduce_sum(x)
-```
-
-Example of adding dependency to the tensor being checked:
-
-```python
-x = tf.with_dependencies([tf.assert_integer(x)], x)
-```
-
-##### Args:
-
-
-*  <b>`x`</b>: `Tensor` whose basetype is integer and is not quantized.
-*  <b>`data`</b>: The tensors to print out if the condition is False.  Defaults to
-    error message and first few entries of `x`.
-*  <b>`summarize`</b>: Print this many entries of each tensor.
-*  <b>`name`</b>: A name for this operation (optional).  Defaults to "assert_integer".
-
-##### Returns:
-
-  Op that raises `InvalidArgumentError` if `x == y` is False.
-
-
 - - -
 
 ### `tf.is_non_decreasing(x, name=None)` {#is_non_decreasing}
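A short sketch of how the new check might be used; `total_sum` is a hypothetical helper, not part of this patch:

```python
import tensorflow as tf

def total_sum(tensors):
  # A lone Tensor or ndarray is itself iterable, so a plain loop would
  # silently iterate over its rows; this raises TypeError up front instead.
  tf.assert_proper_iterable(tensors)
  return tf.add_n([tf.reduce_sum(t) for t in tensors])

total = total_sum([tf.constant([1, 2]), tf.constant([3, 4])])
```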
diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md
index b1c8fe225d6..70aff96a846 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.learn.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md
@@ -34,7 +34,7 @@ Parameters:
 
 - - -
 
-#### `tf.contrib.learn.BaseEstimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=100, metrics=None, name=None)` {#BaseEstimator.evaluate}
+#### `tf.contrib.learn.BaseEstimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=None, metrics=None, name=None)` {#BaseEstimator.evaluate}
 
 Evaluates given model with provided evaluation data.
 
@@ -150,7 +150,7 @@ to converge, and you want to split up training into subparts.
 
 - - -
 
-#### `tf.contrib.learn.BaseEstimator.predict(x, axis=None, batch_size=None)` {#BaseEstimator.predict}
+#### `tf.contrib.learn.BaseEstimator.predict(x=None, input_fn=None, batch_size=None)` {#BaseEstimator.predict}
 
 Returns predictions for given features.
 
@@ -158,7 +158,7 @@ Returns predictions for given features.
 
 
 *  <b>`x`</b>: features.
-*  <b>`axis`</b>: Axis on which to argmax. (for classification).
+*  <b>`input_fn`</b>: Input function. If set, `x` must be `None`.
 *  <b>`batch_size`</b>: Override default batch size.
 
 ##### Returns:
@@ -166,23 +166,6 @@ Returns predictions for given features.
   Numpy array of predicted classes or regression values.
 
 
-- - -
-
-#### `tf.contrib.learn.BaseEstimator.predict_proba(x, batch_size=None)` {#BaseEstimator.predict_proba}
-
-Returns prediction probabilities for given features (classification).
-
-##### Args:
-
-
-*  <b>`x`</b>: features.
-*  <b>`batch_size`</b>: Override default batch size.
-
-##### Returns:
-
-  Numpy array of predicted probabilities.
-
-
 - - -
 
 #### `tf.contrib.learn.BaseEstimator.set_params(**params)` {#BaseEstimator.set_params}
@@ -251,7 +234,7 @@ Parameters:
 
 - - -
 
-#### `tf.contrib.learn.Estimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=100, metrics=None, name=None)` {#Estimator.evaluate}
+#### `tf.contrib.learn.Estimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=None, metrics=None, name=None)` {#Estimator.evaluate}
 
 Evaluates given model with provided evaluation data.
 
@@ -367,7 +350,7 @@ to converge, and you want to split up training into subparts.
 
 - - -
 
-#### `tf.contrib.learn.Estimator.predict(x, axis=None, batch_size=None)` {#Estimator.predict}
+#### `tf.contrib.learn.Estimator.predict(x=None, input_fn=None, axis=None, batch_size=None)` {#Estimator.predict}
 
 Returns predictions for given features.
 
@@ -375,7 +358,9 @@ Returns predictions for given features.
 
 
 *  <b>`x`</b>: features.
-*  <b>`axis`</b>: Axis on which to argmax. (for classification).
+*  <b>`input_fn`</b>: Input function. If set, `x` must be `None`.
+*  <b>`axis`</b>: Axis on which to argmax (for classification).
+    Last axis is used by default.
 *  <b>`batch_size`</b>: Override default batch size.
 
 ##### Returns:
@@ -385,7 +370,7 @@ Returns predictions for given features.
 
 - - -
 
-#### `tf.contrib.learn.Estimator.predict_proba(x, batch_size=None)` {#Estimator.predict_proba}
+#### `tf.contrib.learn.Estimator.predict_proba(x=None, input_fn=None, batch_size=None)` {#Estimator.predict_proba}
 
 Returns prediction probabilities for given features (classification).
 
@@ -393,6 +378,7 @@ Returns prediction probabilities for given features (classification).
 
 
 *  <b>`x`</b>: features.
+*  <b>`input_fn`</b>: Input function. If set, `x` must be `None`.
 *  <b>`batch_size`</b>: Override default batch size.
 
 ##### Returns:
@@ -3198,7 +3184,7 @@ Attributes:
 
 - - -
 
-### `tf.contrib.learn.evaluate(graph, output_dir, checkpoint_path, eval_dict, update_op=None, global_step_tensor=None, init_op=None, supervisor_master='', log_every_steps=10, feed_fn=None, max_steps=None)` {#evaluate}
+### `tf.contrib.learn.evaluate(graph, output_dir, checkpoint_path, eval_dict, update_op=None, global_step_tensor=None, supervisor_master='', log_every_steps=10, feed_fn=None, max_steps=None)` {#evaluate}
 
 Evaluate a model loaded from a checkpoint.
 
@@ -3219,14 +3205,13 @@ and written to `output_dir`.
 *  <b>`output_dir`</b>: A string containing the directory to write a summary to.
 *  <b>`checkpoint_path`</b>: A string containing the path to a checkpoint to restore.
     Can be `None` if the graph doesn't require loading any variables.
-*  <b>`eval_dict`</b>: A `dict` mapping string names to tensors to evaluate for in every
-    eval step.
-*  <b>`update_op`</b>: A 'Tensor' which is run before evaluating 'eval_dict'.
+*  <b>`eval_dict`</b>: A `dict` mapping string names to tensors to evaluate. It is
+    evaluated in every logging step. The result of the final evaluation is
+    returned. If `update_op` is `None`, then it's evaluated in every step.
+*  <b>`update_op`</b>: A `Tensor` which is run in every step.
 *  <b>`global_step_tensor`</b>: A `Variable` containing the global step. If `None`,
     one is extracted from the graph using the same logic as in `Supervisor`.
     Used to place eval summaries on training curves.
-*  <b>`init_op`</b>: An op that initializes the graph. If `None`, use `Supervisor`'s
-    default.
 *  <b>`supervisor_master`</b>: The master string to use when preparing the session.
 *  <b>`log_every_steps`</b>: Integer. Output logs every `log_every_steps` evaluation
     steps. The logs contain the `eval_dict` and timing information.
@@ -3239,7 +3224,7 @@ and written to `output_dir`.
   A tuple `(eval_results, global_step)`:
 
 *  <b>`eval_results`</b>: A `dict` mapping `string` to numeric values (`int`, `float`)
-    that are the eval results from the last step of the eval.  None if no
+    that are the result of running `eval_dict` in the last step. `None` if no
     eval steps were run.
 *  <b>`global_step`</b>: The global step this evaluation corresponds to.
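A sketch of the new `input_fn` calling convention for `predict`; `est` stands in for any fitted `tf.contrib.learn` estimator and is deliberately left undefined here:

```python
import numpy as np
import tensorflow as tf

def predict_input_fn():
  # Build and return the feature tensors to predict on; no targets needed.
  return tf.constant(np.array([[1.0], [2.0]], dtype=np.float32))

# Either pass features directly...
predictions = est.predict(x=np.array([[1.0], [2.0]], dtype=np.float32))
# ...or pass an input function, in which case x must be left as None.
predictions = est.predict(input_fn=predict_input_fn)
```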
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_proper_iterable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_proper_iterable.md
new file mode 100644
index 00000000000..ba010737653
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_proper_iterable.md
@@ -0,0 +1,18 @@
+### `tf.assert_proper_iterable(values)` {#assert_proper_iterable}
+
+Static assert that `values` is a "proper" iterable.
+
+`Ops` that expect iterables of `Tensor` can call this to validate input.
+Useful since `Tensor`, `ndarray`, and byte/text types are all iterables themselves.
+
+##### Args:
+
+
+*  <b>`values`</b>: Object to be checked.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: If `values` is not iterable or is one of
+    `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.BaseEstimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.BaseEstimator.md
index 00bb76637ec..034af231a1a 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.BaseEstimator.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.BaseEstimator.md
@@ -19,7 +19,7 @@ Parameters:
 
 - - -
 
-#### `tf.contrib.learn.BaseEstimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=100, metrics=None, name=None)` {#BaseEstimator.evaluate}
+#### `tf.contrib.learn.BaseEstimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=None, metrics=None, name=None)` {#BaseEstimator.evaluate}
 
 Evaluates given model with provided evaluation data.
 
@@ -135,7 +135,7 @@ to converge, and you want to split up training into subparts.
 
 - - -
 
-#### `tf.contrib.learn.BaseEstimator.predict(x, axis=None, batch_size=None)` {#BaseEstimator.predict}
+#### `tf.contrib.learn.BaseEstimator.predict(x=None, input_fn=None, batch_size=None)` {#BaseEstimator.predict}
 
 Returns predictions for given features.
 
@@ -143,7 +143,7 @@ Returns predictions for given features.
 
 
 *  <b>`x`</b>: features.
-*  <b>`axis`</b>: Axis on which to argmax. (for classification).
+*  <b>`input_fn`</b>: Input function. If set, `x` must be `None`.
 *  <b>`batch_size`</b>: Override default batch size.
 
 ##### Returns:
@@ -151,23 +151,6 @@ Returns predictions for given features.
   Numpy array of predicted classes or regression values.
 
 
-- - -
-
-#### `tf.contrib.learn.BaseEstimator.predict_proba(x, batch_size=None)` {#BaseEstimator.predict_proba}
-
-Returns prediction probabilities for given features (classification).
-
-##### Args:
-
-
-*  <b>`x`</b>: features.
-*  <b>`batch_size`</b>: Override default batch size.
-
-##### Returns:
-
-  Numpy array of predicted probabilities.
-
-
 - - -
 
 #### `tf.contrib.learn.BaseEstimator.set_params(**params)` {#BaseEstimator.set_params}
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.Estimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.Estimator.md
index dbeba21d0cd..00f12fa0a1f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.Estimator.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.Estimator.md
@@ -25,7 +25,7 @@ Parameters:
 
 - - -
 
-#### `tf.contrib.learn.Estimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=100, metrics=None, name=None)` {#Estimator.evaluate}
+#### `tf.contrib.learn.Estimator.evaluate(x=None, y=None, input_fn=None, feed_fn=None, batch_size=32, steps=None, metrics=None, name=None)` {#Estimator.evaluate}
 
 Evaluates given model with provided evaluation data.
 
@@ -141,7 +141,7 @@ to converge, and you want to split up training into subparts.
 
 - - -
 
-#### `tf.contrib.learn.Estimator.predict(x, axis=None, batch_size=None)` {#Estimator.predict}
+#### `tf.contrib.learn.Estimator.predict(x=None, input_fn=None, axis=None, batch_size=None)` {#Estimator.predict}
 
 Returns predictions for given features.
 
@@ -149,7 +149,9 @@ Returns predictions for given features.
 
 
 *  <b>`x`</b>: features.
-*  <b>`axis`</b>: Axis on which to argmax. (for classification).
+*  <b>`input_fn`</b>: Input function. If set, `x` must be `None`.
+*  <b>`axis`</b>: Axis on which to argmax (for classification).
+    Last axis is used by default.
 *  <b>`batch_size`</b>: Override default batch size.
 
 ##### Returns:
@@ -159,7 +161,7 @@ Returns predictions for given features.
 
 - - -
 
-#### `tf.contrib.learn.Estimator.predict_proba(x, batch_size=None)` {#Estimator.predict_proba}
+#### `tf.contrib.learn.Estimator.predict_proba(x=None, input_fn=None, batch_size=None)` {#Estimator.predict_proba}
 
 Returns prediction probabilities for given features (classification).
 
@@ -167,6 +169,7 @@ Returns prediction probabilities for given features (classification).
 
 
 *  <b>`x`</b>: features.
+*  <b>`input_fn`</b>: Input function. If set, `x` must be `None`.
 *  <b>`batch_size`</b>: Override default batch size.
 
 ##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.evaluate.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.evaluate.md
index e0b0891eb88..022662c3f6d 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.evaluate.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.evaluate.md
@@ -1,4 +1,4 @@
-### `tf.contrib.learn.evaluate(graph, output_dir, checkpoint_path, eval_dict, update_op=None, global_step_tensor=None, init_op=None, supervisor_master='', log_every_steps=10, feed_fn=None, max_steps=None)` {#evaluate}
+### `tf.contrib.learn.evaluate(graph, output_dir, checkpoint_path, eval_dict, update_op=None, global_step_tensor=None, supervisor_master='', log_every_steps=10, feed_fn=None, max_steps=None)` {#evaluate}
 
 Evaluate a model loaded from a checkpoint.
 
@@ -19,14 +19,13 @@ and written to `output_dir`.
 *  <b>`output_dir`</b>: A string containing the directory to write a summary to.
 *  <b>`checkpoint_path`</b>: A string containing the path to a checkpoint to restore.
     Can be `None` if the graph doesn't require loading any variables.
-*  <b>`eval_dict`</b>: A `dict` mapping string names to tensors to evaluate for in every
-    eval step.
-*  <b>`update_op`</b>: A 'Tensor' which is run before evaluating 'eval_dict'.
+*  <b>`eval_dict`</b>: A `dict` mapping string names to tensors to evaluate. It is
+    evaluated in every logging step. The result of the final evaluation is
+    returned. If `update_op` is `None`, then it's evaluated in every step.
+*  <b>`update_op`</b>: A `Tensor` which is run in every step.
 *  <b>`global_step_tensor`</b>: A `Variable` containing the global step. If `None`,
     one is extracted from the graph using the same logic as in `Supervisor`.
     Used to place eval summaries on training curves.
-*  <b>`init_op`</b>: An op that initializes the graph. If `None`, use `Supervisor`'s
-    default.
 *  <b>`supervisor_master`</b>: The master string to use when preparing the session.
 *  <b>`log_every_steps`</b>: Integer. Output logs every `log_every_steps` evaluation
     steps. The logs contain the `eval_dict` and timing information.
@@ -39,7 +38,7 @@ and written to `output_dir`.
   A tuple `(eval_results, global_step)`:
 
 *  <b>`eval_results`</b>: A `dict` mapping `string` to numeric values (`int`, `float`)
-    that are the eval results from the last step of the eval.  None if no
+    that are the result of running `eval_dict` in the last step. `None` if no
     eval steps were run.
 *  <b>`global_step`</b>: The global step this evaluation corresponds to.
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.draw_bounding_boxes.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.draw_bounding_boxes.md
index b072207ffed..0e1c6115c7c 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.draw_bounding_boxes.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.draw_bounding_boxes.md
@@ -17,7 +17,7 @@ Parts of the bounding box may fall outside the image.
 ##### Args:
 
 
-*  <b>`images`</b>: A `Tensor` of type `float32`.
+*  <b>`images`</b>: A `Tensor`. Must be one of the following types: `float32`, `half`.
     4-D with shape `[batch, height, width, depth]`. A batch of images.
 *  <b>`boxes`</b>: A `Tensor` of type `float32`.
     3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
@@ -26,7 +26,7 @@ Parts of the bounding box may fall outside the image.
 
 ##### Returns:
 
-  A `Tensor` of type `float32`.
+  A `Tensor`. Has the same type as `images`.
   4-D with the same shape as `images`. The batch of input images with
   bounding boxes drawn on the images.
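A minimal sketch of the widened input type; only `images` may be `half` (`tf.float16`), while `boxes` stays `float32`:

```python
import tensorflow as tf

images = tf.zeros([1, 8, 8, 3], dtype=tf.float16)    # "half" in op terms
boxes = tf.constant([[[0.1, 0.2, 0.5, 0.9]]])        # [batch, num_boxes, 4]
drawn = tf.image.draw_bounding_boxes(images, boxes)  # dtype matches `images`
```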
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.moments.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.moments.md
index f44aa71f12b..704bb5ba49c 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.moments.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.moments.md
@@ -1,4 +1,4 @@
-### `tf.nn.moments(x, axes, name=None, keep_dims=False)` {#moments}
+### `tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)` {#moments}
 
 Calculate the mean and variance of `x`.
 
@@ -18,6 +18,9 @@ When using these moments for batch normalization (see
 *  <b>`x`</b>: A `Tensor`.
 *  <b>`axes`</b>: array of ints.  Axes along which to compute mean and
     variance.
+*  <b>`shift`</b>: A `Tensor` containing the value by which to shift the data for
+    numerical stability, or `None` if no shift is to be performed. A shift
+    close to the true mean provides the most numerically stable results.
 *  <b>`keep_dims`</b>: produce moments with the same dimensionality as the input.
 *  <b>`name`</b>: Name used to scope the operations that compute the moments.
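A sketch of the new argument; any rough estimate of the mean serves as the shift, as in the `batch_norm` change earlier in this patch, which passes the moving mean. The constant offset below is illustrative:

```python
import tensorflow as tf

x = tf.random_normal([256, 64]) + 10000.0  # large offset stresses float32
shift = tf.constant(10000.0)               # any value near the true mean
mean, variance = tf.nn.moments(x, [0], shift=shift)
```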
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sufficient_statistics.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sufficient_statistics.md
index fc9c64a9852..92cb5596e6b 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sufficient_statistics.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sufficient_statistics.md
@@ -1,21 +1,19 @@
-### `tf.nn.sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None)` {#sufficient_statistics}
+### `tf.nn.sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None)` {#sufficient_statistics}
 
 Calculate the sufficient statistics for the mean and variance of `x`.
 
 These sufficient statistics are computed using the one pass algorithm on
-an input that's optionally shifted using the value of the 1st element in `x`.
-See:
+an input that's optionally shifted. See:
 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
-Unfortunately, in some cases using a random individual sample as the shift
-value leads experimentally to very poor numerical stability, so it is disabled
-by default. The one-pass approach might have to be revised accordingly.
 
 ##### Args:
 
 
 *  <b>`x`</b>: A `Tensor`.
 *  <b>`axes`</b>: Array of ints. Axes along which to compute mean and variance.
-*  <b>`shift`</b>: If true, shift the data to provide more numerically stable results.
+*  <b>`shift`</b>: A `Tensor` containing the value by which to shift the data for
+    numerical stability, or `None` if no shift is to be performed. A shift
+    close to the true mean provides the most numerically stable results.
 *  <b>`keep_dims`</b>: produce statistics with the same dimensionality as the input.
 *  <b>`name`</b>: Name used to scope the operations that compute the sufficient stats.
 
@@ -25,5 +23,5 @@ by default. The one-pass approach might have to be revised accordingly.
   * the count (number of elements to average over).
   * the (possibly shifted) sum of the elements in the array.
   * the (possibly shifted) sum of squares of the elements in the array.
-  * the shift by which the mean must be corrected or None if `shift` is False.
+  * the shift by which the mean must be corrected or None if `shift` is None.
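The shifted sums can be folded back into moments with `tf.nn.normalize_moments` (documented in `nn.md` further below); a minimal sketch with an illustrative shift:

```python
import tensorflow as tf

x = tf.random_normal([1024]) + 10000.0
shift = tf.constant(10000.0)  # a value close to the true mean
counts, mean_ss, var_ss, shift_out = tf.nn.sufficient_statistics(
    x, [0], shift=shift)
# Corrects the shifted sums back into the actual mean and variance.
mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift_out)
```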
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.one_hot.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.one_hot.md
index 790ab80c8a0..eebb6ab643c 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.one_hot.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.one_hot.md
@@ -1,11 +1,18 @@
-### `tf.one_hot(indices, depth, on_value=1, off_value=0, axis=None, dtype=tf.float32, name=None)` {#one_hot}
+### `tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)` {#one_hot}
 
 Returns a one-hot tensor.
 
 The locations represented by indices in `indices` take value `on_value`,
-while all other locations take value `off_value`. By default, `on_value` is 1,
-and `off_value` is 0. The type of the output tensor is specified by `dtype`,
-which defaults to `tf.float32`.
+while all other locations take value `off_value`.
+
+`on_value` and `off_value` must have matching data types. If `dtype` is also
+provided, they must be the same data type as specified by `dtype`.
+
+If `on_value` is not provided, it will default to the value `1` with type
+`dtype`.
+
+If `off_value` is not provided, it will default to the value `0` with type
+`dtype`.
 
 If the input `indices` is rank `N`, the output will have rank `N+1`. The
 new axis is created at dimension `axis` (default: the new axis is appended
@@ -27,6 +34,13 @@ shape will be:
   depth x batch x features if axis == 0
 ```
 
+If `dtype` is not provided, it will be inferred from the data type of
+`on_value` or `off_value`, if one or both are passed in. If none of
+`on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
+value `tf.float32`.
+
+Note: If a non-numeric output data type is desired (`tf.string`, `tf.bool`,
+etc.), both `on_value` and `off_value` _must_ be provided to `one_hot`.
 
 Examples
 =========
@@ -74,6 +88,22 @@ Then output is `[2 x 2 x 3]`:
   ]
 ```
 
+Using default values for `on_value` and `off_value`:
+
+```
+  indices = [0, 1, 2]
+  depth = 3
+```
+
+The output will be
+
+```
+  output =
+  [[1., 0., 0.],
+   [0., 1., 0.],
+   [0., 0., 1.]]
+```
+
 ##### Args:
 
 
@@ -94,5 +124,6 @@ Then output is `[2 x 2 x 3]`:
 ##### Raises:
 
 
-*  <b>`TypeError`</b>: If dtype is `tf.string`
+*  <b>`TypeError`</b>: If the dtype of either `on_value` or `off_value` doesn't match `dtype`
+*  <b>`TypeError`</b>: If the dtypes of `on_value` and `off_value` don't match one another
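A sketch of the inference rules above:

```python
import tensorflow as tf

indices = tf.constant([0, 1, 2])

# No on/off values and no dtype: defaults to 1/0 with dtype tf.float32.
numeric = tf.one_hot(indices, depth=3)

# Non-numeric output: both values must be given; tf.string is inferred.
labels = tf.one_hot(indices, depth=3, on_value="yes", off_value="no")
```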
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_concat.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_concat.md
index 69266b42145..8d05472e340 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_concat.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_concat.md
@@ -1,4 +1,4 @@
-### `tf.sparse_concat(concat_dim, sp_inputs, name=None)` {#sparse_concat}
+### `tf.sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False)` {#sparse_concat}
 
 Concatenates a list of `SparseTensor` along the specified dimension.
 
@@ -6,11 +6,19 @@ Concatenation is with respect to the dense versions of each sparse input.
 It is assumed that each input is a `SparseTensor` whose elements are ordered
 along increasing dimension number.
 
-All inputs' shapes must match, except for the concat dimension.  The
-`indices`, `values`, and `shapes` lists must have the same length.
+If `expand_nonconcat_dim` is `False`, all inputs' shapes must match, except for
+the concat dimension. If `expand_nonconcat_dim` is `True`, then inputs' shapes
+are allowed to vary among all inputs.
 
-The output shape is identical to the inputs', except along the concat
-dimension, where it is the sum of the inputs' sizes along that dimension.
+The `indices`, `values`, and `shapes` lists must have the same length.
+
+If `expand_nonconcat_dim` is `False`, then the output shape is identical to the
+inputs', except along the concat dimension, where it is the sum of the inputs'
+sizes along that dimension.
+
+If `expand_nonconcat_dim` is `True`, then the output shape along the non-concat
+dimensions will be expanded to the largest among all inputs, and along the
+concat dimension it is the sum of the inputs' sizes.
 
 The output elements will be resorted to preserve the sort order along
 increasing dimension number.
@@ -44,12 +52,42 @@ Graphically this is equivalent to doing
     [    a] concat [  d e  ] = [    a   d e  ]
     [b c  ]        [       ]   [b c          ]
 
+Another example: if `concat_dim = 1` and the inputs are
+
+    sp_inputs[0]: shape = [3, 3]
+    [0, 2]: "a"
+    [1, 0]: "b"
+    [2, 1]: "c"
+
+    sp_inputs[1]: shape = [2, 4]
+    [0, 1]: "d"
+    [0, 2]: "e"
+
+If `expand_nonconcat_dim = False`, this will result in an error. But if
+`expand_nonconcat_dim = True`, this will result in:
+
+    shape = [3, 7]
+    [0, 2]: "a"
+    [0, 4]: "d"
+    [0, 5]: "e"
+    [1, 0]: "b"
+    [2, 1]: "c"
+
+Graphically this is equivalent to doing
+
+    [    a] concat [  d e  ] = [    a   d e  ]
+    [b    ]        [       ]   [b            ]
+    [  c  ]                    [  c          ]
+
+
 ##### Args:
 
 
 *  <b>`concat_dim`</b>: Dimension to concatenate along.
 *  <b>`sp_inputs`</b>: List of `SparseTensor` to concatenate.
 *  <b>`name`</b>: A name prefix for the returned tensors (optional).
+*  <b>`expand_nonconcat_dim`</b>: Whether to allow expansion in the non-concat
+    dimensions. Defaults to `False`.
 
 ##### Returns:
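The second example above, written out as a sketch (the `SparseTensor` constructor arguments follow this release's `indices`/`values`/`shape` signature):

```python
import tensorflow as tf

a = tf.SparseTensor(indices=[[0, 2], [1, 0], [2, 1]],
                    values=["a", "b", "c"], shape=[3, 3])
b = tf.SparseTensor(indices=[[0, 1], [0, 2]],
                    values=["d", "e"], shape=[2, 4])

# Row counts differ (3 vs 2), so the default expand_nonconcat_dim=False
# would raise; with True the result has shape [3, 7] as shown above.
sp = tf.sparse_concat(1, [a, b], expand_nonconcat_dim=True)
```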
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reset_shape.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reset_shape.md
new file mode 100644
index 00000000000..d0606cdc5d5
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reset_shape.md
@@ -0,0 +1,60 @@
+### `tf.sparse_reset_shape(sp_input, new_shape=None)` {#sparse_reset_shape}
+
+Resets the shape of a `SparseTensor` with indices and values unchanged.
+
+If `new_shape` is None, returns a copy of `sp_input` with its shape reset
+to the tight bounding box of `sp_input`.
+
+If `new_shape` is provided, then it must be larger or equal in all dimensions
+compared to the shape of `sp_input`. When this condition is met, the returned
+SparseTensor will have its shape reset to `new_shape` and its indices and
+values unchanged from that of `sp_input`.
+
+For example:
+
+  Consider a `sp_input` with shape [2, 3, 5]:
+
+    [0, 0, 1]: a
+    [0, 1, 0]: b
+    [0, 2, 2]: c
+    [1, 0, 3]: d
+
+  - It is an error to set `new_shape` as [3, 7] since this represents a
+    rank-2 tensor while `sp_input` is rank-3. This is either a `ValueError`
+    during graph construction (if both shapes are known) or an `OpError` during
+    run time.
+
+  - Setting `new_shape` as [2, 3, 6] will be fine as this shape is larger or
+    equal in every dimension compared to the original shape [2, 3, 5].
+
+  - On the other hand, setting `new_shape` as [2, 3, 4] is also an error: The
+    third dimension is smaller than the original shape [2, 3, 5] (and an
+    `InvalidArgumentError` will be raised).
+
+  - If `new_shape` is None, the returned SparseTensor will have a shape
+    [2, 3, 4], which is the tight bounding box of `sp_input`.
+
+##### Args:
+
+
+*  <b>`sp_input`</b>: The input `SparseTensor`.
+*  <b>`new_shape`</b>: None or a vector representing the new shape for the returned
+    `SparseTensor`.
+
+##### Returns:
+
+  A `SparseTensor` with indices and values unchanged from `sp_input`. Its shape
+    is `new_shape` if that is set. Otherwise it is the tight bounding box of
+    `sp_input`.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
+*  <b>`ValueError`</b>: If `new_shape` represents a tensor with a different rank from
+    that of `sp_input` (if shapes are known when graph is constructed).
+*  <b>`OpError`</b>: 
+    - If `new_shape` has dimension sizes that are too small.
+    - If shapes are not known during graph construction time, and during run
+      time it is found out that the ranks do not match.
+
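The documented example as a sketch:

```python
import tensorflow as tf

sp = tf.SparseTensor(
    indices=[[0, 0, 1], [0, 1, 0], [0, 2, 2], [1, 0, 3]],
    values=["a", "b", "c", "d"], shape=[2, 3, 5])

grown = tf.sparse_reset_shape(sp, new_shape=[2, 3, 6])  # every dim >= original
tight = tf.sparse_reset_shape(sp)  # shape becomes the bounding box [2, 3, 4]
```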
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_softmax.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_softmax.md
new file mode 100644
index 00000000000..cb54fd94525
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_softmax.md
@@ -0,0 +1,51 @@
+### `tf.sparse_softmax(sp_input, name=None)` {#sparse_softmax}
+
+Applies softmax to a batched N-D `SparseTensor`.
+
+The inputs represent an N-D `SparseTensor` with logical shape `[..., B, C]`
+(where `N >= 2`), and with indices sorted in the canonical lexicographic
+order.
+
+This op is equivalent to applying the normal `tf.nn.softmax()` to each
+innermost logical submatrix with shape `[B, C]`, but with the catch that *the
+implicitly zero elements do not participate*.  Specifically, the algorithm is
+equivalent to:
+
+  (1) Applies `tf.nn.softmax()` to a densified view of each innermost
+      submatrix with shape `[B, C]`, along the size-C dimension;
+  (2) Masks out the original implicitly-zero locations;
+  (3) Renormalizes the remaining elements.
+
+Hence, the `SparseTensor` result has exactly the same non-zero indices and
+shape.
+
+Example:
+```python
+import numpy as np
+import tensorflow as tf
+
+# First batch:
+# [?   e.]
+# [1.  ? ]
+# Second batch:
+# [e   ? ]
+# [e   e ]
+shape = [2, 2, 2]  # 3-D SparseTensor
+values = np.asarray([[[0., np.e], [1., 0.]], [[np.e, 0.], [np.e, np.e]]])
+indices = np.vstack(np.where(values)).astype(np.int64).T
+
+result = tf.sparse_softmax(tf.SparseTensor(indices, values[values != 0], shape))
+# ...returning a 3-D SparseTensor, equivalent to:
+# [?   1.]     [1    ?]
+# [1.  ? ] and [.5  .5]
+# where ? means implicitly zero.
+```
+
+##### Args:
+
+
+*  <b>`sp_input`</b>: N-D `SparseTensor`, where `N >= 2`.
+*  <b>`name`</b>: optional name of the operation.
+
+##### Returns:
+
+
+*  <b>`output`</b>: N-D `SparseTensor` representing the results.
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_fast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_fast.md
index 79cc778eb94..e684058326f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_fast.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_fast.md
@@ -4,11 +4,15 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
 
 The hash function is deterministic on the content of the string within the
 process and will never change. However, it is not suitable for cryptography.
+This function may be used when CPU time is scarce and inputs are trusted or
+unimportant. There is a risk of adversaries constructing inputs that all hash
+to the same bucket. To prevent this problem, use a strong hash function with
+`tf.string_to_hash_bucket_strong`.
 
 ##### Args:
 
 
-*  <b>`input`</b>: A `Tensor` of type `string`. The strings to assing a hash bucket.
+*  <b>`input`</b>: A `Tensor` of type `string`. The strings to assign a hash bucket.
 *  <b>`num_buckets`</b>: An `int` that is `>= 1`. The number of buckets.
 *  <b>`name`</b>: A name for the operation (optional).
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_strong.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_strong.md
new file mode 100644
index 00000000000..67cf3b6fd98
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_strong.md
@@ -0,0 +1,30 @@
+### `tf.string_to_hash_bucket_strong(input, num_buckets, key, name=None)` {#string_to_hash_bucket_strong}
+
+Converts each string in the input Tensor to its hash mod by a number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process. The hash function is a keyed hash function, where attribute `key`
+defines the key of the hash function. `key` is an array of 2 elements.
+
+A strong hash is important when inputs may be malicious, e.g. URLs with
+additional components. Adversaries could try to make their inputs hash to the
+same bucket for a denial-of-service attack or to skew the results. A strong
+hash prevents this by making it difficult, if not infeasible, to compute inputs
+that hash to the same bucket. This comes at a cost of roughly 4x higher compute
+time than `tf.string_to_hash_bucket_fast`.
+
+##### Args:
+
+
+*  <b>`input`</b>: A `Tensor` of type `string`. The strings to assign a hash bucket.
+*  <b>`num_buckets`</b>: An `int` that is `>= 1`. The number of buckets.
+*  <b>`key`</b>: A list of `ints`.
+    The key for the keyed hash function passed as a list of two uint64
+    elements.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A `Tensor` of type `int64`.
+  A Tensor of the same shape as the input `string_tensor`.
+
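A sketch mirroring the usage comment in the `strong_hash.h` header above; in adversarial settings the key should be secret and chosen per process:

```python
import tensorflow as tf

key = [123, 456]  # two uint64 values forming the SipHash key
strings = tf.constant(["input string", "another string"])
buckets = tf.string_to_hash_bucket_strong(strings, num_buckets=1024, key=key)
```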
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_axis_size_partitioner.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_axis_size_partitioner.md
index 63ff8111266..5d8822e83ca 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_axis_size_partitioner.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_axis_size_partitioner.md
@@ -1,4 +1,4 @@
-### `tf.variable_axis_size_partitioner(max_shard_bytes, axis=0, bytes_per_string_element=16)` {#variable_axis_size_partitioner}
+### `tf.variable_axis_size_partitioner(max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None)` {#variable_axis_size_partitioner}
 
 Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.
 
@@ -8,6 +8,10 @@ always possible when sharding along only one axis.  When this happens,
 this axis is sharded as much as possible (i.e., every dimension becomes
 a separate shard).
 
+If the partitioner hits the `max_shards` limit, then each shard may end up
+larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
+limit on the number of shards is enforced.
+
 One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
 `64MB`, to keep below the protobuf byte limit.
 
@@ -18,6 +22,8 @@ One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
 *  <b>`axis`</b>: The axis to partition along.  Default: outermost axis.
 *  <b>`bytes_per_string_element`</b>: If the `Variable` is of type string, this provides
     an estimate of how large each scalar in the `Variable` is.
+*  <b>`max_shards`</b>: An `int` giving the maximum number of shards to create,
+    taking precedence over `max_shard_bytes`.
 
 ##### Returns:
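A sketch of capping the shard count; that `tf.variable_scope` accepts a `partitioner` argument in this release is an assumption here:

```python
import tensorflow as tf

partitioner = tf.variable_axis_size_partitioner(
    max_shard_bytes=(64 << 20) - 1,  # keep each shard under the protobuf limit
    max_shards=8)                    # but never create more than 8 shards

with tf.variable_scope("embeddings", partitioner=partitioner):
  weights = tf.get_variable("weights", shape=[1 << 20, 64], dtype=tf.float32)
```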
 
diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md
index b297b31405f..3c93de2a02b 100644
--- a/tensorflow/g3doc/api_docs/python/image.md
+++ b/tensorflow/g3doc/api_docs/python/image.md
@@ -1134,7 +1134,7 @@ Parts of the bounding box may fall outside the image.
 ##### Args:
 
 
-*  <b>`images`</b>: A `Tensor` of type `float32`.
+*  <b>`images`</b>: A `Tensor`. Must be one of the following types: `float32`, `half`.
     4-D with shape `[batch, height, width, depth]`. A batch of images.
 *  <b>`boxes`</b>: A `Tensor` of type `float32`.
     3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
@@ -1143,7 +1143,7 @@ Parts of the bounding box may fall outside the image.
 
 ##### Returns:
 
-  A `Tensor` of type `float32`.
+  A `Tensor`. Has the same type as `images`.
   4-D with the same shape as `images`. The batch of input images with
   bounding boxes drawn on the images.
 
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 49846bdd3fb..1617ebffa9e 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -42,6 +42,7 @@
   * [`assert_non_negative`](../../api_docs/python/check_ops.md#assert_non_negative)
   * [`assert_non_positive`](../../api_docs/python/check_ops.md#assert_non_positive)
   * [`assert_positive`](../../api_docs/python/check_ops.md#assert_positive)
+  * [`assert_proper_iterable`](../../api_docs/python/check_ops.md#assert_proper_iterable)
   * [`assert_rank`](../../api_docs/python/check_ops.md#assert_rank)
   * [`assert_rank_at_least`](../../api_docs/python/check_ops.md#assert_rank_at_least)
   * [`assert_type`](../../api_docs/python/check_ops.md#assert_type)
@@ -259,6 +260,7 @@
   * [`reduce_join`](../../api_docs/python/string_ops.md#reduce_join)
   * [`string_to_hash_bucket`](../../api_docs/python/string_ops.md#string_to_hash_bucket)
   * [`string_to_hash_bucket_fast`](../../api_docs/python/string_ops.md#string_to_hash_bucket_fast)
+  * [`string_to_hash_bucket_strong`](../../api_docs/python/string_ops.md#string_to_hash_bucket_strong)
 
 * **[Histograms](../../api_docs/python/histogram_ops.md)**:
   * [`histogram_fixed_width`](../../api_docs/python/histogram_ops.md#histogram_fixed_width)
@@ -348,7 +350,9 @@
   * [`sparse_fill_empty_rows`](../../api_docs/python/sparse_ops.md#sparse_fill_empty_rows)
   * [`sparse_merge`](../../api_docs/python/sparse_ops.md#sparse_merge)
   * [`sparse_reorder`](../../api_docs/python/sparse_ops.md#sparse_reorder)
+  * [`sparse_reset_shape`](../../api_docs/python/sparse_ops.md#sparse_reset_shape)
   * [`sparse_retain`](../../api_docs/python/sparse_ops.md#sparse_retain)
+  * [`sparse_softmax`](../../api_docs/python/sparse_ops.md#sparse_softmax)
   * [`sparse_split`](../../api_docs/python/sparse_ops.md#sparse_split)
   * [`sparse_tensor_dense_matmul`](../../api_docs/python/sparse_ops.md#sparse_tensor_dense_matmul)
   * [`sparse_tensor_to_dense`](../../api_docs/python/sparse_ops.md#sparse_tensor_to_dense)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 5b594b33e63..35a94808c1e 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -824,24 +824,22 @@ convolutional neural networks (NIPS 2012)]
 
 - - -
 
-### `tf.nn.sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None)` {#sufficient_statistics}
+### `tf.nn.sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None)` {#sufficient_statistics}
 
 Calculate the sufficient statistics for the mean and variance of `x`.
 
 These sufficient statistics are computed using the one pass algorithm on
-an input that's optionally shifted using the value of the 1st element in `x`.
-See:
+an input that's optionally shifted. See:
 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
-Unfortunately, in some cases using a random individual sample as the shift
-value leads experimentally to very poor numerical stability, so it is disabled
-by default. The one-pass approach might have to be revised accordingly.
 
 ##### Args:
 
 
 *  <b>`x`</b>: A `Tensor`.
 *  <b>`axes`</b>: Array of ints. Axes along which to compute mean and variance.
-*  <b>`shift`</b>: If true, shift the data to provide more numerically stable results.
+*  <b>`shift`</b>: A `Tensor` containing the value by which to shift the data for
+    numerical stability, or `None` if no shift is to be performed. A shift
+    close to the true mean provides the most numerically stable results.
 *  <b>`keep_dims`</b>: produce statistics with the same dimensionality as the input.
 *  <b>`name`</b>: Name used to scope the operations that compute the sufficient stats.
 
@@ -851,7 +849,7 @@ by default. The one-pass approach might have to be revised accordingly.
   * the count (number of elements to average over).
   * the (possibly shifted) sum of the elements in the array.
   * the (possibly shifted) sum of squares of the elements in the array.
-  * the shift by which the mean must be corrected or None if `shift` is False.
+  * the shift by which the mean must be corrected or None if `shift` is None.
 
 
 - - -
 Calculate the mean and variance based on the sufficient statistics.
 
 - - -
 
-### `tf.nn.moments(x, axes, name=None, keep_dims=False)` {#moments}
+### `tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)` {#moments}
 
 Calculate the mean and variance of `x`.
 
@@ -899,6 +897,9 @@ When using these moments for batch normalization (see
 *  <b>`x`</b>: A `Tensor`.
 *  <b>`axes`</b>: array of ints.  Axes along which to compute mean and
     variance.
+*  <b>`shift`</b>: A `Tensor` containing the value by which to shift the data for
+    numerical stability, or `None` if no shift is to be performed. A shift
+    close to the true mean provides the most numerically stable results.
 *  <b>`keep_dims`</b>: produce moments with the same dimensionality as the input.
 *  <b>`name`</b>: Name used to scope the operations that compute the moments.
 
diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md
index af42487423e..ba34af0eb39 100644
--- a/tensorflow/g3doc/api_docs/python/sparse_ops.md
+++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md
@@ -422,7 +422,7 @@ equal to:
 
 - - -
 
-### `tf.sparse_concat(concat_dim, sp_inputs, name=None)` {#sparse_concat}
+### `tf.sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False)` {#sparse_concat}
 
 Concatenates a list of `SparseTensor` along the specified dimension.
 
@@ -430,11 +430,19 @@ Concatenation is with respect to the dense versions of each sparse input.
 It is assumed that each input is a `SparseTensor` whose elements are ordered
 along increasing dimension number.
 
-All inputs' shapes must match, except for the concat dimension.  The
-`indices`, `values`, and `shapes` lists must have the same length.
+If `expand_nonconcat_dim` is `False`, all inputs' shapes must match, except for
+the concat dimension. If `expand_nonconcat_dim` is `True`, then inputs' shapes
+are allowed to vary among all inputs.
 
-The output shape is identical to the inputs', except along the concat
-dimension, where it is the sum of the inputs' sizes along that dimension.
+The `indices`, `values`, and `shapes` lists must have the same length.
+
+If `expand_nonconcat_dim` is `False`, then the output shape is identical to the
+inputs', except along the concat dimension, where it is the sum of the inputs'
+sizes along that dimension.
+
+If `expand_nonconcat_dim` is `True`, then the output shape along the non-concat
+dimensions will be expanded to the largest among all inputs, and along the
+concat dimension it is the sum of the inputs' sizes.
 
 The output elements will be resorted to preserve the sort order along
 increasing dimension number.
@@ -468,12 +476,42 @@ Graphically this is equivalent to doing
     [    a] concat [  d e  ] = [    a   d e  ]
     [b c  ]        [       ]   [b c          ]
 
+Another example: if `concat_dim = 1` and the inputs are
+
+    sp_inputs[0]: shape = [3, 3]
+    [0, 2]: "a"
+    [1, 0]: "b"
+    [2, 1]: "c"
+
+    sp_inputs[1]: shape = [2, 4]
+    [0, 1]: "d"
+    [0, 2]: "e"
+
+If `expand_nonconcat_dim = False`, this will result in an error. But if
+`expand_nonconcat_dim = True`, this will result in:
+
+    shape = [3, 7]
+    [0, 2]: "a"
+    [0, 4]: "d"
+    [0, 5]: "e"
+    [1, 0]: "b"
+    [2, 1]: "c"
+
+Graphically this is equivalent to doing
+
+    [    a] concat [  d e  ] = [    a   d e  ]
+    [b    ]        [       ]   [b            ]
+    [  c  ]                    [  c          ]
+
+
 ##### Args:
 
 
 *  <b>`concat_dim`</b>: Dimension to concatenate along.
 *  <b>`sp_inputs`</b>: List of `SparseTensor` to concatenate.
 *  <b>`name`</b>: A name prefix for the returned tensors (optional).
+*  <b>`expand_nonconcat_dim`</b>: Whether to allow expansion in the non-concat
+    dimensions. Defaults to `False`.
 
 ##### Returns:
 
@@ -608,6 +646,69 @@ be a `SparseTensor` of shape `[4, 5]` with 2 non-empty values:
 *  <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
 
 
+- - -
+
+### `tf.sparse_reset_shape(sp_input, new_shape=None)` {#sparse_reset_shape}
+
+Resets the shape of a `SparseTensor` with indices and values unchanged.
+
+If `new_shape` is None, returns a copy of `sp_input` with its shape reset
+to the tight bounding box of `sp_input`.
+
+If `new_shape` is provided, then it must be larger or equal in all dimensions
+compared to the shape of `sp_input`. When this condition is met, the returned
+SparseTensor will have its shape reset to `new_shape` and its indices and
+values unchanged from that of `sp_input`.
+
+For example:
+
+  Consider a `sp_input` with shape [2, 3, 5]:
+
+    [0, 0, 1]: a
+    [0, 1, 0]: b
+    [0, 2, 2]: c
+    [1, 0, 3]: d
+
+  - It is an error to set `new_shape` as [3, 7] since this represents a
+    rank-2 tensor while `sp_input` is rank-3. This is either a `ValueError`
+    during graph construction (if both shapes are known) or an `OpError` during
+    run time.
+
+  - Setting `new_shape` as [2, 3, 6] will be fine as this shape is larger or
+    equal in every dimension compared to the original shape [2, 3, 5].
+
+  - On the other hand, setting `new_shape` as [2, 3, 4] is also an error: The
+    third dimension is smaller than the original shape [2, 3, 5] (and an
+    `InvalidArgumentError` will be raised).
+
+  - If `new_shape` is None, the returned SparseTensor will have a shape
+    [2, 3, 4], which is the tight bounding box of `sp_input`.
+
+##### Args:
+
+
+*  <b>`sp_input`</b>: The input `SparseTensor`.
+*  <b>`new_shape`</b>: None or a vector representing the new shape for the returned
+    `SparseTensor`.
+
+##### Returns:
+
+  A `SparseTensor` with indices and values unchanged from `sp_input`. Its shape
+    is `new_shape` if that is set. Otherwise it is the tight bounding box of
+    `sp_input`.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
+*  <b>`ValueError`</b>: If `new_shape` represents a tensor with a different rank from
+    that of `sp_input` (if shapes are known when graph is constructed).
+*  <b>`OpError`</b>: 
+    - If `new_shape` has dimension sizes that are too small.
+    - If shapes are not known during graph construction time, and during run
+      time it is found out that the ranks do not match.
+
+
 - - -
 
 ### `tf.sparse_fill_empty_rows(sp_input, default_value, name=None)` {#sparse_fill_empty_rows}
@@ -725,6 +826,60 @@ Then,
 *  <b>`TypeError`</b>: If both `a` and `b` are `Tensor`s.  Use `tf.add()` instead.
 
 
+- - -
+
+### `tf.sparse_softmax(sp_input, name=None)` {#sparse_softmax}
+
+Applies softmax to a batched N-D `SparseTensor`.
+
+The inputs represent an N-D `SparseTensor` with logical shape `[..., B, C]`
+(where `N >= 2`), and with indices sorted in the canonical lexicographic
+order.
+
+This op is equivalent to applying the normal `tf.nn.softmax()` to each
+innermost logical submatrix with shape `[B, C]`, but with the catch that *the
+implicitly zero elements do not participate*.  Specifically, the algorithm is
+equivalent to:
+
+  (1) Applies `tf.nn.softmax()` to a densified view of each innermost
+      submatrix with shape `[B, C]`, along the size-C dimension;
+  (2) Masks out the original implicitly-zero locations;
+  (3) Renormalizes the remaining elements.
+
+Hence, the `SparseTensor` result has exactly the same non-zero indices and
+shape.
+
+Example:
+```python
+import numpy as np
+import tensorflow as tf
+
+# First batch:
+# [?   e.]
+# [1.  ? ]
+# Second batch:
+# [e   ? ]
+# [e   e ]
+shape = [2, 2, 2]  # 3-D SparseTensor
+values = np.asarray([[[0., np.e], [1., 0.]], [[np.e, 0.], [np.e, np.e]]])
+indices = np.vstack(np.where(values)).astype(np.int64).T
+
+result = tf.sparse_softmax(tf.SparseTensor(indices, values[values != 0], shape))
+# ...returning a 3-D SparseTensor, equivalent to:
+# [?   1.]     [1    ?]
+# [1.  ? ] and [.5  .5]
+# where ? means implicitly zero.
+```
+
+##### Args:
+
+
+*  <b>`sp_input`</b>: N-D `SparseTensor`, where `N >= 2`.
+*  <b>`name`</b>: optional name of the operation.
+
+##### Returns:
+
+
+*  <b>`output`</b>: N-D `SparseTensor` representing the results.
+
+
 - - -
 
 ### `tf.sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False, name=None)` {#sparse_tensor_dense_matmul}
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 2565142a4b7..cff67f21c15 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -1692,7 +1692,7 @@ An adaptor for ones() to match the Initializer spec.
 
 - - -
 
-### `tf.variable_axis_size_partitioner(max_shard_bytes, axis=0, bytes_per_string_element=16)` {#variable_axis_size_partitioner}
+### `tf.variable_axis_size_partitioner(max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None)` {#variable_axis_size_partitioner}
 
 Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.
 
@@ -1702,6 +1702,10 @@ always possible when sharding along only one axis.  When this happens,
 this axis is sharded as much as possible (i.e., every dimension becomes
 a separate shard).
 
+If the partitioner hits the `max_shards` limit, each shard may end up larger
+than `max_shard_bytes`. By default `max_shards` is `None`, meaning no limit
+on the number of shards is enforced.
+
 One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
 `64MB`, to keep below the protobuf byte limit.
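+
+For example (a sketch; the scope name and limits are illustrative), the
+partitioner is typically passed to a variable scope:
+
+```python
+partitioner = tf.variable_axis_size_partitioner(
+    max_shard_bytes=(64 << 20) - 1,  # keep each shard just under 64MB
+    max_shards=8)                    # and never create more than 8 shards
+
+with tf.variable_scope("embeddings", partitioner=partitioner):
+  ...  # variables created here are partitioned along axis 0
+```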
 
@@ -1712,6 +1716,8 @@ One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
 *  <b>`axis`</b>: The axis to partition along.  Default: outermost axis.
 *  <b>`bytes_per_string_element`</b>: If the `Variable` is of type string, this provides
     an estimate of how large each scalar in the `Variable` is.
+*  <b>`max_shards`</b>: An `int`; the maximum number of shards to create. Takes
+    precedence over `max_shard_bytes`.
 
 ##### Returns:
 
diff --git a/tensorflow/g3doc/api_docs/python/string_ops.md b/tensorflow/g3doc/api_docs/python/string_ops.md
index 302d9df8099..a516d851cf2 100644
--- a/tensorflow/g3doc/api_docs/python/string_ops.md
+++ b/tensorflow/g3doc/api_docs/python/string_ops.md
@@ -20,11 +20,15 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
 
 The hash function is deterministic on the content of the string within the
 process and will never change. However, it is not suitable for cryptography.
+This function may be used when CPU time is scarce and inputs are trusted or
+unimportant. There is a risk of adversaries constructing inputs that all hash
+to the same bucket. To prevent this problem, use a strong hash function with
+`tf.string_to_hash_bucket_strong`.
 
 ##### Args:
 
 
-*  <b>`input`</b>: A `Tensor` of type `string`. The strings to assing a hash bucket.
+*  <b>`input`</b>: A `Tensor` of type `string`. The strings to assign a hash bucket.
 *  <b>`num_buckets`</b>: An `int` that is `>= 1`. The number of buckets.
 *  <b>`name`</b>: A name for the operation (optional).
 
@@ -34,6 +38,39 @@ process and will never change. However, it is not suitable for cryptography.
   A Tensor of the same shape as the input `string_tensor`.
 
 
+- - -
+
+### `tf.string_to_hash_bucket_strong(input, num_buckets, key, name=None)` {#string_to_hash_bucket_strong}
+
+Converts each string in the input Tensor to its hash mod the number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process. The hash function is a keyed hash function, where attribute `key`
+defines the key of the hash function. `key` is an array of 2 elements.
+
+A strong hash is important when inputs may be malicious, e.g. URLs with
+additional components. Adversaries could try to make their inputs hash to the
+same bucket for a denial-of-service attack or to skew the results. A strong
+hash prevents this by making it difficult, if not infeasible, to compute inputs
+that hash to the same bucket. This comes at a cost of roughly 4x higher compute
+time than `tf.string_to_hash_bucket_fast`.
+
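+A short usage sketch (the bucket count and key values are arbitrary):
+
+```python
+strings = tf.constant(['a', 'b', 'c'])
+# `key` must be a list of two uint64 values; keep it secret when inputs may
+# be adversarial.
+buckets = tf.string_to_hash_bucket_strong(strings, num_buckets=10,
+                                          key=[98765, 132])
+# => [4, 2, 8] for this particular key.
+```
+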
+##### Args:
+
+
+*  <b>`input`</b>: A `Tensor` of type `string`. The strings to assign a hash bucket.
+*  <b>`num_buckets`</b>: An `int` that is `>= 1`. The number of buckets.
+*  <b>`key`</b>: A list of `ints`.
+    The key for the keyed hash function passed as a list of two uint64
+    elements.
+*  <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  A `Tensor` of type `int64`.
+  A Tensor of the same shape as the input `string_tensor`.
+
+
 - - -
 
 ### `tf.string_to_hash_bucket(string_tensor, num_buckets, name=None)` {#string_to_hash_bucket}
diff --git a/tensorflow/g3doc/how_tos/style_guide.md b/tensorflow/g3doc/how_tos/style_guide.md
index 715853b93ef..8b498e92721 100644
--- a/tensorflow/g3doc/how_tos/style_guide.md
+++ b/tensorflow/g3doc/how_tos/style_guide.md
@@ -92,8 +92,8 @@ creates a part of the graph and returns output tensors.
  If operation needs to save some `Tensor`s to Graph collections,
  put the arguments with names of the collections right before `name` argument.
 
-* Tensor arguments should be either a single tensor or an iterable of tensors,
-  not both.  E.g. a "Tensor or list of Tensors" is too broad.
+* Tensor arguments should be either a single tensor or an iterable of tensors.
+  A "Tensor or list of Tensors" argument is too broad; see
+  `assert_proper_iterable`.
 
 * Operations that take tensors as arguments should call `convert_to_tensor`
  to convert non-tensor inputs into tensors if they are using C++ operations.
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index e327fec3304..32f0e3f6ceb 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -900,10 +900,12 @@ class InteractiveSession(BaseSession):
 
     super(InteractiveSession, self).__init__(target, graph, config)
     self._default_session = self.as_default()
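+    # Interactive sessions are typically not used as context managers, so
+    # relax the LIFO nesting check on the default-session/graph stacks.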
+    self._default_session.enforce_nesting = False
     self._default_session.__enter__()
     self._explicit_graph = graph
     if self._explicit_graph is not None:
       self._default_graph = graph.as_default()
+      self._default_graph.enforce_nesting = False
       self._default_graph.__enter__()
 
   def close(self):
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index c2e69ea1131..c3cce65a9bf 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1138,11 +1138,33 @@ class SessionTest(test_util.TensorFlowTestCase):
         d = math_ops.mul(c, c)
       for step in xrange(120):
         run_metadata = config_pb2.RunMetadata()
-        sess.run(d, feed_dict={a: 1.0}, options=run_options, run_metadata=run_metadata)
+        sess.run(d, feed_dict={a: 1.0},
+                 options=run_options, run_metadata=run_metadata)
         if step == 99:
           self.assertTrue(run_metadata.HasField('cost_graph'))
         else:
           self.assertFalse(run_metadata.HasField('cost_graph'))
 
+  def testNonInteractiveSessionNesting(self):
+    sess1 = session.Session()
+    sess1_controller = sess1.as_default()
+    sess1_controller.__enter__()
+
+    sess2 = session.Session()
+    sess2_controller = sess2.as_default()
+    sess2_controller.__enter__()
+
+    with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
+      sess1_controller.__exit__(None, None, None)
+
+    ops._default_session_stack.reset()
+
+  def testInteractiveSessionNesting(self):
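+    # InteractiveSession relaxes nesting enforcement, so creating two
+    # sessions back to back and discarding them must not raise.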
+    sess1 = session.InteractiveSession()
+    sess2 = session.InteractiveSession()
+    del sess1
+    del sess2
+
+
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2667b544cce..37178746933 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3306,6 +3306,7 @@ class _DefaultStack(threading.local):
 
   def __init__(self):
     super(_DefaultStack, self).__init__()
+    self._enforce_nesting = True
     self.stack = []
 
   def get_default(self):
@@ -3314,6 +3315,14 @@ class _DefaultStack(threading.local):
   def reset(self):
     self.stack = []
 
+  @property
+  def enforce_nesting(self):
+    return self._enforce_nesting
+
+  @enforce_nesting.setter
+  def enforce_nesting(self, value):
+    self._enforce_nesting = value
+
   @contextlib.contextmanager
   def get_controller(self, default):
     """A context manager for manipulating a default stack."""
@@ -3321,9 +3330,14 @@ class _DefaultStack(threading.local):
       self.stack.append(default)
       yield default
     finally:
-      assert self.stack[-1] is default
-      self.stack.pop()
-
+      if self._enforce_nesting:
+        if self.stack[-1] is not default:
+          raise AssertionError(
+              "Nesting violated for default stack of %s objects"
+              % type(default))
+        self.stack.pop()
+      else:
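+        # Nesting is not enforced (e.g. for InteractiveSession), so allow
+        # `default` to be removed even when it is not on top of the stack.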
+        self.stack.remove(default)
 
 _default_session_stack = _DefaultStack()
 
@@ -3686,6 +3700,8 @@ class GraphKeys(object):
   ACTIVATIONS = "activations"
   # Key to collect update_ops
   UPDATE_OPS = "update_ops"
+  # Key to collect losses
+  LOSSES = "losses"
 
   # Key to indicate various ops.
   INIT_OP = "init_op"
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 39bcc1cc41c..140fafa64c7 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -21,6 +21,42 @@ import numpy as np
 import tensorflow as tf
 
 
+class AssertProperIterableTest(tf.test.TestCase):
+
+  def test_single_tensor_raises(self):
+    tensor = tf.constant(1)
+    with self.assertRaisesRegexp(TypeError, "proper"):
+      tf.assert_proper_iterable(tensor)
+
+  def test_single_sparse_tensor_raises(self):
+    ten = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], shape=[3, 4])
+    with self.assertRaisesRegexp(TypeError, "proper"):
+      tf.assert_proper_iterable(ten)
+
+  def test_single_ndarray_raises(self):
+    array = np.array([1, 2, 3])
+    with self.assertRaisesRegexp(TypeError, "proper"):
+      tf.assert_proper_iterable(array)
+
+  def test_single_string_raises(self):
+    mystr = "hello"
+    with self.assertRaisesRegexp(TypeError, "proper"):
+      tf.assert_proper_iterable(mystr)
+
+  def test_non_iterable_object_raises(self):
+    non_iterable = 1234
+    with self.assertRaisesRegexp(TypeError, "to be iterable"):
+      tf.assert_proper_iterable(non_iterable)
+
+  def test_list_does_not_raise(self):
+    list_of_stuff = [tf.constant([11, 22]), tf.constant([1, 2])]
+    tf.assert_proper_iterable(list_of_stuff)
+
+  def test_generator_does_not_raise(self):
+    generator_of_stuff = (t for t in
+                          [tf.constant([11, 22]), tf.constant([1, 2])])
+    tf.assert_proper_iterable(generator_of_stuff)
+
+
 class AssertEqualTest(tf.test.TestCase):
 
   def test_doesnt_raise_when_equal(self):
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 596390bf428..556e0f65dda 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -515,6 +515,9 @@ class BinaryOpTest(tf.test.TestCase):
     self._compareBoth(x, y, np.true_divide, _TRUEDIV)
     self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
     self._compareBoth(x, y, np.mod, _MOD)
+    # _compareBoth tests on GPU only for floating point types, so test
+    # _MOD for int32 on GPU by calling _compareGpu
+    self._compareGpu(x, y, np.mod, _MOD)
 
   def testInt64Basic(self):
     x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index df6d1a994aa..5a9ad223497 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -31,6 +31,19 @@ get_partitioned_variable_list = variable_scope._get_partitioned_variable_list
 
 class PartitionerCreatorsTest(tf.test.TestCase):
 
+  def _testVariableAxisSizePartitioner(self, name, axis, max_shard_bytes,
+                                       expected_axis_shards,
+                                       expected_partitions,
+                                       max_shards=None):
+    partitioner = tf.variable_axis_size_partitioner(
+        axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards)
+
+    with tf.variable_scope("root", partitioner=partitioner):
+      v0_list, v0_part = get_partitioned_variable_list(
+          name, dtype=tf.float32, shape=(4, 8, 16, 32))
+      self.assertEqual(len(v0_list), expected_axis_shards)
+      self.assertAllEqual(v0_part, expected_partitions)
+
   def testVariableAxisSizePartitioner(self):
     with self.test_session():
       # Create a partitioned variable of shape (4, 8, 16, 32) type float32
@@ -43,69 +56,62 @@ class PartitionerCreatorsTest(tf.test.TestCase):
 
       # Now partition it in different ways...
 
-      partitioner_axis0 = tf.variable_axis_size_partitioner(
-          axis=0, max_shard_bytes=131072, bytes_per_string_element=8)
+      # No need to slice: bytes_per_slice * dim0 = 65536 < max_shard_bytes
+      self._testVariableAxisSizePartitioner("v0", axis=0,
+                                            max_shard_bytes=131072,
+                                            expected_axis_shards=1,
+                                            expected_partitions=(1, 1, 1, 1))
 
-      with tf.variable_scope("root", partitioner=partitioner_axis0):
-        v0_list, v0_part = get_partitioned_variable_list(
-            "v0", dtype=tf.float32, shape=(4, 8, 16, 32))
-        # No need to slice: bytes_per_slice * dim0 = 65536 < max_shard_bytes
-        self.assertEqual(len(v0_list), 1)
-        self.assertAllEqual(v0_part, (1, 1, 1, 1))
+      # Slice exactly once: bytes_per_slice * dim1 = 65536 = max_shard_bytes
+      self._testVariableAxisSizePartitioner("v1", axis=1,
+                                            max_shard_bytes=65536,
+                                            expected_axis_shards=1,
+                                            expected_partitions=(1, 1, 1, 1))
 
-      partitioner_axis1 = tf.variable_axis_size_partitioner(
-          axis=1, max_shard_bytes=65536, bytes_per_string_element=8)
+      # Slice into 2 parts:
+      # bytes_per_slice = 4096
+      # slices_per_shard = 32768 / 4096 = 8
+      # axis_shards = 16 / 8 = 2
+      self._testVariableAxisSizePartitioner("v2", axis=2,
+                                            max_shard_bytes=32768,
+                                            expected_axis_shards=2,
+                                            expected_partitions=(1, 1, 2, 1))
 
-      with tf.variable_scope("root", partitioner=partitioner_axis1):
-        v1_list, v1_part = get_partitioned_variable_list(
-            "v1", dtype=tf.float32, shape=(4, 8, 16, 32))
-        # Slice exactly once: bytes_per_slice * dim1 = 65536 = max_shard_bytes
-        self.assertEqual(len(v1_list), 1)
-        self.assertAllEqual(v1_part, (1, 1, 1, 1))
-
-      partitioner_axis2 = tf.variable_axis_size_partitioner(
-          axis=2, max_shard_bytes=32768, bytes_per_string_element=8)
-
-      with tf.variable_scope("root", partitioner=partitioner_axis2):
-        v2_list, v2_part = get_partitioned_variable_list(
-            "v2", dtype=tf.float32, shape=(4, 8, 16, 32))
-        # Slice into 2 parts:
-        # bytes_per_slice = 4096
-        # slices_per_shard = 32768 / 4096 = 8
-        # axis_shards = 16 / 8 = 2
-        self.assertEqual(len(v2_list), 2)
-        self.assertAllEqual(v2_part, (1, 1, 2, 1))
-
-      # This partitioner makes sure we maximize the number of shards
-      # along axis 3
-      partitioner_axis3_a = tf.variable_axis_size_partitioner(
-          axis=3, max_shard_bytes=2048, bytes_per_string_element=8)
-
-      with tf.variable_scope("root", partitioner=partitioner_axis3_a):
-        v3a_list, v3a_part = get_partitioned_variable_list(
-            "v3a", dtype=tf.float32, shape=(4, 8, 16, 32))
-        # Slice into 32 parts:
-        # bytes_per_slice = 2048
-        # slices_per_shard = 2048 / 2048 = 1
-        # axis_shards = 32 / 1 = 32
-        self.assertEqual(len(v3a_list), 32)
-        self.assertAllEqual(v3a_part, (1, 1, 1, 32))
+      # This partitioner makes sure we maximize the number of shards along
+      # axis 3. Slice it into 32 parts:
+      # bytes_per_slice = 2048
+      # slices_per_shard = 2048 / 2048 = 1
+      # axis_shards = 32 / 1 = 32
+      self._testVariableAxisSizePartitioner("v3a", axis=3,
+                                            max_shard_bytes=2048,
+                                            expected_axis_shards=32,
+                                            expected_partitions=(1, 1, 1, 32))
 
       # This partitioner makes sure we do not go past the bound of allowable
-      # number of shards along axis 3
-      partitioner_axis3_b = tf.variable_axis_size_partitioner(
-          axis=3, max_shard_bytes=1024, bytes_per_string_element=8)
+      # number of shards along axis 3.
+      # Slice into 32 parts:
+      # bytes_per_slice = 2048
+      # slices_per_shard = max(1, 1024 / 2048) = 1
+      # axis_shards = 32 / 1 = 32
+      # Fully sliced (32 parts) because max_shard_bytes < bytes_per_slice
+      self._testVariableAxisSizePartitioner("v3b", axis=3,
+                                            max_shard_bytes=1024,
+                                            expected_axis_shards=32,
+                                            expected_partitions=(1, 1, 1, 32))
 
-      with tf.variable_scope("root", partitioner=partitioner_axis3_b):
-        v3b_list, v3b_part = get_partitioned_variable_list(
-            "v3b", dtype=tf.float32, shape=(4, 8, 16, 32))
-        # Slice into 32 parts:
-        # bytes_per_slice = 2048
-        # slices_per_shard = max(1, 1024 / 2048) = 1
-        # axis_shards = 32 / 1 = 32
-        # Slice into max of 32 parts because: max_shard_bytes < bytes_per_slice
-        self.assertEqual(len(v3b_list), 32)
-        self.assertAllEqual(v3b_part, (1, 1, 1, 32))
+      # Specify max_shards so that it won't affect sharding.
+      self._testVariableAxisSizePartitioner("v3c", axis=3,
+                                            max_shard_bytes=1024,
+                                            expected_axis_shards=32,
+                                            expected_partitions=(1, 1, 1, 32),
+                                            max_shards=33)
+
+      # Specify max_shards so that it will affect sharding.
+      self._testVariableAxisSizePartitioner("v3d", axis=3,
+                                            max_shard_bytes=1024,
+                                            expected_axis_shards=2,
+                                            expected_partitions=(1, 1, 1, 2),
+                                            max_shards=2)
 
       # Use the partitioner with strings
       partitioner_axis3_str = tf.variable_axis_size_partitioner(
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index e3756e03d25..10a1a2e2a39 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -76,7 +76,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])  # Test GRUCell with input_size != num_units.
         m = tf.zeros([1, 2])
-        g, _ = tf.nn.rnn_cell.GRUCell(2, input_size=3)(x, m)
+        g, _ = tf.nn.rnn_cell.GRUCell(2)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g], {x.name: np.array([[1., 1., 1.]]),
                              m.name: np.array([[0.1, 0.1]])})
@@ -104,7 +104,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("other", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])  # Test BasicLSTMCell with input_size != num_units.
         m = tf.zeros([1, 4])
-        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2, input_size=3)(x, m)
+        g, out_m = tf.nn.rnn_cell.BasicLSTMCell(2)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([g, out_m], {x.name: np.array([[1., 1., 1.]]),
                                     m.name: 0.1 * np.ones([1, 4])})
@@ -147,8 +147,7 @@ class RNNCellTest(tf.test.TestCase):
         x = tf.zeros([batch_size, input_size])
         m = tf.zeros([batch_size, state_size])
         output, state = tf.nn.rnn_cell.LSTMCell(
-            num_units=num_units, input_size=input_size,
-            num_proj=num_proj, forget_bias=1.0)(x, m)
+            num_units=num_units, num_proj=num_proj, forget_bias=1.0)(x, m)
         sess.run([tf.initialize_all_variables()])
         res = sess.run([output, state],
                        {x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 646c981791d..469635ae4f8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -26,9 +26,15 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
+from tensorflow.python.ops import rnn_cell
 
-def _flatten(list_of_lists):
-  return [x for y in list_of_lists for x in y]
+# pylint: disable=protected-access
+_is_sequence = rnn_cell._is_sequence
+_unpacked_state = rnn_cell._unpacked_state
+_packed_state = rnn_cell._packed_state
+# pylint: enable=protected-access
+
+_flatten = _unpacked_state
 
 
 class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
@@ -48,24 +54,32 @@ class Plus1RNNCell(tf.nn.rnn_cell.RNNCell):
 
 class TestStateSaver(object):
 
-  def __init__(self, batch_size, state_size, state_is_tuple=False):
+  def __init__(self, batch_size, state_size):
     self._batch_size = batch_size
     self._state_size = state_size
-    self._state_is_tuple = state_is_tuple
     self.saved_state = {}
 
-  def state(self, _):
-    if self._state_is_tuple:
-      return tuple(
-          tf.zeros(tf.pack([self._batch_size, s])) for s in self._state_size)
+  def state(self, name):
+    if isinstance(self._state_size, dict):
+      return tf.zeros([self._batch_size, self._state_size[name]])
     else:
-      return tf.zeros(tf.pack([self._batch_size, self._state_size]))
+      return tf.zeros([self._batch_size, self._state_size])
 
   def save_state(self, name, state):
     self.saved_state[name] = state
     return tf.identity(state)
 
 
+class PackStateTest(tf.test.TestCase):
+
+  def testPackUnpackState(self):
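+    # _unpacked_state flattens an arbitrarily nested state tuple into a flat
+    # tuple; _packed_state rebuilds the original nesting from a flat sequence.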
+    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
+    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
+    self.assertEqual(_unpacked_state(structure), (3, 4, 5, 6, 7, 9, 10, 8))
+    self.assertEqual(_packed_state(structure, flat),
+                     (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))
+
+
 class RNNTest(tf.test.TestCase):
 
   def setUp(self):
@@ -197,7 +211,7 @@ class GRUTest(tf.test.TestCase):
       concat_inputs = tf.placeholder(
           tf.float32, shape=(time_steps, batch_size, input_size))
 
-      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units, input_size=input_size)
+      cell = tf.nn.rnn_cell.GRUCell(num_units=num_units)
 
       with tf.variable_scope("dynamic_scope"):
         outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
@@ -229,8 +243,7 @@ class LSTMTest(tf.test.TestCase):
     max_length = 8
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, initializer=initializer)
+      cell = tf.nn.rnn_cell.LSTMCell(num_units, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -250,8 +263,7 @@ class LSTMTest(tf.test.TestCase):
     with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
-          cell_clip=0.0, initializer=initializer)
+          num_units, use_peepholes=True, cell_clip=0.0, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -276,7 +288,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
       state_saver = TestStateSaver(batch_size, 2 * num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=False, initializer=initializer)
+          num_units, use_peepholes=False, initializer=initializer)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
       with tf.variable_scope("share_scope"):
@@ -293,16 +305,16 @@ class LSTMTest(tf.test.TestCase):
           feed_dict={inputs[0]: input_value})
       self.assertAllEqual(last_state_value, saved_state_value)
 
-  def _testNoProjNoShardingTupleStateSaver(self, use_gpu):
+  def testNoProjNoShardingTupleStateSaver(self):
     num_units = 3
     input_size = 5
     batch_size = 2
     max_length = 8
-    with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+    with self.test_session(graph=tf.Graph()) as sess:
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-      state_saver = TestStateSaver(batch_size, (num_units, num_units))
+      state_saver = TestStateSaver(batch_size, num_units)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=False, initializer=initializer,
+          num_units, use_peepholes=False, initializer=initializer,
           state_is_tuple=True)
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(batch_size, input_size))]
@@ -316,10 +328,70 @@ class LSTMTest(tf.test.TestCase):
       tf.initialize_all_variables().run()
       input_value = np.random.randn(batch_size, input_size)
       last_and_saved_states = sess.run(
-          state + state_saver.saved_state.values(),
+          state + (state_saver.saved_state["c"], state_saver.saved_state["m"]),
           feed_dict={inputs[0]: input_value})
       self.assertEqual(4, len(last_and_saved_states))
-      self.assertEqual(last_and_saved_states[:2], last_and_saved_states[2:])
+      self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:])
+
+  def testNoProjNoShardingNestedTupleStateSaver(self):
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+    max_length = 8
+    with self.test_session(graph=tf.Graph()) as sess:
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+      state_saver = TestStateSaver(batch_size, {"c0": num_units,
+                                                "m0": num_units,
+                                                "c1": num_units + 1,
+                                                "m1": num_units + 1,
+                                                "c2": num_units + 2,
+                                                "m2": num_units + 2,
+                                                "c3": num_units + 3,
+                                                "m3": num_units + 3})
+      def _cell(i):
+        return tf.nn.rnn_cell.LSTMCell(
+            num_units + i, use_peepholes=False, initializer=initializer,
+            state_is_tuple=True)
+
+      # This creates a state tuple which has 4 sub-tuples of length 2 each.
+      cell = tf.nn.rnn_cell.MultiRNNCell(
+          [_cell(i) for i in range(4)], state_is_tuple=True)
+
+      self.assertEqual(len(cell.state_size), 4)
+      for i in range(4):
+        self.assertEqual(len(cell.state_size[i]), 2)
+
+      inputs = max_length * [
+          tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+
+      state_names = (("c0", "m0"), ("c1", "m1"),
+                     ("c2", "m2"), ("c3", "m3"))
+      with tf.variable_scope("share_scope"):
+        outputs, state = tf.nn.state_saving_rnn(
+            cell, inputs, state_saver=state_saver, state_name=state_names)
+      self.assertEqual(len(outputs), len(inputs))
+
+      # Final output comes from _cell(3) which has state size num_units + 3
+      for out in outputs:
+        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3])
+
+      tf.initialize_all_variables().run()
+      input_value = np.random.randn(batch_size, input_size)
+      last_states = sess.run(
+          list(_unpacked_state(state)), feed_dict={inputs[0]: input_value})
+      saved_states = sess.run(
+          list(state_saver.saved_state.values()),
+          feed_dict={inputs[0]: input_value})
+      self.assertEqual(8, len(last_states))
+      self.assertEqual(8, len(saved_states))
+      flat_state_names = _unpacked_state(state_names)
+      named_saved_states = dict(
+          zip(state_saver.saved_state.keys(), saved_states))
+
+      for i in range(8):
+        self.assertAllEqual(
+            last_states[i],
+            named_saved_states[flat_state_names[i]])
 
   def _testProjNoSharding(self, use_gpu):
     num_units = 3
@@ -332,7 +404,7 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
       self.assertEqual(len(outputs), len(inputs))
@@ -353,21 +425,21 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell_notuple = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       cell_tuple = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
       outputs_notuple, state_notuple = tf.nn.rnn(
           cell_notuple, inputs, dtype=tf.float32,
           sequence_length=sequence_length)
       tf.get_variable_scope().reuse_variables()
-      outputs_tuple, state_is_tuple = tf.nn.rnn(
+      outputs_tuple, state_tuple = tf.nn.rnn(
           cell_tuple, inputs, dtype=tf.float32,
           sequence_length=sequence_length)
       self.assertEqual(len(outputs_notuple), len(inputs))
       self.assertEqual(len(outputs_tuple), len(inputs))
-      self.assertTrue(isinstance(state_is_tuple, tuple))
+      self.assertTrue(isinstance(state_tuple, tuple))
       self.assertTrue(isinstance(state_notuple, tf.Tensor))
 
       tf.initialize_all_variables().run()
@@ -380,9 +452,9 @@ class LSTMTest(tf.test.TestCase):
 
       (state_notuple_v,) = sess.run(
           (state_notuple,), feed_dict={inputs[0]: input_value})
-      state_is_tuple_v = sess.run(
-          state_is_tuple, feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(state_notuple_v, np.hstack(state_is_tuple_v))
+      state_tuple_v = sess.run(
+          state_tuple, feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
 
   def _testProjSharding(self, use_gpu):
     num_units = 3
@@ -400,7 +472,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -430,7 +501,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -455,7 +525,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -487,7 +556,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.constant_initializer(0.001)
 
       cell_noshard = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size,
+          num_units,
           num_proj=num_proj,
           use_peepholes=True,
           initializer=initializer,
@@ -495,7 +564,7 @@ class LSTMTest(tf.test.TestCase):
           num_proj_shards=num_proj_shards)
 
       cell_shard = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("noshard_scope"):
@@ -541,7 +610,6 @@ class LSTMTest(tf.test.TestCase):
 
       cell = tf.nn.rnn_cell.LSTMCell(
           num_units,
-          input_size=input_size,
           use_peepholes=True,
           num_proj=num_proj,
           num_unit_shards=num_unit_shards,
@@ -577,10 +645,10 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
       cell_d = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer_d)
 
       with tf.variable_scope("share_scope"):
@@ -616,7 +684,7 @@ class LSTMTest(tf.test.TestCase):
       inputs = max_length * [
           tf.placeholder(tf.float32, shape=(None, input_size))]
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer)
 
       with tf.name_scope("scope0"):
@@ -649,7 +717,7 @@ class LSTMTest(tf.test.TestCase):
           tf.placeholder(tf.float32, shape=(None, input_size))]
       inputs_c = tf.pack(inputs)
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           num_proj=num_proj, initializer=initializer, state_is_tuple=True)
       outputs_static, state_static = tf.nn.rnn(
           cell, inputs, dtype=tf.float32,
@@ -675,6 +743,61 @@ class LSTMTest(tf.test.TestCase):
       self.assertAllEqual(
           np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
+  def testDynamicRNNWithNestedTupleStates(self):
+    num_units = 3
+    input_size = 5
+    batch_size = 2
+    num_proj = 4
+    max_length = 8
+    sequence_length = [4, 6]
+    with self.test_session(graph=tf.Graph()) as sess:
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
+      inputs = max_length * [
+          tf.placeholder(tf.float32, shape=(None, input_size))]
+      inputs_c = tf.pack(inputs)
+      def _cell(i):
+        return tf.nn.rnn_cell.LSTMCell(
+            num_units + i, use_peepholes=True,
+            num_proj=num_proj + i, initializer=initializer, state_is_tuple=True)
+
+      # This creates a state tuple which has 4 sub-tuples of length 2 each.
+      cell = tf.nn.rnn_cell.MultiRNNCell(
+          [_cell(i) for i in range(4)], state_is_tuple=True)
+
+      self.assertEqual(len(cell.state_size), 4)
+      for i in range(4):
+        self.assertEqual(len(cell.state_size[i]), 2)
+
+      test_zero = cell.zero_state(1, tf.float32)
+      self.assertEqual(len(test_zero), 4)
+      for i in range(4):
+        self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
+        self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
+
+      outputs_static, state_static = tf.nn.rnn(
+          cell, inputs, dtype=tf.float32,
+          sequence_length=sequence_length)
+      tf.get_variable_scope().reuse_variables()
+      outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
+          cell, inputs_c, dtype=tf.float32, time_major=True,
+          sequence_length=sequence_length)
+
+      tf.initialize_all_variables().run()
+
+      input_value = np.random.randn(batch_size, input_size)
+      outputs_static_v = sess.run(
+          outputs_static, feed_dict={inputs[0]: input_value})
+      outputs_dynamic_v = sess.run(
+          outputs_dynamic, feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
+
+      state_static_v = sess.run(
+          _unpacked_state(state_static), feed_dict={inputs[0]: input_value})
+      state_dynamic_v = sess.run(
+          _unpacked_state(state_dynamic), feed_dict={inputs[0]: input_value})
+      self.assertAllEqual(
+          np.hstack(state_static_v), np.hstack(state_dynamic_v))
+
   def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
     num_units = 3
@@ -697,7 +820,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
 
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("dynamic_scope"):
@@ -752,7 +875,7 @@ class LSTMTest(tf.test.TestCase):
       initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
 
       cell = tf.nn.rnn_cell.LSTMCell(
-          num_units, input_size, use_peepholes=True,
+          num_units, use_peepholes=True,
           initializer=initializer, num_proj=num_proj)
 
       with tf.variable_scope("dynamic_scope"):
@@ -1010,8 +1133,7 @@ def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1025,8 +1147,7 @@ def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1129,8 +1250,7 @@ def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
 
@@ -1183,7 +1303,7 @@ def _concat_state_vs_tuple_state_rnn_benchmark(
   (_, input_size) = inputs_list_t[0].get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
+      num_units=input_size, use_peepholes=True,
       initializer=initializer, state_is_tuple=state_is_tuple)
   outputs, final_state = tf.nn.rnn(
       cell, inputs_list_t, sequence_length=sequence_length, dtype=tf.float32)
@@ -1239,8 +1359,7 @@ def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
   (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
   initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=127)
   cell = tf.nn.rnn_cell.LSTMCell(
-      num_units=input_size, input_size=input_size, use_peepholes=True,
-      initializer=initializer)
+      num_units=input_size, use_peepholes=True, initializer=initializer)
   outputs, final_state = tf.nn.dynamic_rnn(
       cell, inputs_t, sequence_length=sequence_length,
       swap_memory=swap_memory, dtype=tf.float32)
diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
index 9c365c0bc7b..5c71820b090 100644
--- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
@@ -246,6 +246,16 @@ class SparseConcatTest(tf.test.TestCase):
       with self.assertRaises(ValueError):
         tf.sparse_concat(1, [sp_a, sp_e])
 
+  def testMismatchedRankExpandNonconcatDim(self):
+    with self.test_session(use_gpu=False):
+      sp_a = self._SparseTensor_3x3()
+      sp_e = self._SparseTensor_2x3x4()
+
+      # Rank mismatches should be caught at shape-inference time, even for
+      # expand_nonconcat_dim=True.
+      with self.assertRaises(ValueError):
+        tf.sparse_concat(1, [sp_a, sp_e], expand_nonconcat_dim=True)
+
   def testMismatchedShapes(self):
     with self.test_session(use_gpu=False) as sess:
       sp_a = self._SparseTensor_3x3()
@@ -258,6 +268,42 @@ class SparseConcatTest(tf.test.TestCase):
       with self.assertRaisesOpError("Input shapes must match"):
         sess.run(sp_concat)
 
+  def testMismatchedShapesExpandNonconcatDim(self):
+    with self.test_session(use_gpu=False) as sess:
+      sp_a = self._SparseTensor_3x3()
+      sp_b = self._SparseTensor_3x5()
+      sp_c = self._SparseTensor_3x2()
+      sp_d = self._SparseTensor_2x3()
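+      # With expand_nonconcat_dim=True, mismatched non-concat dimensions are
+      # expanded to the largest input size instead of raising at run time.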
+      sp_concat_dim0 = tf.sparse_concat(0, [sp_a, sp_b, sp_c, sp_d],
+                                        expand_nonconcat_dim=True)
+      sp_concat_dim1 = tf.sparse_concat(1, [sp_a, sp_b, sp_c, sp_d],
+                                        expand_nonconcat_dim=True)
+
+      sp_concat_dim0_out = sess.run(sp_concat_dim0)
+      sp_concat_dim1_out = sess.run(sp_concat_dim1)
+
+      self.assertAllEqual(
+          sp_concat_dim0_out.indices,
+          [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0], [5, 3], [5, 4],
+           [7, 0], [8, 0], [9, 1], [10, 0], [10, 2]])
+      self.assertAllEqual(
+          sp_concat_dim0_out.values,
+          [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2])
+      self.assertAllEqual(
+          sp_concat_dim0_out.shape,
+          [11, 5])
+
+      self.assertAllEqual(
+          sp_concat_dim1_out.indices,
+          [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10], [1, 12], [2, 0],
+           [2, 2], [2, 3], [2, 6], [2, 7], [2, 8]])
+      self.assertAllEqual(
+          sp_concat_dim1_out.values,
+          [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2])
+      self.assertAllEqual(
+          sp_concat_dim1_out.shape,
+          [3, 13])
+
   def testShapeInferenceUnknownShapes(self):
     with self.test_session(use_gpu=False):
       sp_inputs = [
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index a0113e7c20c..6b046883d4d 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -25,6 +25,7 @@ import tensorflow as tf
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import constant_op
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.platform import googletest
@@ -234,6 +235,101 @@ class SparseRetainTest(test_util.TensorFlowTestCase):
         sparse_ops.sparse_retain(sp_input, to_retain)
 
 
+class SparseResetShapeTest(test_util.TensorFlowTestCase):
+
+  _IND_2_5_6 = np.array([[0, 0, 0], [0, 1, 0], [0, 1, 3], [1, 1, 4],
+                         [1, 3, 2], [1, 3, 3]], dtype=np.int64)
+  _VAL_2_5_6 = np.array([0, 10, 13, 14, 32, 33], dtype=np.int32)
+  _SHP_2_5_6 = np.array([2, 5, 6], dtype=np.int64)
+
+  def _SparseTensor_2x5x6(self):
+    return ops.SparseTensor(
+        constant_op.constant(self._IND_2_5_6, dtypes.int64),
+        constant_op.constant(self._VAL_2_5_6, dtypes.int32),
+        constant_op.constant(self._SHP_2_5_6, dtypes.int64))
+
+  def _SparseTensorValue_2x5x6(self):
+    return ops.SparseTensorValue(self._IND_2_5_6, self._VAL_2_5_6,
+                                 self._SHP_2_5_6)
+
+  def testBasic(self):
+    with self.test_session(use_gpu=False) as sess:
+      sp_input = self._SparseTensor_2x5x6()
+      new_shape = np.array([3, 6, 7], dtype=np.int64)
+      sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+      output = sess.run(sp_output)
+
+      self.assertAllEqual(output.indices, [[0, 0, 0], [0, 1, 0],
+                                           [0, 1, 3], [1, 1, 4],
+                                           [1, 3, 2], [1, 3, 3]])
+      self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
+      self.assertAllEqual(output.shape, [3, 6, 7])
+
+  def testInputUnavailableInGraphConstructionOk(self):
+    with self.test_session(use_gpu=False) as sess:
+      sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)
+      new_shape = np.array([3, 6, 7], dtype=np.int64)
+      sp_output = sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+      output = sess.run(sp_output,
+                        feed_dict={sp_input: self._SparseTensorValue_2x5x6()})
+
+      self.assertAllEqual(output.indices, [[0, 0, 0], [0, 1, 0],
+                                           [0, 1, 3], [1, 1, 4],
+                                           [1, 3, 2], [1, 3, 3]])
+      self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
+      self.assertAllEqual(output.shape, [3, 6, 7])
+
+  def testTightBoundingBox(self):
+    with self.test_session(use_gpu=False) as sess:
+      sp_input = self._SparseTensor_2x5x6()
+      sp_output = sparse_ops.sparse_reset_shape(sp_input)
+
+      output = sess.run(sp_output)
+
+      self.assertAllEqual(output.indices, [[0, 0, 0], [0, 1, 0],
+                                           [0, 1, 3], [1, 1, 4],
+                                           [1, 3, 2], [1, 3, 3]])
+      self.assertAllEqual(output.values, [0, 10, 13, 14, 32, 33])
+      self.assertAllEqual(output.shape, [2, 4, 5])
+
+  def testInvalidRank(self):
+    with self.test_session(use_gpu=False):
+      sp_input = self._SparseTensor_2x5x6()
+      new_shape = np.array([3, 7], dtype=np.int64)
+
+      with self.assertRaises(ValueError):
+        sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+  def testInvalidRankNewShapeUnavailableInGraphConstruction(self):
+    with self.test_session(use_gpu=False) as sess:
+      new_shape = array_ops.placeholder(dtype=dtypes.int64)
+      sp_input = self._SparseTensor_2x5x6()
+      out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+      with self.assertRaisesOpError("x == y did not hold element-wise"):
+        sess.run(out, feed_dict={new_shape: np.array([3, 7], dtype=np.int64)})
+
+  def testInvalidDimensionSize(self):
+    with self.test_session(use_gpu=False) as sess:
+      sp_input = self._SparseTensor_2x5x6()
+      new_shape = np.array([3, 7, 5], dtype=np.int64)
+      out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+      with self.assertRaisesOpError("x <= y did not hold element-wise"):
+        sess.run(out)
+
+  def testInvalidDimensionSizeInputUnavailableInGraphConstruction(self):
+    sp_input = array_ops.sparse_placeholder(dtype=dtypes.int32)
+    with self.test_session(use_gpu=False) as sess:
+      new_shape = np.array([3, 7, 5], dtype=np.int64)
+      out = sparse_ops.sparse_reset_shape(sp_input, new_shape)
+
+      with self.assertRaisesOpError("x <= y did not hold element-wise"):
+        sess.run(out, feed_dict={sp_input: self._SparseTensorValue_2x5x6()})
+
+
 class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
 
   def _SparseTensor_5x6(self):
@@ -391,13 +487,15 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
 class SparseMathOpsTest(test_util.TensorFlowTestCase):
 
   def _check(self, result_tensor, result_np, input_sp_t):
+    self.assertTrue(isinstance(result_tensor, ops.SparseTensor))
+    self.assertTrue(isinstance(input_sp_t, ops.SparseTensor))
     self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
     self.assertAllEqual(input_sp_t.shape.eval(), result_tensor.shape.eval())
 
     res_densified = sparse_ops.sparse_to_dense(result_tensor.indices,
                                                result_tensor.shape,
                                                result_tensor.values).eval()
-    self.assertAllEqual(res_densified, result_np)
+    self.assertAllEqual(result_np, res_densified)
 
   def testCwiseDivAndMul(self):
     np.random.seed(1618)
@@ -422,6 +520,23 @@ class SparseMathOpsTest(test_util.TensorFlowTestCase):
             res = sp_t / dense_t  # should invoke "__truediv__"
             self.assertEqual(res.values.eval().dtype, np.float64)
 
+  def testCwiseAdd(self):
+    with self.test_session(use_gpu=False):
+      # Identity(2) + AllOnes(2,2).  Should be equal to 2 * Identity(2).
+      indices = [[0, 0], [1, 1]]
+      vals = [1, 1]
+      shape = (2, 2)
+
+      sp_t = tf.SparseTensor(indices, vals, shape)
+      dense_t = tf.ones(shape, dtype=dtypes.int32)
+      self._check(sparse_ops.sparse_dense_cwise_add(sp_t, dense_t),
+                  np.identity(2) * 2, sp_t)
+
+      # Variant of above, but broadcasts the dense side.
+      dense_t = tf.ones([1], dtype=dtypes.int32)
+      self._check(sparse_ops.sparse_dense_cwise_add(sp_t, dense_t),
+                  np.identity(2) * 2, sp_t)
+
   def testGradients(self):
     np.random.seed(1618)
     sp_shapes = [(10, 10, 10), (5, 5), (1618,), (3, 3, 7)]
@@ -451,5 +566,56 @@ class SparseMathOpsTest(test_util.TensorFlowTestCase):
           self.assertLess(err, 2e-4)
 
 
+class SparseSoftmaxTest(test_util.TensorFlowTestCase):
+
+  def testEquivalentToDensified(self):
+    np.random.seed(1618)
+    n, m = np.random.choice(20, size=2)
+
+    for dtype in [np.float32, np.float64]:
+      sp_vals_np = np.random.rand(n, m).astype(dtype)
+
+      batched_sp_t, unused_nnz1 = _sparsify(
+          sp_vals_np.reshape((1, n, m)), thresh=0.)  # No masking.
+
+      with self.test_session(use_gpu=False):
+        densified = tf.constant(sp_vals_np)
+
+        sp_result = sparse_ops.sparse_softmax(
+            batched_sp_t).eval().values.reshape((n, m))
+        dense_result = tf.nn.softmax(densified)
+
+        self.assertAllClose(dense_result.eval(), sp_result)
+
+  def testHigherRanks(self):
+    # For the first shape:
+    # First batch:
+    # [?   e.]
+    # [1.  ? ]
+    # Second batch:
+    # [e   ? ]
+    # [e   e ]
+    #
+    # The softmax results should be:
+    # [?   1.]     [1    ?]
+    # [1.  ? ] and [.5  .5]
+    # where ? means implicitly zero.
+    #
+    # The second shape: same input data, but with a higher-rank shape.
+    shapes = [[2, 2, 2], [2, 1, 2, 2]]
+    for shape in shapes:
+      values = np.asarray(
+          [0., np.e, 1., 0., np.e, 0., np.e, np.e]).reshape(shape)
+      sp_t, unused_nnz = _sparsify(values, thresh=1e-2)
+      expected_values = [1., 1., 1., .5, .5]
+
+      with self.test_session(use_gpu=False):
+        result = sparse_ops.sparse_softmax(sp_t).eval()
+
+        self.assertAllEqual(expected_values, result.values)
+        self.assertAllEqual(sp_t.indices.eval(), result.indices)
+        self.assertAllEqual(shape, result.shape)
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
index 379edbfbb04..8a018573d1f 100644
--- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import tensorflow as tf
 
 
@@ -66,6 +67,30 @@ class StringToHashBucketOpTest(tf.test.TestCase):
       # Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
       self.assertAllEqual([8, 0, 7], result)
 
+  def testStringToOneHashBucketStrong(self):
+    with self.test_session():
+      input_string = tf.constant(['a', 'b', 'c'])
+      output = tf.string_to_hash_bucket_strong(input_string, 1, key=[123, 345])
+      self.assertAllEqual([0, 0, 0], output.eval())
+
+  def testStringToHashBucketsStrong(self):
+    with self.test_session():
+      input_string = tf.constant(['a', 'b', 'c'])
+      output = tf.string_to_hash_bucket_strong(input_string,
+                                               10,
+                                               key=[98765, 132])
+      # key = [98765, 132]
+      # StrongKeyedHash(key, 'a') -> 7157389809176466784 -> mod 10 -> 4
+      # StrongKeyedHash(key, 'b') -> 15805638358933211562 -> mod 10 -> 2
+      # StrongKeyedHash(key, 'c') -> 18100027895074076528 -> mod 10 -> 8
+      self.assertAllEqual([4, 2, 8], output.eval())
+
+  def testStringToHashBucketsStrongInvalidKey(self):
+    with self.test_session():
+      input_string = tf.constant(['a', 'b', 'c'])
+      with self.assertRaisesOpError('Key must have 2 elements'):
+        tf.string_to_hash_bucket_strong(input_string, 10, key=[98765]).eval()
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index eb617cf1edb..9a1840e4766 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -17,15 +17,16 @@
 
 @@assert_negative
 @@assert_positive
+@@assert_proper_iterable
 @@assert_non_negative
 @@assert_non_positive
 @@assert_equal
+@@assert_integer
 @@assert_less
 @@assert_less_equal
 @@assert_rank
 @@assert_rank_at_least
 @@assert_type
-@@assert_integer
 @@is_non_decreasing
 @@is_numeric_tensor
 @@is_strictly_increasing
@@ -35,6 +36,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -42,6 +45,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import compat
 
 NUMERIC_TYPES = frozenset(
     [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32,
@@ -51,14 +55,15 @@ NUMERIC_TYPES = frozenset(
 __all__ = [
     'assert_negative',
     'assert_positive',
+    'assert_proper_iterable',
     'assert_non_negative',
     'assert_non_positive',
     'assert_equal',
+    'assert_integer',
     'assert_less',
     'assert_less_equal',
     'assert_rank',
     'assert_rank_at_least',
-    'assert_integer',
     'assert_type',
     'is_non_decreasing',
     'is_numeric_tensor',
@@ -66,6 +71,33 @@ __all__ = [
 ]
 
 
+def assert_proper_iterable(values):
+  """Static assert that values is a "proper" iterable.
+
+  `Ops` that expect iterables of `Tensor` can call this to validate input.
+  Useful since `Tensor`, `ndarray`, and byte/text types are all iterables
+  themselves.
+
+  Args:
+    values:  Object to be checked.
+
+  Raises:
+    TypeError:  If `values` is not iterable or is one of
+      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
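+
+  For example, an illustrative sketch:
+
+  ```python
+  assert_proper_iterable([tf.constant(1), tf.constant(2)])  # Fine.
+  assert_proper_iterable(tf.constant([1, 2]))  # Raises TypeError.
+  ```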
+  """
+  unintentional_iterables = (
+      (ops.Tensor, ops.SparseTensor, np.ndarray)
+      + compat.bytes_or_text_types
+  )
+  if isinstance(values, unintentional_iterables):
+    raise TypeError(
+        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
+        type(values))
+
+  if not hasattr(values, '__iter__'):
+    raise TypeError(
+        'Expected argument "values" to be iterable.  Found: %s' % type(values))
+
+
 def assert_negative(x, data=None, summarize=None, name=None):
   """Assert the condition `x < 0` holds element-wise.
 
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 3075c8b9ef8..9312a84f3eb 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1533,6 +1533,7 @@ def _BroadcastShape(op):
 
 @ops.RegisterShape("SparseDenseCwiseMul")
 @ops.RegisterShape("SparseDenseCwiseDiv")
+@ops.RegisterShape("SparseDenseCwiseAdd")
 def _SparseDenseBinaryOpShape(op):  # pylint: disable=invalid-name
   """Common shape for 'sparse <binary cwise op> dense -> sparse' operators."""
   nnz = op.inputs[1].get_shape()[0]
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 92e4ed8c4d4..d47b03db5b6 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -587,21 +587,19 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
                          padding="VALID", name=name)
 
 
-def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
+def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
   """Calculate the sufficient statistics for the mean and variance of `x`.
 
   These sufficient statistics are computed using the one pass algorithm on
-  an input that's optionally shifted using the value of the 1st element in `x`.
-  See:
+  an input that's optionally shifted. See:
   https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
-  Unfortunately, in some cases using a random individual sample as the shift
-  value leads experimentally to very poor numerical stability, so it is disabled
-  by default. The one-pass approach might have to be revised accordingly.
 
   Args:
     x: A `Tensor`.
     axes: Array of ints. Axes along which to compute mean and variance.
-    shift: If true, shift the data to provide more numerically stable results.
+    shift: A `Tensor` containing the value by which to shift the data for
+      numerical stability, or `None` if no shift is to be performed. A shift
+      close to the true mean provides the most numerically stable results.
     keep_dims: produce statistics with the same dimensionality as the input.
     name: Name used to scope the operations that compute the sufficient stats.
 
@@ -610,9 +608,9 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
     * the count (number of elements to average over).
     * the (possibly shifted) sum of the elements in the array.
     * the (possibly shifted) sum of squares of the elements in the array.
-    * the shift by which the mean must be corrected or None if `shift` is False.
+    * the shift by which the mean must be corrected or None if `shift` is None.
   """
-  with ops.op_scope([x, axes], name, "sufficient_statistics"):
+  with ops.op_scope([x, axes, shift], name, "sufficient_statistics"):
     x = ops.convert_to_tensor(x, name="x")
     x_shape = x.get_shape()
     if x_shape.is_fully_defined():
@@ -635,23 +633,16 @@ def sufficient_statistics(x, axes, shift=False, keep_dims=False, name=None):
           math_ops.reduce_prod(x_shape / m_shape),
           x.dtype,
           name="count")
-    if shift:
-      shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
-      m_ss = math_ops.sub(x, shift_value)
-      v_ss = math_ops.squared_difference(x, shift_value)
-      if keep_dims:
-        shift_value = array_ops.identity(shift_value, name="shift")
-      else:
-        shift_value = array_ops.squeeze(shift_value,
-                                        squeeze_dims=axes,
-                                        name="shift")
-    else:  # not shift.
+    if shift is not None:
+      shift = ops.convert_to_tensor(shift, name="shift")
+      m_ss = math_ops.sub(x, shift)
+      v_ss = math_ops.squared_difference(x, shift)
+    else:  # no shift.
       m_ss = x
       v_ss = math_ops.square(x)
-      shift_value = None
     m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
     v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
-  return counts, m_ss, v_ss, shift_value
+  return counts, m_ss, v_ss, shift
 
 
 def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
@@ -685,7 +676,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
   return (mean, variance)
 
 
-def moments(x, axes, name=None, keep_dims=False):
+def moments(x, axes, shift=None, name=None, keep_dims=False):
   """Calculate the mean and variance of `x`.
 
   The mean and variance are calculated by aggregating the contents of `x`
@@ -702,15 +693,19 @@ def moments(x, axes, name=None, keep_dims=False):
     x: A `Tensor`.
     axes: array of ints.  Axes along which to compute mean and
       variance.
+    shift: A `Tensor` containing the value by which to shift the data for
+      numerical stability, or `None` if no shift is to be performed. A shift
+      close to the true mean provides the most numerically stable results.
     keep_dims: produce moments with the same dimensionality as the input.
     name: Name used to scope the operations that compute the moments.
 
   Returns:
     Two `Tensor` objects: `mean` and `variance`.
   """
-  with ops.op_scope([x, axes], name, "moments"):
+  with ops.op_scope([x, axes, shift], name, "moments"):
     counts, m_ss, v_ss, shift = sufficient_statistics(x,
                                                       axes,
+                                                      shift=shift,
                                                       keep_dims=keep_dims,
                                                       name=name)
     return normalize_moments(counts, m_ss, v_ss, shift, name=name)
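A hedged sketch of the new `shift` plumbing (the numbers are purely illustrative): with values clustered near 1e6, the unshifted one-pass sums lose float32 precision, while any shift close to the true mean keeps the sufficient statistics small.

```python
# Hedged sketch of the shift argument added to tf.nn.moments() above.
import tensorflow as tf

x = tf.constant([[1e6 + 1.0, 1e6 + 2.0],
                 [1e6 + 3.0, 1e6 + 4.0]])
# Any estimate near the true mean works as a shift; 1e6 is close enough.
mean, variance = tf.nn.moments(x, axes=[0, 1], shift=1e6)

with tf.Session():
  print(mean.eval())      # ~1000002.5
  print(variance.eval())  # ~1.25
```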
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index 5b9f33a73bb..c6a27a803c4 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -317,16 +317,10 @@ class SufficientStatisticsTest(tf.test.TestCase):
 
   def _npSuffStats(self, x, axes, shift, keep_dims):
     axis = tuple(axes)
-    if shift:
-      shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
-                       for i in xrange(x.ndim)]]
-      m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
-      v_ss = np.sum(
-          (x - shift_value) * (x - shift_value),
-          axis=axis,
-          keepdims=keep_dims)
+    if shift is not None:
+      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
+      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
     else:
-      shift_value = None
       m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
       v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
     count = 1.0
@@ -334,8 +328,8 @@ class SufficientStatisticsTest(tf.test.TestCase):
       if d in set(axes):
         count *= x.shape[d]
     if not keep_dims:
-      shift_value = np.squeeze(shift_value, axis=axis)
-    return count, m_ss, v_ss, shift_value
+      shift = np.squeeze(shift, axis=axis)
+    return count, m_ss, v_ss, shift
 
   def _opSuffStats(self, x, axes, shift, keep_dims):
     return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
@@ -375,7 +369,7 @@ class SufficientStatisticsTest(tf.test.TestCase):
   def testSuffStats(self):
     for has_shape in [True, False]:
       for keep_dims in [True, False]:
-        for shift in [True, False]:
+        for shift in [None, 1.0]:
           self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
           self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
           self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
@@ -419,7 +413,7 @@ class NormalizeMomentsTest(tf.test.TestCase):
         self.assertAllClose(npv, tfv, atol=0.000001)
 
   def testNormalizeMoments(self):
-    for shift in [True, False]:
+    for shift in [None, 4.0]:
       self._testNormalizeMoments([3], shift)
       self._testNormalizeMoments([2, 3], shift)
 
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index 3c8e4d8885e..2971cb9ab43 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -65,7 +65,7 @@ __all__ = ["create_partitioned_variables", "variable_axis_size_partitioner"]
 
 
 def variable_axis_size_partitioner(
-    max_shard_bytes, axis=0, bytes_per_string_element=16):
+    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
   """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.
 
   This partitioner will shard a Variable along one axis, attempting to keep
@@ -74,6 +74,10 @@ def variable_axis_size_partitioner(
   this axis is sharded as much as possible (i.e., every dimension becomes
   a separate shard).
 
+  If the partitioner hits the `max_shards` limit, then each shard may end up
+  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
+  limit on the number of shards is enforced.
+
   One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
   `64MB`, to keep below the protobuf byte limit.
 
@@ -82,6 +86,8 @@ def variable_axis_size_partitioner(
     axis: The axis to partition along.  Default: outermost axis.
     bytes_per_string_element: If the `Variable` is of type string, this provides
       an estimate of how large each scalar in the `Variable` is.
+    max_shards: An `int` specifying the maximum number of shards to create,
+      taking precedence over `max_shard_bytes`.
 
   Returns:
     A partition function usable as the `partitioner` argument to
@@ -93,6 +99,9 @@ def variable_axis_size_partitioner(
   if max_shard_bytes < 1 or bytes_per_string_element < 1:
     raise ValueError(
         "Both max_shard_bytes and bytes_per_string_element must be positive.")
+  if max_shards and max_shards < 1:
+    raise ValueError(
+        "max_shards must be positive.")
 
   def _partitioner(shape, dtype):
     """Partitioner that partitions shards to have max_shard_bytes total size.
@@ -129,8 +138,11 @@ def variable_axis_size_partitioner(
     # How many shards do we need for axis given that each shard fits
     # slices_per_shard slices from a total of shape[axis].value slices?
     axis_shards = int(math.ceil(1.0 * shape[axis].value / slices_per_shard))
+    if max_shards:
+      axis_shards = min(max_shards, axis_shards)
 
     partitions[axis] = axis_shards
+
     return partitions
 
   return _partitioner
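A hedged sketch of the new `max_shards` cap, calling the returned partition function directly (the shape is chosen for round numbers and is an assumption of this example): a 1GB float32 variable at just under 64MB per shard would need roughly 16 shards; `max_shards=4` caps it, so each shard then exceeds `max_shard_bytes`.

```python
# Hedged sketch: the function returned by variable_axis_size_partitioner
# maps (shape, dtype) to a per-axis shard count.
import tensorflow as tf
from tensorflow.python.ops import partitioned_variables

partitioner = partitioned_variables.variable_axis_size_partitioner(
    max_shard_bytes=(64 << 20) - 1,  # just under 64MB per shard
    axis=0,
    max_shards=4)

# [2**20, 256] float32 is 1GB total; uncapped this needs ~16 shards.
partitions = partitioner(tf.TensorShape([1 << 20, 256]), tf.float32)
print(partitions)  # [4, 1]: capped at 4 shards along axis 0
```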
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 259811c2a4f..6d9a0d4e3f2 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -32,6 +32,13 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope as vs
 
 
+# pylint: disable=protected-access
+_is_sequence = rnn_cell._is_sequence
+_unpacked_state = rnn_cell._unpacked_state
+_packed_state = rnn_cell._packed_state
+# pylint: enable=protected-access
+
+
 def rnn(cell, inputs, initial_state=None, dtype=None,
         sequence_length=None, scope=None):
   """Creates a recurrent neural network specified by RNNCell `cell`.
@@ -177,20 +184,26 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
      type of `state_name` does not match that of `cell.state_size`.
   """
   state_size = cell.state_size
-  state_is_tuple = isinstance(state_size, (list, tuple))
-  state_name_tuple = isinstance(state_name, (list, tuple))
+  state_is_tuple = _is_sequence(state_size)
+  state_name_tuple = _is_sequence(state_name)
 
   if state_is_tuple != state_name_tuple:
     raise ValueError(
-        "state_name should be a tuple iff cell.state_size is.  state_name: %s, "
-        "cell.state_size: %s" % (str(state_name), str(state_size)))
+        "state_name should be the same type as cell.state_size.  "
+        "state_name: %s, cell.state_size: %s"
+        % (str(state_name), str(state_size)))
 
   if state_is_tuple:
-    if len(state_name) != len(state_size):
-      raise ValueError("len(state_name) != len(state_size): %d vs. %d"
-                       % (len(state_name), len(state_size)))
+    state_name_flat = _unpacked_state(state_name)
+    state_size_flat = _unpacked_state(state_size)
 
-    initial_state = tuple(state_saver.state(n) for n in state_name)
+    if len(state_name_flat) != len(state_size_flat):
+      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d"
+                       % (len(state_name_flat), len(state_size_flat)))
+
+    initial_state = _packed_state(
+        structure=state_name,
+        state=[state_saver.state(n) for n in state_name_flat])
   else:
     initial_state = state_saver.state(state_name)
 
@@ -198,8 +211,10 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
                          sequence_length=sequence_length, scope=scope)
 
   if state_is_tuple:
+    state_flat = _unpacked_state(state)
     save_state = [
-        state_saver.save_state(n, s) for (n, s) in zip(state_name, state)]
+        state_saver.save_state(n, s)
+        for (n, s) in zip(state_name_flat, state_flat)]
   else:
     save_state = [state_saver.save_state(state_name, state)]
 
@@ -262,9 +277,10 @@ def _rnn_step(
       that returned by `state_size`.
   """
 
-  state_is_tuple = isinstance(state, (list, tuple))
+  state_is_tuple = _is_sequence(state)
+  orig_state = state
   # Convert state to a list for ease of use
-  state = list(state) if state_is_tuple else [state]
+  state = list(_unpacked_state(state)) if state_is_tuple else [state]
   state_shape = [s.get_shape() for s in state]
 
   def _copy_some_through(new_output, new_state):
@@ -279,7 +295,8 @@ def _rnn_step(
   def _maybe_copy_some_through():
     """Run RNN step.  Pass through either no or some past state."""
     new_output, new_state = call_cell()
-    new_state = list(new_state) if state_is_tuple else [new_state]
+    new_state = (
+        list(_unpacked_state(new_state)) if state_is_tuple else [new_state])
 
     if len(state) != len(new_state):
       raise ValueError(
@@ -300,7 +317,8 @@ def _rnn_step(
     # steps.  This is faster when max_seq_len is equal to the number of unrolls
     # (which is typical for dynamic_rnn).
     new_output, new_state = call_cell()
-    new_state = list(new_state) if state_is_tuple else [new_state]
+    new_state = (
+        list(_unpacked_state(new_state)) if state_is_tuple else [new_state])
 
     if len(state) != len(new_state):
       raise ValueError(
@@ -325,7 +343,9 @@ def _rnn_step(
     final_state_i.set_shape(state_shape_i)
 
   if state_is_tuple:
-    return (final_output, tuple(final_state))
+    return (
+        final_output,
+        _packed_state(structure=orig_state, state=final_state))
   else:
     return (final_output, final_state[0])
 
@@ -613,9 +633,9 @@ def _dynamic_rnn_loop(
   time = array_ops.constant(0, dtype=dtypes.int32, name="time")
 
   state_size = cell.state_size
-  state_is_tuple = isinstance(state_size, (list, tuple))
+  state_is_tuple = _is_sequence(state_size)
 
-  state = tuple(state) if state_is_tuple else (state,)
+  state = _unpacked_state(state) if state_is_tuple else (state,)
 
   with ops.op_scope([], "dynamic_rnn") as scope:
     base_name = scope
@@ -646,8 +666,9 @@ def _dynamic_rnn_loop(
     # Restore some shape information
     input_t.set_shape([const_batch_size, const_depth])
 
-    # Unpack state if not using state tuples
-    state = tuple(state) if state_is_tuple else state[0]
+    # Pack state back up for use by cell
+    state = (_packed_state(structure=state_size, state=state)
+             if state_is_tuple else state[0])
 
     call_cell = lambda: cell(input_t, state)
 
@@ -665,7 +686,7 @@ def _dynamic_rnn_loop(
       (output, new_state) = call_cell()
 
     # Pack state if using state tuples
-    new_state = tuple(new_state) if state_is_tuple else (new_state,)
+    new_state = _unpacked_state(new_state) if state_is_tuple else (new_state,)
 
     output_ta_t = output_ta_t.write(time, output)
 
@@ -686,6 +707,7 @@ def _dynamic_rnn_loop(
       const_time_steps, const_batch_size, cell.output_size])
 
   # Unpack final state if not using state tuples.
-  final_state = tuple(final_state) if state_is_tuple else final_state[0]
+  final_state = (
+      _unpacked_state(final_state) if state_is_tuple else final_state[0])
 
   return (final_outputs, final_state)
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index bfd0758883b..69ff7775d52 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import math
 
-# pylint: disable=redefined-builtin,unused-import
-from six.moves import xrange
-# pylint: enable=redefined-builtin,unused-import
+import six
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -39,6 +38,88 @@ from tensorflow.python.ops.math_ops import tanh
 from tensorflow.python.platform import tf_logging as logging
 
 
+def _is_sequence(seq):
+  return (isinstance(seq, collections.Sequence)
+          and not isinstance(seq, six.string_types))
+
+
+def _packed_state_with_indices(structure, flat, index):
+  """Helper function for _packed_state.
+
+  Args:
+    structure: Substructure (tuple of elements and/or tuples) to mimic
+    flat: Flattened values to output substructure for.
+    index: Index at which to start reading from flat.
+
+  Returns:
+    The tuple (new_index, child), where:
+      * new_index - the updated index into `flat` having processed `structure`.
+      * packed - the subset of `flat` corresponding to `structure`,
+                 having started at `index`, and packed into the same nested
+                 format.
+
+  Raises:
+    ValueError: if `structure` contains more elements than `flat`
+      (assuming indexing starts from `index`).
+  """
+  packed = []
+  for s in structure:
+    if _is_sequence(s):
+      new_index, child = _packed_state_with_indices(s, flat, index)
+      packed.append(type(s)(child))
+      index = new_index
+    else:
+      packed.append(flat[index])
+      index += 1
+  return (index, packed)
+
+
+def _yield_unpacked_state(state):
+  for s in state:
+    if _is_sequence(s):
+      for si in _yield_unpacked_state(s):
+        yield si
+    else:
+      yield s
+
+
+def _unpacked_state(state):
+  if not _is_sequence(state):
+    raise TypeError("state must be a sequence")
+  return type(state)(_yield_unpacked_state(state))
+
+
+def _packed_state(structure, state):
+  """Returns the flat state packed into a recursive tuple like structure.
+
+  Args:
+    structure: tuple or list constructed of scalars and/or other tuples/lists.
+    state: flattened state.
+
+  Returns:
+    packed: `state` converted to have the same recursive structure as
+      `structure`.
+
+  Raises:
+    TypeError: If structure or state is not a tuple or list.
+    ValueError: If state and structure have different element counts.
+  """
+  if not _is_sequence(structure):
+    raise TypeError("structure must be a sequence")
+  if not _is_sequence(state):
+    raise TypeError("state must be a sequence")
+
+  flat_structure = _unpacked_state(structure)
+  if len(flat_structure) != len(state):
+    raise ValueError(
+        "Internal error: Could not pack state.  Structure had %d elements, but "
+        "state had %d elements.  Structure: %s, state: %s."
+        % (len(flat_structure), len(state), structure, state))
+
+  (_, packed) = _packed_state_with_indices(structure, state, 0)
+  return type(structure)(packed)
+
+
 class RNNCell(object):
   """Abstract object representing an RNN cell.
 
@@ -98,17 +179,19 @@ class RNNCell(object):
       If `state_size` is an int, then the return value is a `2-D` tensor of
       shape `[batch_size x state_size]` filled with zeros.
 
-      If `state_size` is a list or tuple of ints, then the return value is
-      a tuple of `2-D` tensors with shape
-      `[batch_size x s] for s in state_size`.
+      If `state_size` is a nested list or tuple, then the return value is
+      a nested list or tuple (of the same structure) of `2-D` tensors with
+      the shapes `[batch_size x s]` for each `s` in `state_size`.
     """
     state_size = self.state_size
-    if isinstance(state_size, (list, tuple)):
-      zeros = tuple(
+    if _is_sequence(state_size):
+      state_size_flat = _unpacked_state(state_size)
+      zeros_flat = [
           array_ops.zeros(array_ops.pack([batch_size, s]), dtype=dtype)
-          for s in state_size)
-      for s, z in zip(state_size, zeros):
+          for s in state_size_flat]
+      for s, z in zip(state_size_flat, zeros_flat):
         z.set_shape([None, s])
+      zeros = _packed_state(structure=state_size, state=zeros_flat)
     else:
       zeros = array_ops.zeros(
           array_ops.pack([batch_size, state_size]), dtype=dtype)
@@ -675,7 +758,7 @@ class MultiRNNCell(RNNCell):
     self._cells = cells
     self._state_is_tuple = state_is_tuple
     if not state_is_tuple:
-      if any(isinstance(c.state_size, (list, tuple)) for c in self._cells):
+      if any(_is_sequence(c.state_size) for c in self._cells):
         raise ValueError("Some cells return tuples of states, but the flag "
                          "state_is_tuple is not set.  State sizes are: %s"
                          % str([c.state_size for c in self._cells]))
@@ -700,7 +783,7 @@ class MultiRNNCell(RNNCell):
       for i, cell in enumerate(self._cells):
         with vs.variable_scope("Cell%d" % i):
           if self._state_is_tuple:
-            if not isinstance(state, (list, tuple)):
+            if not _is_sequence(state):
               raise ValueError(
                   "Expected state to be a tuple of length %d, but received: %s"
                   % (len(self.state_size), state))
@@ -778,9 +861,9 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
   Raises:
     ValueError: if some of the arguments has unspecified or wrong shape.
   """
-  if args is None or (isinstance(args, (list, tuple)) and not args):
+  if args is None or (_is_sequence(args) and not args):
     raise ValueError("`args` must be specified")
-  if not isinstance(args, (list, tuple)):
+  if not _is_sequence(args):
     args = [args]
 
   # Calculate the total size of arguments on dimension 1.
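Since the nesting helpers above are module-private, a hedged round-trip illustration may clarify what `rnn.py` and `MultiRNNCell` rely on: `_unpacked_state` flattens an arbitrarily nested state into its leaves, and `_packed_state` restores the original nesting. Strings stand in for Tensors here (strings are explicitly excluded from being treated as sequences).

```python
# Hedged illustration only; these helpers are private to rnn_cell.
from tensorflow.python.ops import rnn_cell

structure = (("a", "b"), "c", ("d", ("e", "f")))

flat = rnn_cell._unpacked_state(structure)
print(flat)      # ('a', 'b', 'c', 'd', 'e', 'f')

repacked = rnn_cell._packed_state(structure=structure, state=list(flat))
print(repacked)  # (('a', 'b'), 'c', ('d', ('e', 'f')))
assert repacked == structure
```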
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 0d43c856342..d2e1ae20967 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -173,6 +173,12 @@ def _SparseTensorDenseMatMulGrad(op, grad):
   return (None, a_values_grad, None, b_grad)
 
 
+@ops.RegisterGradient("SparseDenseCwiseAdd")
+def _SparseDenseCwiseAddGrad(unused_op, unused_grad):
+  raise NotImplementedError("Gradient for SparseDenseCwiseAdd is currently not"
+                            " implemented yet.")
+
+
 def _SparseDenseCwiseMulOrDivGrad(op, grad, is_mul):
   """Common code for SparseDenseCwise{Mul,Div} gradients."""
   x_indices = op.inputs[0]
@@ -218,3 +224,9 @@ def _SparseDenseCwiseMulGrad(op, grad):
 def _SparseDenseCwiseDivGrad(op, grad):
   """Gradients for SparseDenseCwiseDiv."""
   return _SparseDenseCwiseMulOrDivGrad(op, grad, False)
+
+
+@ops.RegisterGradient("SparseSoftmax")
+def _SparseSoftmaxGrad(unused_op, unused_grad):
+  raise NotImplementedError("SparseSoftmax op doesn't have its gradient"
+                            "implemented yet")
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index f19474e8192..4df0e9c5d8e 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -37,10 +37,12 @@ dimension, and dense along all other dimensions.
 @@sparse_reorder
 @@sparse_split
 @@sparse_retain
+@@sparse_reset_shape
 @@sparse_fill_empty_rows
 
 ## Math Operations
 @@sparse_add
+@@sparse_softmax
 @@sparse_tensor_dense_matmul
 """
 from __future__ import absolute_import
@@ -55,6 +57,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_sparse_ops
 from tensorflow.python.ops import math_ops
 # go/tf-wildcard-import
@@ -64,18 +68,26 @@ from tensorflow.python.ops.gen_sparse_ops import *
 
 
 # pylint: disable=protected-access
-def sparse_concat(concat_dim, sp_inputs, name=None):
+def sparse_concat(concat_dim, sp_inputs, name=None, expand_nonconcat_dim=False):
   """Concatenates a list of `SparseTensor` along the specified dimension.
 
   Concatenation is with respect to the dense versions of each sparse input.
   It is assumed that each inputs is a `SparseTensor` whose elements are ordered
   along increasing dimension number.
 
-  All inputs' shapes must match, except for the concat dimension.  The
-  `indices`, `values`, and `shapes` lists must have the same length.
+  If expand_nonconcat_dim is False, all inputs' shapes must match, except for
+  the concat dimension. If expand_nonconcat_dim is True, then inputs' shapes
+  are allowed to vary among all inputs.
 
-  The output shape is identical to the inputs', except along the concat
-  dimension, where it is the sum of the inputs' sizes along that dimension.
+  The `indices`, `values`, and `shapes` lists must have the same length.
+
+  If expand_nonconcat_dim is False, then the output shape is identical to the
+  inputs', except along the concat dimension, where it is the sum of the inputs'
+  sizes along that dimension.
+
+  If expand_nonconcat_dim is True, then the output shape along the non-concat
+  dimensions will be expanded to the largest among all inputs, and it is the
+  sum of the inputs' sizes along the concat dimension.
 
   The output elements will be resorted to preserve the sort order along
   increasing dimension number.
@@ -109,10 +121,40 @@ def sparse_concat(concat_dim, sp_inputs, name=None):
       [    a] concat [  d e  ] = [    a   d e  ]
       [b c  ]        [       ]   [b c          ]
 
+  Another example: if `concat_dim = 1` and the inputs are
+
+      sp_inputs[0]: shape = [3, 3]
+      [0, 2]: "a"
+      [1, 0]: "b"
+      [2, 1]: "c"
+
+      sp_inputs[1]: shape = [2, 4]
+      [0, 1]: "d"
+      [0, 2]: "e"
+
+  if expand_nonconcat_dim is False, this will result in an error. But if
+  expand_nonconcat_dim is True, this will result in:
+
+      shape = [3, 7]
+      [0, 2]: "a"
+      [0, 4]: "d"
+      [0, 5]: "e"
+      [1, 0]: "b"
+      [2, 1]: "c"
+
+  Graphically this is equivalent to doing
+
+      [    a] concat [  d e  ] = [    a   d e  ]
+      [b    ]        [       ]   [b            ]
+      [  c  ]                    [  c          ]
+
+
   Args:
     concat_dim: Dimension to concatenate along.
     sp_inputs: List of `SparseTensor` to concatenate.
     name: A name prefix for the returned tensors (optional).
+    expand_nonconcat_dim: Whether to allow the expansion in the non-concat
+      dimensions. Defaults to False.
 
   Returns:
     A `SparseTensor` with the concatenated output.
@@ -132,6 +174,14 @@ def sparse_concat(concat_dim, sp_inputs, name=None):
   vals = [sp_input.values for sp_input in sp_inputs]
   shapes = [sp_input.shape for sp_input in sp_inputs]
 
+  if expand_nonconcat_dim:
+    max_shape = math_ops.reduce_max(array_ops.concat(0, [array_ops.reshape(
+        shape, [1, -1]) for shape in shapes]), 0)
+    shapes = [array_ops.concat(0, [max_shape[:concat_dim],
+                                   shape[concat_dim:concat_dim + 1],
+                                   max_shape[concat_dim + 1:]])
+              for shape in shapes]
+
   output_ind, output_val, output_shape = (
       gen_sparse_ops._sparse_concat(inds,
                                     vals,
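A hedged, runnable version of the docstring example above, showing `expand_nonconcat_dim=True` padding the non-concat dimension to the largest input:

```python
# Hedged sketch: concatenating [3, 3] and [2, 4] SparseTensors along
# dim 1; with expand_nonconcat_dim=True the result has shape [3, 7].
import tensorflow as tf

a = tf.SparseTensor(indices=[[0, 2], [1, 0], [2, 1]],
                    values=["a", "b", "c"], shape=[3, 3])
b = tf.SparseTensor(indices=[[0, 1], [0, 2]],
                    values=["d", "e"], shape=[2, 4])

result = tf.sparse_concat(1, [a, b], expand_nonconcat_dim=True)
with tf.Session():
  print(result.shape.eval())  # [3 7]
```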
@@ -226,6 +276,31 @@ def _SparseAddShape(op):  # pylint: disable=invalid-name
   ]
 
 
+def sparse_dense_cwise_add(sp_t, dense_t):
+  """Adds up a SparseTensor and a dense Tensor, using these special rules:
+
+  (1) Broadcasts the dense side to have the same shape as the sparse side, if
+      eligible;
+  (2) Then, only the dense values pointed to by the indices of the SparseTensor
+      participate in the cwise addition.
+
+  Under these rules, the result is a logical SparseTensor with exactly the same
+  indices and shape, but possibly with different non-zero values.  The output of
+  this Op is the resultant non-zero values.
+
+  Args:
+    sp_t: the SparseTensor operand.
+    dense_t: the dense Tensor operand; must have the same dtype and a
+      broadcast-compatible shape as `sp_t`.
+
+  Returns:
+    output: the SparseTensor output.
+  """
+  result = gen_sparse_ops.sparse_dense_cwise_add(sp_t.indices, sp_t.values,
+                                                 sp_t.shape, dense_t)
+  return ops.SparseTensor(sp_t.indices, result, sp_t.shape)
+
+
 @ops.RegisterShape("SparseTensorDenseAdd")
 def _SparseTensorDenseAddShape(op):  # pylint: disable=invalid-name
   return [op.inputs[3].get_shape()]
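A hedged sketch of the cwise-add wrapper defined above; only the dense entries at the sparse indices participate, so the indices and shape pass through unchanged:

```python
# Hedged sketch: sparse values [1, 2] at [0, 0] and [1, 1] pick up the
# dense entries 10 and 40, giving output values [11, 42].
import tensorflow as tf
from tensorflow.python.ops import sparse_ops

sp = tf.SparseTensor(indices=[[0, 0], [1, 1]],
                     values=[1.0, 2.0], shape=[2, 2])
dense = tf.constant([[10.0, 20.0], [30.0, 40.0]])

out = sparse_ops.sparse_dense_cwise_add(sp, dense)
with tf.Session():
  print(out.values.eval())  # [ 11.  42.]
```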
@@ -765,6 +840,91 @@ def sparse_retain(sp_input, to_retain):
                           array_ops.identity(sp_input.shape))
 
 
+def sparse_reset_shape(sp_input, new_shape=None):
+  """Resets the shape of a `SparseTensor` with indices and values unchanged.
+
+  If `new_shape` is None, returns a copy of `sp_input` with its shape reset
+  to the tight bounding box of `sp_input`.
+
+  If `new_shape` is provided, then it must be larger than or equal to the shape
+  of `sp_input` in all dimensions. When this condition is met, the returned
+  SparseTensor will have its shape reset to `new_shape` and its indices and
+  values unchanged from those of `sp_input`.
+
+  For example:
+
+    Consider a `sp_input` with shape [2, 3, 5]:
+
+      [0, 0, 1]: a
+      [0, 1, 0]: b
+      [0, 2, 2]: c
+      [1, 0, 3]: d
+
+    - It is an error to set `new_shape` as [3, 7] since this represents a
+      rank-2 tensor while `sp_input` is rank-3. This is either a ValueError
+      during graph construction (if both shapes are known) or an OpError during
+      run time.
+
+    - Setting `new_shape` as [2, 3, 6] is fine, as this shape is larger than
+      or equal to the original shape [2, 3, 5] in every dimension.
+
+    - On the other hand, setting new_shape as [2, 3, 4] is also an error: The
+      third dimension is smaller than the original shape [2, 3, 5] (and an
+      `InvalidArgumentError` will be raised).
+
+    - If `new_shape` is None, the returned SparseTensor will have a shape
+      [2, 3, 4], which is the tight bounding box of `sp_input`.
+
+  Args:
+    sp_input: The input `SparseTensor`.
+    new_shape: None or a vector representing the new shape for the returned
+      `SparseTensor`.
+
+  Returns:
+    A `SparseTensor` with indices and values unchanged from `sp_input`. Its
+      shape is `new_shape` if that is set; otherwise it is the tight bounding
+      box of `sp_input`.
+
+  Raises:
+    TypeError: If `sp_input` is not a `SparseTensor`.
+    ValueError: If `new_shape` represents a tensor with a different rank from
+      that of `sp_input` (if shapes are known when graph is constructed).
+    OpError:
+      - If `new_shape` has dimension sizes that are too small.
+      - If shapes are not known during graph construction time, and during run
+        time it is found that the ranks do not match.
+  """
+  if not isinstance(sp_input, ops.SparseTensor):
+    raise TypeError("Input must be a SparseTensor")
+
+  in_indices = array_ops.identity(sp_input.indices)
+  in_values = array_ops.identity(sp_input.values)
+  in_shape = array_ops.identity(sp_input.shape)
+
+  if new_shape is None:
+    dim_low_bound = math_ops.reduce_max(in_indices, 0)
+    output_shape_tensor = math_ops.add(dim_low_bound,
+                                       array_ops.ones_like(in_shape))
+  else:
+    output_shape_tensor = ops.convert_to_tensor(new_shape)
+    output_shape_tensor.get_shape().assert_has_rank(1)
+    output_shape_tensor = math_ops.cast(output_shape_tensor, dtypes.int64)
+    # For cases when shape is known during graph construction, this catches the
+    # error before the ops.SparseTensor catches it.
+    output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0])
+
+    # For cases where shape is not known during graph construction.
+    output_shape_tensor = control_flow_ops.with_dependencies(
+        [check_ops.assert_equal(array_ops.shape(in_shape),
+                                array_ops.shape(output_shape_tensor))],
+        output_shape_tensor)
+    output_shape_tensor = control_flow_ops.with_dependencies(
+        [check_ops.assert_less_equal(in_shape, output_shape_tensor)],
+        output_shape_tensor)
+
+  return ops.SparseTensor(in_indices, in_values, output_shape_tensor)
+
+
 def sparse_fill_empty_rows(sp_input, default_value, name=None):
   """Fills empty rows in the input 2-D `SparseTensor` with a default value.
 
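A hedged, runnable version of the docstring's bounding-box and grow cases:

```python
# Hedged sketch: with no new_shape, the shape tightens to the bounding
# box [2, 3, 4]; a larger [2, 3, 6] is accepted as-is.
import tensorflow as tf
from tensorflow.python.ops import sparse_ops

sp = tf.SparseTensor(
    indices=[[0, 0, 1], [0, 1, 0], [0, 2, 2], [1, 0, 3]],
    values=["a", "b", "c", "d"], shape=[2, 3, 5])

tight = sparse_ops.sparse_reset_shape(sp)             # -> [2, 3, 4]
grown = sparse_ops.sparse_reset_shape(sp, [2, 3, 6])  # -> [2, 3, 6]

with tf.Session():
  print(tight.shape.eval())  # [2 3 4]
  print(grown.shape.eval())  # [2 3 6]
```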
@@ -1189,3 +1349,66 @@ def _SparseTensorDenseMatMulShape(op):  # pylint: disable=invalid-name
   b_shape = op.inputs[3].get_shape().with_rank(2)
   output_shape_right = b_shape[0] if adjoint_b else b_shape[1]
   return [tensor_shape.matrix(None, output_shape_right)]
+
+
+def sparse_softmax(sp_input, name=None):
+  """Applies softmax to a batched N-D `SparseTensor`.
+
+  The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
+  (where `N >= 2`), and with indices sorted in the canonical lexicographic
+  order.
+
+  This op is equivalent to applying the normal `tf.nn.softmax()` to each
+  innermost logical submatrix with shape `[B, C]`, but with the catch that *the
+  implicitly zero elements do not participate*.  Specifically, the algorithm is
+  equivalent to:
+
+    (1) Applies `tf.nn.softmax()` to a densified view of each innermost
+        submatrix with shape `[B, C]`, along the size-C dimension;
+    (2) Masks out the original implicitly-zero locations;
+    (3) Renormalizes the remaining elements.
+
+  Hence, the `SparseTensor` result has exactly the same non-zero indices and
+  shape.
+
+  Example:
+  ```python
+  # First batch:
+  # [?   e.]
+  # [1.  ? ]
+  # Second batch:
+  # [e   ? ]
+  # [e   e ]
+  shape = [2, 2, 2]  # 3-D SparseTensor
+  values = np.asarray([[[0., np.e], [1., 0.]], [[np.e, 0.], [np.e, np.e]]])
+  indices = np.vstack(np.where(values)).astype(np.int64).T
+
+  result = tf.sparse_softmax(tf.SparseTensor(indices, values, shape))
+  # ...returning a 3-D SparseTensor, equivalent to:
+  # [?   1.]     [1    ?]
+  # [1.  ? ] and [.5  .5]
+  # where ? means implicitly zero.
+  ```
+
+  Args:
+    sp_input: N-D `SparseTensor`, where `N >= 2`.
+    name: optional name of the operation.
+
+  Returns:
+    output: N-D `SparseTensor` representing the results.
+  """
+  with ops.op_scope([sp_input.indices, sp_input.values], name,
+                    "SparseSoftmax") as name:
+    out_vals = gen_sparse_ops.sparse_softmax(sp_input.indices,
+                                             sp_input.values,
+                                             sp_input.shape)
+    return ops.SparseTensor(sp_input.indices, out_vals, sp_input.shape)
+
+
+@ops.RegisterShape("SparseSoftmax")
+def _SparseSoftmaxShape(op):  # pylint: disable=invalid-name
+  """Shape function for SparseSoftmax op."""
+  unused_indices_shape = op.inputs[0].get_shape().with_rank(2)
+  values_shape = op.inputs[1].get_shape().with_rank(1)
+  unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
+  nnz = values_shape[0]
+  return [tensor_shape.vector(nnz)]
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 1cd38af3b8f..e057ba64079 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -19,6 +19,7 @@ String hashing ops take a string input tensor and map each element to an
 integer.
 
 @@string_to_hash_bucket_fast
+@@string_to_hash_bucket_strong
 @@string_to_hash_bucket
 
 ## Joining
@@ -49,10 +50,12 @@ from tensorflow.python.ops.gen_string_ops import *
 
 ops.NoGradient("StringToHashBucket")
 ops.NoGradient("StringToHashBucketFast")
+ops.NoGradient("StringToHashBucketStrong")
 ops.NoGradient("ReduceJoin")
 
 ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
 ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
+ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape)
 
 
 @ops.RegisterShape("ReduceJoin")
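Mirroring the test at the top of this change, a hedged sketch of the strong keyed hash: the key must have exactly two elements, and different keys give different bucket assignments.

```python
# Hedged sketch: keyed hashing into 10 buckets; the 2-element key values
# here are arbitrary placeholders.
import tensorflow as tf

s = tf.constant(["a", "b", "c"])
buckets = tf.string_to_hash_bucket_strong(s, 10, key=[123456, 987654])

with tf.Session():
  print(buckets.eval())  # three ints in [0, 10)
```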
diff --git a/tensorflow/stream_executor/platform/default/mutex.h b/tensorflow/stream_executor/platform/default/mutex.h
index 0ce1eeadbb8..b834895da5c 100644
--- a/tensorflow/stream_executor/platform/default/mutex.h
+++ b/tensorflow/stream_executor/platform/default/mutex.h
@@ -35,6 +35,9 @@ limitations under the License.
 namespace perftools {
 namespace gputools {
 
+#undef mutex_lock
+#undef shared_lock
+
 enum ConditionResult { kCond_Timeout, kCond_MaybeNotified };
 
 #ifdef STREAM_EXECUTOR_USE_SHARED_MUTEX
@@ -62,6 +65,9 @@ class SCOPED_LOCKABLE mutex_lock : public std::unique_lock<BaseMutex> {
   ~mutex_lock() RELEASE() {}
 };
 
+// Catch bug where variable name is omitted, e.g. mutex_lock (mu);
+#define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name");
+
 #ifdef STREAM_EXECUTOR_USE_SHARED_MUTEX
 // TODO(vrv): Annotate these with ACQUIRE_SHARED after implementing
 // as classes.
@@ -70,6 +76,9 @@ typedef std::shared_lock<BaseMutex> shared_lock;
 typedef mutex_lock shared_lock;
 #endif
 
+// Catch bug where variable name is omitted, e.g. shared_lock (mu);
+#define shared_lock(x) static_assert(0, "shared_lock_decl_missing_var_name");
+
 using std::condition_variable;
 
 inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG
index 209e3ef4b62..aabe6ec3909 100644
--- a/tensorflow/tensorboard/TAG
+++ b/tensorflow/tensorboard/TAG
@@ -1 +1 @@
-20
+21
diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json
index 0522cb8dff9..1a082fbb816 100644
--- a/tensorflow/tensorboard/bower.json
+++ b/tensorflow/tensorboard/bower.json
@@ -55,7 +55,7 @@
     "iron-list": "PolymerElements/iron-list#1.1.7",
     "iron-menu-behavior": "PolymerElements/iron-menu-behavior#1.1.5",
     "iron-meta": "PolymerElements/iron-meta#1.1.1",
-    "iron-overlay-behavior": "PolymerElements/iron-overlay-behavior#1.7.2",
+    "iron-overlay-behavior": "PolymerElements/iron-overlay-behavior#1.7.6",
     "iron-range-behavior": "PolymerElements/iron-range-behavior#1.0.4",
     "iron-resizable-behavior": "PolymerElements/iron-resizable-behavior#1.0.3",
     "iron-selector": "PolymerElements/iron-selector#1.2.4",
@@ -129,7 +129,7 @@
     "iron-list": "1.1.7",
     "iron-menu-behavior": "1.1.5",
     "iron-meta": "1.1.1",
-    "iron-overlay-behavior": "1.7.2",
+    "iron-overlay-behavior": "1.7.6",
     "iron-range-behavior": "1.0.4",
     "iron-resizable-behavior": "1.0.3",
     "iron-selector": "1.2.4",
diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts
index 852369abbbe..21daf50800f 100644
--- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts
+++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts
@@ -293,7 +293,8 @@ module TF {
         let points = plot.datasets().map(
             (dataset) => this.findClosestPoint(target, dataset));
         let pointsToCircle = points.filter(
-            (p) => Plottable.Utils.DOM.intersectsBBox(p.x, p.y, centerBBox));
+            (p) => p != null &&
+                Plottable.Utils.DOM.intersectsBBox(p.x, p.y, centerBBox));
         let pts: any = pointsComponent.content().selectAll('.point').data(
             pointsToCircle, (p: Point) => p.dataset.metadata().run);
         if (points.length !== 0) {
diff --git a/tensorflow/tensorboard/components/tf-multi-checkbox/tf-multi-checkbox.html b/tensorflow/tensorboard/components/tf-multi-checkbox/tf-multi-checkbox.html
index a5db63a5e2a..6ae2a573624 100644
--- a/tensorflow/tensorboard/components/tf-multi-checkbox/tf-multi-checkbox.html
+++ b/tensorflow/tensorboard/components/tf-multi-checkbox/tf-multi-checkbox.html
@@ -200,7 +200,8 @@ handle these situations gracefully.
       var name = e.srcElement.name;
       var checked = e.srcElement.checked;
       this.runToIsCheckedMapping[name] = checked;
-      this.notifyPath("runToIsCheckedMapping." + name, checked);
+      // n.b. notifyPath won't work because run names may have periods.
+      this.runToIsCheckedMapping = _.clone(this.runToIsCheckedMapping);
     },
     _isChecked: function(item, outSelectedChange) {
       return this.runToIsCheckedMapping[item];
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index be22f17b4f4..b54882c373b 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -358,7 +358,8 @@ var TF;
       var name = e.srcElement.name;
       var checked = e.srcElement.checked;
       this.runToIsCheckedMapping[name] = checked;
-      this.notifyPath("runToIsCheckedMapping." + name, checked);
+      // n.b. notifyPath won't work because run names may have periods.
+      this.runToIsCheckedMapping = _.clone(this.runToIsCheckedMapping);
     },
     _isChecked: function(item, outSelectedChange) {
       return this.runToIsCheckedMapping[item];
@@ -453,6 +454,7 @@ var TF;
     --tb-grey-lighter: #f3f3f3;
     --tb-ui-dark-accent: #757575;
     --tb-ui-light-accent: #e0e0e0;
+    --tb-graph-faded: #e0d4b3;
   }
 
 </style>
@@ -1541,7 +1543,8 @@ var TF;
                 };
                 var centerBBox = _this.gridlines.content().node().getBBox();
                 var points = plot.datasets().map(function (dataset) { return _this.findClosestPoint(target, dataset); });
-                var pointsToCircle = points.filter(function (p) { return Plottable.Utils.DOM.intersectsBBox(p.x, p.y, centerBBox); });
+                var pointsToCircle = points.filter(function (p) { return p != null &&
+                    Plottable.Utils.DOM.intersectsBBox(p.x, p.y, centerBBox); });
                 var pts = pointsComponent.content().selectAll('.point').data(pointsToCircle, function (p) { return p.dataset.metadata().run; });
                 if (points.length !== 0) {
                     pts.enter().append('circle').classed('point', true);
@@ -3697,16 +3700,16 @@ Polymer({
   },
   observers: [
     '_selectedDatasetChanged(selectedDataset, datasets)',
-    '_readAndParseMetadata(selectedDataset, selectedMetadataTag, datasets)'
+    '_readAndParseMetadata(selectedMetadataTag)'
   ],
-  _readAndParseMetadata: function(datasetIndex, metadataIndex, datasets) {
-    if (metadataIndex == -1 || datasets[datasetIndex] == null ||
-        datasets[datasetIndex].runMetadata == null ||
-        datasets[datasetIndex].runMetadata[metadataIndex] == null) {
+  _readAndParseMetadata: function(metadataIndex) {
+    if (metadataIndex == -1 || this.datasets[this.selectedDataset] == null ||
+        this.datasets[this.selectedDataset].runMetadata == null ||
+        this.datasets[this.selectedDataset].runMetadata[metadataIndex] == null) {
       this._setOutStats(null);
       return;
     }
-    var path = datasets[datasetIndex].runMetadata[metadataIndex].path;
+    var path = this.datasets[this.selectedDataset].runMetadata[metadataIndex].path;
     // Reset the progress bar to 0.
     this.set('progress', {
       value: 0,
@@ -6577,8 +6580,9 @@ var tf;
              * for each node in the graph.
              */
             var RenderGraphInfo = (function () {
-                function RenderGraphInfo(hierarchy) {
+                function RenderGraphInfo(hierarchy, displayingStats) {
                     this.hierarchy = hierarchy;
+                    this.displayingStats = displayingStats;
                     this.index = {};
                     this.computeScales();
                     // Maps node name to whether the rendering hierarchy was already
@@ -6663,6 +6667,9 @@ var tf;
                         renderInfo.computeTimeColor =
                             this.computeTimeScale(node.stats.totalMicros);
                     }
+                    // We only fade nodes when we're displaying stats.
+                    renderInfo.isFadedOut = this.displayingStats &&
+                        !tf.graph.util.hasDisplayableNodeStats(node.stats);
                     if (node.isGroupNode) {
                         // Make a list of tuples (device, proportion), where proportion
                         // is the fraction of op nodes that have that device.
@@ -6768,6 +6775,8 @@ var tf;
                     _.each(metagraph.edges(), function (edgeObj) {
                         var metaedge = metagraph.edge(edgeObj);
                         var renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge);
+                        renderMetaedgeInfo.isFadedOut =
+                            _this.index[edgeObj.v].isFadedOut || _this.index[edgeObj.w].isFadedOut;
                         coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo);
                     });
                     if (PARAMS.enableExtraction &&
@@ -7220,6 +7229,8 @@ var tf;
                     this.isInExtract = false;
                     this.isOutExtract = false;
                     this.coreBox = { width: 0, height: 0 };
+                    // By default, we don't fade nodes out. Default to false for safety.
+                    this.isFadedOut = false;
                 }
                 RenderNodeInfo.prototype.isInCore = function () {
                     return !this.isInExtract && !this.isOutExtract;
@@ -7237,6 +7248,7 @@ var tf;
                     this.adjoiningMetaedge = null;
                     this.structural = false;
                     this.weight = 1;
+                    this.isFadedOut = false;
                 }
                 return RenderMetaedgeInfo;
             }());
@@ -8197,6 +8209,7 @@ var tf;
                  * d's label property will be a RenderMetaedgeInfo object.
                  */
                 function stylize(edgeGroup, d, stylize) {
+                    edgeGroup.classed('faded', d.label.isFadedOut);
                     var metaedge = d.label.metaedge;
                     edgeGroup.select('path.' + scene.Class.Edge.LINE)
                         .classed('control-dep', metaedge && !metaedge.numRegularEdges);
@@ -8600,7 +8613,10 @@ var tf;
                                 stampType =
                                     groupNodeInfo.node.hasNonControlEdges ? 'vertical' : 'horizontal';
                             }
-                            scene.selectOrCreateChild(shapeGroup, 'use', scene.Class.Node.COLOR_TARGET)
+                            scene
+                                .selectOrCreateChild(shapeGroup, 'use',
+                                scene.Class.Node.COLOR_TARGET + ' ' +
+                                (groupNodeInfo.isFadedOut ? 'faded-ellipse' : ''))
                                 .attr('xlink:href', '#op-series-' + stampType + '-stamp');
                             scene.selectOrCreateChild(shapeGroup, 'rect', scene.Class.Node.COLOR_TARGET)
                                 .attr({ rx: d.radius, ry: d.radius });
@@ -8782,10 +8798,12 @@ var tf;
                     var isSelected = sceneElement.isNodeSelected(renderInfo.node.name);
                     var isExtract = renderInfo.isInExtract || renderInfo.isOutExtract;
                     var isExpanded = renderInfo.expanded;
+                    var isFadedOut = renderInfo.isFadedOut;
                     nodeGroup.classed('highlighted', isHighlighted);
                     nodeGroup.classed('selected', isSelected);
                     nodeGroup.classed('extract', isExtract);
                     nodeGroup.classed('expanded', isExpanded);
+                    nodeGroup.classed('faded', isFadedOut);
                     // Main node always exists here and it will be reached before subscene,
                     // so d3 selection is fine here.
                     var node = nodeGroup.select('.' + nodeClass + ' .' + scene.Class.Node.COLOR_TARGET);
@@ -9677,6 +9695,14 @@ var tf;
                 return (value.toPrecision(3) - 0) + ' ' + units[unitIndex].symbol;
             }
             util.convertUnitsToHumanReadable = convertUnitsToHumanReadable;
+            function hasDisplayableNodeStats(stats) {
+                if (stats &&
+                    (stats.totalBytes > 0 || stats.totalMicros > 0 || stats.outputSize)) {
+                    return true;
+                }
+                return false;
+            }
+            util.hasDisplayableNodeStats = hasDisplayableNodeStats;
         })(util = graph.util || (graph.util = {}));
     })(graph = tf.graph || (tf.graph = {}));
 })(tf || (tf = {}));
@@ -10083,6 +10109,34 @@ Polymer({
   stroke-width: 4;
 }
 
+::content .faded,
+::content .faded rect,
+::content .faded ellipse,
+::content .faded path,
+::content #rectHatch line,
+::content #ellipseHatch line {
+  color: var(--tb-graph-faded) !important;
+  fill: white;
+  stroke: var(--tb-graph-faded) !important;
+}
+
+
+::content .faded path {
+  stroke-width: 1px !important;
+}
+
+::content .faded rect {
+  fill: url("#rectHatch") !important;
+}
+
+::content .faded ellipse {
+  fill: url("#ellipseHatch") !important;
+}
+
+::content .faded text {
+  opacity: 0;
+}
+
 
 /* --- Op Node --- */
 
@@ -10232,10 +10286,18 @@ Polymer({
   marker-end: url("#annotation-arrowhead");
 }
 
+::content .faded .annotation > .annotation-edge {
+  marker-end: url("#annotation-arrowhead-faded");
+}
+
 ::content .annotation > .annotation-edge.refline {
   marker-start: url("#ref-annotation-arrowhead");
 }
 
+::content .faded .annotation > .annotation-edge.refline {
+  marker-start: url("#ref-annotation-arrowhead-faded");
+}
+
 ::content .annotation > .annotation-control-edge {
   stroke-dasharray: 1, 1;
 }
@@ -10244,10 +10306,18 @@ Polymer({
   fill: #aaa;
 }
 
+::content #annotation-arrowhead-faded {
+  fill: var(--tb-graph-faded);
+}
+
 ::content #ref-annotation-arrowhead {
   fill: #aaa;
 }
 
+::content #ref-annotation-arrowhead-faded {
+  fill: var(--tb-graph-faded);
+}
+
 ::content .annotation > .annotation-label {
   font-size: 5px;
   cursor: pointer;
@@ -10398,9 +10468,15 @@ Polymer({
     <marker id="annotation-arrowhead" markerWidth="5" markerHeight="5" refX="5" refY="2.5" orient="auto">
       <path d="M 0,0 L 5,2.5 L 0,5 L 0,0"></path>
     </marker>
+    <marker id="annotation-arrowhead-faded" markerWidth="5" markerHeight="5" refX="5" refY="2.5" orient="auto">
+      <path d="M 0,0 L 5,2.5 L 0,5 L 0,0"></path>
+    </marker>
     <marker id="ref-annotation-arrowhead" markerWidth="5" markerHeight="5" refX="0" refY="2.5" orient="auto">
       <path d="M 5,0 L 0,2.5 L 5,5 L 5,0"></path>
     </marker>
+    <marker id="ref-annotation-arrowhead-faded" markerWidth="5" markerHeight="5" refX="0" refY="2.5" orient="auto">
+      <path d="M 5,0 L 0,2.5 L 5,5 L 5,0"></path>
+    </marker>
     
     <ellipse id="op-node-stamp" rx="7.5" ry="3" stroke="inherit" fill="inherit"></ellipse>
     
@@ -10428,6 +10504,14 @@ Polymer({
     </svg>
     
     <g id="linearGradients"></g>
+
+    
+    <pattern id="rectHatch" patternTransform="rotate(45 0 0)" width="5" height="5" patternUnits="userSpaceOnUse">
+      <line x1="0" y1="0" x2="0" y2="5" style="stroke-width: 1"></line>
+    </pattern>
+    <pattern id="ellipseHatch" patternTransform="rotate(45 0 0)" width="2" height="2" patternUnits="userSpaceOnUse">
+      <line x1="0" y1="0" x2="0" y2="2" style="stroke-width: 1"></line>
+    </pattern>
   </defs>
   
   <rect fill="white" width="10000" height="10000"></rect>
@@ -10908,9 +10992,12 @@ Polymer({
     '_buildRenderHierarchy(graphHierarchy)'
   ],
   _statsChanged: function(stats) {
-    if (stats != null) {
-      tf.graph.joinStatsInfoWithGraph(this.basicGraph, stats);
-      tf.graph.hierarchy.joinAndAggregateStats(this.graphHierarchy, stats);
+    if (this.graphHierarchy) {
+      if (stats != null) {
+        tf.graph.joinStatsInfoWithGraph(this.basicGraph, stats);
+        tf.graph.hierarchy.joinAndAggregateStats(this.graphHierarchy, stats);
+      }
+
       // Recompute the rendering information.
       this._buildRenderHierarchy(this.graphHierarchy);
     }
@@ -10923,7 +11010,8 @@ Polymer({
         // and thus mistakenly pass non-metanode to this module.
         return;
       }
-      var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy);
+      var renderGraph = new tf.graph.render.RenderGraphInfo(
+          graphHierarchy, !!this.stats /** displayingStats */);
       // Producing the 'color by' parameters to be consumed
       // by the tf-graph-controls panel. It contains information about the
       // min and max values and their respective colors, as well as list
@@ -11083,6 +11171,19 @@ Polymer({
 });
 </script>
 <dom-module id="tf-graph-icon" assetpath="../tf-graph/">
+  <style>
+    .faded-rect {
+      fill: url("#rectHatch");
+    }
+
+    .faded-ellipse {
+      fill: url("#ellipseHatch");
+    }
+
+    .faded-rect, .faded-ellipse, .faded-series {
+      stroke: var(--tb-graph-faded) !important;
+    }
+  </style>
   <template>
     <template is="dom-if" if="[[_isType(node, type, 'OP')]]">
       <template is="dom-if" if="[[_isConst(node, const)]]">
@@ -11097,24 +11198,24 @@ Polymer({
       </template>
       <template is="dom-if" if="[[_isRegularOp(node, const, summary)]]">
         <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 8">
-          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-node-stamp" fill$="[[_getFill(_computedFill, 'OP')]]" stroke$="[[_getStroke(_computedFill, 'OP')]]" x="8" y="4"></use>
+          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-node-stamp" fill$="[[_getFill(_computedFill, 'OP')]]" stroke$="[[_getStroke(_computedFill, 'OP')]]" class$="{{_fadedClass(renderInfo, 'ellipse')}}" x="8" y="4"></use>
         </svg>
       </template>
     </template>
     <template is="dom-if" if="[[_isType(node, type, 'META')]]">
       <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 37 16">
-        <rect x="1" y="1" fill$="[[_getFill(_computedFill, 'META')]]" stroke$="[[_getStroke(_computedFill, 'META')]]" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect>
+        <rect x="1" y="1" fill$="[[_getFill(_computedFill, 'META')]]" stroke$="[[_getStroke(_computedFill, 'META')]]" class$="{{_fadedClass(renderInfo, 'rect')}}" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect>
       </svg>
     </template>
     <template is="dom-if" if="[[_isType(node, type, 'SERIES')]]">
       <template is="dom-if" if="[[_isVertical(node, vertical)]]">
         <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 15">
-          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-vertical-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" x="0" y="2"></use>
+          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-vertical-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" class$="{{_fadedClass(renderInfo, 'series')}}" x="0" y="2"></use>
         </svg>
       </template>
       <template is="dom-if" if="[[!_isVertical(node, vertical)]]">
         <svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 24 10">
-          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-horizontal-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" x="0" y="1"></use>
+          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-horizontal-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" class$="{{_fadedClass(renderInfo, 'series')}}" x="0" y="1"></use>
         </svg>
       </template>
     </template>
@@ -11308,6 +11409,10 @@ Polymer({
         _isRegularOp: function(inputNode, inputConst, inputSummary) {
           return !this._isConst(inputNode, inputConst) &&
               !this._isSummary(inputNode, inputSummary);
+        },
+
+        _fadedClass: function(itemRenderInfo, shape) {
+          return itemRenderInfo && itemRenderInfo.isFadedOut ? 'faded-' + shape : '';
         }
       });
     })();
@@ -11352,12 +11457,18 @@ Polymer({
     top: 1px;
     left: 2px;
   }
+
+  .faded span {
+    color: var(--tb-graph-faded);
+  }
   </style>
   <template>
     <div id="list-item" on-mouseover="_nodeListener" on-mouseout="_nodeListener" on-click="_nodeListener">
-      <tf-graph-icon class="node-icon" height="12" color-by="[[colorBy]]" color-by-params="[[colorByParams]]" node="[[itemNode]]" render-info="[[itemRenderInfo]]" template-index="[[templateIndex]]"></tf-graph-icon>
-      <span title$="[[name]]">[[name]]</span>
-      <span class="edge-label">[[edgeLabel]]</span>
+      <div class$="{{_fadedClass(itemRenderInfo)}}">
+        <tf-graph-icon class="node-icon" height="12" color-by="[[colorBy]]" color-by-params="[[colorByParams]]" node="[[itemNode]]" render-info="[[itemRenderInfo]]" template-index="[[templateIndex]]"></tf-graph-icon>
+        <span title$="[[name]]">[[name]]</span>
+        <span class="edge-label">[[edgeLabel]]</span>
+      </div>
     </div>
   </template>
 
@@ -11409,6 +11520,10 @@ Polymer({
             nodeName: this.name,
             type: this.itemType
           });
+        },
+
+        _fadedClass: function(itemRenderInfo) {
+          return itemRenderInfo && itemRenderInfo.isFadedOut ? 'faded' : '';
         }
       });
     })();
@@ -11453,6 +11568,10 @@ Polymer({
     display: table-row;
   }
 
+  .sub-list-table-row .sub-list-table-cell:last-child {
+    text-align: right;
+  }
+
   .sub-list-table-cell {
     color: #565656;
     display: table-cell;
@@ -11795,13 +11914,7 @@ Polymer({
           return null;
         },
         _getHasDisplayableNodeStats: function(stats) {
-          if (stats &&
-              (stats.totalBytes > 0 ||
-                  stats.totalBytes > 0 ||
-                  stats.outputSize)) {
-            return true;
-          }
-          return false;
+          return tf.graph.util.hasDisplayableNodeStats(stats);
         },
         _getNodeStatsFormattedBytes(stats) {
           if (!stats || !stats.totalBytes) {
@@ -12228,7 +12341,6 @@ Polymer({
   }
 });
 </script>
-
 <dom-module id="tf-graph-controls" assetpath="../tf-graph/">
 <template>
 <style>
@@ -12315,6 +12427,7 @@ svg.icon {
   fill: #D9D9D9;
 }
 .domainValues {
+  margin-bottom: 10px;
   width: 165px;
 }
 .domainStart {
@@ -12353,8 +12466,31 @@ svg.icon {
   padding: 8px 0;
 }
 
-.color-text {
-  padding: 0 0 0 49px;
+.color-legend-row {
+  clear: both;
+  height: 20px;
+  margin-top: 5px;
+  position: relative;
+}
+
+.color-legend-row svg {
+  position: absolute;
+  top: -1px;
+  width: 40px;
+}
+
+.color-legend-row span.color-legend-value {
+  margin-left: 60px;
+}
+
+#grey-rect {
+  fill: #eee;
+  stroke: #a6a6a6;
+}
+
+#faded-rect {
+  fill: url("#rectHatch");
+  stroke: var(--tb-graph-faded);
 }
 
 .button-text {
@@ -12397,6 +12533,19 @@ span.counter {
   color: gray;
 }
 </style>
+<svg width="0" height="0">
+  <defs>
+    <g id="legend-rect">
+      <rect x="1" y="1" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect>
+    </g>
+    <g id="grey-rect">
+       <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#legend-rect"></use>
+     </g>
+     <g id="faded-rect">
+       <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#legend-rect"></use>
+     </g>
+  </defs>
+</svg>
 <div class="allcontrols">
   <div class="control-holder">
     <paper-icon-button icon="aspect-ratio" class="iconbutton" on-click="fit" alt="Fit to screen">
@@ -12466,11 +12615,22 @@ span.counter {
         <div class="domainStart">[[_currentGradientParams.minValue]]</div>
         <div class="domainEnd">[[_currentGradientParams.maxValue]]</div>
       </div>
+      <br style="clear: both">
     </template>
     <template is="dom-if" if="[[_equals(colorBy, 'structure')]]">
       <div class="color-text">
-        color: same substructure<br>
-        gray: unique substructure
+        <div class="color-legend-row">
+          <div style="position: absolute;">
+            colors
+          </div>
+          <span class="color-legend-value">same substructure</span>
+        </div>
+        <div class="color-legend-row">
+          <svg>
+            <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#grey-rect" x="0" y="0"></use>
+          </svg>
+          <span class="color-legend-value">unique substructure</span>
+        </div>
       </div>
     </template>
     <template is="dom-if" if="[[_equals(colorBy, 'device')]]">
@@ -12490,7 +12650,20 @@ span.counter {
           </table>
         </div>
         <br>
-        gray: unknown device
+        <div class="color-legend-row">
+          <svg>
+            <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#grey-rect" x="0" y="0"></use>
+          </svg>
+          <span class="color-legend-value">unknown device</span>
+        </div>
+      </div>
+    </template>
+    <template is="dom-if" if="[[_statsNotNull(stats)]]">
+      <div class="color-legend-row">
+        <svg>
+          <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#faded-rect" x="0" y="0"></use>
+        </svg>
+        <span class="color-legend-value">unused substructure</span>
       </div>
     </template>
   </div>
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index f6118e2cc07..03054238a24 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -86,10 +86,22 @@ def tf_proto_text_protos_relative():
   return [p for p in tf_android_core_proto_sources_relative()
           if p not in ("util/test_log.proto")]
 
-def if_android_arm(a, b=[]):
+def if_android_arm(a):
   return select({
       "//tensorflow:android_arm": a,
-      "//conditions:default": b,
+      "//conditions:default": [],
+  })
+
+def if_not_android(a):
+  return select({
+      "//tensorflow:android": [],
+      "//conditions:default": a,
+  })
+
+def if_android(a):
+  return select({
+      "//tensorflow:android": a,
+      "//conditions:default": [],
   })
 
 def tf_copts():
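
Note: the new `if_android` / `if_not_android` helpers follow the same `select()` pattern as the existing `if_android_arm`. A minimal sketch of how a BUILD target could consume them, with hypothetical target and file names (nothing below is part of this change):

```python
# Sketch only: illustrative BUILD usage of the new config helpers.
# select() resolves to exactly one branch at analysis time, depending on
# whether the //tensorflow:android config_setting is active.
load("//tensorflow:tensorflow.bzl", "if_android", "if_not_android")

cc_library(
    name = "platform_ops",  # hypothetical target
    srcs = ["ops_common.cc"] + if_android([
        "ops_android_stub.cc",  # built only for Android configurations
    ]) + if_not_android([
        "ops_full.cc",  # built for every non-Android configuration
    ]),
)
```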
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d99cb5b5e3d..7c68fb763fa 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -52,6 +52,13 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
     actual = "@farmhash//:farmhash",
   )
 
+  native.git_repository(
+    name = "highwayhash",
+    remote = "https://github.com/google/highwayhash.git",
+    commit = "be5edafc2e1a455768e260ccd68ae7317b6690ee",
+    init_submodules = True,
+  )
+
   native.new_http_archive(
     name = "jpeg_archive",
     url = "http://www.ijg.org/files/jpegsrc.v9a.tar.gz",