diff --git a/WORKSPACE b/WORKSPACE
index 36d382095b5..a0aaefa6f58 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -70,7 +70,7 @@ new_git_repository(
   name = "iron_a11y_keys_behavior",
   build_file = "bower.BUILD",
   remote = "https://github.com/polymerelements/iron-a11y-keys-behavior.git",
-  tag = "v1.1.5",
+  tag = "v1.1.6",
 )
 
 new_git_repository(
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 8f1e5b860a4..7f831e67c41 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -171,6 +171,16 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "transformed_distribution_test",
+    size = "small",
+    srcs = ["python/kernel_tests/transformed_distribution_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index d04693ce983..0957c681016 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -47,6 +47,10 @@ initialized with parameters that define the distributions.
 
 @@DirichletMultinomial
 
+### Transformed distributions
+
+@@ContinuousTransformedDistribution
+
 ## Operators allowing for matrix-free methods
 
 ### Positive definite operators
@@ -95,4 +99,5 @@ from tensorflow.contrib.distributions.python.ops.operator_pd import *
 from tensorflow.contrib.distributions.python.ops.operator_pd_cholesky import *
 from tensorflow.contrib.distributions.python.ops.operator_pd_full import *
 from tensorflow.contrib.distributions.python.ops.student_t import *
+from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
 from tensorflow.contrib.distributions.python.ops.uniform import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
new file mode 100644
index 00000000000..d78f4a92161
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -0,0 +1,79 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ContinuousTransformedDistribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class ContinuousTransformedDistributionTest(tf.test.TestCase):
+
+  def testContinuousTransformedDistribution(self):
+    with self.test_session():
+      mu = 3.0
+      sigma = 0.02
+      log_normal = tf.contrib.distributions.ContinuousTransformedDistribution(
+          base_dist_cls=tf.contrib.distributions.Normal,
+          mu=mu,
+          sigma=sigma,
+          transform=lambda x: tf.exp(x),
+          inverse=lambda y: tf.log(y),
+          log_det_jacobian=(lambda x: tf.reduce_sum(x)))
+
+      # sample
+      self.assertAllClose([stats.lognorm.mean(s=sigma, scale=np.exp(mu))],
+                          [np.mean(log_normal.sample(100000, seed=235).eval())],
+                          atol=1e-2)
+
+      # pdf, log_pdf
+      test_vals = np.linspace(0.00001, 10.).astype(np.float32)
+      for test_val in test_vals:
+        expected = stats.lognorm.logpdf(test_val, s=sigma, scale=np.exp(mu))
+        self.assertAllClose([expected], [log_normal.log_pdf(test_val).eval()])
+        self.assertAllClose([np.exp(expected)],
+                            [log_normal.pdf(test_val).eval()])
+
+  def testCachedSamplesWithoutInverse(self):
+    with self.test_session() as sess:
+      mu = 3.0
+      sigma = 0.02
+      log_normal = tf.contrib.distributions.ContinuousTransformedDistribution(
+          base_dist_cls=tf.contrib.distributions.Normal,
+          mu=mu,
+          sigma=sigma,
+          transform=lambda x: tf.exp(x),
+          inverse=None,
+          log_det_jacobian=(lambda x: tf.reduce_sum(x)))
+
+      sample = log_normal.sample(1)
+      sample_val, log_pdf_val = sess.run([sample, log_normal.log_pdf(sample)])
+      self.assertAllClose(
+          stats.lognorm.logpdf(sample_val, s=sigma,
+                               scale=np.exp(mu)),
+          log_pdf_val,
+          atol=1e-2)
+
+      with self.assertRaisesRegexp(ValueError,
+                                   "was not returned from `sample`"):
+        log_normal.log_pdf(tf.constant(3.0))
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
new file mode 100644
index 00000000000..8899d7a59f7
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
@@ -0,0 +1,252 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A Transformed Distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import distribution  # pylint: disable=line-too-long
+from tensorflow.python.framework import ops
+
+
+class ContinuousTransformedDistribution(distribution.ContinuousDistribution):
+  """A Transformed Distribution.
+
+  A Transformed Distribution models `p(y)` given a base distribution `p(x)`,
+  an invertible transform, `y = f(x)`, and the determinant of the Jacobian of
+  `f(x)`.
+
+  Shapes, dtype, and reparameterization are taken from the base distribution.
+
+  #### Mathematical details
+
+  * `p(x)` - probability distribution for random variable X
+  * `p(y)` - probability distribution for random variable Y
+  * `f` - transform
+  * `g` - inverse transform, `f(g(x)) = x`
+  * `J(x)` - Jacobian of f(x)
+
+  A Transformed Distribution exposes `sample` and `pdf`:
+
+    * `sample`: `y = f(x)`, after drawing a sample of X.
+    * `pdf`: `p(y) = p(x) / |det J(x)| = p(g(y)) / |det J(g(y))|`
+
+  A simple example constructing a Logit-Normal distribution from a Normal
+  distribution:
+
+  ```
+  logit_normal = ContinuousTransformedDistribution(
+    base_dist_cls=Normal,
+    mu=mu, sigma=sigma,
+    transform=lambda x: tf.sigmoid(x),
+    inverse=lambda y: tf.log(y) - tf.log(1. - y),
+    log_det_jacobian=(lambda x:
+        tf.reduce_sum(tf.log(tf.sigmoid(x)) + tf.log(1. - tf.sigmoid(x)),
+                      reduction_indices=[-1])),
+    name="LogitNormalTransformedDistribution")
+  ```
+  """
+
+  def __init__(self,
+               base_dist_cls,
+               transform,
+               inverse,
+               log_det_jacobian,
+               name="ContinuousTransformedDistribution",
+               **base_dist_args):
+    """Construct a Transformed Distribution.
+
+    Args:
+      base_dist_cls: the base distribution class to transform. Must be a
+          subclass of `ContinuousDistribution`.
+      transform: a callable that takes a `Tensor` sample from `base_dist` and
+          returns a `Tensor` of the same shape and type. `x => y`.
+      inverse: a callable that computes the inverse of transform. `y => x`. If
+          None, users can only call `log_pdf` on values returned by `sample`.
+      log_det_jacobian: a callable that takes a `Tensor` sample from `base_dist`
+          and returns the log of the determinant of the Jacobian of `transform`.
+      name: The name for the distribution.
+      **base_dist_args: kwargs to pass on to base_dist_cls on construction.
+
+    Raises:
+      TypeError: if `base_dist_cls` is not a subclass of
+          `ContinuousDistribution`.
+    """
+    if not issubclass(base_dist_cls, distribution.ContinuousDistribution):
+      raise TypeError("base_dist_cls must be a subclass of"
+                      "ContinuousDistribution.")
+    with ops.op_scope(base_dist_args.values(), name) as scope:
+      self._name = scope
+      self._base_dist = base_dist_cls(**base_dist_args)
+    self._transform = transform
+    self._inverse = inverse
+    self._log_det_jacobian = log_det_jacobian
+    self._inverse_cache = {}
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def dtype(self):
+    return self._base_dist.dtype
+
+  def batch_shape(self, name="batch_shape"):
+    """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+    The product of the dimensions of the `batch_shape` is the number of
+    independent distributions of this kind the instance represents.
+
+    Args:
+      name: name to give to the op.
+
+    Returns:
+      `Tensor` `batch_shape`
+    """
+    with ops.name_scope(self.name):
+      return self._base_dist.batch_shape(name)
+
+  def get_batch_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `batch_shape`. May be only partially defined.
+
+    Returns:
+      batch shape
+    """
+    return self._base_dist.get_batch_shape()
+
+  def event_shape(self, name="event_shape"):
+    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+    Args:
+      name: name to give to the op.
+
+    Returns:
+      `Tensor` `event_shape`
+    """
+    with ops.name_scope(self.name):
+      return self._base_dist.event_shape(name)
+
+  def get_event_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `event_shape`. May be only partially defined.
+
+    Returns:
+      event shape
+    """
+    return self._base_dist.get_event_shape()
+
+  @property
+  def base_distribution(self):
+    """Base distribution, p(x)."""
+    return self._base_dist
+
+  @property
+  def transform(self):
+    """Function transforming x => y."""
+    return self._transform
+
+  @property
+  def inverse(self):
+    """Inverse function of transform, y => x."""
+    return self._inverse
+
+  @property
+  def log_det_jacobian(self):
+    """Function computing the log determinant of the Jacobian of transform."""
+    return self._log_det_jacobian
+
+  def log_pdf(self, y, name="log_pdf"):
+    """Log pdf of observations in `y`.
+
+    `log ( p(g(y)) / |det J(g(y))| )`, where `g` is the inverse of `transform`.
+
+    Args:
+      y: tensor of dtype `dtype`.
+      name: The name to give this op.
+
+    Returns:
+      log_pdf: tensor of dtype `dtype`, the log-PDFs of `y`.
+
+    Raises:
+      ValueError: if `inverse` was not provided to the distribution and `y` was
+          not returned from `sample`.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([y], name):
+        y = ops.convert_to_tensor(y)
+        if y.dtype != self.dtype:
+          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
+                          (y.dtype, self.dtype))
+        with ops.name_scope("inverse"):
+          if y in self._inverse_cache:
+            x = self._inverse_cache[y]
+          elif self._inverse:
+            x = self._inverse(y)
+          else:
+            raise ValueError("No inverse function exists and input `y` was not "
+                             "returned from `sample`.")
+        with ops.name_scope("log_det_jacobian"):
+          log_det_jacobian = self._log_det_jacobian(x)
+        return self._base_dist.log_likelihood(x) - log_det_jacobian
+
+  def pdf(self, y, name="pdf"):
+    """The PDF of observations in `y`.
+
+    `p(g(y)) / |det J(g(y))|`, where `g` is the inverse of `transform`.
+
+    Args:
+      y: `Tensor` of dtype `dtype`.
+      name: The name to give this op.
+
+    Returns:
+      pdf: `Tensor` of dtype `dtype`, the pdf values of `y`.
+    """
+    return super(ContinuousTransformedDistribution, self).pdf(y, name=name)
+
+  def sample(self, n, seed=None, name="sample"):
+    """Sample `n` observations.
+
+    Samples from the base distribution and then passes through the transform.
+
+    Args:
+      n: scalar, type int32, the number of observations to sample.
+      seed: Python integer, the random seed.
+      name: The name to give this op.
+
+    Returns:
+      samples: `[n, ...]`, a `Tensor` of `n` samples.
+    """
+    with ops.name_scope(self.name):
+      with ops.name_scope(name):
+        samples = self._base_dist.sample(n=n, seed=seed)
+        with ops.name_scope("transform"):
+          transformed = self._transform(samples)
+          self._inverse_cache[transformed] = samples
+          return transformed
+
+  @property
+  def is_reparameterized(self):
+    return self._base_dist.is_reparameterized
+
+  @property
+  def strict_statistics(self):
+    return self._base_dist.strict_statistics
+
+  @property
+  def strict(self):
+    return self._base_dist.strict
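
For readers skimming the diff, here is a minimal usage sketch of the new class, mirroring the kernel test above. It assumes the TF 0.x-era `tf.contrib.distributions` API that this change introduces; it is an illustration, not part of the change itself.

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# Log-normal via change of variables: if X ~ Normal(mu, sigma) and
# Y = exp(X), then log p(y) = log p_X(log y) - log y, and since
# dy/dx = exp(x), the log-det-Jacobian of the transform at x is just x
# (reduce_sum covers the single-sample case used here, as in the test).
log_normal = distributions.ContinuousTransformedDistribution(
    base_dist_cls=distributions.Normal,
    mu=3.0,
    sigma=0.02,
    transform=lambda x: tf.exp(x),
    inverse=lambda y: tf.log(y),
    log_det_jacobian=lambda x: tf.reduce_sum(x))

with tf.Session() as sess:
  sample = log_normal.sample(1, seed=42)
  # Sampled tensors are cached against their pre-transform values, so
  # log_pdf works on them even when inverse=None.
  print(sess.run([sample, log_normal.log_pdf(sample)]))
```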
diff --git a/tensorflow/contrib/framework/python/ops/embedding_ops.py b/tensorflow/contrib/framework/python/ops/embedding_ops.py
index af51042944d..76f4143c09a 100644
--- a/tensorflow/contrib/framework/python/ops/embedding_ops.py
+++ b/tensorflow/contrib/framework/python/ops/embedding_ops.py
@@ -17,18 +17,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import embedding_ops as tf_embedding_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
+from tensorflow.contrib.framework.python.framework.deprecation import deprecated
+from tensorflow.contrib.layers import embedding_ops as embedding_ops
 
 __all__ = ["safe_embedding_lookup_sparse",]
 
 
+@deprecated("2016-09-01",
+            "Please use tf.contrib.layers.safe_embedding_lookup_sparse.")
 def safe_embedding_lookup_sparse(embedding_weights,
                                  sparse_ids,
                                  sparse_weights=None,
@@ -74,82 +70,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
   Raises:
     ValueError: if `embedding_weights` is empty.
   """
-  if embedding_weights is None or len(embedding_weights) < 1:
-    raise ValueError("Missing embedding_weights %s." % embedding_weights)
-
-  dtype = sparse_weights.dtype if sparse_weights is not None else None
-  embedding_weights = [
-      ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
-  ]
-
-  contrib_tensor_util.assert_same_float_dtype(embedding_weights +
-                                              [sparse_weights])
-
-  with ops.op_scope(embedding_weights + [sparse_ids, sparse_weights], name,
-                    "embedding_lookup") as scope:
-    # Reshape higher-rank sparse ids and weights to linear segment ids.
-    original_shape = sparse_ids.shape
-    original_rank_dim = sparse_ids.shape.get_shape()[0]
-    original_rank = (
-        array_ops.size(original_shape)
-        if original_rank_dim.value is None
-        else original_rank_dim.value)
-    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
-        math_ops.reduce_prod(
-            array_ops.slice(original_shape, [0], [original_rank - 1])),
-        array_ops.gather(original_shape, original_rank - 1)])
-    if sparse_weights is not None:
-      sparse_weights = ops.SparseTensor(sparse_ids.indices,
-                                        sparse_weights.values, sparse_ids.shape)
-
-    # Prune invalid ids and weights.
-    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
-
-    # Fill in dummy values for empty features, if necessary.
-    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
-                                                                 default_id or
-                                                                 0)
-    if sparse_weights is not None:
-      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
-
-    result = tf_embedding_ops.embedding_lookup_sparse(
-        embedding_weights,
-        sparse_ids,
-        sparse_weights,
-        combiner=combiner,
-        partition_strategy=partition_strategy,
-        name=None if default_id is None else scope)
-
-    if default_id is None:
-      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
-      # for use in Select.
-      is_row_empty = array_ops.tile(
-          array_ops.reshape(is_row_empty, [-1, 1]),
-          array_ops.pack([1, array_ops.shape(result)[1]]))
-
-      result = math_ops.select(is_row_empty,
-                               array_ops.zeros_like(result),
-                               result,
-                               name=scope)
-
-    # Reshape back from linear ids back into higher-dimensional dense result.
-    final_result = array_ops.reshape(result, array_ops.concat(0, [
-        array_ops.slice(
-            math_ops.cast(original_shape, dtypes.int32),
-            [0], [original_rank - 1]),
-        array_ops.slice(array_ops.shape(result), [1], [-1])]))
-    final_result.set_shape(tensor_shape.unknown_shape(
-        (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
-    return final_result
-
-
-def _prune_invalid_ids(sparse_ids, sparse_weights):
-  """Prune invalid IDs (< 0) from the input ids and weights."""
-  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
-  if sparse_weights is not None:
-    is_id_valid = math_ops.logical_and(
-        is_id_valid, math_ops.greater(sparse_weights.values, 0))
-  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
-  if sparse_weights is not None:
-    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
-  return sparse_ids, sparse_weights
+  return embedding_ops.safe_embedding_lookup_sparse(
+      embedding_weights,
+      sparse_ids,
+      sparse_weights=sparse_weights,
+      combiner=combiner,
+      default_id=default_id,
+      name=name,
+      partition_strategy=partition_strategy)
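
The framework copy is now a thin shim: the implementation moved to `tf.contrib.layers`, and the old entry point merely warns and forwards. Below is a self-contained sketch of that deprecate-and-forward pattern; the decorator and function names are stand-ins for illustration, not the real TF symbols (the real decorator lives in `tensorflow.contrib.framework.python.framework.deprecation`).

```python
import functools
import warnings


def deprecated(date, instructions):
  """Sketch of a deprecation decorator in the spirit of the one used above."""
  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      # Warn once per call site, then forward to the wrapped function.
      warnings.warn('%s is deprecated after %s. %s'
                    % (func.__name__, date, instructions),
                    DeprecationWarning, stacklevel=2)
      return func(*args, **kwargs)
    return wrapper
  return decorator


def _new_impl(x):
  return x + 1  # stand-in for the canonical implementation


@deprecated('2016-09-01', 'Please use the new implementation.')
def old_api(x):
  return _new_impl(x)  # thin forwarding shim, as in the diff above


print(old_api(1))  # works, but emits a DeprecationWarning
```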
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 4904c16a9cd..b40b622b8f1 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -18,11 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.framework.python.ops import embedding_ops as contrib_embedding_ops
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
 from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
-from tensorflow.python.framework import dtypes
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
@@ -32,8 +33,130 @@ __all__ = ["safe_embedding_lookup_sparse", "hashed_embedding_lookup",
            "hashed_embedding_lookup_sparse"]
 
 
-# TODO(chapelle): move the safe_embedding_lookup_sparse code here (b/29826543)
-safe_embedding_lookup_sparse = contrib_embedding_ops.safe_embedding_lookup_sparse  # pylint: disable=line-too-long
+def safe_embedding_lookup_sparse(embedding_weights,
+                                 sparse_ids,
+                                 sparse_weights=None,
+                                 combiner="mean",
+                                 default_id=None,
+                                 name=None,
+                                 partition_strategy="div"):
+  """Lookup embedding results, accounting for invalid IDs and empty features.
+
+  The partitioned embedding tensors in `embedding_weights` must all have the
+  same shape except for the first dimension, which may vary because the
+  vocabulary size is not necessarily a multiple of `P`.
+
+  Invalid IDs (< 0) are pruned from the input IDs and weights, as are any IDs
+  with non-positive weight. For an entry with no features, the embedding vector
+  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+
+  The ids and weights may be multi-dimensional. Embeddings are always aggregated
+  along the last dimension.
+
+  Args:
+    embedding_weights:  A list of `P` float tensors or values representing
+        partitioned embedding tensors.  The total unpartitioned shape should be
+        `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size and
+        `e_1, ..., e_m` are the embedding dimensions.
+    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+        ids. `d_0` is typically batch size.
+    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+        float weights corresponding to `sparse_ids`, or `None` if all weights
+        are assumed to be 1.0.
+    combiner: A string specifying how to combine embedding results for each
+        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+        the default.
+    default_id: The id to use for an entry with no features.
+    name: A name for this operation (optional).
+    partition_strategy: A string specifying the partitioning strategy.
+        Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+
+  Returns:
+    Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
+
+  Raises:
+    ValueError: if `embedding_weights` is empty.
+  """
+  if embedding_weights is None or len(embedding_weights) < 1:
+    raise ValueError("Missing embedding_weights %s." % embedding_weights)
+
+  dtype = sparse_weights.dtype if sparse_weights is not None else None
+  embedding_weights = [
+      ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+  ]
+
+  contrib_tensor_util.assert_same_float_dtype(embedding_weights +
+                                              [sparse_weights])
+
+  with ops.op_scope(embedding_weights + [sparse_ids, sparse_weights], name,
+                    "embedding_lookup") as scope:
+    # Reshape higher-rank sparse ids and weights to linear segment ids.
+    original_shape = sparse_ids.shape
+    original_rank_dim = sparse_ids.shape.get_shape()[0]
+    original_rank = (
+        array_ops.size(original_shape)
+        if original_rank_dim.value is None
+        else original_rank_dim.value)
+    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
+        math_ops.reduce_prod(
+            array_ops.slice(original_shape, [0], [original_rank - 1])),
+        array_ops.gather(original_shape, original_rank - 1)])
+    if sparse_weights is not None:
+      sparse_weights = ops.SparseTensor(sparse_ids.indices,
+                                        sparse_weights.values, sparse_ids.shape)
+
+    # Prune invalid ids and weights.
+    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
+
+    # Fill in dummy values for empty features, if necessary.
+    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
+                                                                 default_id or
+                                                                 0)
+    if sparse_weights is not None:
+      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
+
+    result = embedding_ops.embedding_lookup_sparse(
+        embedding_weights,
+        sparse_ids,
+        sparse_weights,
+        combiner=combiner,
+        partition_strategy=partition_strategy,
+        name=None if default_id is None else scope)
+
+    if default_id is None:
+      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
+      # for use in Select.
+      is_row_empty = array_ops.tile(
+          array_ops.reshape(is_row_empty, [-1, 1]),
+          array_ops.pack([1, array_ops.shape(result)[1]]))
+
+      result = math_ops.select(is_row_empty,
+                               array_ops.zeros_like(result),
+                               result,
+                               name=scope)
+
+    # Reshape back from linear ids back into higher-dimensional dense result.
+    final_result = array_ops.reshape(result, array_ops.concat(0, [
+        array_ops.slice(
+            math_ops.cast(original_shape, dtypes.int32),
+            [0], [original_rank - 1]),
+        array_ops.slice(array_ops.shape(result), [1], [-1])]))
+    final_result.set_shape(tensor_shape.unknown_shape(
+        (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
+    return final_result
+
+
+def _prune_invalid_ids(sparse_ids, sparse_weights):
+  """Prune invalid IDs (< 0) from the input ids and weights."""
+  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
+  if sparse_weights is not None:
+    is_id_valid = math_ops.logical_and(
+        is_id_valid, math_ops.greater(sparse_weights.values, 0))
+  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
+  if sparse_weights is not None:
+    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
+  return sparse_ids, sparse_weights
 
 
 def hashed_embedding_lookup(params, values, dimension, name=None):
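
As a plain-NumPy illustration of the semantics documented above (invalid ids and non-positive weights are pruned; an empty row falls back to `default_id` or the 0-vector; "mean" combining), here is a hedged reference model for the 2-D, single-partition case. The helper name is hypothetical, and this is illustration only: the real op operates on `SparseTensor`s and partitioned weight tensors.

```python
import numpy as np


def safe_lookup_reference(weights, rows, default_id=None):
  """NumPy reference for the 2-D, "mean"-combiner case of
  safe_embedding_lookup_sparse. `weights` is [vocab, dim]; `rows` is a
  list of (id, weight) pairs per example."""
  dim = weights.shape[1]
  out = np.zeros((len(rows), dim), dtype=weights.dtype)
  for i, pairs in enumerate(rows):
    # Prune invalid ids (< 0) and non-positive weights.
    pairs = [(j, w) for j, w in pairs if j >= 0 and w > 0]
    if not pairs:
      # Empty row: use default_id's vector, else the 0-vector.
      if default_id is not None:
        out[i] = weights[default_id]
      continue
    ws = np.array([w for _, w in pairs])
    vecs = np.stack([weights[j] for j, _ in pairs])
    out[i] = (ws[:, None] * vecs).sum(axis=0) / ws.sum()  # "mean" combiner
  return out


emb = np.arange(12.0).reshape(4, 3)
# Second row contains only an invalid id, so it yields the 0-vector.
print(safe_lookup_reference(emb, [[(0, 1.0), (2, 1.0)], [(-1, 1.0)]]))
```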
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
index 823a38f30d0..f343b68f7c0 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
@@ -98,9 +98,13 @@ def input_from_feature_columns(columns_to_tensors,
                                     [ops.GraphKeys.VARIABLES]))
 
     for column in sorted(set(feature_columns), key=lambda x: x.key):
-      transformed_tensor = transformer.transform(column)
-      output_tensors.append(column.to_dnn_input_layer(
-          transformed_tensor, weight_collections, trainable))
+      try:
+        transformed_tensor = transformer.transform(column)
+        output_tensors.append(column.to_dnn_input_layer(
+            transformed_tensor, weight_collections, trainable))
+      except ValueError as e:
+        raise ValueError('Error creating input layer for column: {}.\n'
+                         '{}'.format(column.name, e))
     return array_ops.concat(1, output_tensors)
 
 
@@ -174,11 +178,15 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
     column_to_variable = dict()
     transformer = _Transformer(columns_to_tensors)
     for column in sorted(set(feature_columns), key=lambda x: x.key):
-      transformed_tensor = transformer.transform(column)
-      predictions, variable = column.to_weighted_sum(transformed_tensor,
-                                                     num_outputs,
-                                                     weight_collections,
-                                                     trainable)
+      try:
+        transformed_tensor = transformer.transform(column)
+        predictions, variable = column.to_weighted_sum(transformed_tensor,
+                                                       num_outputs,
+                                                       weight_collections,
+                                                       trainable)
+      except ValueError as e:
+        raise ValueError('Error creating weighted sum for column: {}.\n'
+                         '{}'.format(column.name, e))
       output_tensors.append(predictions)
       column_to_variable[column] = variable
       _log_variable(variable)
@@ -305,7 +313,10 @@ def check_feature_columns(feature_columns):
   for f in feature_columns:
     key = f.key
     if key in seen_keys:
-      raise ValueError('Duplicate feature column key found: %s' % key)
+      raise ValueError('Duplicate feature column key found for column: {}. '
+                       'This usually means that the column is almost identical '
+                       'to another column, and one must be discarded.'.format(
+                           f.name))
     seen_keys.add(key)
 
 
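Both hunks above apply the same wrap-and-reraise pattern so that a failure points at a specific feature column. A minimal sketch of the pattern (the helper name is hypothetical):

```python
def transform_columns(columns, transform_fn):
  """Re-raise ValueErrors with the offending column's name attached."""
  outputs = []
  for column in columns:
    try:
      outputs.append(transform_fn(column))
    except ValueError as e:
      raise ValueError('Error creating input layer for column: {}.\n'
                       '{}'.format(column.name, e))
  return outputs
```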
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 1d0f45357ed..33aa3c8b091 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -341,9 +341,12 @@ class InputLayerTest(tf.test.TestCase):
 
     # Makes sure that trying to use different initializers with the same
     # embedding column explicitly fails.
-    with self.assertRaises(ValueError):
-      tf.contrib.layers.input_from_feature_columns(
-          features, [embedded_sparse, embedded_sparse_alternate])
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError,
+          "Duplicate feature column key found for column: wire_embedding"):
+        tf.contrib.layers.input_from_feature_columns(
+            features, [embedded_sparse, embedded_sparse_alternate])
 
   def testSparseColumn(self):
     hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
@@ -351,9 +354,11 @@ class InputLayerTest(tf.test.TestCase):
                                   indices=[[0, 0], [1, 0], [1, 1]],
                                   shape=[2, 2])
     features = {"wire": wire_tensor}
-    with self.assertRaises(ValueError):
-      tf.initialize_all_variables().run()
-      tf.contrib.layers.input_layer(features, [hashed_sparse])
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError, "Error creating input layer for column: wire"):
+        tf.initialize_all_variables().run()
+        tf.contrib.layers.input_from_feature_columns(features, [hashed_sparse])
 
   def testCrossedColumn(self):
     a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
@@ -366,9 +371,11 @@ class InputLayerTest(tf.test.TestCase):
                                   indices=[[0, 0], [1, 0], [1, 1]],
                                   shape=[2, 2])
     features = {"aaa": wire_tensor, "bbb": wire_tensor}
-    with self.assertRaises(ValueError):
-      tf.initialize_all_variables().run()
-      tf.contrib.layers.input_layer(features, [crossed])
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError, "Error creating input layer for column: aaa_X_bbb"):
+        tf.initialize_all_variables().run()
+        tf.contrib.layers.input_from_feature_columns(features, [crossed])
 
   def testAllColumns(self):
     real_valued = tf.contrib.layers.real_valued_column("income", 3)
@@ -477,10 +484,13 @@ class WeightedSumTest(tf.test.TestCase):
                                   shape=[2, 2])
     features = {"wire": wire_tensor}
     embeded_sparse = tf.contrib.layers.embedding_column(hashed_sparse, 10)
-    with self.assertRaises(ValueError):
-      tf.initialize_all_variables().run()
-      tf.contrib.layers.weighted_sum_from_feature_columns(features,
-                                                          [embeded_sparse])
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError, "Error creating weighted sum for column: wire_embedding"):
+        tf.initialize_all_variables().run()
+        tf.contrib.layers.weighted_sum_from_feature_columns(features,
+                                                            [embeded_sparse],
+                                                            num_outputs=5)
 
   def testRealValuedColumnWithMultiDimensions(self):
     real_valued = tf.contrib.layers.real_valued_column("price", 2)
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index ff834723f44..5db69a8b5d0 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -19,6 +19,7 @@ py_library(
     deps = [
         "//tensorflow/contrib/learn/python/learn/datasets",
         "//tensorflow/contrib/session_bundle:exporter",
+        "//tensorflow/contrib/tensor_forest:client_lib",
         "//tensorflow/python:framework",
     ],
 )
@@ -390,6 +391,18 @@ py_test(
     ],
 )
 
+py_test(
+    name = "random_forest_test",
+    size = "medium",
+    srcs = ["python/learn/estimators/random_forest_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":learn",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 py_test(
     name = "rnn_test",
     size = "medium",
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index 12a8b7cfe51..b6e6e57ebe4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -40,6 +40,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import TensorFlowLi
 from tensorflow.contrib.learn.python.learn.estimators.linear import TensorFlowLinearRegressor
 from tensorflow.contrib.learn.python.learn.estimators.linear import TensorFlowRegressor
 from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
+from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
 from tensorflow.contrib.learn.python.learn.estimators.rnn import TensorFlowRNNClassifier
 from tensorflow.contrib.learn.python.learn.estimators.rnn import TensorFlowRNNRegressor
 from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index dcd4719c5cf..5ab5eec26a5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -110,8 +110,7 @@ class _ComposableModel(object):
     grads = gradients.gradients(loss, my_vars)
     if self._gradient_clip_norm:
       grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)
-    self._optimizer = self._get_optimizer()
-    return [self._optimizer.apply_gradients(zip(grads, my_vars))]
+    return [self._get_optimizer().apply_gradients(zip(grads, my_vars))]
 
   def _get_feature_columns(self):
     if not self._feature_columns:
@@ -130,6 +129,16 @@ class _ComposableModel(object):
     return []
 
   def _get_optimizer(self):
+    if (self._optimizer is None or isinstance(self._optimizer,
+                                              six.string_types)):
+      optimizer = self._get_default_optimizer(self._optimizer)
+    elif callable(self._optimizer):
+      optimizer = self._optimizer()
+    else:
+      optimizer = self._optimizer
+    return optimizer
+
+  def _get_default_optimizer(self, optimizer_name=None):
     raise NotImplementedError
 
 
@@ -173,14 +182,12 @@ class _LinearComposableModel(_ComposableModel):
         name="linear")
     return logits
 
-  def _get_optimizer(self):
-    if self._optimizer is None:
-      self._optimizer = "Ftrl"
-    if isinstance(self._optimizer, six.string_types):
-      default_learning_rate = 1. / math.sqrt(len(self._get_feature_columns()))
-      self._optimizer = layers.OPTIMIZER_CLS_NAMES[self._optimizer](
-          learning_rate=default_learning_rate)
-    return self._optimizer
+  def _get_default_optimizer(self, optimizer_name=None):
+    if optimizer_name is None:
+      optimizer_name = "Ftrl"
+    default_learning_rate = 1. / math.sqrt(len(self._get_feature_columns()))
+    return layers.OPTIMIZER_CLS_NAMES[optimizer_name](
+        learning_rate=default_learning_rate)
 
 
 class _DNNComposableModel(_ComposableModel):
@@ -269,13 +276,10 @@ class _DNNComposableModel(_ComposableModel):
     self._add_hidden_layer_summary(logits, "dnn_logits")
     return logits
 
-  def _get_optimizer(self):
-    if self._optimizer is None:
-      self._optimizer = "Adagrad"
-    if isinstance(self._optimizer, six.string_types):
-      self._optimizer = layers.OPTIMIZER_CLS_NAMES[self._optimizer](
-          learning_rate=0.05)
-    return self._optimizer
+  def _get_default_optimizer(self, optimizer_name=None):
+    if optimizer_name is None:
+      optimizer_name = "Adagrad"
+    return layers.OPTIMIZER_CLS_NAMES[optimizer_name](learning_rate=0.05)
 
 
 # TODO(ispir): Increase test coverage
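
The refactor above centralizes optimizer resolution in `_ComposableModel._get_optimizer` and, notably, no longer caches the constructed optimizer in `self._optimizer`, so a callable optimizer is re-invoked for each graph (this is what lets `testCustomOptimizerByFunction` below pass an exponential-decay factory). A minimal sketch of the dispatch, using plain `str` where the original uses `six.string_types`; the function name is illustrative:

```python
def resolve_optimizer(optimizer, default_factory):
  """Sketch of the dispatch in _get_optimizer above.

  `optimizer` may be None or a string (delegate to the model's default
  factory), a zero-arg callable (call it, so schedules are rebuilt in the
  current graph), or an already-constructed optimizer instance.
  """
  if optimizer is None or isinstance(optimizer, str):
    return default_factory(optimizer)
  if callable(optimizer):
    return optimizer()
  return optimizer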
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index 345065c4832..30750227681 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -262,6 +262,36 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
     scores = classifier.evaluate(input_fn=_iris_input_logistic_fn, steps=100)
     self.assertGreater(scores['accuracy'], 0.9)
 
+  def testCustomOptimizerByFunction(self):
+    """Tests binary classification using matrix data as input."""
+    iris = _prepare_iris_data_for_logistic_regression()
+    cont_features = [
+        tf.contrib.layers.real_valued_column('feature', dimension=4)
+    ]
+    bucketized_features = [
+        tf.contrib.layers.bucketized_column(
+            cont_features[0], _get_quantile_based_buckets(iris.data, 10))
+    ]
+
+    def _optimizer_exp_decay():
+      global_step = tf.contrib.framework.get_global_step()
+      learning_rate = tf.train.exponential_decay(learning_rate=0.1,
+                                                 global_step=global_step,
+                                                 decay_steps=100,
+                                                 decay_rate=0.001)
+      return tf.train.AdagradOptimizer(learning_rate=learning_rate)
+
+    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+        linear_feature_columns=bucketized_features,
+        linear_optimizer=_optimizer_exp_decay,
+        dnn_feature_columns=cont_features,
+        dnn_hidden_units=[3, 3],
+        dnn_optimizer=_optimizer_exp_decay)
+
+    classifier.fit(input_fn=_iris_input_logistic_fn, steps=100)
+    scores = classifier.evaluate(input_fn=_iris_input_logistic_fn, steps=100)
+    self.assertGreater(scores['accuracy'], 0.9)
+
   def testPredict(self):
     """Tests weight column in evaluation."""
     def _input_fn_train():
@@ -441,6 +471,25 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
     self.assertNotIn('linear/feature_BUCKETIZED_weights',
                      classifier.get_variable_names())
 
+  def testDNNWeightsBiasesNames(self):
+    """Tests the names of DNN weights and biases in the checkpoints."""
+    def _input_fn_train():
+      # Create 4 rows, three (y = x), one (y=Not(x))
+      target = tf.constant([[1], [1], [1], [0]])
+      features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
+      return features, target
+    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+        linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+        dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+        dnn_hidden_units=[3, 3])
+
+    classifier.fit(input_fn=_input_fn_train, steps=5)
+    # hiddenlayer_0/weights, hiddenlayer_1/weights and dnn_logits/weights.
+    self.assertEquals(3, len(classifier.dnn_weights_))
+    # hiddenlayer_0/biases, hiddenlayer_1/biases, dnn_logits/biases,
+    # centered_bias_weight.
+    self.assertEquals(4, len(classifier.dnn_bias_))
+
 
 class DNNLinearCombinedRegressorTest(tf.test.TestCase):
 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
new file mode 100644
index 00000000000..b9f50701ee3
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
@@ -0,0 +1,191 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A tf.learn implementation of tensor_forest (extremely random forests)."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+import six
+
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib.learn.python.learn import monitors as mon
+
+from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import run_config
+
+from tensorflow.contrib.tensor_forest.client import eval_metrics
+from tensorflow.contrib.tensor_forest.data import data_ops
+from tensorflow.contrib.tensor_forest.python import tensor_forest
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+
+
+class LossMonitor(mon.EveryN):
+  """Terminates training when training loss stops decreasing."""
+
+  def __init__(self,
+               early_stopping_rounds,
+               every_n_steps):
+    super(LossMonitor, self).__init__(every_n_steps=every_n_steps)
+    self.early_stopping_rounds = early_stopping_rounds
+    self.min_loss = None
+    self.min_loss_step = 0
+
+  def set_estimator(self, est):
+    """This function gets called in the same graph as _get_train_ops."""
+    super(LossMonitor, self).set_estimator(est)
+    self._loss_op_name = est.training_loss.name
+
+  def every_n_step_end(self, step, outputs):
+    super(LossMonitor, self).every_n_step_end(step, outputs)
+    current_loss = outputs[self._loss_op_name]
+    if self.min_loss is None or current_loss < self.min_loss:
+      self.min_loss = current_loss
+      self.min_loss_step = step
+    return step - self.min_loss_step >= self.early_stopping_rounds
+
+
+class TensorForestEstimator(estimator.BaseEstimator):
+  """An estimator that can train and evaluate a random forest."""
+
+  def __init__(self, params, device_assigner=None, model_dir=None,
+               graph_builder_class=tensor_forest.RandomForestGraphs,
+               master='', accuracy_metric=None,
+               tf_random_seed=None, continue_training=False, verbose=1,
+               max_to_keep=5, save_checkpoint_secs=300):
+    self.params = params.fill()
+    self.accuracy_metric = (accuracy_metric or
+                            ('r2' if self.params.regression else 'accuracy'))
+    self.data_feeder = None
+    self.device_assigner = (
+        device_assigner or tensor_forest.RandomForestDeviceAssigner())
+    self.graph_builder_class = graph_builder_class
+    self.training_args = {}
+    self.construction_args = {}
+
+    config = run_config.RunConfig(
+        master=master,
+        tf_random_seed=(tf_random_seed or int((time.time() * 1000) % 1000)),
+        save_checkpoints_secs=save_checkpoint_secs,
+        keep_checkpoint_max=max_to_keep)
+
+    super(TensorForestEstimator, self).__init__(model_dir=model_dir,
+                                                config=config)
+
+  def predict_proba(self, x=None, input_fn=None, batch_size=None):
+    """Returns prediction probabilities for given features (classification).
+
+    Args:
+      x: features.
+      input_fn: Input function. If set, x must be None.
+      batch_size: Override default batch size.
+
+    Returns:
+      Numpy array of predicted probabilities.
+
+    Raises:
+      ValueError: If both or neither of x and input_fn were given.
+    """
+    return super(TensorForestEstimator, self).predict(
+        x=x, input_fn=input_fn, batch_size=batch_size)
+
+  def predict(self, x=None, input_fn=None, axis=None, batch_size=None):
+    """Returns predictions for given features.
+
+    Args:
+      x: features.
+      input_fn: Input function. If set, x must be None.
+      axis: Axis on which to argmax (for classification).
+            Last axis is used by default.
+      batch_size: Override default batch size.
+
+    Returns:
+      Numpy array of predicted classes or regression values.
+    """
+    probabilities = self.predict_proba(x, input_fn, batch_size)
+    if self.params.regression:
+      return probabilities
+    else:
+      return np.argmax(probabilities, axis=1)
+
+  def _get_train_ops(self, features, targets):
+    """Method that builds model graph and returns trainer ops.
+
+    Args:
+      features: `Tensor` or `dict` of `Tensor` objects.
+      targets: `Tensor` or `dict` of `Tensor` objects.
+
+    Returns:
+      Tuple of train `Operation` and loss `Tensor`.
+    """
+    features, spec = data_ops.ParseDataTensorOrDict(features)
+    labels = data_ops.ParseLabelTensorOrDict(targets)
+
+    graph_builder = self.graph_builder_class(
+        self.params, device_assigner=self.device_assigner,
+        **self.construction_args)
+
+    epoch = None
+    if self.data_feeder:
+      epoch = self.data_feeder.make_epoch_variable()
+
+    train = control_flow_ops.group(
+        graph_builder.training_graph(
+            features, labels, data_spec=spec, epoch=epoch,
+            **self.training_args),
+        state_ops.assign_add(contrib_framework.get_global_step(), 1))
+
+    self.training_loss = graph_builder.training_loss()
+
+    return train, self.training_loss
+
+  def _get_predict_ops(self, features):
+    graph_builder = self.graph_builder_class(
+        self.params, device_assigner=self.device_assigner, training=False,
+        **self.construction_args)
+    features, spec = data_ops.ParseDataTensorOrDict(features)
+    return graph_builder.inference_graph(features, data_spec=spec)
+
+  def _get_eval_ops(self, features, targets, metrics):
+    features, spec = data_ops.ParseDataTensorOrDict(features)
+    labels = data_ops.ParseLabelTensorOrDict(targets)
+
+    graph_builder = self.graph_builder_class(
+        self.params, device_assigner=self.device_assigner, training=False,
+        **self.construction_args)
+
+    probabilities = graph_builder.inference_graph(features, data_spec=spec)
+
+    # One-hot the labels.
+    if not self.params.regression:
+      labels = math_ops.to_int64(array_ops.one_hot(math_ops.to_int64(
+          array_ops.squeeze(labels)), self.params.num_classes, 1, 0))
+
+    if metrics is None:
+      metrics = {self.accuracy_metric:
+                 eval_metrics.get_metric(self.accuracy_metric)}
+
+    result = {}
+    for name, metric in six.iteritems(metrics):
+      result[name] = metric(probabilities, labels)
+
+    return result
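
`LossMonitor` implements min-loss early stopping: training ends once the best (minimum) loss seen is at least `early_stopping_rounds` steps old. A plain-Python sketch of just that rule, detached from the monitor machinery (class and variable names here are illustrative):

```python
class EarlyStopper(object):
  """Stop once the best loss seen is at least `rounds` steps old."""

  def __init__(self, rounds):
    self.rounds = rounds
    self.min_loss = None
    self.min_loss_step = 0

  def should_stop(self, step, loss):
    if self.min_loss is None or loss < self.min_loss:
      self.min_loss = loss
      self.min_loss_step = step
    return step - self.min_loss_step >= self.rounds


stopper = EarlyStopper(rounds=2)
for step, loss in enumerate([1.0, 0.5, 0.6, 0.7, 0.4]):
  if stopper.should_stop(step, loss):
    print('stopping at step', step)  # fires at step 3 (loss 0.7)
    break
```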
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py
new file mode 100644
index 00000000000..15db20906d1
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py
@@ -0,0 +1,55 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for TensorForestTrainer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class TensorForestTrainerTests(tf.test.TestCase):
+
+  def testClassification(self):
+    """Tests multi-class classification using matrix data as input."""
+    hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+        num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
+    classifier = tf.contrib.learn.TensorForestEstimator(hparams)
+
+    iris = tf.contrib.learn.datasets.load_iris()
+
+    classifier.fit(x=iris.data, y=iris.target, steps=100)
+    classifier.evaluate(x=iris.data, y=iris.target, steps=10)
+
+  def testRegression(self):
+    """Tests multi-class classification using matrix data as input."""
+
+    hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+        num_trees=3, max_nodes=1000, num_classes=1, num_features=13,
+        regression=True)
+
+    regressor = tf.contrib.learn.TensorForestEstimator(hparams)
+
+    boston = tf.contrib.learn.datasets.load_boston()
+
+    regressor.fit(x=boston.data, y=boston.target, steps=100)
+    regressor.evaluate(x=boston.data, y=boston.target, steps=10)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
index 921dad614c2..15deca87e67 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
+++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc
@@ -600,13 +600,13 @@ Status RunTrainStepsForMiniBatch(
     const DeviceBase::CpuWorkerThreads& worker_threads,
     const Regularizations& regularizations, const DualLossUpdater& loss_updater,
     FeaturesAndWeights* const features_and_weights,
-    TTypes<float>::Matrix example_state_data) {
+    TTypes<float>::Matrix* const example_state_data) {
   // Process examples in parallel, in a partitioned fashion.
   mutex mu;
   Status train_step_status GUARDED_BY(mu);
   auto train_step = [&](const int64 begin, const int64 end) {
     for (int64 example_index = begin; example_index < end; ++example_index) {
-      const float dual = example_state_data(example_index, 0);
+      const float dual = (*example_state_data)(example_index, 0);
       const float example_weight = example_weights(example_index);
       float example_label = example_labels(example_index);
       const Status conversion_status =
@@ -641,10 +641,10 @@ Status RunTrainStepsForMiniBatch(
           example_index, bounded_dual_delta, regularizations.symmetric_l2());
 
       // Update example data.
-      example_state_data(example_index, 0) = new_dual;
-      example_state_data(example_index, 1) = primal_loss;
-      example_state_data(example_index, 2) = dual_loss;
-      example_state_data(example_index, 3) = example_weight;
+      (*example_state_data)(example_index, 0) = new_dual;
+      (*example_state_data)(example_index, 1) = primal_loss;
+      (*example_state_data)(example_index, 2) = dual_loss;
+      (*example_state_data)(example_index, 3) = example_weight;
     }
   };
   // TODO(sibyl-Aix6ihai): Current multiplier 100000 works well empirically
@@ -737,11 +737,11 @@ class SdcaSolver : public OpKernel {
                        num_examples, example_labels, example_weights,
                        *context->device()->tensorflow_cpu_worker_threads(),
                        regularizations_, *loss_updater_, &features_and_weights,
-                       example_state_data));
+                       &example_state_data));
     }
     features_and_weights.AddDeltaWeights();
 
-    context->set_output(0, mutable_example_state_data_t);
+    context->set_output("example_data_data_out", mutable_example_state_data_t);
   }
 
  private:
@@ -775,6 +775,12 @@ class SdcaShrinkL1 : public OpKernel {
 };
 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
 
+// Computes a platform-independent, compact and unique (with very high
+// probability) representation of an example id. It shouldn't be put in
+// persistent storage, as its implementation may change in the future.
+//
+// The current probability of at least one collision for 1B example_ids is
+// approximately 10^-21 (i.e. 2^60 / 2^129).
 class SdcaFprint : public OpKernel {
  public:
   explicit SdcaFprint(OpKernelConstruction* context) : OpKernel(context) {}
@@ -788,8 +794,8 @@ class SdcaFprint : public OpKernel {
     auto out_values = out->flat<string>();
 
     for (int64 i = 0; i < in_values.size(); ++i) {
-      const string& s = in_values(i);
-      Fprint128 fprint = Fingerprint128(s);
+      const Fprint128 fprint = Fingerprint128(in_values(i));
+      // Hex encode the fprint as a string (33 characters).
       out_values(i) = strings::StrCat(strings::FpToString(fprint.high64), "-",
                                       strings::FpToString(fprint.low64));
     }
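
The collision figure in the new `SdcaFprint` comment follows from the birthday bound. A quick sanity check of the arithmetic:

```python
# Birthday bound: with n ~ 2^30 (about 1B) ids hashed into a 2^128
# fingerprint space, P(any collision) ~= n^2 / 2^129 = 2^60 / 2^129.
n = 2.0 ** 30
print(n * n / 2.0 ** 129)  # ~1.69e-21, i.e. about 10^-21 as quoted
```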
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 8ecf4bfe89a..2b9a95a1d65 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -781,7 +781,14 @@ class SdcaWithHingeLossTest(SdcaOptimizerTest):
 
 
 class SdcaFprintTest(TensorFlowTestCase):
-  """Tests for the SdcaFprint op."""
+  """Tests for the SdcaFprint op.
+
+  This is one way of enforcing the platform-agnostic nature of SdcaFprint.
+  We check against exact values, and this test may run on different
+  platforms. Note that it is fine for the expected values to change in the
+  future if the implementation of SdcaFprint changes (i.e. this is *not* a
+  frozen test).
+  """
 
   def testFprint(self):
     with self.test_session():
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index f669407c176..b4e9e5b23d9 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -95,8 +95,7 @@ class SdcaModel(object):
     ```
 
     In the training program you will just have to run the returned Op from
-    minimize(). You should also eventually cleanup the temporary state used by
-    the model, by resetting its (possibly shared) container.
+    minimize().
 
     ```python
     # Execute opt_op and train for num_steps.
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 792243790ce..8c9dc742228 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -18,6 +18,54 @@ filegroup(
     ),
 )
 
+py_library(
+    name = "constants",
+    srcs = [
+        "python/constants.py",
+    ],
+    srcs_version = "PY2AND3",
+)
+
+tf_custom_op_library(
+    name = "data/_data_ops.so",
+    srcs = [
+        "data/string_to_float_op.cc",
+    ],
+    deps = [
+        ":tree_utils",
+    ],
+)
+
+py_library(
+    name = "data_ops_lib",
+    srcs = [
+        "data/data_ops.py",
+    ],
+    data = [
+        "data/_data_ops.so",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":constants",
+    ],
+)
+
+py_library(
+    name = "eval_metrics",
+    srcs = ["client/eval_metrics.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "client_lib",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":data_ops_lib",
+        ":eval_metrics",
+        ":tensor_forest_py",
+    ],
+)
+
 cc_library(
     name = "tree_utils",
     srcs = ["core/ops/tree_utils.cc"],
@@ -86,6 +134,7 @@ py_test(
     srcs = ["python/kernel_tests/count_extremely_random_stats_op_test.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":constants",
         ":ops_lib",
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:framework_test_lib",
@@ -151,6 +200,7 @@ py_test(
     srcs = ["python/kernel_tests/tree_predictions_op_test.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":constants",
         ":ops_lib",
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:framework_test_lib",
@@ -176,6 +226,7 @@ py_library(
     srcs = ["python/tensor_forest.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":constants",
         ":ops_lib",
     ],
 )
diff --git a/tensorflow/contrib/tensor_forest/__init__.py b/tensorflow/contrib/tensor_forest/__init__.py
index 7cf05299c4b..7d97e01df08 100644
--- a/tensorflow/contrib/tensor_forest/__init__.py
+++ b/tensorflow/contrib/tensor_forest/__init__.py
@@ -18,4 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensor_forest.client import *
+from tensorflow.contrib.tensor_forest.data import *
 from tensorflow.contrib.tensor_forest.python import *
diff --git a/tensorflow/contrib/tensor_forest/client/__init__.py b/tensorflow/contrib/tensor_forest/client/__init__.py
new file mode 100644
index 00000000000..753f406cbc7
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/client/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Random forest implementation in tensorflow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensor_forest.client import eval_metrics
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
new file mode 100644
index 00000000000..f41794a886e
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -0,0 +1,69 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A collection of functions to be used as evaluation metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import losses
+from tensorflow.contrib.metrics.python.ops import metric_ops
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _accuracy(probabilities, targets):
+  predictions = math_ops.argmax(probabilities, 1)
+  # Undo the one-hot encoding of targets to recover integer class labels.
+  labels = math_ops.argmax(targets, 1)
+  return metric_ops.streaming_accuracy(predictions, labels)
+
+
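+# A note on the formula computed by `_r2` below (a sketch of the standard
+# definition): R^2 = 1 - SS_res / SS_tot, where SS_res is the sum of squared
+# residuals between targets and predictions and SS_tot is the sum of squares
+# around the per-column target mean; the per-column ratios are summed before
+# subtracting from 1.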
+def _r2(probabilities, targets):
+  if targets.get_shape().ndims == 1:
+    targets = array_ops.expand_dims(targets, -1)
+  y_mean = math_ops.reduce_mean(targets, 0)
+  squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0)
+  squares_residuals = math_ops.reduce_sum(math_ops.square(
+      targets - probabilities), 0)
+  score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
+  return metric_ops.streaming_mean(score)
+
+
+def _sigmoid_entropy(probabilities, targets):
+  return metric_ops.streaming_mean(losses.sigmoid_cross_entropy(
+      probabilities, targets))
+
+
+def _softmax_entropy(probabilities, targets):
+  return metric_ops.streaming_mean(losses.softmax_cross_entropy(
+      probabilities, targets))
+
+
+def _predictions(probabilities, unused_targets):
+  return math_ops.argmax(probabilities, 1)
+
+
+_EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy,
+                 'softmax_entropy': _softmax_entropy,
+                 'accuracy': _accuracy,
+                 'r2': _r2,
+                 'predictions': _predictions}
+
+
+def get_metric(metric_name):
+  """Given a metric name, return the corresponding metric function."""
+  return _EVAL_METRICS[metric_name]
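+
+# Usage sketch (assuming `probabilities` and `targets` tensors are in scope):
+#   value_op, update_op = get_metric('accuracy')(probabilities, targets)
+# The streaming metrics return a (value, update_op) pair, while 'predictions'
+# returns the argmax tensor directly.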
diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
index 0ccf75bcc6f..0413f1e20a1 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
@@ -43,7 +43,7 @@ using tensorforest::LEAF_NODE;
 using tensorforest::FREE_NODE;
 
 using tensorforest::CheckTensorBounds;
-using tensorforest::DecideNode;
+using tensorforest::DataColumnTypes;
 using tensorforest::Initialize;
 using tensorforest::IsAllInitialized;
 
@@ -61,61 +61,77 @@ struct InputDataResult {
   bool splits_initialized;
 };
 
-void Evaluate(const Tensor& input_data, const Tensor& input_labels,
-              const Tensor& tree_tensor, const Tensor& tree_thresholds,
-              const Tensor& node_to_accumulator,
-              const Tensor& candidate_split_features,
-              const Tensor& candidate_split_thresholds,
-              InputDataResult* results, int32 start, int32 end) {
-  const auto tree = tree_tensor.tensor<int32, 2>();
-  const auto thresholds = tree_thresholds.unaligned_flat<float>();
-  const auto node_map = node_to_accumulator.unaligned_flat<int32>();
-  const auto split_features = candidate_split_features.tensor<int32, 2>();
-  const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+struct EvaluateParams {
+  std::function<bool(int, int, float,
+                    tensorforest::DataColumnTypes)> decide_function;
+  Tensor input_spec;
+  Tensor input_labels;
+  Tensor tree_tensor;
+  Tensor tree_thresholds;
+  Tensor node_to_accumulator;
+  Tensor candidate_split_features;
+  Tensor candidate_split_thresholds;
+  InputDataResult* results;
+};
+
+void Evaluate(const EvaluateParams& params, int32 start, int32 end) {
+  const auto tree = params.tree_tensor.tensor<int32, 2>();
+  const auto thresholds = params.tree_thresholds.unaligned_flat<float>();
+  const auto node_map = params.node_to_accumulator.unaligned_flat<int32>();
+  const auto split_features =
+      params.candidate_split_features.tensor<int32, 2>();
+  const auto split_thresholds =
+      params.candidate_split_thresholds.tensor<float, 2>();
+  const auto spec = params.input_spec.unaligned_flat<int32>();
 
   const int32 num_splits = static_cast<int32>(
-      candidate_split_features.shape().dim_size(1));
-  const int32 num_nodes = static_cast<int32>(tree_tensor.shape().dim_size(0));
+      params.candidate_split_features.shape().dim_size(1));
+  const int32 num_nodes = static_cast<int32>(
+      params.tree_tensor.shape().dim_size(0));
   const int32 num_accumulators = static_cast<int32>(
-      candidate_split_features.shape().dim_size(0));
+      params.candidate_split_features.shape().dim_size(0));
 
   for (int32 i = start; i < end; ++i) {
-    const Tensor point = input_data.Slice(i, i + 1);
     int node_index = 0;
-    results[i].splits_initialized = false;
+    params.results[i].splits_initialized = false;
     while (true) {
-      results[i].node_indices.push_back(node_index);
+      params.results[i].node_indices.push_back(node_index);
       CHECK_LT(node_index, num_nodes);
       int32 left_child = internal::SubtleMustCopy(
           tree(node_index, CHILDREN_INDEX));
       if (left_child == LEAF_NODE) {
         const int32 accumulator = internal::SubtleMustCopy(
             node_map(node_index));
-        results[i].leaf_accumulator = accumulator;
+        params.results[i].leaf_accumulator = accumulator;
         // If the leaf is not fertile or is not yet initialized, we don't
         // count it in the candidate/total split per-class-weights because
         // it won't have any candidate splits yet.
         if (accumulator >= 0 &&
-            IsAllInitialized(candidate_split_features.Slice(
+            IsAllInitialized(params.candidate_split_features.Slice(
                 accumulator, accumulator + 1))) {
           CHECK_LT(accumulator, num_accumulators);
-          results[i].splits_initialized = true;
+          params.results[i].splits_initialized = true;
           for (int split = 0; split < num_splits; split++) {
-            if (!DecideNode(point, split_features(accumulator, split),
-                            split_thresholds(accumulator, split))) {
-              results[i].split_adds.push_back(split);
+            const int32 feature = split_features(accumulator, split);
+            if (!params.decide_function(
+                i, feature, split_thresholds(accumulator, split),
+                static_cast<tensorforest::DataColumnTypes>(spec(feature)))) {
+              params.results[i].split_adds.push_back(split);
             }
           }
         }
         break;
       } else if (left_child == FREE_NODE) {
         LOG(ERROR) << "Reached a free node, not good.";
-        results[i].node_indices.push_back(FREE_NODE);
+        params.results[i].node_indices.push_back(FREE_NODE);
         break;
       }
+      const int32 feature = tree(node_index, FEATURE_INDEX);
       node_index =
-          left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
-                                  thresholds(node_index));
+          left_child + params.decide_function(
+              i, feature, thresholds(node_index),
+              static_cast<tensorforest::DataColumnTypes>(spec(feature)));
     }
   }
 }
@@ -124,16 +140,18 @@ REGISTER_OP("CountExtremelyRandomStats")
     .Attr("num_classes: int")
     .Attr("regression: bool = false")
     .Input("input_data: float")
+    .Input("sparse_input_indices: int64")
+    .Input("sparse_input_values: float")
+    .Input("sparse_input_shape: int64")
+    .Input("input_spec: int32")
     .Input("input_labels: float")
-
     .Input("tree: int32")
     .Input("tree_thresholds: float")
-
     .Input("node_to_accumulator: int32")
-
     .Input("candidate_split_features: int32")
     .Input("candidate_split_thresholds: float")
-
+    .Input("birth_epochs: int32")
+    .Input("current_epoch: int32")
     .Output("pcw_node_sums_delta: float")
     .Output("pcw_node_squares_delta: float")
     .Output("pcw_splits_indices: int32")
@@ -142,7 +160,6 @@ REGISTER_OP("CountExtremelyRandomStats")
     .Output("pcw_totals_indices: int32")
     .Output("pcw_totals_sums_delta: float")
     .Output("pcw_totals_squares_delta: float")
-
     .Output("leaves: int32")
     .Doc(R"doc(
 Calculates incremental statistics for a batch of training data.
@@ -156,7 +173,7 @@ For `regression` = false (classification), `pcw_node_sums_delta[i]` is
 incremented for every node i that it passes through, and the leaf it ends up
 in is recorded in `leaves[i]`.  Then, if the leaf is fertile and
 initialized, the statistics for its corresponding accumulator slot
-are updated in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`.
+are updated in `pcw_candidate_sums_delta` and `pcw_totals_sums_delta`.
 
 For `regression` = true, outputs contain the sum of the input_labels
 for the appropriate nodes.  In addition, the *_squares outputs are filled
@@ -171,6 +188,11 @@ The attr `num_classes` is needed to appropriately size the outputs.
 
 input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
   gives the j-th feature of the i-th input.
+sparse_input_indices: The indices tensor from the SparseTensor input.
+sparse_input_values: The values tensor from the SparseTensor input.
+sparse_input_shape: The shape tensor from the SparseTensor input.
+input_spec: A 1-D tensor containing the type of each column in input_data
+  (e.g. continuous float, categorical).
 input_labels: The training batch's labels; `input_labels[i]` is the class
   of the i-th input.
 tree:= A 2-d int32 tensor.  `tree[i][0]` gives the index of the left child
@@ -185,6 +207,10 @@ candidate_split_features: `candidate_split_features[a][s]` is the
   index of the feature being considered by split s of accumulator slot a.
 candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the
   threshold value being considered by split s of accumulator slot a.
+birth_epochs: `birth_epochs[i]` is the epoch node i was born in.  Only
+  nodes satisfying `current_epoch - birth_epochs[i] <= 1` accumulate statistics.
+current_epoch:= A 1-d int32 tensor with shape (1).  current_epoch[0] contains
+  the current epoch.
 pcw_node_sums_delta: `pcw_node_sums_delta[i][c]` is the number of training
   examples in this training batch with class c that passed through node i for
   classification.  For regression, it is the sum of the input_labels that
@@ -236,17 +262,57 @@ class CountExtremelyRandomStats : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input_data = context->input(0);
-    const Tensor& input_labels = context->input(1);
-    const Tensor& tree_tensor = context->input(2);
-    const Tensor& tree_thresholds = context->input(3);
-    const Tensor& node_to_accumulator = context->input(4);
-    const Tensor& candidate_split_features = context->input(5);
-    const Tensor& candidate_split_thresholds = context->input(6);
+    const Tensor& sparse_input_indices = context->input(1);
+    const Tensor& sparse_input_values = context->input(2);
+    const Tensor& sparse_input_shape = context->input(3);
+    const Tensor& input_spec = context->input(4);
+    const Tensor& input_labels = context->input(5);
+    const Tensor& tree_tensor = context->input(6);
+    const Tensor& tree_thresholds = context->input(7);
+    const Tensor& node_to_accumulator = context->input(8);
+    const Tensor& candidate_split_features = context->input(9);
+    const Tensor& candidate_split_thresholds = context->input(10);
+    const Tensor& birth_epochs = context->input(11);
+    const Tensor& current_epoch = context->input(12);
+
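+    // A rank-2 sparse_input_indices tensor signals that the caller supplied a
+    // SparseTensor; otherwise the dense input_data tensor is used.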
+    bool sparse_input = (sparse_input_indices.shape().dims() == 2);
 
     // Check inputs.
-    OP_REQUIRES(context, input_data.shape().dims() == 2,
+    if (sparse_input) {
+      OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1,
+                  errors::InvalidArgument(
+                      "sparse_input_shape should be one-dimensional"));
+      OP_REQUIRES(context,
+                  sparse_input_shape.shape().dim_size(0) == 2,
+                  errors::InvalidArgument(
+                      "The sparse input data should be two-dimensional"));
+      OP_REQUIRES(context, sparse_input_values.shape().dims() == 1,
+                  errors::InvalidArgument(
+                      "sparse_input_values should be one-dimensional"));
+      OP_REQUIRES(context, sparse_input_indices.shape().dims() == 2,
+                  errors::InvalidArgument(
+                      "The sparse input data should be two-dimensional"));
+      OP_REQUIRES(context,
+                  sparse_input_indices.shape().dim_size(0) ==
+                  sparse_input_values.shape().dim_size(0),
+                  errors::InvalidArgument(
+                      "sparse_input_indices and sparse_input_values should "
+                      "agree on the number of non-zero values"));
+    } else {
+      OP_REQUIRES(context, input_data.shape().dims() == 2,
+                  errors::InvalidArgument(
+                      "input_data should be two-dimensional"));
+      OP_REQUIRES(
+          context,
+          input_data.shape().dim_size(0) == input_labels.shape().dim_size(0),
+          errors::InvalidArgument(
+              "Number of inputs should be the same in "
+              "input_data and input_labels."));
+    }
+
+    OP_REQUIRES(context, input_labels.shape().dims() >= 1,
                 errors::InvalidArgument(
-                    "input_data should be two-dimensional"));
+                    "input_labels should be at least one-dimensional"));
     OP_REQUIRES(context, tree_tensor.shape().dims() == 2,
             errors::InvalidArgument(
                 "tree should be two-dimensional"));
@@ -262,58 +328,93 @@ class CountExtremelyRandomStats : public OpKernel {
     OP_REQUIRES(context, candidate_split_thresholds.shape().dims() == 2,
             errors::InvalidArgument(
                 "candidate_split_thresholds should be two-dimensional"));
-
-    OP_REQUIRES(
-        context,
-        input_data.shape().dim_size(0) == input_labels.shape().dim_size(0),
-        errors::InvalidArgument(
-            "Number of inputs should be the same in "
-            "input_data and input_labels."));
+    OP_REQUIRES(context, birth_epochs.shape().dims() == 1,
+            errors::InvalidArgument(
+                "birth_epochs should be one-dimensional"));
+    OP_REQUIRES(context, current_epoch.shape().dims() == 1,
+            errors::InvalidArgument(
+                "current_epoch should be one-dimensional"));
 
     OP_REQUIRES(
         context,
         tree_tensor.shape().dim_size(0) ==
         tree_thresholds.shape().dim_size(0) &&
         tree_tensor.shape().dim_size(0) ==
-        node_to_accumulator.shape().dim_size(0),
+        node_to_accumulator.shape().dim_size(0) &&
+        tree_tensor.shape().dim_size(0) ==
+        birth_epochs.shape().dim_size(0),
         errors::InvalidArgument(
             "Number of nodes should be the same in "
-            "tree, tree_thresholds, and node_to_accumulator"));
+            "tree, tree_thresholds, node_to_accumulator, and birth_epoch."));
     OP_REQUIRES(
         context,
         candidate_split_features.shape() == candidate_split_thresholds.shape(),
         errors::InvalidArgument(
             "candidate_split_features and candidate_split_thresholds should be "
             "the same shape."));
+    OP_REQUIRES(
+        context,
+        current_epoch.shape().dim_size(0) == 1,
+        errors::InvalidArgument(
+            "The current_epoch should be a tensor of shape (1)."));
 
     // Check tensor bounds.
     if (!CheckTensorBounds(context, input_data)) return;
+    if (!CheckTensorBounds(context, sparse_input_indices)) return;
+    if (!CheckTensorBounds(context, sparse_input_values)) return;
+    if (!CheckTensorBounds(context, sparse_input_shape)) return;
     if (!CheckTensorBounds(context, input_labels)) return;
     if (!CheckTensorBounds(context, tree_tensor)) return;
     if (!CheckTensorBounds(context, tree_thresholds)) return;
     if (!CheckTensorBounds(context, node_to_accumulator)) return;
     if (!CheckTensorBounds(context, candidate_split_features)) return;
     if (!CheckTensorBounds(context, candidate_split_thresholds)) return;
+    if (!CheckTensorBounds(context, birth_epochs)) return;
+    if (!CheckTensorBounds(context, current_epoch)) return;
 
     // Evaluate input data in parallel.
-    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
+    const int32 epoch = current_epoch.unaligned_flat<int32>()(0);
+    int32 num_data;
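+    // decide_function(i, feature, bias, type) returns true when example i
+    // should follow the right branch (left_child + 1) of a node that tests
+    // `feature` against threshold `bias`.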
+    std::function<bool(int, int, float,
+                      tensorforest::DataColumnTypes)> decide_function;
+    if (sparse_input) {
+      num_data = sparse_input_shape.unaligned_flat<int64>()(0);
+      decide_function = [&sparse_input_indices, &sparse_input_values](
+          int32 i, int32 feature, float bias, DataColumnTypes type) {
+        const auto sparse_indices = sparse_input_indices.matrix<int64>();
+        const auto sparse_values = sparse_input_values.vec<float>();
+        return tensorforest::DecideSparseNode(
+            sparse_indices, sparse_values, i, feature, bias, type);
+      };
+    } else {
+      num_data = static_cast<int32>(input_data.shape().dim_size(0));
+      decide_function = [&input_data](
+          int32 i, int32 feature, float bias, DataColumnTypes type) {
+        const auto input_matrix = input_data.matrix<float>();
+        return tensorforest::DecideDenseNode(
+            input_matrix, i, feature, bias, type);
+      };
+    }
     std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
     int num_threads = worker_threads->num_threads;
+    EvaluateParams params;
+    params.decide_function = decide_function;
+    params.input_spec = input_spec;
+    params.input_labels = input_labels;
+    params.tree_tensor = tree_tensor;
+    params.tree_thresholds = tree_thresholds;
+    params.node_to_accumulator = node_to_accumulator;
+    params.candidate_split_features = candidate_split_features;
+    params.candidate_split_thresholds = candidate_split_thresholds;
+    params.results = results.get();
     if (num_threads <= 1) {
-      Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
-               node_to_accumulator, candidate_split_features,
-               candidate_split_thresholds, results.get(), 0, num_data);
+      Evaluate(params, 0, num_data);
     } else {
-      auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
-                   &node_to_accumulator, &candidate_split_features,
-                   &candidate_split_thresholds, &num_data,
-                   &results](int64 start, int64 end) {
+      auto work = [&params, num_data](int64 start, int64 end) {
         CHECK(start <= end);
         CHECK(end <= num_data);
-        Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
-                 node_to_accumulator, candidate_split_features,
-                 candidate_split_thresholds, results.get(),
+        Evaluate(params,
                  static_cast<int32>(start), static_cast<int32>(end));
       };
       Shard(num_threads, worker_threads->workers, num_data, 100, work);
@@ -321,11 +422,13 @@ class CountExtremelyRandomStats : public OpKernel {
 
     const int32 num_nodes = static_cast<int32>(tree_tensor.shape().dim_size(0));
     if (regression_) {
-      ProcessResultsRegression(context, input_labels, std::move(results),
-                               num_nodes);
+      ProcessResultsRegression(
+          context, input_labels, birth_epochs, epoch, std::move(results),
+          num_nodes);
     } else {
-      ProcessResultsClassification(context, input_labels, std::move(results),
-                                   num_nodes);
+      ProcessResultsClassification(
+          context, input_labels, birth_epochs, epoch, std::move(results),
+          num_nodes);
     }
   }
 
@@ -333,10 +436,13 @@ class CountExtremelyRandomStats : public OpKernel {
   void ProcessResultsClassification(
       OpKernelContext* context,
       const Tensor &input_labels,
+      const Tensor &birth_epochs,
+      int32 epoch,
       std::unique_ptr<InputDataResult[]> results,
       int32 num_nodes) {
     const int32 num_data = static_cast<int32>(input_labels.shape().dim_size(0));
     const auto labels = input_labels.unaligned_flat<float>();
+    const auto start_epochs = birth_epochs.unaligned_flat<int32>();
 
     // Unused outputs for classification.  Still have to specify them or
     // tensorflow complains.
@@ -381,10 +487,16 @@ class CountExtremelyRandomStats : public OpKernel {
       CHECK_LT(column, num_classes_);
       const int32 accumulator = results[i].leaf_accumulator;
       for (const int32 node : results[i].node_indices) {
+        if (epoch > start_epochs(node) + 1) {
+          continue;
+        }
         ++out_node_sums(node, column);
         ++out_node_sums(node, 0);
       }
       out_leaves(i) = results[i].node_indices.back();
+      if (epoch > start_epochs(out_leaves(i)) + 1) {
+        continue;
+      }
       if (accumulator >= 0 && results[i].splits_initialized) {
         ++total_delta[make_pair(accumulator, column)];
         ++total_delta[make_pair(accumulator, 0)];
@@ -457,6 +569,8 @@ class CountExtremelyRandomStats : public OpKernel {
   void ProcessResultsRegression(
       OpKernelContext* context,
       const Tensor &input_labels,
+      const Tensor &birth_epochs,
+      const int32 epoch,
       std::unique_ptr<InputDataResult[]> results,
       int32 num_nodes) {
     const int32 num_data = static_cast<int32>(input_labels.shape().dim_size(0));
@@ -465,6 +579,7 @@ class CountExtremelyRandomStats : public OpKernel {
         num_outputs = static_cast<int32>(input_labels.shape().dim_size(1));
     }
     const auto labels = input_labels.unaligned_flat<float>();
+    const auto start_epochs = birth_epochs.unaligned_flat<int32>();
 
     // node pcw delta
     Tensor* output_node_pcw_sums_delta = nullptr;
@@ -503,6 +618,9 @@ class CountExtremelyRandomStats : public OpKernel {
     for (int32 i = 0; i < num_data; ++i) {
       const int32 accumulator = results[i].leaf_accumulator;
       for (const int32 node : results[i].node_indices) {
+        if (epoch > start_epochs(node) + 1) {
+          continue;
+        }
         for (int32 j = 0; j < num_outputs; ++j) {
           const float output = labels(i * num_outputs + j);
           out_node_sums(node, j + 1) += output;
@@ -512,6 +630,9 @@ class CountExtremelyRandomStats : public OpKernel {
         }
       }
       out_leaves(i) = results[i].node_indices.back();
+      if (epoch > start_epochs(out_leaves(i)) + 1) {
+        continue;
+      }
       if (accumulator >= 0 && results[i].splits_initialized) {
         total_delta[accumulator].insert(i);
         for (const int32 split : results[i].split_adds) {
diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
index e1369f9d8cb..d179f5b84e9 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
@@ -24,42 +24,84 @@ namespace tensorflow {
 
 using tensorforest::CheckTensorBounds;
 using tensorforest::Sum;
+using tensorforest::BestSplitDominatesClassification;
+using tensorforest::BestSplitDominatesRegression;
 
 REGISTER_OP("FinishedNodes")
+    .Attr("regression: bool = false")
     .Attr("num_split_after_samples: int")
+    .Attr("min_split_samples: int")
+    .Attr("dominate_fraction: float = 0.95")
     .Input("leaves: int32")
     .Input("node_to_accumulator: int32")
+    .Input("split_sums: float")
+    .Input("split_squares: float")
     .Input("accumulator_sums: float")
-
+    .Input("accumulator_squares: float")
+    .Input("birth_epochs: int32")
+    .Input("current_epoch: int32")
     .Output("finished: int32")
+    .Output("stale: int32")
     .Doc(R"doc(
 Determines which of the given leaf nodes are done accumulating.
 
 leaves:= A 1-d int32 tensor.  Lists the nodes that are currently leaves.
 node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]`
   is its accumulator slot.  Otherwise, `node_to_accumulator[i]` is -1.
-accumulator_sums: For classification, `accumulator_sums[a][c]` records how many
-  training examples have class c and have ended up in the fertile node
+split_sums:= A 3-d tensor where `split_sums[a][s]` summarizes the
+  training labels for examples that fall into the fertile node associated with
+  accumulator slot a and have then taken the *left* branch of candidate split
+  s.  For a classification problem, `split_sums[a][s][c]` is the count of such
+  examples with class c, and for regression problems, `split_sums[a][s]` is the
+  sum of the regression labels for such examples.
+split_squares: Same as split_sums, but it contains the sum of the
+  squares of the regression labels.  Only used for regression.  For
+  classification problems, pass a dummy tensor into this.
+accumulator_sums: For classification, `accumulator_sums[a][c]` records how
+  many training examples have class c and have ended up in the fertile node
   associated with accumulator slot a.  It has the total sum in entry 0 for
   convenience. For regression, it is the same except it contains the sum
   of the input labels that have been seen, and entry 0 contains the number
   of training examples that have been seen.
-finished:= A 1-d int32 tensor. Contains the nodes that have total split
- counts greater or equal to the num_split_after_samples attribute.
+accumulator_squares: Same as accumulator_sums, but it contains the sum of the
+  squares of the regression labels.  Only used for regression.  For
+  classification problems, pass a dummy tensor into this.
+birth_epochs:= A 1-d int32 tensor.  `birth_epochs[i]` contains the epoch
+  the i-th node was created in.
+current_epoch:= A 1-d int32 tensor with shape (1).  `current_epoch[0]`
+  stores the current epoch number.
+finished:= A 1-d int32 tensor containing the indices of the finished nodes.
+  Nodes are finished if they have received at least num_split_after_samples
+  samples, or if they have received min_split_samples and the best-scoring
+  split is sufficiently better than the next-best split.
+stale:= A 1-d int32 tensor containing the fertile nodes that were created two
+  or more epochs ago.
+
 )doc");
 
 class FinishedNodes : public OpKernel {
  public:
   explicit FinishedNodes(OpKernelConstruction* context)
       : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "regression", &regression_));
     OP_REQUIRES_OK(context, context->GetAttr(
         "num_split_after_samples", &num_split_after_samples_));
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "min_split_samples", &min_split_samples_));
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "dominate_fraction", &dominate_fraction_));
   }
 
   void Compute(OpKernelContext* context) override {
     const Tensor& leaf_tensor = context->input(0);
     const Tensor& node_to_accumulator = context->input(1);
-    const Tensor& accumulator_sums = context->input(2);
+    const Tensor& split_sums = context->input(2);
+    const Tensor& split_squares = context->input(3);
+    const Tensor& accumulator_sums = context->input(4);
+    const Tensor& accumulator_squares = context->input(5);
+    const Tensor& birth_epochs = context->input(6);
+    const Tensor& current_epoch = context->input(7);
 
     OP_REQUIRES(context, leaf_tensor.shape().dims() == 1,
                 errors::InvalidArgument(
@@ -67,25 +109,45 @@ class FinishedNodes : public OpKernel {
     OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
                 errors::InvalidArgument(
                     "node_to_accumulator should be one-dimensional"));
+    OP_REQUIRES(context, split_sums.shape().dims() == 3,
+                errors::InvalidArgument(
+                    "split_sums should be three-dimensional"));
     OP_REQUIRES(context, accumulator_sums.shape().dims() == 2,
                 errors::InvalidArgument(
                     "accumulator_sums should be two-dimensional"));
+    OP_REQUIRES(context, birth_epochs.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "birth_epochs should be one-dimensional"));
+    OP_REQUIRES(
+        context,
+        birth_epochs.shape().dim_size(0) ==
+        node_to_accumulator.shape().dim_size(0),
+        errors::InvalidArgument(
+            "birth_epochs and node_to_accumulator should be the same size."));
 
     // Check tensor bounds.
     if (!CheckTensorBounds(context, leaf_tensor)) return;
     if (!CheckTensorBounds(context, node_to_accumulator)) return;
+    if (!CheckTensorBounds(context, split_sums)) return;
+    if (!CheckTensorBounds(context, split_squares)) return;
     if (!CheckTensorBounds(context, accumulator_sums)) return;
+    if (!CheckTensorBounds(context, accumulator_squares)) return;
+    if (!CheckTensorBounds(context, birth_epochs)) return;
+    if (!CheckTensorBounds(context, current_epoch)) return;
 
     const auto leaves = leaf_tensor.unaligned_flat<int32>();
     const auto node_map = node_to_accumulator.unaligned_flat<int32>();
     const auto sums = accumulator_sums.tensor<float, 2>();
+    const auto start_epochs = birth_epochs.unaligned_flat<int32>();
+    const int32 epoch = current_epoch.unaligned_flat<int32>()(0);
 
     const int32 num_leaves = static_cast<int32>(
         leaf_tensor.shape().dim_size(0));
     const int32 num_accumulators = static_cast<int32>(
         accumulator_sums.shape().dim_size(0));
 
-    std::vector<int32> finished;
+    std::vector<int32> finished_leaves;
+    std::vector<int32> stale;
     for (int32 i = 0; i < num_leaves; i++) {
       const int32 leaf = internal::SubtleMustCopy(leaves(i));
       OP_REQUIRES(context, FastBoundsCheck(leaf, node_map.size()),
@@ -97,30 +159,74 @@ class FinishedNodes : public OpKernel {
 
       OP_REQUIRES(context, FastBoundsCheck(accumulator, num_accumulators),
                   errors::InvalidArgument("accumulator not in valid range."))
-
       // The first column holds the number of samples seen.
       // For classification, this should be the sum of the other columns.
-      if (sums(accumulator, 0) >= num_split_after_samples_) {
-        finished.push_back(leaf);
+      int32 count = sums(accumulator, 0);
+
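+      // A leaf created at least two epochs ago will not accumulate further
+      // statistics; finish it if it has already seen enough samples to
+      // split, otherwise report it as stale.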
+      if (epoch > start_epochs(leaf) + 1) {
+        if (count >= min_split_samples_) {
+          finished_leaves.push_back(leaf);
+        } else {
+          stale.push_back(leaf);
+        }
+        continue;
+      }
+
+      if (count >= num_split_after_samples_) {
+        finished_leaves.push_back(leaf);
+        continue;
+      }
+
+      if (count < min_split_samples_) {
+        continue;
+      }
+
+      bool finished = false;
+      if (regression_) {
+        finished = BestSplitDominatesRegression(
+            accumulator_sums, accumulator_squares,
+            split_sums, split_squares, accumulator);
+      } else {
+        finished = BestSplitDominatesClassification(
+            accumulator_sums, split_sums, accumulator, dominate_fraction_);
+      }
+
+      if (finished) {
+        finished_leaves.push_back(leaf);
       }
     }
 
     // Copy to output.
     Tensor* output_finished = nullptr;
     TensorShape finished_shape;
-    finished_shape.AddDim(finished.size());
+    finished_shape.AddDim(finished_leaves.size());
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, finished_shape,
                                             &output_finished));
     auto out_finished = output_finished->unaligned_flat<int32>();
 
-    for (int32 i = 0; i < finished.size(); i++) {
-      out_finished(i) = finished[i];
+    for (int32 i = 0; i < finished_leaves.size(); i++) {
+      out_finished(i) = finished_leaves[i];
+    }
+
+    Tensor* output_stale = nullptr;
+    TensorShape stale_shape;
+    stale_shape.AddDim(stale.size());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, stale_shape,
+                                            &output_stale));
+    auto out_stale = output_stale->unaligned_flat<int32>();
+
+    for (int32 i = 0; i < stale.size(); i++) {
+      out_stale(i) = stale[i];
     }
   }
 
  private:
+  bool regression_;
   int32 num_split_after_samples_;
+  int32 min_split_samples_;
+  float dominate_fraction_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU),
diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
index 182b1257b6e..8b15f8a0b5f 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
@@ -35,6 +35,9 @@ REGISTER_OP("SampleInputs")
     .Attr("split_initializations_per_input: int")
     .Attr("split_sampling_random_seed: int")
     .Input("input_data: float")
+    .Input("sparse_input_indices: int64")
+    .Input("sparse_input_values: float")
+    .Input("sparse_input_shape: int64")
     .Input("node_to_accumulator: int32")
     .Input("leaves: int32")
     .Input("candidate_split_features: int32")
@@ -60,6 +63,9 @@ a single training example can initialize, and the attribute
 
 input_data: The features for the current batch of training data.
   `input_data[i][j]` is the j-th feature of the i-th input.
+sparse_input_indices: The indices tensor from the SparseTensor input.
+sparse_input_values: The values tensor from the SparseTensor input.
+sparse_input_shape: The shape tensor from the SparseTensor input.
 node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the
   associated accumulator slot.  For non-fertile nodes, it is -1.
 leaves: `leaves[i]` is the leaf that the i-th input landed in, as
@@ -82,6 +88,7 @@ new_split_threshold_rows:  The new values for the candidate_split_thresholds
   `tf.scatter_update(candidate_split_thresholds,
                      accumulators_to_update,
                      new_split_feature_thresholds)`
+
 )doc");
 
 class SampleInputs : public OpKernel {
@@ -106,16 +113,74 @@ class SampleInputs : public OpKernel {
         new random::SimplePhilox(single_rand_.get()));
   }
 
+  template <typename T>
+  void GetRandomFeatureDense(const T& inputs, int32 num_features,
+                             int32 input_index, int32* index, float* val) {
+    *index = rng_->Uniform(num_features);
+    *val = inputs(input_index, *index);
+  }
+
+  template <typename T1, typename T2>
+  void GetRandomFeatureSparse(const T1& sparse_indices, const T2& sparse_values,
+                              int32 input_index, int32* index, float* val) {
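+    // Randomized binary search: rows of sparse_indices are sorted by example
+    // index, so sample a position in [low, high) and narrow the range toward
+    // the rows belonging to input_index, returning one of that example's
+    // (feature, value) pairs.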
+    int32 low = 0;
+    int32 high = sparse_values.dimension(0);
+    while (low < high) {
+      int32 vi = low + rng_->Uniform(high - low);
+      int64 i = internal::SubtleMustCopy(sparse_indices(vi, 0));
+      if (i == input_index) {
+        *index = internal::SubtleMustCopy(sparse_indices(vi, 1));
+        *val = sparse_values(vi);
+        return;
+      }
+      if (i < input_index) {
+        low = vi + 1;
+      } else {
+        high = vi;
+      }
+    }
+    LOG(FATAL) << "Could not find any values for input " << input_index
+               << " inside sparse_input_indices";
+  }
+
   void Compute(OpKernelContext* context) override {
     const Tensor& input_data = context->input(0);
-    const Tensor& node_to_accumulator = context->input(1);
-    const Tensor& leaves = context->input(2);
-    const Tensor& split_features = context->input(3);
-    const Tensor& split_thresholds = context->input(4);
+    const Tensor& sparse_input_indices = context->input(1);
+    const Tensor& sparse_input_values = context->input(2);
+    const Tensor& sparse_input_shape = context->input(3);
+    const Tensor& node_to_accumulator = context->input(4);
+    const Tensor& leaves = context->input(5);
+    const Tensor& split_features = context->input(6);
+    const Tensor& split_thresholds = context->input(7);
+
+    bool sparse_input = (sparse_input_indices.shape().dims() == 2);
+
+    if (sparse_input) {
+      OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1,
+                  errors::InvalidArgument(
+                      "sparse_input_shape should be one-dimensional"));
+      OP_REQUIRES(context,
+                  sparse_input_shape.shape().dim_size(0) == 2,
+                  errors::InvalidArgument(
+                      "The sparse input data should be two-dimensional"));
+      OP_REQUIRES(context, sparse_input_values.shape().dims() == 1,
+                  errors::InvalidArgument(
+                      "sparse_input_values should be one-dimensional"));
+      OP_REQUIRES(context, sparse_input_indices.shape().dims() == 2,
+                  errors::InvalidArgument(
+                      "The sparse input data should be two-dimensional"));
+      OP_REQUIRES(context,
+                  sparse_input_indices.shape().dim_size(0) ==
+                  sparse_input_values.shape().dim_size(0),
+                  errors::InvalidArgument(
+                      "sparse_input_indices and sparse_input_values should "
+                      "agree on the number of non-zero values"));
+    } else {
+      OP_REQUIRES(context, input_data.shape().dims() == 2,
+                  errors::InvalidArgument(
+                  "input_data should be two-dimensional"));
+    }
 
-    OP_REQUIRES(context, input_data.shape().dims() == 2,
-                errors::InvalidArgument(
-                    "input_data should be two-dimensional"));
     OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
                 errors::InvalidArgument(
                     "node_to_accumulator should be one-dimensional"));
@@ -137,12 +202,36 @@ class SampleInputs : public OpKernel {
 
     // Check tensor bounds.
     if (!CheckTensorBounds(context, input_data)) return;
+    if (!CheckTensorBounds(context, sparse_input_indices)) return;
+    if (!CheckTensorBounds(context, sparse_input_values)) return;
+    if (!CheckTensorBounds(context, sparse_input_shape)) return;
     if (!CheckTensorBounds(context, node_to_accumulator)) return;
     if (!CheckTensorBounds(context, leaves)) return;
     if (!CheckTensorBounds(context, split_features)) return;
     if (!CheckTensorBounds(context, split_thresholds)) return;
 
-    const auto inputs = input_data.tensor<float, 2>();
+    int32 num_features;
+    std::function<void(int32, int32*, float*)> get_random_feature;
+    // TODO(thomaswc): Figure out a way to avoid calling .vec, etc. over and
+    // over again.
+    if (sparse_input) {
+      num_features = sparse_input_shape.unaligned_flat<int64>()(1);
+      get_random_feature = [&sparse_input_indices, &sparse_input_values, this](
+          int32 input_index, int32* index, float* val) {
+        const auto sparse_indices = sparse_input_indices.matrix<int64>();
+        const auto sparse_values = sparse_input_values.vec<float>();
+        GetRandomFeatureSparse(sparse_indices, sparse_values, input_index,
+                               index, val);
+      };
+    } else {
+      num_features = static_cast<int32>(input_data.shape().dim_size(1));
+      get_random_feature = [&input_data, num_features, this](
+          int32 input_index, int32* index, float* val) {
+        const auto inputs = input_data.tensor<float, 2>();
+        GetRandomFeatureDense(inputs, num_features, input_index, index, val);
+      };
+    }
+
     const auto leaves_vec = leaves.unaligned_flat<int32>();
     const auto node_map = node_to_accumulator.unaligned_flat<int32>();
     const auto features = split_features.tensor<int32, 2>();
@@ -151,8 +240,6 @@ class SampleInputs : public OpKernel {
     const int32 num_data = static_cast<int32>(leaves.shape().dim_size(0));
     const int32 num_splits = static_cast<int32>(
         split_features.shape().dim_size(1));
-    const int32 num_features = static_cast<int32>(
-        input_data.shape().dim_size(1));
     const int32 num_accumulators = static_cast<int32>(
         split_features.shape().dim_size(0));
 
@@ -234,10 +321,11 @@ class SampleInputs : public OpKernel {
         for (int split = 0; split < num_splits && num_inits > 0; split++) {
           if (new_split_feature_rows_flat(output_slot, split) < 0) {
             VLOG(1) << "Over-writing @ " << output_slot << "," << split;
-            const int32 index = rng_->Uniform(num_features);
+            int32 index;
+            float val;
+            get_random_feature(i, &index, &val);
             new_split_feature_rows_flat(output_slot, split) = index;
-            new_split_threshold_rows_flat(output_slot, split) =
-                inputs(i, index);
+            new_split_threshold_rows_flat(output_slot, split) = val;
             --num_inits;
           }
         }
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
index 7db52ec3cae..1f77212d20e 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
@@ -30,12 +30,17 @@ using tensorforest::LEAF_NODE;
 using tensorforest::FREE_NODE;
 
 using tensorforest::CheckTensorBounds;
-using tensorforest::DecideNode;
+using tensorforest::DataColumnTypes;
 using tensorforest::Sum;
 
 REGISTER_OP("TreePredictions")
   .Attr("valid_leaf_threshold: float")
   .Input("input_data: float")
+  .Input("sparse_input_indices: int64")
+  .Input("sparse_input_values: float")
+  .Input("sparse_input_shape: int64")
+  .Input("input_spec: int32")
+
   .Input("tree: int32")
   .Input("tree_thresholds: float")
   .Input("node_per_class_weights: float")
@@ -46,6 +51,11 @@ REGISTER_OP("TreePredictions")
 
   input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
    gives the j-th feature of the i-th input.
+  sparse_input_indices: The indices tensor from the SparseTensor input.
+  sparse_input_values: The values tensor from the SparseTensor input.
+  sparse_input_shape: The shape tensor from the SparseTensor input.
+  input_spec: A 1-D tensor containing the type of each column in input_data
+     (e.g. continuous float, categorical).
   tree:= A 2-d int32 tensor.  `tree[i][0]` gives the index of the left child
    of the i-th node, `tree[i][0] + 1` gives the index of the right child of
    the i-th node, and `tree[i][1]` gives the index of the feature used to
@@ -70,10 +80,42 @@ class TreePredictions : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input_data = context->input(0);
+    const Tensor& sparse_input_indices = context->input(1);
+    const Tensor& sparse_input_values = context->input(2);
+    const Tensor& sparse_input_shape = context->input(3);
+    const Tensor& input_spec = context->input(4);
+    const Tensor& tree_tensor = context->input(5);
+    const Tensor& tree_thresholds = context->input(6);
+    const Tensor& node_per_class_weights = context->input(7);
 
-    const Tensor& tree_tensor = context->input(1);
-    const Tensor& tree_thresholds = context->input(2);
-    const Tensor& node_per_class_weights = context->input(3);
+    bool sparse_input = (sparse_input_indices.shape().dims() == 2);
+
+    if (sparse_input) {
+      OP_REQUIRES(context, sparse_input_values.shape().dims() == 1,
+                  errors::InvalidArgument(
+                      "sparse_input_values should be one-dimensional"));
+      OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1,
+                  errors::InvalidArgument(
+                      "sparse_input_shape should be one-dimensional"));
+      OP_REQUIRES(context,
+                  sparse_input_indices.shape().dim_size(0) ==
+                  sparse_input_values.shape().dim_size(0),
+                  errors::InvalidArgument(
+                      "sparse_input_indices and sparse_input_values should "
+                      "agree on the number of non-zero values"));
+      OP_REQUIRES(context,
+                  sparse_input_indices.shape().dim_size(1) ==
+                  sparse_input_shape.shape().dim_size(0),
+                  errors::InvalidArgument(
+                      "sparse_input_indices and sparse_input_shape should "
+                      "agree on the dimensionality of data points"));
+    } else {
+      if (input_data.shape().dim_size(0) > 0) {
+        OP_REQUIRES(context, input_data.shape().dims() == 2,
+                    errors::InvalidArgument(
+                        "input_data should be two-dimensional"));
+      }
+    }
 
     OP_REQUIRES(context, tree_tensor.shape().dims() == 2,
                 errors::InvalidArgument(
@@ -85,11 +127,6 @@ class TreePredictions : public OpKernel {
                 errors::InvalidArgument(
                     "node_pcw should be two-dimensional"));
 
-    if (input_data.shape().dim_size(0) > 0) {
-      OP_REQUIRES(context, input_data.shape().dims() == 2,
-                  errors::InvalidArgument(
-                      "input_data should be two-dimensional"));
-    }
     OP_REQUIRES(
         context,
         tree_tensor.shape().dim_size(0) ==
@@ -102,16 +139,43 @@ class TreePredictions : public OpKernel {
 
     // Check tensor bounds.
     if (!CheckTensorBounds(context, input_data)) return;
+    if (!CheckTensorBounds(context, sparse_input_indices)) return;
+    if (!CheckTensorBounds(context, sparse_input_values)) return;
+    if (!CheckTensorBounds(context, sparse_input_shape)) return;
     if (!CheckTensorBounds(context, tree_tensor)) return;
     if (!CheckTensorBounds(context, tree_thresholds)) return;
     if (!CheckTensorBounds(context, node_per_class_weights)) return;
 
     const int32 num_classes = static_cast<int32>(
         node_per_class_weights.shape().dim_size(1));
-    const int32 num_data = static_cast<int32>(
-        input_data.shape().dim_size(0));
     const int32 num_nodes = static_cast<int32>(
         tree_tensor.shape().dim_size(0));
+    int32 num_data;
+    std::function<bool(int, int, float,
+                       tensorforest::DataColumnTypes)> decide_function;
+
+    if (sparse_input) {
+      num_data = sparse_input_shape.unaligned_flat<int64>()(0);
+      decide_function = [&sparse_input_indices, &sparse_input_values](
+          int32 i, int32 feature, float bias, DataColumnTypes type) {
+        const auto sparse_indices = sparse_input_indices.matrix<int64>();
+        const auto sparse_values = sparse_input_values.vec<float>();
+        return tensorforest::DecideSparseNode(
+            sparse_indices, sparse_values, i, feature, bias, type);
+      };
+    } else {
+      num_data = static_cast<int32>(input_data.shape().dim_size(0));
+      int32 num_features = 0;
+      if (num_data > 0) {
+        num_features = input_data.NumElements() / num_data;
+      }
+      decide_function = [&input_data](
+          int32 i, int32 feature, float bias, DataColumnTypes type) {
+        const auto input_matrix = input_data.matrix<float>();
+        return tensorforest::DecideDenseNode(
+            input_matrix, i, feature, bias, type);
+      };
+    }
 
     Tensor* output_predictions = nullptr;
     TensorShape output_shape;
@@ -124,10 +188,10 @@ class TreePredictions : public OpKernel {
 
     const auto node_pcw = node_per_class_weights.tensor<float, 2>();
     const auto tree = tree_tensor.tensor<int32, 2>();
+    const auto spec = input_spec.unaligned_flat<int32>();
     const auto thresholds = tree_thresholds.unaligned_flat<float>();
 
     for (int i = 0; i < num_data; i++) {
-      const Tensor point = input_data.Slice(i, i+1);
       int node_index = 0;
       int parent = -1;
       while (true) {
@@ -162,9 +226,11 @@ class TreePredictions : public OpKernel {
           return;
         }
         parent = node_index;
+        const int32 feature = tree(node_index, FEATURE_INDEX);
         node_index = left_child +
-            DecideNode(point, tree(node_index, FEATURE_INDEX),
-                       thresholds(node_index));
+            decide_function(
+                i, feature, thresholds(node_index),
+                static_cast<tensorforest::DataColumnTypes>(spec(feature)));
       }
     }
 
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
index f3fc4160554..398990780cd 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
@@ -13,56 +13,127 @@
 // limitations under the License.
 // =============================================================================
 #include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+#include <cfloat>
+#include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
 namespace tensorforest {
 
 using tensorflow::Tensor;
 
-int32 BestFeatureClassification(
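+// Finds the smallest and second-smallest values of score_fn over [0, max)
+// (lower scores are better throughout this file).  For example (a sketch),
+// with max = 3 and scores {3.0, 1.0, 2.0}, this yields *best_score = 1.0,
+// *best_index = 1 and *second_best_score = 2.0.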
+void GetTwoBest(int max, std::function<float(int)> score_fn,
+                float *best_score, int *best_index,
+                float *second_best_score) {
+  *best_index = -1;
+  *best_score = FLT_MAX;
+  *second_best_score = FLT_MAX;
+  for (int i = 0; i < max; i++) {
+    float score = score_fn(i);
+    if (score < *best_score) {
+      *second_best_score = *best_score;
+      *best_score = score;
+      *best_index = i;
+    } else if (score < *second_best_score) {
+      *second_best_score = score;
+    }
+  }
+}
+
+float ClassificationSplitScore(
+    const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits,
+    const Eigen::Tensor<float, 1, Eigen::RowMajor>& rights,
+    int32 num_classes, int i) {
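+  // Entry 0 of each num_classes-wide block holds the total count, so slice
+  // from offset 1 with extent num_classes - 1 to score only the per-class
+  // counts.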
+  Eigen::array<int, 1> offsets;
+  offsets[0] = i * num_classes + 1;
+  Eigen::array<int, 1> extents;
+  extents[0] = num_classes - 1;
+  return WeightedGiniImpurity(splits.slice(offsets, extents)) +
+      WeightedGiniImpurity(rights.slice(offsets, extents));
+}
+
+void GetTwoBestClassification(
     const Tensor& total_counts, const Tensor& split_counts,
-    int32 accumulator) {
-  int32 best_feature_index = -1;
-  // We choose the split with the lowest score.
-  float best_score = kint64max;
+    int32 accumulator,
+    float *best_score, int *best_index,
+    float *second_best_score) {
   const int32 num_splits = static_cast<int32>(split_counts.shape().dim_size(1));
   const int32 num_classes = static_cast<int32>(
       split_counts.shape().dim_size(2));
+
   // Ideally, Eigen::Tensor::chip would be best to use here but it results
   // in seg faults, so we have to go with flat views of these tensors.  However,
   // it is still pretty efficient because we put off evaluation until the
   // score is actually returned.
   const auto tc = total_counts.Slice(
       accumulator, accumulator + 1).unaligned_flat<float>();
-  const auto splits = split_counts.Slice(
+
+  // TODO(gilberth): See if we can delay evaluation here by templating the
+  // arguments to ClassificationSplitScore.
+  const Eigen::Tensor<float, 1, Eigen::RowMajor> splits = split_counts.Slice(
       accumulator, accumulator + 1).unaligned_flat<float>();
   Eigen::array<int, 1> bcast;
   bcast[0] = num_splits;
-  const auto rights = tc.broadcast(bcast) - splits;
+  const Eigen::Tensor<float, 1, Eigen::RowMajor> rights =
+      tc.broadcast(bcast) - splits;
 
-  for (int i = 0; i < num_splits; i++) {
-    Eigen::array<int, 1> offsets;
-    offsets[0] = i * num_classes;
-    Eigen::array<int, 1> extents;
-    extents[0] = num_classes;
-    float score = WeightedGiniImpurity(splits.slice(offsets, extents)) +
-        WeightedGiniImpurity(rights.slice(offsets, extents));
+  std::function<float(int)> score_fn = std::bind(
+      ClassificationSplitScore, splits, rights, num_classes,
+      std::placeholders::_1);
 
-    if (score < best_score) {
-      best_score = score;
-      best_feature_index = i;
-    }
-  }
+  GetTwoBest(
+      num_splits, score_fn,
+      best_score, best_index, second_best_score);
+}
+
+int32 BestFeatureClassification(
+    const Tensor& total_counts, const Tensor& split_counts,
+    int32 accumulator) {
+  float best_score;
+  float second_best_score;
+  int best_feature_index;
+  GetTwoBestClassification(
+      total_counts, split_counts, accumulator,
+      &best_score, &best_feature_index, &second_best_score);
   return best_feature_index;
 }
 
-int32 BestFeatureRegression(
+float RegressionSplitScore(
+    const Eigen::Tensor<float, 3, Eigen::RowMajor>& splits_count_accessor,
+    const Eigen::Tensor<float, 2, Eigen::RowMajor>& totals_count_accessor,
+    const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits_sum,
+    const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits_square,
+    const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_sums,
+    const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_squares,
+    int32 accumulator,
+    int32 num_regression_dims, int i) {
+  Eigen::array<int, 1> offsets = {i * num_regression_dims + 1};
+  Eigen::array<int, 1> extents = {num_regression_dims - 1};
+  float left_count = splits_count_accessor(accumulator, i, 0);
+  float right_count = totals_count_accessor(accumulator, 0) - left_count;
+
+  float score = 0;
+
+  // Guard against divide-by-zero.
+  if (left_count > 0) {
+    score += WeightedVariance(
+        splits_sum.slice(offsets, extents),
+        splits_square.slice(offsets, extents), left_count);
+  }
+
+  if (right_count > 0) {
+    score += WeightedVariance(right_sums.slice(offsets, extents),
+                              right_squares.slice(offsets, extents),
+                              right_count);
+  }
+  return score;
+}
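
RegressionSplitScore sums the weighted variance of the two partitions, skipping whichever side is empty. A sketch of that computation on plain Python lists, under the assumption that WeightedVariance computes the count-times-variance identity sum(x^2) - sum(x)^2 / count per output dimension:

```python
# Illustrative only: assumes WeightedVariance(sums, squares, count) is
# sum(x^2) - sum(x)^2 / count, summed over output dimensions.
def weighted_variance(sums, squares, count):
    return sum(sq - s * s / count for s, sq in zip(sums, squares))

def regression_split_score(left_sums, left_squares, left_count,
                           right_sums, right_squares, right_count):
    score = 0.0
    if left_count > 0:  # guard against divide-by-zero, as in the C++ above
        score += weighted_variance(left_sums, left_squares, left_count)
    if right_count > 0:
        score += weighted_variance(right_sums, right_squares, right_count)
    return score
```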
+
+void GetTwoBestRegression(
     const Tensor& total_sums, const Tensor& total_squares,
     const Tensor& split_sums, const Tensor& split_squares,
-    int32 accumulator) {
-  int32 best_feature_index = -1;
-  // We choose the split with the lowest score.
-  float best_score = kint64max;
+    int32 accumulator,
+    float *best_score, int *best_index,
+    float *second_best_score) {
   const int32 num_splits = static_cast<int32>(split_sums.shape().dim_size(1));
   const int32 num_regression_dims = static_cast<int32>(
       split_sums.shape().dim_size(2));
@@ -90,43 +161,138 @@ int32 BestFeatureRegression(
   const auto right_sums = tc_sum.broadcast(bcast) - splits_sum;
   const auto right_squares = tc_square.broadcast(bcast) - splits_square;
 
-  for (int i = 0; i < num_splits; i++) {
-    Eigen::array<int, 1> offsets;
-    offsets[0] = i * num_regression_dims;
-    Eigen::array<int, 1> extents;
-    extents[0] = num_regression_dims;
-    float left_count = splits_count_accessor(accumulator, i, 0);
-    float right_count = totals_count_accessor(accumulator, 0) - left_count;
+  GetTwoBest(
+      num_splits,
+      std::bind(RegressionSplitScore,
+                splits_count_accessor, totals_count_accessor,
+                splits_sum, splits_square, right_sums, right_squares,
+                accumulator, num_regression_dims, std::placeholders::_1),
+      best_score, best_index, second_best_score);
+}
 
-    float score = 0;
-
-    // Guard against divide-by-zero.
-    if (left_count > 0) {
-      score += WeightedVariance(
-        splits_sum.slice(offsets, extents),
-        splits_square.slice(offsets, extents), left_count);
-    }
-
-    if (right_count > 0) {
-        score += WeightedVariance(right_sums.slice(offsets, extents),
-                                  right_squares.slice(offsets, extents),
-                                  right_count);
-    }
-
-    if (score < best_score) {
-      best_score = score;
-      best_feature_index = i;
-    }
-  }
+int32 BestFeatureRegression(
+    const Tensor& total_sums, const Tensor& total_squares,
+    const Tensor& split_sums, const Tensor& split_squares,
+    int32 accumulator) {
+  float best_score;
+  float second_best_score;
+  int best_feature_index;
+  GetTwoBestRegression(
+      total_sums, total_squares, split_sums, split_squares, accumulator,
+      &best_score, &best_feature_index, &second_best_score);
   return best_feature_index;
 }
 
-bool DecideNode(const Tensor& point, int32 feature, float bias) {
+
+bool BestSplitDominatesRegression(
+    const Tensor& total_sums, const Tensor& total_squares,
+    const Tensor& split_sums, const Tensor& split_squares,
+    int32 accumulator) {
+  // TODO(thomaswc): Implement this, probably as part of v3.
+  return false;
+}
+
+bool BestSplitDominatesClassification(
+    const Tensor& total_counts,
+    const Tensor& split_counts, int32 accumulator,
+    float dominate_fraction) {
+  float best_score;
+  float second_best_score;
+  int best_feature_index;
+  GetTwoBestClassification(
+      total_counts, split_counts, accumulator,
+      &best_score, &best_feature_index, &second_best_score);
+
+  // Total counts are stored in the first column.
+  const int32 num_classes = split_counts.shape().dim_size(2) - 1;
+
+  // total_class_counts(c) is the # of class c examples seen by this
+  // accumulator.
+  auto total_class_counts = total_counts.Slice(
+      accumulator, accumulator + 1).unaligned_flat<float>();
+
+  const Eigen::Tensor<float, 1, Eigen::RowMajor> splits = split_counts.Slice(
+      accumulator, accumulator + 1).unaligned_flat<float>();
+
+  // For some reason, Eigen is fine with offsets being an array<int, 1> in
+  // ClassificationSplitScore, but it demands an array<Index, 1> here.
+  const Eigen::array<Eigen::Index, 1> offsets =
+      {(num_classes + 1) * best_feature_index};
+  const Eigen::array<Eigen::Index, 1> extents = {num_classes + 1};
+
+  const Eigen::Tensor<float, 1, Eigen::RowMajor> left_counts =
+      splits.slice(offsets, extents);
+  // I can find no other way using Eigen to copy a const Tensor into a
+  // non-const Tensor.
+  Eigen::Tensor<float, 1, Eigen::RowMajor> left_counts_copy(num_classes + 1);
+  for (int i = 0; i <= num_classes; i++) {
+    left_counts_copy(i) = left_counts(i);
+  }
+
+  Eigen::Tensor<float, 1, Eigen::RowMajor> right_counts_copy =
+      total_class_counts - left_counts_copy;
+
+  // "Reverse-jackknife" estimate of how often the chosen best split is
+  // truly better than the second best split.  We use the reverse jackknife
+  // (in which counts are incremented) rather than the normal jackknife
+  // (in which counts are decremented) because the latter badly underestimates
+  // the score variance of perfect splits.
+  float better_count = 0.0;
+  float worse_count = 0.0;
+  for (int i = 1; i <= num_classes; i++) {
+    left_counts_copy(i) += 1.0;
+    float weight = left_counts_copy(i);
+    float v = WeightedGiniImpurity(left_counts_copy)
+        + WeightedGiniImpurity(right_counts_copy);
+    left_counts_copy(i) -= 1.0;
+    if (v < second_best_score) {
+      better_count += weight;
+    } else {
+      worse_count += weight;
+    }
+
+    right_counts_copy(i) += 1.0;
+    weight = right_counts_copy(i);
+    v = WeightedGiniImpurity(left_counts_copy)
+        + WeightedGiniImpurity(right_counts_copy);
+    right_counts_copy(i) -= 1.0;
+    if (v < second_best_score) {
+      better_count += weight;
+    } else {
+      worse_count += weight;
+    }
+  }
+
+  VLOG(1) << "Better count = " << better_count;
+  VLOG(1) << "Worse count = " << worse_count;
+  return better_count > dominate_fraction * (better_count + worse_count);
+}
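
The reverse-jackknife loop perturbs each class count upward by one, rescores both sides, and tallies the perturbed weight as better or worse than the second-best split. A simplified sketch on plain class-count lists, assuming the weighted Gini form c - sum_i c(i)^2 / c and ignoring the leading total-count column the C++ code carries along:

```python
def weighted_gini(counts):
    total = float(sum(counts))
    if total == 0:
        return 0.0
    return total - sum(c * c for c in counts) / total

# Reverse jackknife: increment each count (rather than decrementing, which
# badly underestimates the score variance of perfect splits), rescore, and
# weight the outcome by the perturbed count.
def best_split_dominates(left, right, second_best_score, dominate_fraction):
    better_count = 0.0
    worse_count = 0.0
    for counts in (left, right):
        for i in range(len(counts)):
            counts[i] += 1.0
            weight = counts[i]
            v = weighted_gini(left) + weighted_gini(right)
            counts[i] -= 1.0  # restore the original count
            if v < second_best_score:
                better_count += weight
            else:
                worse_count += weight
    return better_count > dominate_fraction * (better_count + worse_count)
```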
+
+
+bool DecideNode(const Tensor& point, int32 feature, float bias,
+                DataColumnTypes type) {
   const auto p = point.unaligned_flat<float>();
   CHECK_LT(feature, p.size());
-  return p(feature) > bias;
+  return Decide(p(feature), bias, type);
 }
 
+
+bool Decide(float value, float bias, DataColumnTypes type) {
+  switch (type) {
+    case kDataFloat:
+      return value > bias;
+
+    case kDataCategorical:
+      // We arbitrarily define categorical equality as going left.
+      return value != bias;
+
+    default:
+      LOG(ERROR) << "Got unknown column type: " << type;
+      return false;
+  }
+}
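
Decide is the single routing primitive shared by the dense and sparse paths. A tiny Python sketch of the same contract (True routes the example right, False left; the constants mirror DataColumnTypes and constants.py introduced below):

```python
DATA_FLOAT = 0        # mirrors kDataFloat
DATA_CATEGORICAL = 1  # mirrors kDataCategorical

def decide(value, bias, col_type=DATA_FLOAT):
    """Returns True to send the example right, False to send it left."""
    if col_type == DATA_FLOAT:
        return value > bias
    if col_type == DATA_CATEGORICAL:
        # Equality is arbitrarily defined as going left.
        return value != bias
    raise ValueError("unknown column type: %r" % col_type)
```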
+
+
 bool IsAllInitialized(const Tensor& features) {
   const auto feature_vec = features.unaligned_flat<int32>();
   return feature_vec(feature_vec.size() - 1) >= 0;
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
index 19b02e379e7..067f0768d3c 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
@@ -19,6 +19,7 @@
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -26,6 +27,7 @@
 namespace tensorflow {
 namespace tensorforest {
 
+// TODO(gilberth): Put these in protos so they can be shared by C++ and python.
 // Indexes in the tree representation's 2nd dimension for children and features.
 const int32 CHILDREN_INDEX = 0;
 const int32 FEATURE_INDEX = 1;
@@ -34,6 +36,14 @@ const int32 FEATURE_INDEX = 1;
 const int32 LEAF_NODE = -1;
 const int32 FREE_NODE = -2;
 
+// Used to indicate column types, e.g. categorical vs. float
+enum DataColumnTypes {
+  kDataFloat = 0,
+  kDataCategorical = 1
+};
+
+
+
 // Calculates the sum of a tensor.
 template<typename T>
 T Sum(Tensor counts) {
@@ -80,6 +90,20 @@ int32 BestFeatureRegression(const Tensor& total_sums,
                             const Tensor& split_sums,
                             const Tensor& split_squares, int32 accumulator);
 
+// Returns true if the best split's variance is sufficiently smaller than
+// that of the next best split.
+bool BestSplitDominatesRegression(
+    const Tensor& total_sums, const Tensor& total_squares,
+    const Tensor& split_sums, const Tensor& split_squares,
+    int32 accumulator);
+
+// Returns true if the best split's Gini impurity is sufficiently smaller than
+// that of the next best split.
+bool BestSplitDominatesClassification(
+    const Tensor& total_counts,
+    const Tensor& split_counts, int32 accumulator,
+    float dominate_fraction);
+
 // Initializes everything in the given tensor to the given value.
 template <typename T>
 void Initialize(Tensor counts, T val = 0) {
@@ -90,7 +114,74 @@ void Initialize(Tensor counts, T val = 0) {
 // Returns true if the point falls to the right (i.e., the selected feature
 // of the input point is greater than the bias threshold), and false if it
 // falls to the left.
-bool DecideNode(const Tensor& point, int32 feature, float bias);
+// Even though our input data is forced into float Tensors, it could have
+// originally been something else (e.g. categorical string data) which
+// we treat differently.
+bool DecideNode(const Tensor& point, int32 feature, float bias,
+                DataColumnTypes type = kDataFloat);
+
+// Returns the Decide() result for input_data(i, feature) against bias.
+template <typename T>
+bool DecideDenseNode(const T& input_data,
+                     int32 i, int32 feature, float bias,
+                     DataColumnTypes type = kDataFloat) {
+  CHECK_LT(i, input_data.dimensions()[0]);
+  CHECK_LT(feature, input_data.dimensions()[1]);
+  return Decide(input_data(i, feature), bias, type);
+}
+
+// If T is a sparse float matrix represented by sparse_input_indices and
+// sparse_input_values, FindSparseValue returns T(i,j), or 0.0 if (i,j)
+// isn't present in sparse_input_indices.  sparse_input_indices is assumed
+// to be sorted.
+template <typename T1, typename T2>
+float FindSparseValue(
+    const T1& sparse_input_indices,
+    const T2& sparse_input_values,
+    int32 i, int32 j) {
+  int32 low = 0;
+  int32 high = sparse_input_values.dimension(0);
+  while (low < high) {
+    int32 mid = (low + high) / 2;
+    int64 midi = internal::SubtleMustCopy(sparse_input_indices(mid, 0));
+    int64 midj = internal::SubtleMustCopy(sparse_input_indices(mid, 1));
+    if (midi == i) {
+      if (midj == j) {
+        return sparse_input_values(mid);
+      }
+      if (midj < j) {
+        low = mid + 1;
+      } else {
+        high = mid;
+      }
+      continue;
+    }
+    if (midi < i) {
+      low = mid + 1;
+    } else {
+      high = mid;
+    }
+  }
+  return 0.0;
+}
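
FindSparseValue is a binary search over (row, column) pairs in row-major order. The same lookup in Python, where lexicographic tuple comparison replaces the hand-written two-level compare (indices must be sorted, and absent entries read as 0.0):

```python
def find_sparse_value(indices, values, i, j):
    # indices: sorted list of (row, col) pairs; values[k] is the entry at
    # indices[k].  Tuples compare lexicographically, matching row-major order.
    low, high = 0, len(values)
    while low < high:
        mid = (low + high) // 2
        if indices[mid] == (i, j):
            return values[mid]
        if indices[mid] < (i, j):
            low = mid + 1
        else:
            high = mid
    return 0.0  # (i, j) is absent from the sparse tensor

# Example: find_sparse_value([(0, 0), (0, 4), (1, 7)], [3.0, -1.0, 6.0], 0, 4)
# returns -1.0; a missing coordinate such as (2, 5) returns 0.0.
```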
+
+// Returns the Decide() result for t(i, feature) against bias, where t is the
+// sparse tensor represented by sparse_input_indices and sparse_input_values.
+template <typename T1, typename T2>
+bool DecideSparseNode(
+    const T1& sparse_input_indices,
+    const T2& sparse_input_values,
+    int32 i, int32 feature, float bias,
+    DataColumnTypes type = kDataFloat) {
+  return Decide(
+      FindSparseValue(sparse_input_indices, sparse_input_values, i, feature),
+      bias, type);
+}
+
+// Returns left/right decision between the input value and the threshold bias.
+// For floating point types, the decision is value > bias, but for
+// categorical data, it is value != bias.
+bool Decide(float value, float bias, DataColumnTypes type = kDataFloat);
 
 // Returns true if all the splits are initialized. Since they get initialized
 // in order, we can simply infer this from the last split.
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
index 026262e47ff..33638ca7e67 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
@@ -36,7 +36,7 @@ using tensorforest::Initialize;
 using tensorforest::WeightedGiniImpurity;
 
 REGISTER_OP("UpdateFertileSlots")
    .Attr("max_depth: int")
     .Attr("regression: bool = False")
     .Input("finished: int32")
     .Input("non_fertile_leaves: int32")
@@ -45,11 +45,10 @@ REGISTER_OP("UpdateFertileSlots")
     .Input("tree_depths: int32")
     .Input("accumulator_sums: float")
     .Input("node_to_accumulator: int32")
+    .Input("stale_leaves: int32")
     .Output("node_map_updates: int32")
     .Output("accumulators_cleared: int32")
     .Output("accumulators_allocated: int32")
-    .Output("new_nonfertile_leaves: int32")
-    .Output("new_nonfertile_leaves_scores: float")
     .Doc(R"doc(
 Updates accumulator slots to reflect finished or newly fertile nodes.
 
@@ -77,6 +76,8 @@ accumulator_sums: For classification, `accumulator_sums[a][c]` records how
   of training examples that have been seen.
 node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
   fertile node i, or -1 if node i isn't fertile.
+stale_leaves:= A 1-d int32 tensor containing the indices of all leaves that
+  have stopped accumulating statistics because they are too old.
 node_map_updates:= A 2-d int32 tensor describing the changes that need to
   be applied to the node_to_accumulator map.  Intended to be used with
   `tf.scatter_update(node_to_accumulator,
@@ -86,10 +87,7 @@ accumulators_cleared:= A 1-d int32 tensor containing the indices of all
   the accumulator slots that need to be cleared.
 accumulators_allocated:= A 1-d int32 tensor containing the indices of all
   the accumulator slots that need to be allocated.
-new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the
-  leaves that are now non-fertile.
-new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the
-  splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`.
+
 )doc");
 
 class UpdateFertileSlots : public OpKernel {
@@ -112,6 +110,7 @@ class UpdateFertileSlots : public OpKernel {
 
     const Tensor& accumulator_sums = context->input(5);
     const Tensor& node_to_accumulator = context->input(6);
+    const Tensor& stale_leaves = context->input(7);
 
     OP_REQUIRES(context, finished.shape().dims() == 1,
                 errors::InvalidArgument(
@@ -134,6 +133,9 @@ class UpdateFertileSlots : public OpKernel {
      OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
                 errors::InvalidArgument(
                     "node_to_accumulator should be one-dimensional"));
+     OP_REQUIRES(context, stale_leaves.shape().dims() == 1,
+                errors::InvalidArgument(
+                    "stale_leaves should be one-dimensional"));
 
     OP_REQUIRES(
         context,
@@ -151,6 +153,7 @@ class UpdateFertileSlots : public OpKernel {
     if (!CheckTensorBounds(context, tree_depths)) return;
     if (!CheckTensorBounds(context, accumulator_sums)) return;
     if (!CheckTensorBounds(context, node_to_accumulator)) return;
+    if (!CheckTensorBounds(context, stale_leaves)) return;
 
     // Read finished accumulators into a set for quick lookup.
     const auto node_map = node_to_accumulator.unaligned_flat<int32>();
@@ -164,6 +167,16 @@ class UpdateFertileSlots : public OpKernel {
           errors::InvalidArgument("finished node is outside the valid range"));
       finished_accumulators.insert(node_map(node));
     }
+    // Stale accumulators are also finished for the purposes of clearing
+    // and re-allocating.
+    const auto stale_vec = stale_leaves.unaligned_flat<int32>();
+    for (int32 i = 0; i < stale_vec.size(); ++i) {
+      const int32 node = internal::SubtleMustCopy(stale_vec(i));
+      OP_REQUIRES(
+          context, FastBoundsCheck(node, node_map.size()),
+          errors::InvalidArgument("stale node is outside the valid range"));
+      finished_accumulators.insert(node_map(node));
+    }
 
     // Construct leaf heap to sort leaves to allocate accumulators to.
     const int32 num_nodes = static_cast<int32>(tree_depths.shape().dim_size(0));
@@ -210,11 +223,10 @@ class UpdateFertileSlots : public OpKernel {
     }
 
     // Construct and fill outputs.
-    SetNodeMapUpdates(accumulators_to_node, finished, context);
+    SetNodeMapUpdates(accumulators_to_node, finished, stale_leaves, context);
     SetAccumulatorsCleared(finished_accumulators,
                            accumulators_to_node, context);
     SetAccumulatorsAllocated(accumulators_to_node, context);
-    SetNewNonFertileLeaves(values.get(), i, context);
   }
 
  private:
@@ -228,18 +240,20 @@ class UpdateFertileSlots : public OpKernel {
   typedef TopN<std::pair<int32, float>, OrderBySecondGreater> LeafHeapType;
   typedef std::vector<std::pair<int32, float>> HeapValuesType;
 
-  // Creates an update tensor for node to accumulator map.  Sets finished nodes
-  // to -1 (no accumulator assigned) and newly allocated nodes to their
-  // accumulator.
+  // Creates an update tensor for node to accumulator map.  Sets finished and
+  // stale nodes to -1 (no accumulator assigned) and newly allocated nodes to
+  // their accumulator.
   void SetNodeMapUpdates(
       const std::unordered_map<int32, int32>& accumulators_to_node,
-      const Tensor& finished, OpKernelContext* context) {
+      const Tensor& finished, const Tensor& stale, OpKernelContext* context) {
     // Node map updates.
     Tensor* output_node_map = nullptr;
     TensorShape node_map_shape;
     node_map_shape.AddDim(2);
-    node_map_shape.AddDim(accumulators_to_node.size() +
-                          static_cast<int32>(finished.shape().dim_size(0)));
+    node_map_shape.AddDim(
+        accumulators_to_node.size() +
+        static_cast<int32>(stale.shape().dim_size(0) +
+                           finished.shape().dim_size(0)));
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, node_map_shape,
                                             &output_node_map));
@@ -254,6 +268,13 @@ class UpdateFertileSlots : public OpKernel {
       out_node(1, output_slot)  = -1;
       ++output_slot;
     }
+    // Set stale nodes to -1.
+    const auto stale_vec = stale.unaligned_flat<int32>();
+    for (int32 i = 0; i < stale_vec.size(); ++i) {
+      out_node(0, output_slot) = stale_vec(i);
+      out_node(1, output_slot)  = -1;
+      ++output_slot;
+    }
 
     // Set newly allocated nodes to their allocator.
     for (const auto& node_alloc_pair : accumulators_to_node) {
@@ -315,56 +336,6 @@ class UpdateFertileSlots : public OpKernel {
     }
   }
 
-  // Creates output tensors for non-fertile leaves and non-fertile leaf scores.
-  // Start indicates the index in values where the leaves that weren't
-  // allocated this round begin, and should thus be placed in the new
-  // nonfertile_leaves tensors.
-  void SetNewNonFertileLeaves(HeapValuesType* values, int32 start,
-                              OpKernelContext* context) {
-    // Node map updates.
-    int32 num_values = static_cast<int32>(values->size()) - start;
-
-    // Unfortunately, a zero-sized Variable results in an uninitialized
-    // error, probably because they check for zero size instead of
-    // a real inititalization condition.
-    bool fill_with_garbage = false;
-    if (num_values == 0) {
-      num_values = 1;
-      fill_with_garbage = true;
-    }
-    Tensor* output_nonfertile_leaves = nullptr;
-    TensorShape nonfertile_leaves_shape;
-    nonfertile_leaves_shape.AddDim(num_values);
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(3, nonfertile_leaves_shape,
-                                            &output_nonfertile_leaves));
-
-    auto out_nonfertile_leaves =
-        output_nonfertile_leaves->unaligned_flat<int32>();
-
-    Tensor* output_nonfertile_leaves_scores = nullptr;
-    TensorShape nonfertile_leaves_scores_shape;
-    nonfertile_leaves_scores_shape.AddDim(num_values);
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(4, nonfertile_leaves_scores_shape,
-                                            &output_nonfertile_leaves_scores));
-
-    auto out_nonfertile_leaves_scores =
-        output_nonfertile_leaves_scores->unaligned_flat<float>();
-
-    if (fill_with_garbage) {
-      out_nonfertile_leaves(0) = -1;
-      out_nonfertile_leaves_scores(0) = 0.0;
-      return;
-    }
-
-    for (int32 i = start; i < values->size(); ++i) {
-      const std::pair<int32, float>& node = (*values)[i];
-      out_nonfertile_leaves(i -start) = node.first;
-      out_nonfertile_leaves_scores(i - start) = node.second;
-    }
-  }
-
   void ConstructLeafHeap(const Tensor& non_fertile_leaves,
                          const Tensor& non_fertile_leaf_scores,
                          const Tensor& tree_depths, int32 end_of_tree,
diff --git a/tensorflow/contrib/tensor_forest/data/__init__.py b/tensorflow/contrib/tensor_forest/data/__init__.py
new file mode 100644
index 00000000000..3d04705878d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/data/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Random forest implementation in tensorflow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensor_forest.data import data_ops
diff --git a/tensorflow/contrib/tensor_forest/data/data_ops.py b/tensorflow/contrib/tensor_forest/data/data_ops.py
new file mode 100644
index 00000000000..ca229f4ce93
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/data/data_ops.py
@@ -0,0 +1,109 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for preprocessing data."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.contrib.tensor_forest.python import constants
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
+DATA_OPS_FILE = '_data_ops.so'
+
+_data_ops = None
+_ops_lock = threading.Lock()
+
+
+ops.NoGradient('StringToFloat')
+
+
+@ops.RegisterShape('StringToFloat')
+def StringToFloatShape(op):
+  """Shape function for StringToFloat Op."""
+  return [op.inputs[0].get_shape()]
+
+
+# Workaround for the fact that importing tensorflow imports contrib
+# (even if a user isn't using this or any other contrib op), but
+# there's not yet any guarantee that the shared object exists.
+# Without this lazy loading, "import tensorflow" would always crash, even
+# for users that never use contrib.
+def Load():
+  """Load the data ops library and return the loaded module."""
+  with _ops_lock:
+    global _data_ops
+    if not _data_ops:
+      ops_path = resource_loader.get_path_to_datafile(DATA_OPS_FILE)
+      logging.info('data path: %s', ops_path)
+      _data_ops = load_library.load_op_library(ops_path)
+
+      assert _data_ops, 'Could not load _data_ops.so'
+  return _data_ops
+
+
+def ParseDataTensorOrDict(data):
+  """Return a tensor to use for input data.
+
+  The incoming features can be a dict where keys are the string names of the
+  columns, which we turn into a single 2-D tensor.
+
+  Args:
+    data: `Tensor` or `dict` of `Tensor` objects.
+
+  Returns:
+    A 2-D tensor for input to tensor_forest and a list giving the type
+      of each column (e.g. continuous float, categorical).
+  """
+  convert_ops = Load()
+  if isinstance(data, dict):
+    data_spec = [constants.DATA_CATEGORICAL if data[k].dtype == dtypes.string
+                 else constants.DATA_FLOAT
+                 for k in sorted(data.keys())]
+    return array_ops.concat(1, [
+        convert_ops.string_to_float(data[k])
+        if data[k].dtype == dtypes.string else data[k]
+        for k in sorted(data.keys())]), data_spec
+  else:
+    return data, [constants.DATA_FLOAT] * data.get_shape().as_list()[1]
+
+
+def ParseLabelTensorOrDict(labels):
+  """Return a tensor to use for input labels to tensor_forest.
+
+  The incoming targets can be a dict where keys are the string names of the
+  columns, which we turn into a single 1-D tensor for classification or
+  2-D tensor for regression.
+
+  Args:
+    labels: `Tensor` or `dict` of `Tensor` objects.
+
+  Returns:
+    A 2-D tensor for labels/outputs.
+  """
+  if isinstance(labels, dict):
+    return math_ops.to_float(array_ops.concat(
+        1, [labels[k] for k in sorted(labels.keys())]))
+  else:
+    return math_ops.to_float(labels)
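
A hypothetical usage sketch of the parsing helpers above (the column names are invented; keys are processed in sorted order, so 'age' maps to DATA_FLOAT and 'country', being a string column, to DATA_CATEGORICAL):

```python
import tensorflow as tf

from tensorflow.contrib.tensor_forest.data import data_ops

features = {
    'age': tf.constant([[25.0], [40.0]]),
    'country': tf.constant([['us'], ['ca']]),
}
# data is a single 2-D float tensor; spec is [DATA_FLOAT, DATA_CATEGORICAL].
data, spec = data_ops.ParseDataTensorOrDict(features)
```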
diff --git a/tensorflow/contrib/tensor_forest/data/string_to_float_op.cc b/tensorflow/contrib/tensor_forest/data/string_to_float_op.cc
new file mode 100644
index 00000000000..3908855063d
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/data/string_to_float_op.cc
@@ -0,0 +1,111 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+// Converts strings of arbitrary length to float values by
+// hashing and cramming bits.
+#include <functional>
+
+#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+using tensorforest::CheckTensorBounds;
+
+
+float Convert(const string& in) {
+  const std::size_t intval = std::hash<string>()(in);
+  return static_cast<float>(intval);
+}
+
+
+void Evaluate(const Tensor& input_data, Tensor output_data,
+              int32 start, int32 end) {
+  auto out_data = output_data.tensor<float, 2>();
+  const auto in_data = input_data.tensor<string, 2>();
+
+  for (int32 i = start; i < end; ++i) {
+    for (int32 j = 0; j < output_data.dim_size(1); ++j) {
+      out_data(i, j) = Convert(in_data(i, j));
+    }
+  }
+}
+
+
+REGISTER_OP("StringToFloat")
+  .Input("input_data: string")
+  .Output("output_data: float")
+
+  .Doc(R"doc(
+   Converts byte arrays represented by strings to 32-bit
+   floating point numbers. The output numbers themselves are meaningless, and
+   should only be used in == comparisons.
+
+   input_data: A batch of string features as a 2-d tensor; `input_data[i][j]`
+     gives the j-th feature of the i-th input.
+   output_data: A tensor of the same shape as input_data but the values are
+     float32.
+
+)doc");
+
+class StringToFloat : public OpKernel {
+ public:
+  explicit StringToFloat(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(0);
+
+    // Check inputs.
+    OP_REQUIRES(context, input_data.shape().dims() == 2,
+                errors::InvalidArgument(
+                    "input_data should be two-dimensional"));
+
+    // Check tensor bounds.
+    if (!CheckTensorBounds(context, input_data)) return;
+
+    Tensor* output_data = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, input_data.shape(),
+                                            &output_data));
+
+    // Evaluate input data in parallel.
+    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    int num_threads = worker_threads->num_threads;
+    if (num_threads <= 1) {
+      Evaluate(input_data, *output_data, 0, num_data);
+    } else {
+      auto work = [&input_data, output_data, num_data](int64 start, int64 end) {
+        CHECK(start <= end);
+        CHECK(end <= num_data);
+        Evaluate(input_data, *output_data,
+                 static_cast<int32>(start), static_cast<int32>(end));
+      };
+      Shard(num_threads, worker_threads->workers, num_data, 100, work);
+    }
+  }
+};
+
+
+REGISTER_KERNEL_BUILDER(Name("StringToFloat").Device(DEVICE_CPU),
+                        StringToFloat);
+
+}  // namespace tensorflow
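
The op hashes each string and reinterprets the hash as a float, so equal strings always map to the same value while the magnitude carries no meaning. A rough Python analogue (Python's built-in hash is process-salted, so this matches the spirit of std::hash, not its bits):

```python
def string_to_float(s):
    # Hash the string and cram the (truncated) bits into a float.  The result
    # is only meaningful for == / != comparisons, exactly like the op's output.
    return float(hash(s) & 0xFFFFFFFF)

assert string_to_float('us') == string_to_float('us')
```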
diff --git a/tensorflow/contrib/tensor_forest/python/__init__.py b/tensorflow/contrib/tensor_forest/python/__init__.py
index 0f692bbe972..a9dd599c970 100644
--- a/tensorflow/contrib/tensor_forest/python/__init__.py
+++ b/tensorflow/contrib/tensor_forest/python/__init__.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.tensor_forest.python import constants
 from tensorflow.contrib.tensor_forest.python import tensor_forest
 from tensorflow.contrib.tensor_forest.python.ops import inference_ops
 from tensorflow.contrib.tensor_forest.python.ops import training_ops
diff --git a/tensorflow/contrib/tensor_forest/python/constants.py b/tensorflow/contrib/tensor_forest/python/constants.py
new file mode 100644
index 00000000000..029c7824615
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/constants.py
@@ -0,0 +1,26 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants used by tensorforest.  Some of these map to values in C++ ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# If tree[i][0] equals this value, then i is a leaf node.
+LEAF_NODE = -1
+
+# Data column types for indicating categorical or other non-float values.
+DATA_FLOAT = 0
+DATA_CATEGORICAL = 1
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
index c5b5981adba..3641ab0ee06 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py
@@ -30,14 +30,16 @@ class BestSplitsClassificationTests(test_util.TensorFlowTestCase):
   def setUp(self):
     self.finished = [3, 5]
     self.node_map = [-1, -1, -1, 0, -1, 3, -1, -1, -1]
-    self.candidate_counts = [[[50., 60., 40., 3.], [70., 30., 70., 30.]],
-                             [[0., 0., 0., 0.], [0., 0., 0., 0.]],
-                             [[0., 0., 0., 0.], [0., 0., 0., 0.]],
-                             [[10., 10., 10., 10.], [10., 5., 5., 10.]]]
-    self.total_counts = [[100., 100., 100., 100.],
-                         [0., 0., 0., 0.],
-                         [0., 0., 0., 0.],
-                         [100., 100., 100., 100.]]
+    self.candidate_counts = [[[153., 50., 60., 40., 3.],
+                              [200., 70., 30., 70., 30.]],
+                             [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+                             [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+                             [[40., 10., 10., 10., 10.],
+                              [30., 10., 5., 5., 10.]]]
+    self.total_counts = [[400., 100., 100., 100., 100.],
+                         [0., 0., 0., 0., 0.],
+                         [0., 0., 0., 0., 0.],
+                         [400., 100., 100., 100., 100.]]
     self.squares = []
     self.ops = training_ops.Load()
 
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
index eb61573f24f..a50eb22795c 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow.contrib.tensor_forest.python import constants
 from tensorflow.contrib.tensor_forest.python.ops import training_ops
 
 from tensorflow.python.framework import test_util
@@ -37,16 +38,20 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
     self.split_features = [[1], [-1]]
     self.split_thresholds = [[1.], [0.]]
     self.ops = training_ops.Load()
+    self.epochs = [0, 1, 1]
+    self.current_epoch = [1]
+    self.data_spec = [constants.DATA_FLOAT] * 2
 
   def testSimple(self):
     with self.test_session():
       (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
        pcw_totals_indices, pcw_totals_sums, _, leaves) = (
            self.ops.count_extremely_random_stats(
-               self.input_data, self.input_labels, self.tree,
-               self.tree_thresholds, self.node_map,
-               self.split_features, self.split_thresholds, num_classes=5,
-               regression=False))
+               self.input_data, [], [], [], self.data_spec, self.input_labels,
+               self.tree, self.tree_thresholds, self.node_map,
+               self.split_features, self.split_thresholds, self.epochs,
+               self.current_epoch,
+               num_classes=5, regression=False))
 
       self.assertAllEqual(
           [[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.], [2., 0., 0., 1., 1.]],
@@ -57,15 +62,68 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
       self.assertAllEqual([1., 2., 1.], pcw_totals_sums.eval())
       self.assertAllEqual([1, 1, 2, 2], leaves.eval())
 
+  def testSparseInput(self):
+    sparse_shape = [4, 10]
+    sparse_indices = [[0, 0], [0, 4], [0, 9],
+                      [1, 0], [1, 7],
+                      [2, 0],
+                      [3, 1], [3, 4]]
+    sparse_values = [3.0, -1.0, 0.5,
+                     1.5, 6.0,
+                     -2.0,
+                     -0.5, 2.0]
+    with self.test_session():
+      (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
+       pcw_totals_indices, pcw_totals_sums, _, leaves) = (
+           self.ops.count_extremely_random_stats(
+               [], sparse_indices, sparse_values, sparse_shape, self.data_spec,
+               self.input_labels, self.tree,
+               self.tree_thresholds, self.node_map,
+               self.split_features, self.split_thresholds, self.epochs,
+               self.current_epoch,
+               num_classes=5, regression=False))
+
+      self.assertAllEqual(
+          [[4., 1., 1., 1., 1.],
+           [2., 0., 0., 1., 1.],
+           [2., 1., 1., 0., 0.]],
+          pcw_node_sums.eval())
+      self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
+                          pcw_splits_indices.eval())
+      self.assertAllEqual([1., 2., 1.], pcw_splits_sums.eval())
+      self.assertAllEqual([[0, 4], [0, 0], [0, 3]], pcw_totals_indices.eval())
+      self.assertAllEqual([1., 2., 1.], pcw_totals_sums.eval())
+      self.assertAllEqual([2, 2, 1, 1], leaves.eval())
+
+  def testFutureEpoch(self):
+    current_epoch = [3]
+    with self.test_session():
+      (pcw_node_sums, _, _, pcw_splits_sums, _,
+       _, pcw_totals_sums, _, leaves) = (
+           self.ops.count_extremely_random_stats(
+               self.input_data, [], [], [], self.data_spec, self.input_labels,
+               self.tree, self.tree_thresholds, self.node_map,
+               self.split_features, self.split_thresholds, self.epochs,
+               current_epoch, num_classes=5, regression=False))
+
+      self.assertAllEqual(
+          [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]],
+          pcw_node_sums.eval())
+      self.assertAllEqual([], pcw_splits_sums.eval())
+      self.assertAllEqual([], pcw_totals_sums.eval())
+      self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
   def testThreaded(self):
     with self.test_session(
         config=tf.ConfigProto(intra_op_parallelism_threads=2)):
       (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
        pcw_totals_indices, pcw_totals_sums, _, leaves) = (
            self.ops.count_extremely_random_stats(
-               self.input_data, self.input_labels, self.tree,
-               self.tree_thresholds, self.node_map, self.split_features,
-               self.split_thresholds, num_classes=5, regression=False))
+               self.input_data, [], [], [], self.data_spec, self.input_labels,
+               self.tree, self.tree_thresholds, self.node_map,
+               self.split_features,
+               self.split_thresholds, self.epochs, self.current_epoch,
+               num_classes=5, regression=False))
 
       self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
                            [2., 0., 0., 1., 1.]],
@@ -81,10 +139,10 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
       (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
        pcw_totals_indices, pcw_totals_sums, _, leaves) = (
            self.ops.count_extremely_random_stats(
-               self.input_data, self.input_labels, self.tree,
-               self.tree_thresholds, [-1] * 3,
-               self.split_features, self.split_thresholds, num_classes=5,
-               regression=False))
+               self.input_data, [], [], [], self.data_spec, self.input_labels,
+               self.tree, self.tree_thresholds, [-1] * 3,
+               self.split_features, self.split_thresholds, self.epochs,
+               self.current_epoch, num_classes=5, regression=False))
 
       self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
                            [2., 0., 0., 1., 1.]],
@@ -101,13 +159,13 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase):
     with self.test_session():
       with self.assertRaisesOpError(
           'Number of nodes should be the same in '
-          'tree, tree_thresholds, and node_to_accumulator'):
+          'tree, tree_thresholds, node_to_accumulator, and birth_epoch.'):
         pcw_node, _, _, _, _, _, _, _, _ = (
             self.ops.count_extremely_random_stats(
-                self.input_data, self.input_labels, self.tree,
-                self.tree_thresholds, self.node_map,
-                self.split_features, self.split_thresholds, num_classes=5,
-                regression=False))
+                self.input_data, [], [], [], self.data_spec, self.input_labels,
+                self.tree, self.tree_thresholds, self.node_map,
+                self.split_features, self.split_thresholds, self.epochs,
+                self.current_epoch, num_classes=5, regression=False))
 
         self.assertAllEqual([], pcw_node.eval())
 
@@ -124,6 +182,9 @@ class CountExtremelyRandomStatsRegressionTest(test_util.TensorFlowTestCase):
     self.split_features = [[1], [-1]]
     self.split_thresholds = [[1.], [0.]]
     self.ops = training_ops.Load()
+    self.epochs = [0, 1, 1]
+    self.current_epoch = [1]
+    self.data_spec = [constants.DATA_FLOAT] * 2
 
   def testSimple(self):
     with self.test_session():
@@ -131,10 +192,10 @@ class CountExtremelyRandomStatsRegressionTest(test_util.TensorFlowTestCase):
        pcw_splits_squares, pcw_totals_indices,
        pcw_totals_sums, pcw_totals_squares, leaves) = (
            self.ops.count_extremely_random_stats(
-               self.input_data, self.input_labels, self.tree,
-               self.tree_thresholds, self.node_map,
-               self.split_features, self.split_thresholds, num_classes=2,
-               regression=True))
+               self.input_data, [], [], [], self.data_spec, self.input_labels,
+               self.tree, self.tree_thresholds, self.node_map,
+               self.split_features, self.split_thresholds, self.epochs,
+               self.current_epoch, num_classes=2, regression=True))
 
       self.assertAllEqual(
           [[4., 14.], [2., 9.], [2., 5.]], pcw_node_sums.eval())
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
index 24fbe2c11d6..222ef2b2eb7 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py
@@ -30,35 +30,71 @@ class FinishedNodesTest(test_util.TensorFlowTestCase):
   def setUp(self):
     self.leaves = [1, 3, 4]
     self.node_map = [-1, -1, -1, 0, 1, -1]
-    self.pcw_total_splits = [[6, 3, 3], [11, 4, 7], [0, 0, 0], [0, 0, 0],
+    self.split_sums = [
+        # Accumulator 1
+        [[3, 0, 3], [2, 1, 1], [3, 1, 2]],
+        # Accumulator 2
+        [[6, 3, 3], [6, 2, 4], [5, 0, 5]],
+        # Accumulator 3
+        [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
+        # Accumulator 4
+        [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
+        # Accumulator 5
+        [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
+    ]
+    self.split_squares = []
+    self.accumulator_sums = [[6, 3, 3], [11, 4, 7], [0, 0, 0], [0, 0, 0],
                              [0, 0, 0]]
+    self.accumulator_squares = []
     self.ops = training_ops.Load()
+    self.birth_epochs = [0, 0, 0, 1, 1, 1]
+    self.current_epoch = [1]
 
   def testSimple(self):
     with self.test_session():
-      finished = self.ops.finished_nodes(self.leaves, self.node_map,
-                                         self.pcw_total_splits,
-                                         num_split_after_samples=10)
+      finished, stale = self.ops.finished_nodes(
+          self.leaves, self.node_map, self.split_sums,
+          self.split_squares, self.accumulator_sums, self.accumulator_squares,
+          self.birth_epochs, self.current_epoch,
+          regression=False, num_split_after_samples=10, min_split_samples=10)
 
       self.assertAllEqual([4], finished.eval())
+      self.assertAllEqual([], stale.eval())
 
   def testNoAccumulators(self):
     with self.test_session():
-      finished = self.ops.finished_nodes(self.leaves, [-1] * 6,
-                                         self.pcw_total_splits,
-                                         num_split_after_samples=10)
+      finished, stale = self.ops.finished_nodes(
+          self.leaves, [-1] * 6, self.split_sums,
+          self.split_squares, self.accumulator_sums, self.accumulator_squares,
+          self.birth_epochs, self.current_epoch,
+          regression=False, num_split_after_samples=10, min_split_samples=10)
 
       self.assertAllEqual([], finished.eval())
+      self.assertAllEqual([], stale.eval())
 
   def testBadInput(self):
     with self.test_session():
       with self.assertRaisesOpError(
           'leaf_tensor should be one-dimensional'):
-        finished = self.ops.finished_nodes([self.leaves], self.node_map,
-                                           self.pcw_total_splits,
-                                           num_split_after_samples=10)
+        finished, stale = self.ops.finished_nodes(
+            [self.leaves], self.node_map, self.split_sums,
+            self.split_squares, self.accumulator_sums, self.accumulator_squares,
+            self.birth_epochs, self.current_epoch,
+            regression=False, num_split_after_samples=10, min_split_samples=10)
 
         self.assertAllEqual([], finished.eval())
+        self.assertAllEqual([], stale.eval())
+
+  def testEarlyDominates(self):
+    with self.test_session():
+      finished, stale = self.ops.finished_nodes(
+          self.leaves, self.node_map, self.split_sums,
+          self.split_squares, self.accumulator_sums, self.accumulator_squares,
+          self.birth_epochs, self.current_epoch,
+          regression=False, num_split_after_samples=10, min_split_samples=5)
+
+      self.assertAllEqual([4], finished.eval())
+      self.assertAllEqual([], stale.eval())
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
index 0bbd94a2a4a..9830651a5d0 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py
@@ -41,7 +41,8 @@ class SampleInputsTest(test_util.TensorFlowTestCase):
       tf.initialize_all_variables().run()
       indices, feature_updates, threshold_updates = (
           self.ops.sample_inputs(
-              self.input_data, self.node_map, self.leaves, self.split_features,
+              self.input_data, [], [], [],
+              self.node_map, self.leaves, self.split_features,
               self.split_thresholds, split_initializations_per_input=1,
               split_sampling_random_seed=3))
       self.assertAllEqual([1, 0], indices.eval())
@@ -50,12 +51,38 @@ class SampleInputsTest(test_util.TensorFlowTestCase):
       self.assertAllEqual([[5., -2., 50.], [-1., -10., 0.]],
                           threshold_updates.eval())
 
+  def testSparse(self):
+    sparse_shape = [4, 10]
+    sparse_indices = [[0, 0], [0, 4], [0, 9],
+                      [1, 0], [1, 7],
+                      [2, 0],
+                      [3, 1], [3, 4]]
+    sparse_values = [3.0, -1.0, 0.5,
+                     1.5, 6.0,
+                     -2.0,
+                     -0.5, 2.0]
+
+    with self.test_session():
+      tf.initialize_all_variables().run()
+      indices, feature_updates, threshold_updates = (
+          self.ops.sample_inputs(
+              [], sparse_indices, sparse_values, sparse_shape,
+              self.node_map, self.leaves, self.split_features,
+              self.split_thresholds, split_initializations_per_input=1,
+              split_sampling_random_seed=3))
+      self.assertAllEqual([1, 0], indices.eval())
+      self.assertAllEqual([[1, 0, 0], [4, 7, -1]],
+                          feature_updates.eval())
+      self.assertAllEqual([[5., -2., -2.], [-1., 6., 0.]],
+                          threshold_updates.eval())
+
   def testNoAccumulators(self):
     with self.test_session():
       tf.initialize_all_variables().run()
       indices, feature_updates, threshold_updates = (
           self.ops.sample_inputs(
-              self.input_data, [-1] * 3, self.leaves, self.split_features,
+              self.input_data, [], [], [],
+              [-1] * 3, self.leaves, self.split_features,
               self.split_thresholds, split_initializations_per_input=1,
               split_sampling_random_seed=3))
       self.assertAllEqual([], indices.eval())
@@ -69,7 +96,8 @@ class SampleInputsTest(test_util.TensorFlowTestCase):
       with self.assertRaisesOpError(
           'split_features and split_thresholds should be the same shape.'):
         indices, _, _ = self.ops.sample_inputs(
-            self.input_data, self.node_map, self.leaves, self.split_features,
+            self.input_data, [], [], [],
+            self.node_map, self.leaves, self.split_features,
             self.split_thresholds, split_initializations_per_input=1,
             split_sampling_random_seed=3)
         self.assertAllEqual([], indices.eval())
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
index e61085657a1..aaead5610f5 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import tensorflow  # pylint: disable=unused-import
 
+from tensorflow.contrib.tensor_forest.python import constants
 from tensorflow.contrib.tensor_forest.python.ops import inference_ops
 
 from tensorflow.python.framework import test_util
@@ -29,6 +30,7 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
 
   def setUp(self):
     self.ops = inference_ops.Load()
+    self.data_spec = [constants.DATA_FLOAT] * 2
 
   def testSimple(self):
     input_data = [[-1., 0.], [-1., 2.],  # node 1
@@ -41,13 +43,65 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
 
     with self.test_session():
       predictions = self.ops.tree_predictions(
-          input_data, tree, tree_thresholds, node_pcw,
-          valid_leaf_threshold=1)
+          input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+          node_pcw, valid_leaf_threshold=1)
 
       self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8],
                            [0.5, 0.25, 0.25], [0.5, 0.25, 0.25]],
                           predictions.eval())
 
+  def testSparseInput(self):
+    sparse_shape = [3, 10]
+    sparse_indices = [[0, 0], [0, 4], [0, 9],
+                      [1, 0], [1, 7],
+                      [2, 0]]
+    sparse_values = [3.0, -1.0, 0.5,
+                     1.5, 6.0,
+                     -2.0]
+    sparse_data_spec = [constants.DATA_FLOAT] * 10
+
+    tree = [[1, 0], [-1, 0], [-1, 0]]
+    tree_thresholds = [0., 0., 0.]
+    node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8],
+                [1.0, 0.5, 0.25, 0.25]]
+
+    with self.test_session():
+      predictions = self.ops.tree_predictions(
+          [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec,
+          tree, tree_thresholds, node_pcw,
+          valid_leaf_threshold=1)
+
+      self.assertAllClose([[0.5, 0.25, 0.25],
+                           [0.5, 0.25, 0.25],
+                           [0.1, 0.1, 0.8]],
+                          predictions.eval())
+
+  def testSparseInputDefaultIsZero(self):
+    sparse_shape = [3, 10]
+    sparse_indices = [[0, 0], [0, 4], [0, 9],
+                      [1, 0], [1, 7],
+                      [2, 0]]
+    sparse_values = [3.0, -1.0, 0.5,
+                     1.5, 6.0,
+                     -2.0]
+    sparse_data_spec = [constants.DATA_FLOAT] * 10
+
+    tree = [[1, 7], [-1, 0], [-1, 0]]
+    tree_thresholds = [3.0, 0., 0.]
+    node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8],
+                [1.0, 0.5, 0.25, 0.25]]
+
+    with self.test_session():
+      predictions = self.ops.tree_predictions(
+          [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec,
+          tree, tree_thresholds, node_pcw,
+          valid_leaf_threshold=1)
+
+      self.assertAllClose([[0.1, 0.1, 0.8],
+                           [0.5, 0.25, 0.25],
+                           [0.1, 0.1, 0.8]],
+                          predictions.eval())
+
   def testBackoffToParent(self):
     input_data = [[-1., 0.], [-1., 2.],  # node 1
                   [1., 0.], [1., -2.]]  # node 2
@@ -59,8 +113,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
 
     with self.test_session():
       predictions = self.ops.tree_predictions(
-          input_data, tree, tree_thresholds, node_pcw,
-          valid_leaf_threshold=10)
+          input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+          node_pcw, valid_leaf_threshold=10)
 
       # Node 2 has enough data, but Node 1 needs to combine with the parent
       # counts.
@@ -78,8 +132,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
 
     with self.test_session():
       predictions = self.ops.tree_predictions(
-          input_data, tree, tree_thresholds, node_pcw,
-          valid_leaf_threshold=10)
+          input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+          node_pcw, valid_leaf_threshold=10)
 
       self.assertEquals((0, 3), predictions.eval().shape)
 
@@ -97,8 +151,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase):
           'Number of nodes should be the same in tree, tree_thresholds '
           'and node_pcw.'):
         predictions = self.ops.tree_predictions(
-            input_data, tree, tree_thresholds, node_pcw,
-            valid_leaf_threshold=10)
+            input_data, [], [], [], self.data_spec, tree, tree_thresholds,
+            node_pcw, valid_leaf_threshold=10)
 
         self.assertEquals((0, 3), predictions.eval().shape)
 
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
index f370903b3c6..c9af01c50b7 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py
@@ -40,48 +40,43 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
     self.node_map = [-1, -1, 0, -1, -1, -1, -1]
     self.total_counts = [[80., 40., 40.]]
     self.ops = training_ops.Load()
+    self.stale_leaves = []
 
   def testSimple(self):
     with self.test_session():
-      (node_map_updates, accumulators_cleared, accumulators_allocated,
-       new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+      (node_map_updates, accumulators_cleared,
+       accumulators_allocated) = self.ops.update_fertile_slots(
            self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
            self.end_of_tree, self.depths,
-           self.total_counts, self.node_map, max_depth=4)
+           self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
 
       self.assertAllEqual([[2, 4], [-1, 0]], node_map_updates.eval())
       self.assertAllEqual([], accumulators_cleared.eval())
       self.assertAllEqual([0], accumulators_allocated.eval())
-      self.assertAllEqual([3, 5, 6], new_nfl.eval())
-      self.assertAllEqual([10., 1., 1.], new_nfl_scores.eval())
 
   def testReachedMaxDepth(self):
     with self.test_session():
-      (node_map_updates, accumulators_cleared, accumulators_allocated,
-       new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+      (node_map_updates, accumulators_cleared,
+       accumulators_allocated) = self.ops.update_fertile_slots(
            self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores,
            self.end_of_tree, self.depths,
-           self.total_counts, self.node_map, max_depth=3)
+           self.total_counts, self.node_map, self.stale_leaves, max_depth=3)
 
       self.assertAllEqual([[2], [-1]], node_map_updates.eval())
       self.assertAllEqual([0], accumulators_cleared.eval())
       self.assertAllEqual([], accumulators_allocated.eval())
-      self.assertAllEqual([-1], new_nfl.eval())
-      self.assertAllEqual([0.0], new_nfl_scores.eval())
 
   def testNoFinished(self):
     with self.test_session():
-      (node_map_updates, accumulators_cleared, accumulators_allocated,
-       new_nfl, new_nfl_scores) = self.ops.update_fertile_slots(
+      (node_map_updates, accumulators_cleared,
+       accumulators_allocated) = self.ops.update_fertile_slots(
            [], self.non_fertile_leaves, self.non_fertile_leaf_scores,
            self.end_of_tree, self.depths,
-           self.total_counts, self.node_map, max_depth=4)
+           self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
 
       self.assertAllEqual((2, 0), node_map_updates.eval().shape)
       self.assertAllEqual([], accumulators_cleared.eval())
       self.assertAllEqual([], accumulators_allocated.eval())
-      self.assertAllEqual([4, 3], new_nfl.eval())
-      self.assertAllEqual([15., 10.], new_nfl_scores.eval())
 
   def testBadInput(self):
     del self.non_fertile_leaf_scores[-1]
@@ -89,10 +84,10 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase):
       with self.assertRaisesOpError(
           'Number of non fertile leaves should be the same in '
           'non_fertile_leaves and non_fertile_leaf_scores.'):
-        (node_map_updates, _, _, _, _) = self.ops.update_fertile_slots(
+        (node_map_updates, _, _) = self.ops.update_fertile_slots(
             self.finished, self.non_fertile_leaves,
             self.non_fertile_leaf_scores, self.end_of_tree, self.depths,
-            self.total_counts, self.node_map, max_depth=4)
+            self.total_counts, self.node_map, self.stale_leaves, max_depth=4)
         self.assertAllEqual((2, 0), node_map_updates.eval().shape)
 
 
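The hunks above narrow `update_fertile_slots` from five outputs to three and add a `stale_leaves` input. Sketching the revised contract with the test's placeholder values (`stale_leaves` would typically be the second output of `finished_nodes`, as seen later in this patch):

```python
(node_map_updates, accumulators_cleared,
 accumulators_allocated) = training_ops.Load().update_fertile_slots(
     finished, non_fertile_leaves, non_fertile_leaf_scores,
     end_of_tree, depths, total_counts, node_map,
     stale_leaves,   # new input; e.g. the "stale" output of finished_nodes
     max_depth=4)
# new_nfl / new_nfl_scores are gone: non-fertile leaves and their scores are
# now computed in the Python layer (see tensor_forest.py below).
```
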
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
index 6f4e6fff401..88f8112ed4c 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +18,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 import threading
 
-import tensorflow as tf
-
+from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
 
 INFERENCE_OPS_FILE = '_inference_ops.so'
 
@@ -38,7 +40,11 @@ ops.NoGradient('TreePredictions')
 def TreePredictions(op):
   """Shape function for TreePredictions Op."""
   num_points = op.inputs[0].get_shape()[0].value
-  num_classes = op.inputs[3].get_shape()[1].value
+  sparse_shape = op.inputs[3].get_shape()
+  if sparse_shape.ndims == 2:
+    num_points = sparse_shape[0].value
+  num_classes = op.inputs[7].get_shape()[1].value
+
   # The output of TreePredictions is
   # [node_pcw(evaluate_tree(x), c) for c in classes for x in input_data].
   return [tensor_shape.TensorShape([num_points, num_classes - 1])]
@@ -49,16 +55,14 @@ def TreePredictions(op):
 # there's not yet any guarantee that the shared object exists.
 # In which case, "import tensorflow" will always crash, even for users that
 # never use contrib.
-def Load(library_base_dir=''):
+def Load():
   """Load the inference ops library and return the loaded module."""
   with _ops_lock:
     global _inference_ops
     if not _inference_ops:
-      data_files_path = os.path.join(library_base_dir,
-                                     tf.resource_loader.get_data_files_path())
-      tf.logging.info('data path: %s', data_files_path)
-      _inference_ops = tf.load_op_library(os.path.join(
-          data_files_path, INFERENCE_OPS_FILE))
+      ops_path = resource_loader.get_path_to_datafile(INFERENCE_OPS_FILE)
+      logging.info('data path: %s', ops_path)
+      _inference_ops = load_library.load_op_library(ops_path)
 
       assert _inference_ops, 'Could not load inference_ops.so'
   return _inference_ops
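
The `Load()` rewrite here (and the identical one in `training_ops.py` below) drops the `library_base_dir` parameter: `resource_loader.get_path_to_datafile` resolves the shared object relative to the module itself. The pattern in isolation, minus the locking:

```python
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

# Resolve _inference_ops.so next to this module and load its kernels.
ops_path = resource_loader.get_path_to_datafile('_inference_ops.so')
inference_ops_module = load_library.load_op_library(ops_path)
```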
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
index 7a108baf426..d25d5ce50bf 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +18,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 import threading
 
-import tensorflow as tf
-
+from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
 
 
 TRAINING_OPS_FILE = '_training_ops.so'
@@ -45,7 +46,10 @@ def _CountExtremelyRandomStatsShape(op):
   """Shape function for CountExtremelyRandomStats Op."""
   regression = op.get_attr('regression')
   num_points = op.inputs[0].get_shape()[0].value
-  num_nodes = op.inputs[2].get_shape()[0].value
+  sparse_shape = op.inputs[3].get_shape()
+  if sparse_shape.ndims == 2:
+    num_points = sparse_shape[0].value
+  num_nodes = op.inputs[6].get_shape()[0].value
   num_classes = op.get_attr('num_classes')
   # The output of TraverseTree is [leaf_node_index(x) for x in input_data].
   return [tensor_shape.TensorShape([num_nodes, num_classes]),  # node sums
@@ -66,7 +70,7 @@ def _CountExtremelyRandomStatsShape(op):
 @ops.RegisterShape('SampleInputs')
 def _SampleInputsShape(op):
   """Shape function for SampleInputs Op."""
-  num_splits = op.inputs[3].get_shape()[1].value
+  num_splits = op.inputs[6].get_shape()[1].value
   return [[None], [None, num_splits], [None, num_splits]]
 
 
@@ -85,7 +89,7 @@ def _GrowTreeShape(unused_op):
 @ops.RegisterShape('FinishedNodes')
 def _FinishedNodesShape(unused_op):
   """Shape function for FinishedNodes Op."""
-  return [[None]]
+  return [[None], [None]]
 
 
 @ops.RegisterShape('ScatterAddNdim')
@@ -97,7 +101,7 @@ def _ScatterAddNdimShape(unused_op):
 @ops.RegisterShape('UpdateFertileSlots')
 def _UpdateFertileSlotsShape(unused_op):
   """Shape function for UpdateFertileSlots Op."""
-  return [[None, 2], [None], [None], [None], [None]]
+  return [[None, 2], [None], [None]]
 
 
 # Workaround for the fact that importing tensorflow imports contrib
@@ -105,16 +109,14 @@ def _UpdateFertileSlotsShape(unused_op):
 # there's not yet any guarantee that the shared object exists.
 # In which case, "import tensorflow" will always crash, even for users that
 # never use contrib.
-def Load(library_base_dir=''):
+def Load():
   """Load training ops library and return the loaded module."""
   with _ops_lock:
     global _training_ops
     if not _training_ops:
-      data_files_path = os.path.join(library_base_dir,
-                                     tf.resource_loader.get_data_files_path())
-      tf.logging.info('data path: %s', data_files_path)
-      _training_ops = tf.load_op_library(os.path.join(
-          data_files_path, TRAINING_OPS_FILE))
+      ops_path = resource_loader.get_path_to_datafile(TRAINING_OPS_FILE)
+      logging.info('data path: %s', ops_path)
+      _training_ops = load_library.load_op_library(ops_path)
 
       assert _training_ops, 'Could not load _training_ops.so'
   return _training_ops
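
Both shape functions touched above follow the same convention for the new sparse inputs: dense data stays at input 0, and when the static shape of the sparse-shape tensor at input 3 has rank 2, the batch size is read from it instead. A hypothetical shape function written to that convention (the op and its output shape are illustrative only):

```python
def _ExampleOpShape(op):  # hypothetical op, mirroring the pattern above
  num_points = op.inputs[0].get_shape()[0].value
  sparse_shape = op.inputs[3].get_shape()
  if sparse_shape.ndims == 2:
    num_points = sparse_shape[0].value
  return [tensor_shape.TensorShape([num_points, None])]
```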
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index f48efaa5db1..791954c51f4 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -1,3 +1,4 @@
+# pylint: disable=g-bad-file-header
 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,14 +21,22 @@ from __future__ import print_function
 import math
 import random
 
-import tensorflow as tf
-
+from tensorflow.contrib.tensor_forest.python import constants
 from tensorflow.contrib.tensor_forest.python.ops import inference_ops
 from tensorflow.contrib.tensor_forest.python.ops import training_ops
 
-
-# If tree[i][0] equals this value, then i is a leaf node.
-LEAF_NODE = -1
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import tf_logging as logging
 
 
 # A convenience class for holding random forest hyperparameters.
@@ -49,6 +58,7 @@ class ForestHParams(object):
                max_depth=0, num_splits_to_consider=0,
                feature_bagging_fraction=1.0,
                max_fertile_nodes=0, split_after_samples=250,
+               min_split_samples=5,
                valid_leaf_threshold=1, **kwargs):
     self.num_trees = num_trees
     self.max_nodes = max_nodes
@@ -58,6 +68,7 @@ class ForestHParams(object):
     self.num_splits_to_consider = num_splits_to_consider
     self.max_fertile_nodes = max_fertile_nodes
     self.split_after_samples = split_after_samples
+    self.min_split_samples = min_split_samples
     self.valid_leaf_threshold = valid_leaf_threshold
 
     for name, value in kwargs.items():
@@ -72,11 +83,6 @@ class ForestHParams(object):
     _ = getattr(self, 'num_classes')
     _ = getattr(self, 'num_features')
 
-    self.training_library_base_dir = getattr(
-        self, 'training_library_base_dir', '')
-    self.inference_library_base_dir = getattr(
-        self, 'inference_library_base_dir', '')
-
     self.bagged_num_features = int(self.feature_bagging_fraction *
                                    self.num_features)
 
@@ -147,92 +153,86 @@ class TreeTrainingVariables(object):
   """
 
   def __init__(self, params, tree_num, training):
-    self.tree = tf.get_variable(
-        name=self.get_tree_name('tree', tree_num), dtype=tf.int32,
-        initializer=tf.constant(
-            [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1)))
-    self.tree_thresholds = tf.get_variable(
+    self.tree = variable_scope.get_variable(
+        name=self.get_tree_name('tree', tree_num), dtype=dtypes.int32,
+        shape=[params.max_nodes, 2],
+        initializer=init_ops.constant_initializer(-2))
+    self.tree_thresholds = variable_scope.get_variable(
         name=self.get_tree_name('tree_thresholds', tree_num),
         shape=[params.max_nodes],
-        initializer=tf.constant_initializer(-1.0))
-    self.tree_depths = tf.get_variable(
+        initializer=init_ops.constant_initializer(-1.0))
+    self.tree_depths = variable_scope.get_variable(
         name=self.get_tree_name('tree_depths', tree_num),
         shape=[params.max_nodes],
-        dtype=tf.int32,
-        initializer=tf.constant_initializer(1))
-    self.end_of_tree = tf.get_variable(
+        dtype=dtypes.int32,
+        initializer=init_ops.constant_initializer(1))
+    self.end_of_tree = variable_scope.get_variable(
         name=self.get_tree_name('end_of_tree', tree_num),
-        dtype=tf.int32,
-        initializer=tf.constant([1]))
+        dtype=dtypes.int32,
+        initializer=constant_op.constant([1]))
+    self.start_epoch = tf_variables.Variable(
+        [0] * (params.max_nodes), name='start_epoch')
 
     if training:
-      self.non_fertile_leaves = tf.get_variable(
-          name=self.get_tree_name('non_fertile_leaves', tree_num),
-          dtype=tf.int32,
-          initializer=tf.constant([0]))
-      self.non_fertile_leaf_scores = tf.get_variable(
-          name=self.get_tree_name('non_fertile_leaf_scores', tree_num),
-          initializer=tf.constant([1.0]))
-
-      self.node_to_accumulator_map = tf.get_variable(
+      self.node_to_accumulator_map = variable_scope.get_variable(
           name=self.get_tree_name('node_to_accumulator_map', tree_num),
           shape=[params.max_nodes],
-          dtype=tf.int32,
-          initializer=tf.constant_initializer(-1))
+          dtype=dtypes.int32,
+          initializer=init_ops.constant_initializer(-1))
 
-      self.candidate_split_features = tf.get_variable(
+      self.candidate_split_features = variable_scope.get_variable(
           name=self.get_tree_name('candidate_split_features', tree_num),
           shape=[params.max_fertile_nodes, params.num_splits_to_consider],
-          dtype=tf.int32,
-          initializer=tf.constant_initializer(-1))
-      self.candidate_split_thresholds = tf.get_variable(
+          dtype=dtypes.int32,
+          initializer=init_ops.constant_initializer(-1))
+      self.candidate_split_thresholds = variable_scope.get_variable(
           name=self.get_tree_name('candidate_split_thresholds', tree_num),
           shape=[params.max_fertile_nodes, params.num_splits_to_consider],
-          initializer=tf.constant_initializer(0.0))
+          initializer=init_ops.constant_initializer(0.0))
 
     # Statistics shared by classification and regression.
-    self.node_sums = tf.get_variable(
+    self.node_sums = variable_scope.get_variable(
         name=self.get_tree_name('node_sums', tree_num),
         shape=[params.max_nodes, params.num_output_columns],
-        initializer=tf.constant_initializer(0.0))
+        initializer=init_ops.constant_initializer(0.0))
 
     if training:
-      self.candidate_split_sums = tf.get_variable(
+      self.candidate_split_sums = variable_scope.get_variable(
           name=self.get_tree_name('candidate_split_sums', tree_num),
           shape=[params.max_fertile_nodes, params.num_splits_to_consider,
                  params.num_output_columns],
-          initializer=tf.constant_initializer(0.0))
-      self.accumulator_sums = tf.get_variable(
+          initializer=init_ops.constant_initializer(0.0))
+      self.accumulator_sums = variable_scope.get_variable(
           name=self.get_tree_name('accumulator_sums', tree_num),
           shape=[params.max_fertile_nodes, params.num_output_columns],
-          initializer=tf.constant_initializer(-1.0))
+          initializer=init_ops.constant_initializer(-1.0))
 
       # Regression also tracks second order stats.
       if params.regression:
-        self.node_squares = tf.get_variable(
+        self.node_squares = variable_scope.get_variable(
             name=self.get_tree_name('node_squares', tree_num),
             shape=[params.max_nodes, params.num_output_columns],
-            initializer=tf.constant_initializer(0.0))
+            initializer=init_ops.constant_initializer(0.0))
 
-        self.candidate_split_squares = tf.get_variable(
+        self.candidate_split_squares = variable_scope.get_variable(
             name=self.get_tree_name('candidate_split_squares', tree_num),
             shape=[params.max_fertile_nodes, params.num_splits_to_consider,
                    params.num_output_columns],
-            initializer=tf.constant_initializer(0.0))
+            initializer=init_ops.constant_initializer(0.0))
 
-        self.accumulator_squares = tf.get_variable(
+        self.accumulator_squares = variable_scope.get_variable(
             name=self.get_tree_name('accumulator_squares', tree_num),
             shape=[params.max_fertile_nodes, params.num_output_columns],
-            initializer=tf.constant_initializer(-1.0))
+            initializer=init_ops.constant_initializer(-1.0))
 
       else:
-        self.node_squares = tf.constant(
+        self.node_squares = constant_op.constant(
             0.0, name=self.get_tree_name('node_squares', tree_num))
 
-        self.candidate_split_squares = tf.constant(
+        self.candidate_split_squares = constant_op.constant(
             0.0, name=self.get_tree_name('candidate_split_squares', tree_num))
 
-        self.accumulator_squares = tf.constant(
+        self.accumulator_squares = constant_op.constant(
             0.0, name=self.get_tree_name('accumulator_squares', tree_num))
 
   def get_tree_name(self, name, num):
@@ -273,11 +273,11 @@ class ForestTrainingVariables(object):
   """
 
   def __init__(self, params, device_assigner, training=True,
-               tree_variable_class=TreeTrainingVariables):
+               tree_variables_class=TreeTrainingVariables):
     self.variables = []
     for i in range(params.num_trees):
-      with tf.device(device_assigner.get_device(i)):
-        self.variables.append(tree_variable_class(params, i, training))
+      with ops.device(device_assigner.get_device(i)):
+        self.variables.append(tree_variables_class(params, i, training))
 
   def __setitem__(self, t, val):
     self.variables[t] = val
@@ -299,7 +299,7 @@ class RandomForestDeviceAssigner(object):
 
   def get_device(self, unused_tree_num):
     if not self.cached:
-      dummy = tf.constant(0)
+      dummy = constant_op.constant(0)
       self.cached = dummy.device
 
     return self.cached
@@ -308,43 +308,51 @@ class RandomForestDeviceAssigner(object):
 class RandomForestGraphs(object):
   """Builds TF graphs for random forest training and inference."""
 
-  def __init__(self, params, device_assigner=None, variables=None,
-               tree_graphs=None,
+  def __init__(self, params, device_assigner=None,
+               variables=None, tree_variables_class=TreeTrainingVariables,
+               tree_graphs=None, training=True,
                t_ops=training_ops,
                i_ops=inference_ops):
     self.params = params
     self.device_assigner = device_assigner or RandomForestDeviceAssigner()
-    tf.logging.info('Constructing forest with params = ')
-    tf.logging.info(self.params.__dict__)
+    logging.info('Constructing forest with params = ')
+    logging.info(self.params.__dict__)
     self.variables = variables or ForestTrainingVariables(
-        self.params, device_assigner=self.device_assigner)
+        self.params, device_assigner=self.device_assigner, training=training,
+        tree_variables_class=tree_variables_class)
     tree_graph_class = tree_graphs or RandomTreeGraphs
     self.trees = [
         tree_graph_class(
             self.variables[i], self.params,
-            t_ops.Load(self.params.training_library_base_dir),
-            i_ops.Load(self.params.inference_library_base_dir), i)
+            t_ops.Load(), i_ops.Load(), i)
         for i in range(self.params.num_trees)]
 
   def _bag_features(self, tree_num, input_data):
-    split_data = tf.split(1, self.params.num_features, input_data)
-    return tf.concat(1, [split_data[ind]
-                         for ind in self.params.bagged_features[tree_num]])
+    split_data = array_ops.split(1, self.params.num_features, input_data)
+    return array_ops.concat(
+        1, [split_data[ind] for ind in self.params.bagged_features[tree_num]])
 
-  def training_graph(self, input_data, input_labels):
+  def training_graph(self, input_data, input_labels, data_spec=None,
+                     epoch=None, **tree_kwargs):
     """Constructs a TF graph for training a random forest.
 
     Args:
-      input_data: A tensor or placeholder for input data.
+      input_data: A tensor or SparseTensor or placeholder for input data.
       input_labels: A tensor or placeholder for labels associated with
         input_data.
+      data_spec: A list of tf.dtype values specifying the original types of
+        each column.
+      epoch: A tensor or placeholder for the epoch the training data comes from.
+      **tree_kwargs: Keyword arguments passed to each tree's training_graph.
 
     Returns:
       The last op in the random forest training graph.
     """
+    data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+                 if data_spec is None else data_spec)
     tree_graphs = []
     for i in range(self.params.num_trees):
-      with tf.device(self.device_assigner.get_device(i)):
+      with ops.device(self.device_assigner.get_device(i)):
         seed = self.params.base_random_seed
         if seed != 0:
           seed += i
@@ -354,40 +362,54 @@ class RandomForestGraphs(object):
         if self.params.bagging_fraction < 1.0:
           # TODO(thomaswc): This does sampling without replacement.  Consider
           # also allowing sampling with replacement as an option.
-          batch_size = tf.slice(tf.shape(input_data), [0], [1])
-          r = tf.random_uniform(batch_size, seed=seed)
-          mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
-          gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
+          batch_size = array_ops.slice(array_ops.shape(input_data), [0], [1])
+          r = random_ops.random_uniform(batch_size, seed=seed)
+          mask = math_ops.less(
+              r, array_ops.ones_like(r) * self.params.bagging_fraction)
+          gather_indices = array_ops.squeeze(
+              array_ops.where(mask), squeeze_dims=[1])
           # TODO(thomaswc): Calculate out-of-bag data and labels, and store
           # them for use in calculating statistics later.
-          tree_data = tf.gather(input_data, gather_indices)
-          tree_labels = tf.gather(input_labels, gather_indices)
+          tree_data = array_ops.gather(input_data, gather_indices)
+          tree_labels = array_ops.gather(input_labels, gather_indices)
         if self.params.bagged_features:
           tree_data = self._bag_features(i, tree_data)
 
-        tree_graphs.append(
-            self.trees[i].training_graph(tree_data, tree_labels, seed))
-    return tf.group(*tree_graphs)
+        initialization = self.trees[i].tree_initialization()
 
-  def inference_graph(self, input_data):
+        with ops.control_dependencies([initialization]):
+          tree_graphs.append(
+              self.trees[i].training_graph(
+                  tree_data, tree_labels, seed, data_spec=data_spec,
+                  epoch=([0] if epoch is None else epoch),
+                  **tree_kwargs))
+
+    return control_flow_ops.group(*tree_graphs)
+
+  def inference_graph(self, input_data, data_spec=None):
     """Constructs a TF graph for evaluating a random forest.
 
     Args:
-      input_data: A tensor or placeholder for input data.
+      input_data: A tensor or SparseTensor or placeholder for input data.
+      data_spec: A list of tf.dtype values specifying the original types of
+        each column.
 
     Returns:
       The last op in the random forest inference graph.
     """
+    data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+                 if data_spec is None else data_spec)
     probabilities = []
     for i in range(self.params.num_trees):
-      with tf.device(self.device_assigner.get_device(i)):
+      with ops.device(self.device_assigner.get_device(i)):
         tree_data = input_data
         if self.params.bagged_features:
           tree_data = self._bag_features(i, input_data)
-        probabilities.append(self.trees[i].inference_graph(tree_data))
-    with tf.device(self.device_assigner.get_device(0)):
-      all_predict = tf.pack(probabilities)
-      return tf.reduce_sum(all_predict, 0) / self.params.num_trees
+        probabilities.append(self.trees[i].inference_graph(tree_data,
+                                                           data_spec))
+    with ops.device(self.device_assigner.get_device(0)):
+      all_predict = array_ops.pack(probabilities)
+      return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees
 
   def average_size(self):
     """Constructs a TF graph for evaluating the average size of a forest.
@@ -397,9 +419,16 @@ class RandomForestGraphs(object):
     """
     sizes = []
     for i in range(self.params.num_trees):
-      with tf.device(self.device_assigner.get_device(i)):
+      with ops.device(self.device_assigner.get_device(i)):
         sizes.append(self.trees[i].size())
-    return tf.reduce_mean(tf.pack(sizes))
+    return math_ops.reduce_mean(array_ops.pack(sizes))
+
+  def training_loss(self):
+    return math_ops.neg(self.average_size())
+
+  # pylint: disable=unused-argument
+  def validation_loss(self, features, labels):
+    return math_ops.neg(self.average_size())
 
   def average_impurity(self):
     """Constructs a TF graph for evaluating the leaf impurity of a forest.
@@ -409,14 +438,14 @@ class RandomForestGraphs(object):
     """
     impurities = []
     for i in range(self.params.num_trees):
-      with tf.device(self.device_assigner.get_device(i)):
+      with ops.device(self.device_assigner.get_device(i)):
         impurities.append(self.trees[i].average_impurity())
-    return tf.reduce_mean(tf.pack(impurities))
+    return math_ops.reduce_mean(array_ops.pack(impurities))
 
   def get_stats(self, session):
     tree_stats = []
     for i in range(self.params.num_trees):
-      with tf.device(self.device_assigner.get_device(i)):
+      with ops.device(self.device_assigner.get_device(i)):
         tree_stats.append(self.trees[i].get_stats(session))
     return ForestStats(tree_stats, self.params)
 
@@ -431,6 +460,18 @@ class RandomTreeGraphs(object):
     self.params = params
     self.tree_num = tree_num
 
+  def tree_initialization(self):
+    def _init_tree():
+      return state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]]).op
+
+    def _nothing():
+      return control_flow_ops.no_op()
+
+    return control_flow_ops.cond(
+        math_ops.equal(array_ops.squeeze(array_ops.slice(
+            self.variables.tree, [0, 0], [1, 1])), -2),
+        _init_tree, _nothing)
+
   def _gini(self, class_counts):
     """Calculate the Gini impurity.
 
@@ -444,9 +485,9 @@ class RandomTreeGraphs(object):
     Returns:
       A 1-D tensor of the Gini impurities for each row in the input.
     """
-    smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1])
-    sums = tf.reduce_sum(smoothed, 1)
-    sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
+    sums = math_ops.reduce_sum(smoothed, 1)
+    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
 
     return 1.0 - sum_squares / (sums * sums)
 
@@ -463,9 +504,9 @@ class RandomTreeGraphs(object):
     Returns:
       A 1-D tensor of the Gini impurities for each row in the input.
     """
-    smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1])
-    sums = tf.reduce_sum(smoothed, 1)
-    sum_squares = tf.reduce_sum(tf.square(smoothed), 1)
+    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
+    sums = math_ops.reduce_sum(smoothed, 1)
+    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
 
     return sums - sum_squares / sums
 
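In equation form, writing $\tilde{n}_c = n_c + 1$ for the smoothed count of class $c$ (column 0 holds the total and is skipped by the slice) and $N = \sum_c \tilde{n}_c$, the two helpers above compute

$$G = 1 - \frac{\sum_c \tilde{n}_c^2}{N^2}, \qquad G_w = N - \frac{\sum_c \tilde{n}_c^2}{N} = N \cdot G,$$

i.e. the Gini impurity and its count-weighted variant. The `_variance` helper below likewise returns $\sum_j \left(\mathbb{E}[x_j^2] - \mathbb{E}[x_j]^2\right)$ from the accumulated sums and sums of squares.
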
@@ -483,40 +524,58 @@ class RandomTreeGraphs(object):
     Returns:
       A 1-D tensor of the variances for each row in the input.
     """
-    total_count = tf.slice(sums, [0, 0], [-1, 1])
+    total_count = array_ops.slice(sums, [0, 0], [-1, 1])
     e_x = sums / total_count
     e_x2 = squares / total_count
 
-    return tf.reduce_sum(e_x2 - tf.square(e_x), 1)
+    return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1)
+
+  def training_graph(self, input_data, input_labels, random_seed,
+                     data_spec, epoch=None):
 
-  def training_graph(self, input_data, input_labels, random_seed):
     """Constructs a TF graph for training a random tree.
 
     Args:
-      input_data: A tensor or placeholder for input data.
+      input_data: A tensor or SparseTensor or placeholder for input data.
       input_labels: A tensor or placeholder for labels associated with
         input_data.
       random_seed: The random number generator seed to use for this tree.  0
         means use the current time as the seed.
+      data_spec: A list of tf.dtype values specifying the original types of
+        each column.
+      epoch: A tensor or placeholder for the epoch the training data comes from.
 
     Returns:
       The last op in the random tree training graph.
     """
+    epoch = [0] if epoch is None else epoch
+
+    sparse_indices = []
+    sparse_values = []
+    sparse_shape = []
+    if isinstance(input_data, ops.SparseTensor):
+      sparse_indices = input_data.indices
+      sparse_values = input_data.values
+      sparse_shape = input_data.shape
+      input_data = []
+
     # Count extremely random stats.
     (node_sums, node_squares, splits_indices, splits_sums,
      splits_squares, totals_indices, totals_sums,
      totals_squares, input_leaves) = (
          self.training_ops.count_extremely_random_stats(
-             input_data, input_labels, self.variables.tree,
+             input_data, sparse_indices, sparse_values, sparse_shape,
+             data_spec, input_labels, self.variables.tree,
              self.variables.tree_thresholds,
              self.variables.node_to_accumulator_map,
              self.variables.candidate_split_features,
              self.variables.candidate_split_thresholds,
+             self.variables.start_epoch, epoch,
              num_classes=self.params.num_output_columns,
              regression=self.params.regression))
     node_update_ops = []
     node_update_ops.append(
-        tf.assign_add(self.variables.node_sums, node_sums))
+        state_ops.assign_add(self.variables.node_sums, node_sums))
 
     splits_update_ops = []
     splits_update_ops.append(self.training_ops.scatter_add_ndim(
@@ -527,8 +586,8 @@ class RandomTreeGraphs(object):
         totals_sums))
 
     if self.params.regression:
-      node_update_ops.append(tf.assign_add(self.variables.node_squares,
-                                           node_squares))
+      node_update_ops.append(state_ops.assign_add(self.variables.node_squares,
+                                                  node_squares))
       splits_update_ops.append(self.training_ops.scatter_add_ndim(
           self.variables.candidate_split_squares,
           splits_indices, splits_squares))
@@ -539,63 +598,56 @@ class RandomTreeGraphs(object):
     # Sample inputs.
     update_indices, feature_updates, threshold_updates = (
         self.training_ops.sample_inputs(
-            input_data, self.variables.node_to_accumulator_map,
+            input_data, sparse_indices, sparse_values, sparse_shape,
+            self.variables.node_to_accumulator_map,
             input_leaves, self.variables.candidate_split_features,
             self.variables.candidate_split_thresholds,
             split_initializations_per_input=(
                 self.params.split_initializations_per_input),
             split_sampling_random_seed=random_seed))
-    update_features_op = tf.scatter_update(
+    update_features_op = state_ops.scatter_update(
         self.variables.candidate_split_features, update_indices,
         feature_updates)
-    update_thresholds_op = tf.scatter_update(
+    update_thresholds_op = state_ops.scatter_update(
         self.variables.candidate_split_thresholds, update_indices,
         threshold_updates)
 
     # Calculate finished nodes.
-    with tf.control_dependencies(splits_update_ops):
-      children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
-                            squeeze_dims=[1])
-      is_leaf = tf.equal(LEAF_NODE, children)
-      leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
-      finished = self.training_ops.finished_nodes(
+    with ops.control_dependencies(splits_update_ops):
+      children = array_ops.squeeze(array_ops.slice(
+          self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
+      is_leaf = math_ops.equal(constants.LEAF_NODE, children)
+      leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
+                                                   squeeze_dims=[1]))
+      finished, stale = self.training_ops.finished_nodes(
           leaves, self.variables.node_to_accumulator_map,
+          self.variables.candidate_split_sums,
+          self.variables.candidate_split_squares,
           self.variables.accumulator_sums,
-          num_split_after_samples=self.params.split_after_samples)
+          self.variables.accumulator_squares,
+          self.variables.start_epoch, epoch,
+          num_split_after_samples=self.params.split_after_samples,
+          min_split_samples=self.params.min_split_samples)
 
     # Update leaf scores.
-    # TODO(gilberth): Optimize this. It currently calculates counts for
-    # every non-fertile leaf.
-    with tf.control_dependencies(node_update_ops):
-      def dont_update_leaf_scores():
-        return self.variables.non_fertile_leaf_scores
+    non_fertile_leaves = array_ops.boolean_mask(
+        leaves, math_ops.less(array_ops.gather(
+            self.variables.node_to_accumulator_map, leaves), 0))
 
-      def update_leaf_scores_regression():
-        sums = tf.gather(self.variables.node_sums,
-                         self.variables.non_fertile_leaves)
-        squares = tf.gather(self.variables.node_squares,
-                            self.variables.non_fertile_leaves)
-        new_scores = self._variance(sums, squares)
-        return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
-
-      def update_leaf_scores_classification():
-        counts = tf.gather(self.variables.node_sums,
-                           self.variables.non_fertile_leaves)
-        new_scores = self._weighted_gini(counts)
-        return tf.assign(self.variables.non_fertile_leaf_scores, new_scores)
-
-      # Because we can't have tf.self.variables of size 0, we have to put in a
-      # garbage value of -1 in there.  Here we check for that so we don't
-      # try to index into node_per_class_weights in a tf.gather with a negative
-      # number.
-      update_nonfertile_leaves_scores_op = tf.cond(
-          tf.less(self.variables.non_fertile_leaves[0], 0),
-          dont_update_leaf_scores,
-          update_leaf_scores_regression if self.params.regression else
-          update_leaf_scores_classification)
+    # TODO(gilberth): It should be possible to limit the number of non
+    # fertile leaves we calculate scores for, especially since we can only take
+    # at most array_ops.shape(finished)[0] of them.
+    with ops.control_dependencies(node_update_ops):
+      sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
+      if self.params.regression:
+        squares = array_ops.gather(self.variables.node_squares,
+                                   non_fertile_leaves)
+        non_fertile_leaf_scores = self._variance(sums, squares)
+      else:
+        non_fertile_leaf_scores = self._weighted_gini(sums)
 
     # Calculate best splits.
-    with tf.control_dependencies(splits_update_ops):
+    with ops.control_dependencies(splits_update_ops):
       split_indices = self.training_ops.best_splits(
           finished, self.variables.node_to_accumulator_map,
           self.variables.candidate_split_sums,
@@ -605,7 +657,7 @@ class RandomTreeGraphs(object):
           regression=self.params.regression)
 
     # Grow tree.
-    with tf.control_dependencies([update_features_op, update_thresholds_op]):
+    with ops.control_dependencies([update_features_op, update_thresholds_op]):
       (tree_update_indices, tree_children_updates,
        tree_threshold_updates, tree_depth_updates, new_eot) = (
            self.training_ops.grow_tree(
@@ -613,110 +665,138 @@ class RandomTreeGraphs(object):
                self.variables.node_to_accumulator_map, finished, split_indices,
                self.variables.candidate_split_features,
                self.variables.candidate_split_thresholds))
-      tree_update_op = tf.scatter_update(
+      tree_update_op = state_ops.scatter_update(
           self.variables.tree, tree_update_indices, tree_children_updates)
-      threhsolds_update_op = tf.scatter_update(
+      thresholds_update_op = state_ops.scatter_update(
           self.variables.tree_thresholds, tree_update_indices,
           tree_threshold_updates)
-      depth_update_op = tf.scatter_update(
+      depth_update_op = state_ops.scatter_update(
           self.variables.tree_depths, tree_update_indices, tree_depth_updates)
+      # TODO(thomaswc): Only update the epoch on the new leaves.
+      new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates)
+      epoch_update_op = state_ops.scatter_update(
+          self.variables.start_epoch, tree_update_indices,
+          new_epoch_updates)
 
     # Update fertile slots.
-    with tf.control_dependencies([update_nonfertile_leaves_scores_op,
-                                  depth_update_op]):
-      (node_map_updates, accumulators_cleared, accumulators_allocated,
-       new_nonfertile_leaves, new_nonfertile_leaves_scores) = (
-           self.training_ops.update_fertile_slots(
-               finished, self.variables.non_fertile_leaves,
-               self.variables.non_fertile_leaf_scores,
-               self.variables.end_of_tree, self.variables.tree_depths,
-               self.variables.accumulator_sums,
-               self.variables.node_to_accumulator_map,
-               max_depth=self.params.max_depth,
-               regression=self.params.regression))
+    with ops.control_dependencies([depth_update_op]):
+      (node_map_updates, accumulators_cleared, accumulators_allocated) = (
+          self.training_ops.update_fertile_slots(
+              finished, non_fertile_leaves,
+              non_fertile_leaf_scores,
+              self.variables.end_of_tree, self.variables.tree_depths,
+              self.variables.accumulator_sums,
+              self.variables.node_to_accumulator_map,
+              stale,
+              max_depth=self.params.max_depth,
+              regression=self.params.regression))
 
     # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
     # used it to calculate new leaves.
-    gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves])
-    eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot)
+    gated_new_eot, = control_flow_ops.tuple([new_eot],
+                                            control_inputs=[node_map_updates])
+    eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)
 
     updates = []
     updates.append(eot_update_op)
     updates.append(tree_update_op)
-    updates.append(threhsolds_update_op)
-    updates.append(tf.assign(
-        self.variables.non_fertile_leaves, new_nonfertile_leaves,
-        validate_shape=False))
-    updates.append(tf.assign(
-        self.variables.non_fertile_leaf_scores,
-        new_nonfertile_leaves_scores, validate_shape=False))
+    updates.append(thresholds_update_op)
+    updates.append(epoch_update_op)
 
-    updates.append(tf.scatter_update(
+    updates.append(state_ops.scatter_update(
         self.variables.node_to_accumulator_map,
-        tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]),
-                   squeeze_dims=[0]),
-        tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]),
-                   squeeze_dims=[0])))
+        array_ops.squeeze(array_ops.slice(node_map_updates, [0, 0], [1, -1]),
+                          squeeze_dims=[0]),
+        array_ops.squeeze(array_ops.slice(node_map_updates, [1, 0], [1, -1]),
+                          squeeze_dims=[0])))
 
-    cleared_and_allocated_accumulators = tf.concat(
+    cleared_and_allocated_accumulators = array_ops.concat(
         0, [accumulators_cleared, accumulators_allocated])
     # Calculate values to put into scatter update for candidate counts.
     # Candidate split counts are always reset back to 0 for both cleared
     # and allocated accumulators. This means some accumulators might be doubly
     # reset to 0 if they were released and not allocated, then later allocated.
-    split_values = tf.tile(
-        tf.expand_dims(tf.expand_dims(
-            tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32),
-            1), 2),
+    split_values = array_ops.tile(
+        array_ops.expand_dims(array_ops.expand_dims(
+            array_ops.zeros_like(cleared_and_allocated_accumulators,
+                                 dtype=dtypes.float32), 1), 2),
         [1, self.params.num_splits_to_consider, self.params.num_output_columns])
-    updates.append(tf.scatter_update(
+    updates.append(state_ops.scatter_update(
         self.variables.candidate_split_sums,
         cleared_and_allocated_accumulators, split_values))
     if self.params.regression:
-      updates.append(tf.scatter_update(
+      updates.append(state_ops.scatter_update(
           self.variables.candidate_split_squares,
           cleared_and_allocated_accumulators, split_values))
 
     # Calculate values to put into scatter update for total counts.
-    total_cleared = tf.tile(
-        tf.expand_dims(
-            tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1),
+    total_cleared = array_ops.tile(
+        array_ops.expand_dims(
+            math_ops.neg(array_ops.ones_like(accumulators_cleared,
+                                             dtype=dtypes.float32)), 1),
         [1, self.params.num_output_columns])
-    total_reset = tf.tile(
-        tf.expand_dims(
-            tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1),
+    total_reset = array_ops.tile(
+        array_ops.expand_dims(
+            array_ops.zeros_like(accumulators_allocated,
+                                 dtype=dtypes.float32), 1),
         [1, self.params.num_output_columns])
-    accumulator_updates = tf.concat(0, [total_cleared, total_reset])
-    updates.append(tf.scatter_update(
+    accumulator_updates = array_ops.concat(0, [total_cleared, total_reset])
+    updates.append(state_ops.scatter_update(
         self.variables.accumulator_sums,
         cleared_and_allocated_accumulators, accumulator_updates))
     if self.params.regression:
-      updates.append(tf.scatter_update(
+      updates.append(state_ops.scatter_update(
           self.variables.accumulator_squares,
           cleared_and_allocated_accumulators, accumulator_updates))
 
     # Calculate values to put into scatter update for candidate splits.
-    split_features_updates = tf.tile(
-        tf.expand_dims(
-            tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1),
+    split_features_updates = array_ops.tile(
+        array_ops.expand_dims(
+            math_ops.neg(array_ops.ones_like(
+                cleared_and_allocated_accumulators)), 1),
         [1, self.params.num_splits_to_consider])
-    updates.append(tf.scatter_update(
+    updates.append(state_ops.scatter_update(
         self.variables.candidate_split_features,
         cleared_and_allocated_accumulators, split_features_updates))
 
-    return tf.group(*updates)
+    updates += self.finish_iteration()
 
-  def inference_graph(self, input_data):
+    return control_flow_ops.group(*updates)
+
+  def finish_iteration(self):
+    """Perform any operations that should be done at the end of an iteration.
+
+    This is mostly useful for subclasses that need to reset variables after
+    an iteration, such as ones that are used to finish nodes.
+
+    Returns:
+      A list of operations.
+    """
+    return []
+
+  def inference_graph(self, input_data, data_spec):
     """Constructs a TF graph for evaluating a random tree.
 
     Args:
-      input_data: A tensor or placeholder for input data.
+      input_data: A tensor or SparseTensor or placeholder for input data.
+      data_spec: A list of tf.dtype values specifying the original types of
+        each column.
 
     Returns:
       The last op in the random tree inference graph.
     """
+    sparse_indices = []
+    sparse_values = []
+    sparse_shape = []
+    if isinstance(input_data, ops.SparseTensor):
+      sparse_indices = input_data.indices
+      sparse_values = input_data.values
+      sparse_shape = input_data.shape
+      input_data = []
     return self.inference_ops.tree_predictions(
-        input_data, self.variables.tree, self.variables.tree_thresholds,
+        input_data, sparse_indices, sparse_values, sparse_shape, data_spec,
+        self.variables.tree,
+        self.variables.tree_thresholds,
         self.variables.node_sums,
         valid_leaf_threshold=self.params.valid_leaf_threshold)
 
@@ -729,13 +809,22 @@ class RandomTreeGraphs(object):
     Returns:
       The last op in the graph.
     """
-    children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]),
-                          squeeze_dims=[1])
-    is_leaf = tf.equal(LEAF_NODE, children)
-    leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1]))
-    counts = tf.gather(self.variables.node_sums, leaves)
-    impurity = self._weighted_gini(counts)
-    return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0)
+    children = array_ops.squeeze(array_ops.slice(
+        self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
+    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
+    leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
+                                                 squeeze_dims=[1]))
+    counts = array_ops.gather(self.variables.node_sums, leaves)
+    gini = self._weighted_gini(counts)
+    # Guard against step 1, when there often are no leaves yet.
+    def impurity():
+      return gini
+    # Since average impurity can be used for loss, when there's no data just
+    # return a big number so that loss always decreases.
+    def big():
+      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
+    return control_flow_ops.cond(math_ops.greater(
+        array_ops.shape(leaves)[0], 0), impurity, big)
 
   def size(self):
     """Constructs a TF graph for evaluating the current number of nodes.
@@ -747,7 +836,8 @@ class RandomTreeGraphs(object):
 
   def get_stats(self, session):
     num_nodes = self.variables.end_of_tree.eval(session=session) - 1
-    num_leaves = tf.where(
-        tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])),
-                 LEAF_NODE)).eval(session=session).shape[0]
+    num_leaves = array_ops.where(
+        math_ops.equal(array_ops.squeeze(array_ops.slice(
+            self.variables.tree, [0, 0], [-1, 1])), constants.LEAF_NODE)
+        ).eval(session=session).shape[0]
     return TreeStats(num_nodes, num_leaves)
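
The sparse-input tests added below cover `SparseTensor` handling but not the new `data_spec` and `epoch` arguments, so here is a hedged sketch of passing them explicitly; `input_data`, `input_labels`, and the placeholder shape are assumptions, not part of the patch:

```python
params = tensor_forest.ForestHParams(
    num_classes=2, num_features=3, num_trees=1, max_nodes=100,
    split_after_samples=25).fill()
graph_builder = tensor_forest.RandomForestGraphs(params)

data_spec = [constants.DATA_FLOAT] * params.num_features
epoch = tf.placeholder(tf.int32, shape=[1])  # epoch this batch comes from
train_op = graph_builder.training_graph(
    input_data, input_labels, data_spec=data_spec, epoch=epoch)
```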
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index c3e1c8520d3..4e4cfcd1e82 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -105,6 +105,47 @@ class TensorForestTest(test_util.TensorFlowTestCase):
     graph = graph_builder.average_impurity()
     self.assertTrue(isinstance(graph, tf.Tensor))
 
+  def testTrainingConstructionClassificationSparse(self):
+    input_data = tf.SparseTensor(
+        indices=[[0, 0], [0, 3],
+                 [1, 0], [1, 7],
+                 [2, 1],
+                 [3, 9]],
+        values=[-1.0, 0.0,
+                -1., 2.,
+                1.,
+                -2.0],
+        shape=[4, 10])
+    input_labels = [0, 1, 2, 3]
+
+    params = tensor_forest.ForestHParams(
+        num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
+        split_after_samples=25).fill()
+
+    graph_builder = tensor_forest.RandomForestGraphs(params)
+    graph = graph_builder.training_graph(input_data, input_labels)
+    self.assertTrue(isinstance(graph, tf.Operation))
+
+  def testInferenceConstructionSparse(self):
+    input_data = tf.SparseTensor(
+        indices=[[0, 0], [0, 3],
+                 [1, 0], [1, 7],
+                 [2, 1],
+                 [3, 9]],
+        values=[-1.0, 0.0,
+                -1., 2.,
+                1.,
+                -2.0],
+        shape=[4, 10])
+
+    params = tensor_forest.ForestHParams(
+        num_classes=4, num_features=10, num_trees=10, max_nodes=1000,
+        split_after_samples=25).fill()
+
+    graph_builder = tensor_forest.RandomForestGraphs(params)
+    graph = graph_builder.inference_graph(input_data)
+    self.assertTrue(isinstance(graph, tf.Tensor))
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b684522eb6d..b2a928867f4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1812,10 +1812,13 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
-    name = "ops/math_ops_test",
+tf_cc_tests(
     size = "small",
     linkstatic = tf_kernel_tests_linkstatic(),
+    tests = [
+        "ops/array_ops_test.cc",
+        "ops/math_ops_test.cc",
+    ],
     deps = [
         ":core",
         ":core_cpu",
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index a7f1a4fa78b..9959280029e 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -475,7 +475,7 @@ void TF_Run_Helper(TF_Session* s, const char* handle,
   // Store results in c_outputs[]
   for (int i = 0; i < noutputs; i++) {
     const Tensor& src = outputs[i];
-    if (!src.IsInitialized()) {
+    if (!src.IsInitialized() || src.NumElements() == 0) {
       c_outputs[i] = tensorflow::EmptyTensor(
           static_cast<TF_DataType>(src.dtype()), src.shape());
       continue;
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 870970b7cab..46dd7913d32 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -746,6 +746,12 @@ void MasterSession::UpdateLastAccessTime() {
 }
 
 Status MasterSession::Create(GraphDef* graph_def) {
+  if (session_opts_.config.graph_options().place_pruned_graph()) {
+    // TODO(b/29900832): Fix this or remove the option.
+    return errors::Unimplemented(
+        "MasterSession does not support the place_pruned_graph option.");
+  }
+
   // Keeps a copy of graph_def->library() and flib_def_ serves the
   // OpRegistryInterface used by the SimpleGraphExecutionState to construct the
   // pre-partitioned graphs during DoRunWithLocalExecution().
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 6bfc55df41f..12df379e8f0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -159,7 +159,7 @@ Status OpKernelConstruction::allocate_temp(DataType type,
   attr.allocation_will_be_logged = true;
   Tensor new_temp(allocator_, type, shape, attr);
 
-  if (!new_temp.IsInitialized() && shape.num_elements() > 0) {
+  if (!new_temp.IsInitialized()) {
     return errors::ResourceExhausted(
         "OOM when allocating temporary tensor with shape", shape.DebugString());
   }
@@ -447,7 +447,7 @@ Status OpKernelContext::allocate_tensor(
   logged_attr.allocation_will_be_logged = true;
   Tensor new_tensor(a, type, shape, logged_attr);
 
-  if (!new_tensor.IsInitialized() && shape.num_elements() > 0) {
+  if (!new_tensor.IsInitialized()) {
     return errors::ResourceExhausted("OOM when allocating tensor with shape",
                                      shape.DebugString());
   }
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 0092c6286f8..a6cc323ceae 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -199,7 +199,9 @@ class PersistentTensor {
 
   // The check for initialization does not need to access the
   // underlying tensor buffer.
-  bool IsInitialized() { return tensor_.IsInitialized(); }
+  bool IsInitialized() const { return tensor_.IsInitialized(); }
+
+  int64 NumElements() const { return tensor_.NumElements(); }
 
  private:
   Tensor tensor_;
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 2df57d6cab3..bd8e6ea3094 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -25,9 +25,9 @@ constexpr int32 InferenceContext::kUnknownRank;
 constexpr int64 InferenceContext::kUnknownDim;
 
 InferenceContext::InferenceContext(
-    const std::vector<string>& input_shapes, int num_outputs,
-    const std::vector<const Tensor*>& input_tensors)
-    : input_tensors_(input_tensors) {
+    const NodeDef* node_def, const std::vector<string>& input_shapes,
+    int num_outputs, const std::vector<const Tensor*>& input_tensors)
+    : input_tensors_(input_tensors), node_def_(*CHECK_NOTNULL(node_def)) {
   for (const string& spec : input_shapes) {
     if (spec == "?") {
       inputs_.push_back(CreateUnknownShape());
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index bb6a66dc533..6385177bc19 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <vector>
 
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -80,7 +82,10 @@ class InferenceContext {
   //               the same Dimension*.
   //
   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
-  InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
+  //
+  // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
+  InferenceContext(const NodeDef* node_def,
+                   const std::vector<string>& input_shapes, int num_outputs,
                    const std::vector<const Tensor*>& input_tensors = {});
   ~InferenceContext();
 
@@ -162,6 +167,12 @@ class InferenceContext {
   const Dimension* CreateDim(int64 value);
   const Dimension* CreateUnknownDim();
 
+  // Look up the attr for the NodeDef being evaluated with name attr_name and
+  // set *value to its value.  If no attr with attr_name is found in def(), or
+  // the attr does not have a matching type, a non-ok status will be returned.
+  template <class T>
+  Status GetAttr(StringPiece attr_name, T* value) const;
+
  private:
   Status ReturnUnknownShape(const Shape** out) {
     *out = CreateUnknownShape();
@@ -181,9 +192,14 @@ class InferenceContext {
   std::vector<const Tensor*> input_tensors_;
   std::vector<const Shape*> outputs_;
 
+  const NodeDef& node_def_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
 };
 
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
 inline Dimension::Dimension(int64 value) : value_(value) {}
 
@@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
 inline Shape::Shape(const std::vector<const Dimension*> dims)
     : rank_(dims.size()), dims_(dims) {}
 
+template <class T>
+Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
+  return GetNodeAttr(node_def_, attr_name, value);
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
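
Since GetAttr simply forwards to GetNodeAttr on the stored NodeDef, a shape function can read its op's attrs straight off the context. A hedged sketch of such a function (the op and its "axis" attr are hypothetical; this mirrors the pattern used in array_ops.cc later in this change):

    Status ExampleShapeFn(shape_inference::InferenceContext* c) {
      int32 axis;
      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
      // ... use axis to decide how to build the output shape ...
      c->set_output(0, c->CreateUnknownShape());
      return Status::OK();
    }
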
 
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index e4ca7645b2e..e52d1c5a2d6 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/framework/shape_inference.h"
 
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -21,7 +23,8 @@ namespace tensorflow {
 namespace shape_inference {
 
 TEST(ShapeInferenceTest, RankAndDimInspection) {
-  InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(2, c.num_outputs());
 
@@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
 }
 
 TEST(ShapeInferenceTest, WithRank) {
-  InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
 }
 
 TEST(ShapeInferenceTest, WithRankAtLeast) {
-  InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
 }
 
 TEST(ShapeInferenceTest, WithValue) {
-  InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
 
   auto d0 = c.Dim(c.input(0), 0);
   auto d1 = c.Dim(c.input(0), 1);
@@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
 }
 
 TEST(ShapeInferenceTest, MergeDim) {
-  InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
 
   auto d2 = c.Dim(c.input(0), 0);
   auto d_unknown = c.Dim(c.input(0), 1);
@@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
 }
 
 TEST(ShapeInferenceTest, MergeShape) {
-  InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
+  NodeDef def;
+  InferenceContext c(&def,
+                     {"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
                      2 /* num_outputs */);
 
   auto s_unknown = c.input(0);
@@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
 }
 
 TEST(ShapeInferenceTest, Subshape) {
-  InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
 
   const Shape* unknown = c.input(1);
   const Shape* out;
@@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
 }
 
 TEST(ShapeInferenceTest, Concatenate) {
-  InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
 }
 
 TEST(ShapeInferenceTest, CreateShape) {
-  InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
 
   std::vector<const Dimension*> dims;
   auto in0 = c.input(0);
@@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
 }
 
 TEST(ShapeInferenceTest, CreateUnknownShape) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto u0 = c.CreateUnknownShape();
   auto u1 = c.CreateUnknownShape();
@@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
 
 TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
   auto create = [](Tensor* t) {
-    InferenceContext c({"?"}, 0 /* num_outputs */, {t});
+    NodeDef def;
+    InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
     const Shape* out;
     Status s = c.CreateShapeFromShapeTensor(0, &out);
     if (s.ok()) {
@@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
 }
 
 TEST(ShapeInferenceTest, CreateDim) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto* d0 = c.CreateDim(1);
   auto* d1 = c.CreateDim(1);
@@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
 }
 
 TEST(ShapeInferenceTest, CreateUnknownDim) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto* d0 = c.CreateUnknownDim();
   auto* d1 = c.CreateUnknownDim();
@@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
 TEST(ShapeInferenceTest, InputTensors) {
   const Tensor t1 = tensorflow::test::AsTensor<float>({10});
   const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
-  InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
+  NodeDef def;
+  InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
+                     {&t1, &t2});
 
   EXPECT_TRUE(c.input_tensor(0) == &t1);
   EXPECT_TRUE(c.input_tensor(1) == &t2);
   EXPECT_TRUE(c.input_tensor(2) == nullptr);
 }
 
+TEST(ShapeInferenceTest, GetAttr) {
+  OpRegistrationData op_reg_data;
+  CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
+  NodeDef def;
+  CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
+            .Attr("foo", "bar")
+            .Finalize(&def)
+            .ok());
+
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
+  string value;
+  EXPECT_TRUE(c.GetAttr("foo", &value).ok());
+  EXPECT_EQ("bar", value);
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index 9b56014edbe..f771e477644 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -29,13 +29,18 @@ using shape_inference::Shape;
 using errors::Unknown;
 
 Status InferShapes(const string& op_name, const string& ins,
-                   const string& expected_outs) {
+                   const string& expected_outs, const NodeDef* node_def) {
   const OpRegistrationData* op_reg_data;
   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
   const int num_outputs = op_reg_data->op_def.output_arg_size();
 
   std::vector<string> ins_v = str_util::Split(ins, ';');
-  shape_inference::InferenceContext c(ins_v, num_outputs);
+  std::unique_ptr<const NodeDef> new_node_def;
+  if (node_def == nullptr) {
+    new_node_def.reset(new NodeDef);
+    node_def = new_node_def.get();
+  }
+  shape_inference::InferenceContext c(node_def, ins_v, num_outputs);
   TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
 
   std::unordered_map<const Dimension*, std::pair<int, int>>
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index f2581247d9e..221ec875fb0 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -23,6 +23,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+class NodeDef;
+
 // Run shape inference for <op_name>, given inputs specified by <ins>
 // and returns an error if the inferred shape does not match expected_outs.
 //
@@ -45,11 +47,16 @@ namespace tensorflow {
 // <expected_outs> can be "e"; this is used to indicate that shape inference
 // should have failed.
 Status InferShapes(const string& op_name, const string& ins,
-                   const string& expected_outs);
+                   const string& expected_outs,
+                   const NodeDef* node_def = nullptr);
 
 #define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
 #define INFER_ERROR(s, op, i) \
-  EXPECT_EQ(s, InferShapes(op, i, "x").error_message())
+  EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
+#define INFER_OK_WITH_DEF(op, nd, i, o) \
+  EXPECT_EQ("", InferShapes(op, i, o, nd).error_message())
+#define INFER_ERROR_WITH_DEF(s, op, nd, i) \
+  EXPECT_EQ(s, InferShapes(op, i, "e", nd).error_message())
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 7b85ff9c364..de15d82269c 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -416,7 +416,8 @@ Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
 }
 
 bool Tensor::IsInitialized() const {
-  return buf_ != nullptr && buf_->data() != nullptr;
+  return (buf_ != nullptr && buf_->data() != nullptr) ||
+         shape_.num_elements() == 0;
 }
 
 void Tensor::CheckType(DataType expected_dtype) const {
@@ -507,7 +508,7 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
   if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
     CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
   }
-  if (IsInitialized() && LogMemory::IsEnabled()) {
+  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                       *this);
   }
@@ -521,8 +522,8 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
   if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
     CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
   }
-  if (!allocation_attr.allocation_will_be_logged && IsInitialized() &&
-      LogMemory::IsEnabled()) {
+  if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
+      buf_->data() != nullptr && LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                       LogMemory::UNKNOWN_STEP_ID, *this);
   }
@@ -617,7 +618,7 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
   buf_ = p;
   // TODO(misard) add tracking of which kernels and steps are calling
   // FromProto.
-  if (IsInitialized() && LogMemory::IsEnabled()) {
+  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
     LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                       LogMemory::UNKNOWN_STEP_ID, *this);
   }
@@ -765,7 +766,7 @@ string Tensor::DebugString() const {
 void Tensor::FillDescription(TensorDescription* description) const {
   description->set_dtype(dtype());
   shape().AsProto(description->mutable_shape());
-  if (IsInitialized()) {
+  if (buf_ != nullptr && buf_->data() != nullptr) {
     buf_->FillAllocationDescription(
         description->mutable_allocation_description());
   }
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index dd2d9a4c863..48fbd38e0c4 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -120,7 +120,10 @@ class Tensor {
   // underlying refcounted storage
   size_t BufferHash() const;
 
-  /// Has this Tensor been initialized?
+  /// \brief Has this Tensor been initialized?
+  ///
+  /// Zero-element Tensors are always considered initialized, even if they
+  /// have never been assigned to and do not have any memory allocated.
   bool IsInitialized() const;
 
   /// Returns the estimated memory usage of this tensor.
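
Under the new definition, a zero-element Tensor reports itself initialized even when it owns no buffer, which is exactly why call sites in this change (barrier_ops, tensor_array, unique_tensor_references) pair IsInitialized() with a NumElements() check when they really mean "holds data". A small illustration of the semantics (a sketch, not from the patch):

    Tensor empty(DT_FLOAT, TensorShape({0}));
    CHECK(empty.IsInitialized());   // zero elements: initialized by definition
    Tensor scalar(DT_FLOAT, TensorShape({}));
    CHECK(scalar.IsInitialized());  // one element, backing buffer allocated
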
diff --git a/tensorflow/core/framework/unique_tensor_references.cc b/tensorflow/core/framework/unique_tensor_references.cc
index 2ac6431c54b..ab33d9ede6c 100644
--- a/tensorflow/core/framework/unique_tensor_references.cc
+++ b/tensorflow/core/framework/unique_tensor_references.cc
@@ -33,7 +33,7 @@ UniqueTensorReferences::~UniqueTensorReferences() {
 void UniqueTensorReferences::Add(const Tensor& tensor) {
   DCHECK(!frozen_);
   // Do nothing if the tensor has a null buffer.
-  if (tensor.IsInitialized()) {
+  if (tensor.IsInitialized() && tensor.NumElements() > 0) {
     if (referenced_tensors_set_ != nullptr) {
       // There are enough tensors that we are using a hash set to
       // de-duplicate.
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5cf48bfab5c..142f63c6b47 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1753,6 +1753,7 @@ filegroup(
         "cwise_ops.h",
         "cwise_ops_common.cc",
         "cwise_ops_common.h",
+        "cwise_ops_gradients.h",
         "dense_update_ops.cc",
         "dense_update_ops.h",
         "example_parsing_ops.cc",
diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc
index 533b03db0e2..e66b9d41689 100644
--- a/tensorflow/core/kernels/barrier_ops.cc
+++ b/tensorflow/core/kernels/barrier_ops.cc
@@ -354,7 +354,8 @@ class Barrier : public ResourceBase {
         element.push_back(PersistentTensor(uninitialized));
       }
     }
-    if (element[1 + component_index].IsInitialized()) {
+    const PersistentTensor& component = element[1 + component_index];
+    if (component.IsInitialized() && component.NumElements() > 0) {
       return errors::InvalidArgument("Key ", keys_vec(i),
                                      " already has a value for component ",
                                      component_index, " in barrier ", name());
@@ -374,7 +375,7 @@ class Barrier : public ResourceBase {
     // ready queue.
     bool is_complete = true;
     for (int j = 0; is_complete && j < element.size(); ++j) {
-      is_complete = element[j].IsInitialized();
+      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
     }
     if (is_complete) {
       // Add tuple to the ready queue. A queue tuple has the index
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 508ffc04029..487daa7c2da 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -1024,6 +1024,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
       compatible_input_shape = input_shape;
     }
 
+    CHECK(padding_rows >= 0 && padding_cols >= 0)
+        << "Negative row or col paddings: (" << padding_rows << ", "
+        << padding_cols << ")";
     perftools::gputools::dnn::BatchDescriptor input_desc;
     input_desc.set_count(dims.batch_size)
         .set_height(GetTensorDim(compatible_input_shape, data_format_, 'H'))
@@ -1382,6 +1385,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
       compatible_input = input;
     }
 
+    CHECK(padding_rows >= 0 && padding_cols >= 0)
+        << "Negative row or col paddings: (" << padding_rows << ", "
+        << padding_cols << ")";
     perftools::gputools::dnn::BatchDescriptor input_desc;
     input_desc.set_count(dims.batch_size)
         .set_height(GetTensorDim(compatible_input, data_format_, 'H'))
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index d0c6865951e..62e60d018b5 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -438,10 +438,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     if (padding_ == Padding::SAME) {
       padding_planes =
           (output_planes - 1) * strides[0] + filter_size[0] - input_size[0];
-      padding_cols =
-          (output_cols - 1) * strides[2] + filter_size[2] - input_size[2];
-      padding_rows =
-          (output_rows - 1) * strides[1] + filter_size[1] - input_size[1];
+      padding_cols = std::max<int>(
+          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
+      padding_rows = std::max<int>(
+          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
     }
     const bool rows_odd = (padding_rows % 2 != 0);
     const bool cols_odd = (padding_cols % 2 != 0);
@@ -462,6 +462,9 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
                                 input_size[2]};
     }
 
+    CHECK(padding_rows >= 0 && padding_cols >= 0)
+        << "Negative row or col paddings: (" << padding_rows << ", "
+        << padding_cols << ")";
     perftools::gputools::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(batch)
         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
@@ -659,10 +662,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     if (padding_ == Padding::SAME) {
       padding_planes =
           (output_planes - 1) * strides[0] + filter_size[0] - input_size[0];
-      padding_cols =
-          (output_cols - 1) * strides[2] + filter_size[2] - input_size[2];
-      padding_rows =
-          (output_rows - 1) * strides[1] + filter_size[1] - input_size[1];
+      padding_cols = std::max<int>(
+          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
+      padding_rows = std::max<int>(
+          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
     }
     bool rows_odd = (padding_rows % 2 != 0);
     bool cols_odd = (padding_cols % 2 != 0);
@@ -686,6 +689,9 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
       compatible_input = input;
     }
 
+    CHECK(padding_rows >= 0 && padding_cols >= 0)
+        << "Negative row or col paddings: (" << padding_rows << ", "
+        << padding_cols << ")";
     perftools::gputools::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(batch)
         .set_spatial_dim(DimIndex::X, compatible_input.dim_size(3))
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ede9a77ed0f..e0aff98854d 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -334,8 +334,10 @@ class LaunchConvOp<GPUDevice, T> {
       // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
       // and Pc - Pc/2 on the bottom.  When Pr or Pc is odd, this means
       // we pad more on the right and bottom than on the top and left.
-      padding_rows = (out_rows - 1) * row_stride + patch_rows - in_rows;
-      padding_cols = (out_cols - 1) * col_stride + patch_cols - in_cols;
+      padding_rows =
+          std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+      padding_cols =
+          std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
       const bool rows_odd = (padding_rows % 2 != 0);
       const bool cols_odd = (padding_cols % 2 != 0);
       if (rows_odd || cols_odd) {
@@ -375,6 +377,9 @@ class LaunchConvOp<GPUDevice, T> {
       input = transformed_input;
     }
 
+    CHECK(padding_rows >= 0 && padding_cols >= 0)
+        << "Negative row or col paddings: (" << padding_rows << ", "
+        << padding_cols << ")";
     perftools::gputools::dnn::BatchDescriptor input_desc;
     input_desc.set_count(in_batch)
         .set_feature_map_count(in_depths)
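
For SAME padding the convolution kernels derive the total padding from the output size. In the notation of the code above, with input size n, stride s, and filter size k (a restatement of the code, not new behavior):

    o = \lceil n / s \rceil, \qquad
    p = (o - 1)\,s + k - n, \qquad
    p \leftarrow \max(0,\, p)

When the filter is smaller than the stride, p can come out negative (for example n = 4, s = 2, k = 1 gives o = 2 and p = (2 - 1) * 2 + 1 - 4 = -1), which is what the new std::max clamp absorbs and the CHECK guards against.
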
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 697b3f62679..e236edfc0d3 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -160,8 +160,10 @@ struct LaunchConvOp<GPUDevice, T> {
 
     if (padding == Padding::SAME) {
       pad_planes = (out_planes - 1) * strides[0] + filter_planes - in_planes;
-      pad_rows = (out_rows - 1) * strides[1] + filter_rows - in_rows;
-      pad_cols = (out_cols - 1) * strides[2] + filter_cols - in_cols;
+      pad_rows = std::max<int64>(
+          0, (out_rows - 1) * strides[1] + filter_rows - in_rows);
+      pad_cols = std::max<int64>(
+          0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
     }
 
     // NOTE: This only works in NHWC.
@@ -239,6 +241,9 @@ struct LaunchConvOp<GPUDevice, T> {
         transformed_input.tensor<T, 5>());
     input = transformed_input;
 
+    CHECK(pad_rows >= 0 && pad_cols >= 0) << "Negative row or col paddings: ("
+                                          << pad_rows << ", " << pad_cols
+                                          << ")";
     perftools::gputools::dnn::BatchDescriptor input_desc(3);
     input_desc.set_count(in_batch)
         .set_feature_map_count(in_depth)
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
index a7ac9baca08..b59d22310e0 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
 #if GOOGLE_CUDA
 
 #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
 
 namespace tensorflow {
 namespace functor {
 DEFINE_UNARY3(sigmoid, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(sigmoid_grad, Eigen::half, float, double);
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
index 1678086c35e..66ee3c193e0 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
 #if GOOGLE_CUDA
 
 #include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
 
 namespace tensorflow {
 namespace functor {
 DEFINE_UNARY3(tanh, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double);
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
index 9d8a849bd33..cc1f9b8f03e 100644
--- a/tensorflow/core/kernels/cwise_op_sigmoid.cc
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
 
 namespace tensorflow {
 REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
@@ -22,4 +23,12 @@ REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
 REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
           double);
 #endif
+
+REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float,
+          Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float,
+          Eigen::half, double);
+#endif
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc
index 6604d71d14c..a4c4aad053f 100644
--- a/tensorflow/core/kernels/cwise_op_tanh.cc
+++ b/tensorflow/core/kernels/cwise_op_tanh.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
 
 namespace tensorflow {
 REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
@@ -21,4 +22,11 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
 #if GOOGLE_CUDA
 REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
 #endif
+
+REGISTER5(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float,
+          Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "TanhGrad", functor::tanh_grad, float,
+          Eigen::half, double);
+#endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 02a82c00bf0..6ccbe46c7fa 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -21,6 +21,7 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
 
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -130,6 +131,35 @@ class BinaryOp : public BinaryOpShared {
   }
 };
 
+// Basic coefficient-wise binary operations that are known to not require
+// any broadcasting. This is the case, for example, for the gradients of
+// unary operations.
+//   Device: E.g., CPUDevice, GPUDevice.
+//   Functor: defined above. E.g., functor::tanh_grad.
+template <typename Device, typename Functor>
+class SimpleBinaryOp : public OpKernel {
+ public:
+  typedef typename Functor::in_type Tin;    // Input scalar data type.
+  typedef typename Functor::out_type Tout;  // Output scalar data type.
+
+  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& in0 = ctx->input(0);
+    const Tensor& in1 = ctx->input(1);
+
+    Tensor* out;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
+    auto out_flat = out->flat<Tout>();
+    auto in0_flat = in0.flat<Tin>();
+    auto in1_flat = in1.flat<Tin>();
+    const Device& eigen_device = ctx->eigen_device<Device>();
+
+    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
+                                                    in0_flat, in1_flat);
+  }
+};
+
 // Coefficient-wise unary operations:
 //   Device: E.g., CPUDevice, GPUDevice.
 //   Functor: defined in cwise_functors.h. E.g., functor::sqrt.
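
SimpleBinaryOp deliberately bypasses BinaryOpShared's broadcasting machinery: the gradient of a unary op pairs the forward output with an incoming gradient of identical shape, so the output can be allocated with in0's shape and the functor applied over flat buffers. On CPU this reduces to a plain element-wise Eigen expression; a standalone sketch of the tanh_grad case (assuming the Eigen Tensor module is included; not TF code):

    Eigen::Tensor<float, 1> y(4), dy(4), dx(4);
    y.setRandom();
    dy.setRandom();
    dx = dy * (y.constant(1.0f) - y * y);  // grad = dy * (1 - y^2)
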
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
new file mode 100644
index 00000000000..43947707089
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
@@ -0,0 +1,71 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+
+#define EIGEN_USE_GPU
+
+#include <complex>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+typedef std::complex<float> complex64;
+typedef std::complex<double> complex128;
+
+// Partial specialization of SimpleBinaryFunctor<Device=GPUDevice, Functor>.
+template <typename Functor>
+struct SimpleBinaryFunctor<GPUDevice, Functor> {
+  void operator()(const GPUDevice& d, typename Functor::tout_type out,
+                  typename Functor::tin_type in1,
+                  typename Functor::tin_type in2) {
+    To32Bit(out).device(d) =
+        To32Bit(in1).binaryExpr(in2, typename Functor::func());
+  }
+};
+
+// Macros to explicitly instantiate kernels on GPU for multiple types
+// (T0, T1, etc.) for SimpleBinaryFunctor (e.g., functor::tanh_grad).
+#define DEFINE_SIMPLE_BINARY1(F, T)                  \
+  template struct SimpleBinaryFunctor<GPUDevice, F<T> >
+#define DEFINE_SIMPLE_BINARY2(F, T0, T1)             \
+  DEFINE_SIMPLE_BINARY1(F, T0);                      \
+  DEFINE_SIMPLE_BINARY1(F, T1)
+#define DEFINE_SIMPLE_BINARY3(F, T0, T1, T2)         \
+  DEFINE_SIMPLE_BINARY2(F, T0, T1);                  \
+  DEFINE_SIMPLE_BINARY1(F, T2)
+#define DEFINE_SIMPLE_BINARY4(F, T0, T1, T2, T3)     \
+  DEFINE_SIMPLE_BINARY2(F, T0, T1);                  \
+  DEFINE_SIMPLE_BINARY2(F, T2, T3)
+#define DEFINE_SIMPLE_BINARY5(F, T0, T1, T2, T3, T4) \
+  DEFINE_SIMPLE_BINARY2(F, T0, T1);                  \
+  DEFINE_SIMPLE_BINARY3(F, T2, T3, T4)
+
+}  // end namespace functor
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
new file mode 100644
index 00000000000..a59f1572810
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -0,0 +1,107 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+
+#define EIGEN_USE_THREADS
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+namespace internal {
+
+// Gradient for the tanh function
+template <typename T>
+struct scalar_tanh_gradient_op {
+  EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op)
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+  operator()(const T& output, const T& output_gradient) const {
+    return output_gradient * (T(1) - output * output);
+  }
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+  packetOp(const Packet& output, const Packet& output_gradient) const {
+    return pmul(output_gradient,
+                psub(pset1<Packet>(T(1)), pmul(output, output)));
+  }
+};
+template <typename T>
+struct functor_traits<scalar_tanh_gradient_op<T>> {
+  enum {
+    Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+    PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
+  };
+};
+
+// Gradient for the sigmoid function
+template <typename T>
+struct scalar_sigmoid_gradient_op {
+  EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op)
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+  operator()(const T& output, const T& output_gradient) const {
+    return output_gradient * output * (T(1) - output);
+  }
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+  packetOp(const Packet& output, const Packet& output_gradient) const {
+    return pmul(output_gradient,
+                pmul(output, psub(pset1<Packet>(T(1)), output)));
+  }
+};
+template <typename T>
+struct functor_traits<scalar_sigmoid_gradient_op<T>> {
+  enum {
+    Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+    PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
+  };
+};
+
+}  // end namespace internal
+}  // end namespace Eigen
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename Functor>
+struct SimpleBinaryFunctor {
+  void operator()(const Device& d, typename Functor::tout_type out,
+                  typename Functor::tin_type in0,
+                  typename Functor::tin_type in1);
+};
+
+// Partial specialization of SimpleBinaryFunctor for CPU devices.
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Functor>
+struct SimpleBinaryFunctor<CPUDevice, Functor> {
+  void operator()(const CPUDevice& d, typename Functor::tout_type out,
+                  typename Functor::tin_type in0,
+                  typename Functor::tin_type in1) {
+    out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+  }
+};
+
+template <typename T>
+struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};
+
+template <typename T>
+struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
+};
+
+}  // end namespace functor
+
+}  // end namespace tensorflow
+#endif  // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
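
Both functors express the derivative in terms of the forward output, so no transcendental needs to be re-evaluated on the backward pass. A quick standalone sanity check of the tanh identity against a central finite difference (plain C++; a sketch, not part of the patch):

    #include <cassert>
    #include <cmath>

    int main() {
      const double x = 0.3, dy = 1.0;
      const double y = std::tanh(x);
      const double grad = dy * (1.0 - y * y);  // scalar_tanh_gradient_op
      const double eps = 1e-6;
      const double numeric =
          (std::tanh(x + eps) - std::tanh(x - eps)) / (2 * eps);
      assert(std::abs(grad - numeric) < 1e-8);
      return 0;
    }
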
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 34411c9bbb6..48124d20af9 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -35,38 +35,42 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    const Tensor& logits_in = context->input(0);
-    const Tensor& labels_in = context->input(1);
-    OP_REQUIRES(context, logits_in.shape().dim_size(0) == labels_in.NumElements(),
+    const Tensor& logits = context->input(0);
+    const Tensor& labels = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
                 errors::InvalidArgument(
-                    "logits first dimension must match labels size.  logits shape=",
-                    logits_in.shape().DebugString(), " labels shape=",
-                    labels_in.shape().DebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
-                errors::InvalidArgument("logits must be 2-dimensional"));
-    // As we already tested that both inputs have the same shape no need to
-    // check that "labels" is a matrix too.
-
-    // loss is 1-D (one per example), and size is batch_size.
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits.shape().DebugString(), " and labels shape ",
+                    labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits.shape().DebugString()));
 
     Tensor scratch;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::value,
-                                        TensorShape({logits_in.dim_size(0)}),
-                                        &scratch));
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   labels.shape(), &scratch));
 
     Tensor* loss_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+                   context->allocate_output(0, labels.shape(), &loss_out));
     Tensor* back_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(1, logits_in.shape(), &back_out));
+                   context->allocate_output(1, logits.shape(), &back_out));
 
-    functor::SparseXentFunctor<Device, T, Index> functor;
-    functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
-            labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
-            back_out->matrix<T>());
+    if (logits.dim_size(0) > 0) {
+      functor::SparseXentFunctor<Device, T, Index> functor;
+      functor(context->eigen_device<Device>(), logits.matrix<T>(),
+              labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+              back_out->matrix<T>());
+    }
   }
 };
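
After the new validation, logits are [batch_size, num_classes] and labels are [batch_size]. For example i with label \ell_i the op computes the standard sparse softmax cross-entropy (stated here for reference):

    \mathrm{loss}_i = \log \sum_{c} e^{\,\mathrm{logits}_{i,c}}
                      - \mathrm{logits}_{i,\ell_i}

with the backprop output being the softmax of row i minus the one-hot encoding of \ell_i. The new dim_size(0) > 0 guard skips the functor for an empty batch, where the zero-element outputs are already complete as allocated.
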
 
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 03b2c3b68b9..1456ec28447 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -441,7 +441,7 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
           " but the new input shape is ", value_t->shape().DebugString(), ".");
     }
 
-    if (!t.tensor.IsInitialized()) {
+    if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
       // If existing_t == nullptr but written == true, then what was stored
       // was just a shape, which just means zeros.  So all we must do in this
       // case is copy the reference over and return early.
@@ -502,7 +502,7 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
                                    "clear_after_read = false?).");
   }
 
-  if (!t.tensor.IsInitialized()) {
+  if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
     // We stored just a shape, but no value.  This means create and
     // return zeros of the appropriate shape.
     Tensor* tensor_t;
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
index 88ec1069c5d..7ce1a1d395f 100644
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc
@@ -285,6 +285,7 @@ TEST(RefCountedVec, InsertConstructorDestructor) {
     for (int pos = 0; pos <= len; pos++) {
       SCOPED_TRACE(pos);
       std::vector<int> counts(len, 0);
+      int inserted_count = 0;
       RefCountedVec v;
       for (int i = 0; i < len; ++i) {
         SCOPED_TRACE(i);
@@ -295,7 +296,6 @@ TEST(RefCountedVec, InsertConstructorDestructor) {
         EXPECT_EQ(1, elem);
       }
 
-      int inserted_count = 0;
       RefCounted insert_element(9999, &inserted_count);
       EXPECT_EQ(1, inserted_count);
       v.insert(v.begin() + pos, insert_element);
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index dc96588f73a..4ef3a48221a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -14,17 +14,67 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/util/mirror_pad_mode.h"
 #include "tensorflow/core/util/padding.h"
 
 namespace tensorflow {
 
+typedef shape_inference::Dimension Dimension;
+typedef shape_inference::InferenceContext InferenceContext;
+typedef shape_inference::Shape Shape;
+
+namespace {
+
+Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
+                               int32* axis) {
+  TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
+  if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
+    return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
+                                   -1 * rank_after_pack, ",", rank_after_pack,
+                                   ")");
+  }
+  if (*axis < 0) *axis = (rank_after_pack + *axis);
+  return Status::OK();
+}
+
+}  // namespace
+
 REGISTER_OP("Pack")
     .Input("values: N * T")
     .Output("output: T")
     .Attr("N: int >= 1")
     .Attr("T: type")
     .Attr("axis: int = 0")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      // Validate shapes of all inputs are compatible
+      const Shape* cur = c->input(c->num_inputs() - 1);
+      for (int i = c->num_inputs() - 2; i >= 0; --i) {
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
+                                        "From merging shape ", i,
+                                        " with other shapes.");
+      }
+      if (!c->RankKnown(cur)) {
+        c->set_output(0, c->CreateUnknownShape());
+        return Status::OK();
+      }
+      // Determine the axis that will be added, converting a negative axis
+      // to its positive equivalent per negative indexing rules.
+      int32 rank = c->Rank(cur);
+      int32 axis;
+      TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
+
+      // Copy all dimensions over, inserting a dimension of value #inputs
+      // at <axis>.
+      std::vector<const Dimension*> dims;
+      int index = 0;
+      while (index < axis) dims.push_back(c->Dim(cur, index++));
+      dims.push_back(c->CreateDim(c->num_inputs()));
+      while (index < rank) dims.push_back(c->Dim(cur, index++));
+
+      c->set_output(0, c->CreateShape(dims));
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
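
GetAxisForPackAndUnpack accepts any axis in [-R, R), where R is the rank after packing, and maps negative values per Python-style indexing. A worked sketch of the normalization (the concrete values are illustrative):

    int32 axis = -1;
    const int32 rank_after_pack = 3;  // so valid axes are [-3, 3)
    if (axis < 0) axis = rank_after_pack + axis;  // -1 -> 2 (last position)
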
 
@@ -61,6 +111,29 @@ REGISTER_OP("Unpack")
     .Attr("num: int >= 0")
     .Attr("T: type")
     .Attr("axis: int = 0")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      const Shape* s = c->input(0);
+      const Shape* out;
+      if (c->RankKnown(s)) {
+        // Determine the axis that will be removed, converting a negative
+        // axis to its positive equivalent per negative indexing rules.
+        int32 rank = c->Rank(s);
+        int32 axis;
+        TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
+
+        // Copy all dimensions, removing the <axis> dimension.
+        std::vector<const Dimension*> dims;
+        for (int i = 0; i < rank; ++i) {
+          if (i != axis) dims.push_back(c->Dim(s, i));
+        }
+        out = c->CreateShape(dims);
+      } else {
+        // All outputs are the same shape, but it's not known.
+        out = c->CreateUnknownShape();
+      }
+      for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
 
@@ -154,6 +227,18 @@ REGISTER_OP("Const")
     .Output("output: dtype")
     .Attr("value: tensor")
     .Attr("dtype: type")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      const TensorProto* proto = nullptr;
+      TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
+      TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
+      TensorShape shape(proto->tensor_shape());
+      std::vector<const Dimension*> dims;
+      for (int i = 0; i < shape.dims(); ++i) {
+        dims.push_back(c->CreateDim(shape.dim_size(i)));
+      }
+      c->set_output(0, c->CreateShape(dims));
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Returns a constant tensor.
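
The dimension bookkeeping in the Pack shape function above, extracted as a standalone sketch (plain C++; the concrete numbers are illustrative): packing N = 3 merged inputs of shape [2,3] at axis 1 produces [2,3,3].

    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<long long> in = {2, 3};  // merged input shape
      const int num_inputs = 3, axis = 1;        // N and insertion point
      std::vector<long long> out;
      int index = 0;
      while (index < axis) out.push_back(in[index++]);
      out.push_back(num_inputs);                 // inserted dimension == N
      while (index < static_cast<int>(in.size())) out.push_back(in[index++]);
      for (long long d : out) std::printf("%lld ", d);  // prints: 2 3 3
      return 0;
    }

Unpack at the same axis is the exact inverse: a [2,3,3] input yields three outputs of shape [2,3].
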
 
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
new file mode 100644
index 00000000000..19dfa293584
--- /dev/null
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ArrayOpsTest, Pack_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  auto set_axis = [def](int axis) {
+    TF_CHECK_OK(NodeDefBuilder("test", "Pack")
+                    .Input({{"a", 0, DT_FLOAT}})
+                    .Attr("axis", axis)
+                    .Finalize(def));
+  };
+  const char op[] = "Pack";
+
+  set_axis(0);
+  INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+  for (int axis : {0, -3}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
+  }
+  for (int axis : {1, -2}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
+  }
+  for (int axis : {2, -1}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
+  }
+
+  set_axis(-4);
+  INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+                       "[1,3];[1,3];?");
+  set_axis(3);
+  INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+                       "[1,3];[1,3];?");
+
+  set_axis(0);
+  INFER_ERROR_WITH_DEF(("Shapes must be equal rank, but are 3 and 2"
+                        "\n\tFrom merging shape 0 with other shapes."),
+                       op, def, "[1,2,3];?;[1,4]");
+}
+
+TEST(ArrayOpsTest, UnPack_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  auto set_axis = [def](int axis) {
+    TF_CHECK_OK(NodeDefBuilder("test", "Unpack")
+                    .Input("a", 0, DT_FLOAT)
+                    .Attr("axis", axis)
+                    .Finalize(def));
+  };
+  const char op[] = "Unpack";
+
+  set_axis(0);
+  INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+  for (int axis : {0, -3}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_1,d0_2]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_1,d0_2]");
+  }
+  for (int axis : {1, -2}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_2]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_2]");
+  }
+  for (int axis : {2, -1}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_1]");
+  }
+
+  set_axis(-4);
+  INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+                       "[1,2,3]");
+  set_axis(3);
+  INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+                       "[1,2,3]");
+}
+
+TEST(ArrayOpsTest, Const_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  TensorProto tensor_proto;
+  auto* shape_proto = tensor_proto.mutable_tensor_shape();
+  auto rebuild_node_def = [def, &tensor_proto]() {
+    TF_CHECK_OK(NodeDefBuilder("test", "Const")
+                    .Attr("value", tensor_proto)
+                    .Finalize(def));
+  };
+  const char op[] = "Const";
+
+  TensorShape{}.AsProto(shape_proto);
+  rebuild_node_def();
+  INFER_OK_WITH_DEF(op, def, "", "[]");
+  TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
+  rebuild_node_def();
+  INFER_OK_WITH_DEF(op, def, "", "[1,2,3,4]");
+
+  shape_proto->add_dim()->set_size(-1);
+  rebuild_node_def();
+  INFER_ERROR_WITH_DEF("Shape [1,2,3,4,-1] has negative dimensions", op, def,
+                       "");
+}
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index adaa47ab8c5..2dba61efe78 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -20208,6 +20208,34 @@ op {
     }
   }
 }
+op {
+  name: "SigmoidGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
 op {
   name: "Sign"
   input_arg {
@@ -24557,6 +24585,34 @@ op {
     }
   }
 }
+op {
+  name: "TanhGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
 op {
   name: "TemporaryVariable"
   output_arg {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0f9ee4942aa..b220a2d2d62 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -238,6 +238,13 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
       .Attr("T: {half, float, double, complex64, complex128}") \
       .SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
 
+#define UNARY_GRADIENT_COMPLEX()                               \
+  Input("x: T")                                                \
+      .Input("y: T")                                           \
+      .Output("z: T")                                          \
+      .Attr("T: {half, float, double, complex64, complex128}") \
+      .SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+
 REGISTER_OP("Neg")
     .UNARY()
     .Doc(R"doc(
@@ -292,6 +299,13 @@ REGISTER_OP("Tanh")
 Computes hyperbolic tangent of `x` element-wise.
 )doc");
 
+REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient for the tanh of `x` wrt its input.
+
+Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
+is the corresponding input gradient.
+)doc");
+
 REGISTER_OP("Lgamma")
     .UNARY_REAL()
     .Doc(R"doc(
@@ -325,6 +339,13 @@ Computes sigmoid of `x` element-wise.
 Specifically, `y = 1 / (1 + exp(-x))`.
 )doc");
 
+REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient of the sigmoid of `x` wrt its input.
+
+Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
+`dy` is the corresponding input gradient.
+)doc");
+
 REGISTER_OP("Sin")
     .UNARY_COMPLEX()
     .Doc(R"doc(
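
Both gradient ops are phrased in terms of the forward output y rather than the original input, so the backward pass reuses the already-computed activation. The identities behind the two Doc strings, in LaTeX:

    \frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)
        \;\Longrightarrow\; \mathrm{grad} = dy\,(1 - y^2), \quad y = \tanh(x)

    \frac{d}{dx}\sigma(x) = \sigma(x)\,(1 - \sigma(x))
        \;\Longrightarrow\; \mathrm{grad} = dy\,y\,(1 - y), \quad y = \sigma(x)
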
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 474516bf4c6..afd6507b0d8 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -11796,6 +11796,36 @@ op {
   summary: "Computes sigmoid of `x` element-wise."
   description: "Specifically, `y = 1 / (1 + exp(-x))`."
 }
+op {
+  name: "SigmoidGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  summary: "Computes the gradient of the sigmoid of `x` wrt its input."
+  description: "Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and\n`dy` is the corresponding input gradient."
+}
 op {
   name: "Sign"
   input_arg {
@@ -14643,6 +14673,36 @@ op {
   }
   summary: "Computes hyperbolic tangent of `x` element-wise."
 }
+op {
+  name: "TanhGrad"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  summary: "Computes the gradient for the tanh of `x` wrt its input."
+  description: "Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`\nis the corresponding input gradient."
+}
 op {
   name: "TemporaryVariable"
   output_arg {
diff --git a/tensorflow/core/platform/default/tracing.cc b/tensorflow/core/platform/default/tracing.cc
index 7910e97db9a..422564fb3e4 100644
--- a/tensorflow/core/platform/default/tracing.cc
+++ b/tensorflow/core/platform/default/tracing.cc
@@ -15,8 +15,6 @@ limitations under the License.
 
 #include "tensorflow/core/platform/tracing.h"
 
-#include <unistd.h>
-
 namespace tensorflow {
 namespace port {
 
@@ -26,21 +24,6 @@ void Tracing::RegisterEvent(EventCategory id, const char* name) {
 
 void Tracing::Initialize() {}
 
-static bool TryGetEnv(const char* name, const char** value) {
-  *value = getenv(name);
-  return *value != nullptr && (*value)[0] != '\0';
-}
-
-const char* Tracing::LogDir() {
-  const char* dir;
-  if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
-  if (TryGetEnv("TMP", &dir)) return dir;
-  if (TryGetEnv("TMPDIR", &dir)) return dir;
-  dir = "/tmp";
-  if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
-  return ".";  // Default to current directory.
-}
-
 static bool DoInit() {
   Tracing::Initialize();
   return true;
diff --git a/tensorflow/core/platform/posix/tracing.cc b/tensorflow/core/platform/posix/tracing.cc
new file mode 100644
index 00000000000..1d1aa53f2ca
--- /dev/null
+++ b/tensorflow/core/platform/posix/tracing.cc
@@ -0,0 +1,40 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/tracing.h"
+
+#include <stdlib.h>
+#include <unistd.h>
+
+namespace tensorflow {
+namespace port {
+
+static bool TryGetEnv(const char* name, const char** value) {
+  *value = getenv(name);
+  return *value != nullptr && (*value)[0] != '\0';
+}
+
+const char* Tracing::LogDir() {
+  const char* dir;
+  if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
+  if (TryGetEnv("TMP", &dir)) return dir;
+  if (TryGetEnv("TMPDIR", &dir)) return dir;
+  dir = "/tmp";
+  if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
+  return ".";  // Default to current directory.
+}
+
+}  // namespace port
+}  // namespace tensorflow
diff --git a/tensorflow/examples/skflow/BUILD b/tensorflow/examples/skflow/BUILD
index 5d6eae87459..7cac13df98f 100644
--- a/tensorflow/examples/skflow/BUILD
+++ b/tensorflow/examples/skflow/BUILD
@@ -231,7 +231,11 @@ sh_test(
     data = [
         ":boston",
         ":iris",
+        ":iris_custom_decay_dnn",
         ":iris_custom_model",
+        ":iris_run_config",
+        ":iris_val_based_early_stopping",
+        ":iris_with_pipeline",
         ":text_classification",
         ":text_classification_builtin_rnn_model",
         ":text_classification_character_cnn",
diff --git a/tensorflow/examples/skflow/examples_test.sh b/tensorflow/examples/skflow/examples_test.sh
index da6b35c9bb3..f4010c915e3 100755
--- a/tensorflow/examples/skflow/examples_test.sh
+++ b/tensorflow/examples/skflow/examples_test.sh
@@ -49,6 +49,10 @@ function test() {
 test boston
 test iris
 test iris_custom_model
+test iris_custom_decay_dnn
+test iris_run_config
+test iris_val_based_early_stopping
+test iris_with_pipeline
 test text_classification --test_with_fake_data
 test text_classification_builtin_rnn_model --test_with_fake_data
 test text_classification_cnn --test_with_fake_data
diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
index c1e7d22d53a..1ce6a830e4b 100644
--- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
@@ -17,24 +17,29 @@ from __future__ import print_function
 
 from sklearn import datasets, metrics
 from sklearn.cross_validation import train_test_split
-
 import tensorflow as tf
 
-iris = datasets.load_iris()
-X_train, X_test, y_train, y_test = train_test_split(iris.data,
-                                                    iris.target,
-                                                    test_size=0.2,
-                                                    random_state=42)
-# setup exponential decay function
-def exp_decay(global_step):
-    return tf.train.exponential_decay(
-        learning_rate=0.1, global_step=global_step,
-        decay_steps=100, decay_rate=0.001)
 
-# use customized decay function in learning_rate
-optimizer = tf.train.AdagradOptimizer(learning_rate=exp_decay)
-classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
-                                            n_classes=3,
-                                            optimizer=optimizer)
-classifier.fit(X_train, y_train, steps=800)
-score = metrics.accuracy_score(y_test, classifier.predict(X_test))
+def optimizer_exp_decay():
+  global_step = tf.contrib.framework.get_or_create_global_step()
+  learning_rate = tf.train.exponential_decay(
+      learning_rate=0.1, global_step=global_step,
+      decay_steps=100, decay_rate=0.001)
+  return tf.train.AdagradOptimizer(learning_rate=learning_rate)
+
+def main(unused_argv):
+  iris = datasets.load_iris()
+  x_train, x_test, y_train, y_test = train_test_split(
+      iris.data, iris.target, test_size=0.2, random_state=42)
+
+  classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
+                                              n_classes=3,
+                                              optimizer=optimizer_exp_decay)
+
+  classifier.fit(x_train, y_train, steps=800)
+  score = metrics.accuracy_score(y_test, classifier.predict(x_test))
+  print('Accuracy: {0:f}'.format(score))
+
+
+if __name__ == '__main__':
+  tf.app.run()
diff --git a/tensorflow/examples/skflow/iris_run_config.py b/tensorflow/examples/skflow/iris_run_config.py
index dff0daf9e8c..c678c7c738c 100644
--- a/tensorflow/examples/skflow/iris_run_config.py
+++ b/tensorflow/examples/skflow/iris_run_config.py
@@ -16,24 +16,31 @@ from __future__ import division
 from __future__ import print_function
 
 from sklearn import datasets, metrics, cross_validation
-
-from tensorflow.contrib import learn
+import tensorflow as tf
 
 
-# Load dataset.
-iris = datasets.load_iris()
-X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
-    test_size=0.2, random_state=42)
+def main(unused_argv):
+  # Load dataset.
+  iris = datasets.load_iris()
+  x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+      iris.data, iris.target, test_size=0.2, random_state=42)
 
-# You can define you configurations by providing a RunConfig object to
-# estimator to control session configurations, e.g. num_cores and gpu_memory_fraction
-run_config = learn.estimators.RunConfig(num_cores=3, gpu_memory_fraction=0.6)
+  # You can define your configurations by providing a RunConfig object to
+  # the estimator to control session configurations, e.g. num_cores
+  # and gpu_memory_fraction.
+  run_config = tf.contrib.learn.estimators.RunConfig(
+      num_cores=3, gpu_memory_fraction=0.6)
 
-# Build 3 layer DNN with 10, 20, 10 units respectively.
-classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10],
-    n_classes=3, steps=200, config=run_config)
+  # Build 3 layer DNN with 10, 20, 10 units respectively.
+  classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
+                                              n_classes=3,
+                                              config=run_config)
 
-# Fit and predict.
-classifier.fit(X_train, y_train)
-score = metrics.accuracy_score(y_test, classifier.predict(X_test))
-print('Accuracy: {0:f}'.format(score))
+  # Fit and predict.
+  classifier.fit(x_train, y_train, steps=200)
+  score = metrics.accuracy_score(y_test, classifier.predict(x_test))
+  print('Accuracy: {0:f}'.format(score))
+
+
+if __name__ == '__main__':
+  tf.app.run()
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
index 72e0595544f..05dfa96a077 100644
--- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py
+++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
@@ -34,21 +34,23 @@ def main(unused_argv):
       x_val, y_val, early_stopping_rounds=200)
 
   # classifier with early stopping on training data
-  classifier1 = learn.TensorFlowDNNClassifier(
+  classifier1 = learn.DNNClassifier(
       hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model/')
   classifier1.fit(x=x_train, y=y_train, steps=2000)
   score1 = metrics.accuracy_score(y_test, classifier1.predict(x_test))
 
   # classifier with early stopping on validation data, save frequently for
   # monitor to pick up new checkpoints.
-  classifier2 = learn.TensorFlowDNNClassifier(
+  classifier2 = learn.DNNClassifier(
       hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model_val/',
       config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
   classifier2.fit(x=x_train, y=y_train, steps=2000, monitors=[val_monitor])
   score2 = metrics.accuracy_score(y_test, classifier2.predict(x_test))
 
   # In many applications, the score is improved by using early stopping
-  print(score2 > score1)
+  print('score1: ', score1)
+  print('score2: ', score2)
+  print('score2 > score1: ', score2 > score1)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py
index 3ba5739250e..5535cd9e3bf 100644
--- a/tensorflow/examples/skflow/iris_with_pipeline.py
+++ b/tensorflow/examples/skflow/iris_with_pipeline.py
@@ -20,22 +20,31 @@ from sklearn.datasets import load_iris
 from sklearn import cross_validation
 from sklearn.preprocessing import StandardScaler
 from sklearn.metrics import accuracy_score
+import tensorflow as tf
+
 from tensorflow.contrib import learn
 
-iris = load_iris()
-X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target,
-    test_size=0.2, random_state=42)
 
-# It's useful to scale to ensure Stochastic Gradient Descent will do the right thing
-scaler = StandardScaler()
+def main(unused_argv):
+  iris = load_iris()
+  x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+      iris.data, iris.target, test_size=0.2, random_state=42)
 
-# DNN classifier
-DNNclassifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200)
+  # It's useful to scale to ensure Stochastic Gradient Descent
+  # will do the right thing.
+  scaler = StandardScaler()
 
-pipeline = Pipeline([('scaler', scaler), ('DNNclassifier', DNNclassifier)])
+  # DNN classifier
+  classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
 
-pipeline.fit(X_train, y_train)
+  pipeline = Pipeline([('scaler', scaler),
+                       ('DNNclassifier', classifier)])
 
-score = accuracy_score(y_test, pipeline.predict(X_test))
+  pipeline.fit(x_train, y_train, DNNclassifier__steps=200)
 
-print('Accuracy: {0:f}'.format(score))
+  score = accuracy_score(y_test, pipeline.predict(x_test))
+  print('Accuracy: {0:f}'.format(score))
+
+
+if __name__ == '__main__':
+  tf.app.run()
diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
index 773d4678276..5ab6024c2b8 100644
--- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os.path
 import time
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -198,7 +199,8 @@ def run_training():
 
       # Save a checkpoint and evaluate the model periodically.
       if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
-        saver.save(sess, FLAGS.train_dir, global_step=step)
+        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
+        saver.save(sess, checkpoint_file, global_step=step)
         # Evaluate against the training set.
         print('Training Data Eval:')
         do_eval(sess,
diff --git a/tensorflow/g3doc/api_docs/python/contrib.distributions.md b/tensorflow/g3doc/api_docs/python/contrib.distributions.md
index c2e67db7cb6..7bea8d72dd2 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.distributions.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.distributions.md
@@ -4022,6 +4022,334 @@ Variance of the distribution.
 
 
 
+### Transformed distributions
+
+- - -
+
+### `class tf.contrib.distributions.ContinuousTransformedDistribution` {#ContinuousTransformedDistribution}
+
+A Transformed Distribution.
+
+A Transformed Distribution models `p(y)` given a base distribution `p(x)`,
+an invertible transform, `y = f(x)`, and the determinant of the Jacobian of
+`f(x)`.
+
+Shapes, type, and reparameterization are taken from the base distribution.
+
+#### Mathematical details
+
+* `p(x)` - probability distribution for random variable X
+* `p(y)` - probability distribution for random variable Y
+* `f` - transform
+* `g` - inverse transform, `f(g(x)) = x`
+* `J(x)` - Jacobian of f(x)
+
+A Transformed Distribution exposes `sample` and `pdf`:
+
+  * `sample`: `y = f(x)`, after drawing a sample of X.
+  * `pdf`: `p(y) = p(x) / det|J(x)| = p(g(y)) / det|J(g(y))|`
+
+A simple example constructing a Logit-Normal distribution from a Normal
+distribution:
+
+```
+logit_normal = ContinuousTransformedDistribution(
+  base_dist_cls=Normal,
+  mu=mu, sigma=sigma,
+  transform=lambda x: tf.sigmoid(x),
+  inverse=lambda y: tf.log(y) - tf.log(1. - y),
+  log_det_jacobian=(lambda x:
+      tf.reduce_sum(tf.log(tf.sigmoid(x)) + tf.log(1. - tf.sigmoid(x)),
+                    reduction_indices=[-1])),
+  name="LogitNormalTransformedDistribution")
+```
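+
+A minimal sketch (not part of the API) of how `pdf` combines the pieces
+above, assuming `dist` is a `ContinuousTransformedDistribution` and `y` is
+a float `Tensor`:
+
+```
+# log p(y) = log p(g(y)) - log det|J(g(y))|; work in log space for stability.
+x = dist.inverse(y)  # g(y): map y back to the base space.
+log_pdf_y = dist.base_distribution.log_pdf(x) - dist.log_det_jacobian(x)
+pdf_y = tf.exp(log_pdf_y)
+```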
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.__init__(base_dist_cls, transform, inverse, log_det_jacobian, name='ContinuousTransformedDistribution', **base_dist_args)` {#ContinuousTransformedDistribution.__init__}
+
+Construct a Transformed Distribution.
+
+##### Args:
+
+
+*  <b>`base_dist_cls`</b>: the base distribution class to transform. Must be a
+      subclass of `ContinuousDistribution`.
+*  <b>`transform`</b>: a callable that takes a `Tensor` sample from `base_dist` and
+      returns a `Tensor` of the same shape and type. `x => y`.
+*  <b>`inverse`</b>: a callable that computes the inverse of transform. `y => x`. If
+      None, users can only call `log_pdf` on values returned by `sample`.
+*  <b>`log_det_jacobian`</b>: a callable that takes a `Tensor` sample from `base_dist`
+      and returns the log of the determinant of the Jacobian of `transform`.
+*  <b>`name`</b>: The name for the distribution.
+*  <b>`**base_dist_args`</b>: kwargs to pass on to `base_dist_cls` on construction.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: if `base_dist_cls` is not a subclass of
+      `ContinuousDistribution`.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.base_distribution` {#ContinuousTransformedDistribution.base_distribution}
+
+Base distribution, p(x).
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.batch_shape(name='batch_shape')` {#ContinuousTransformedDistribution.batch_shape}
+
+Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+The product of the dimensions of the `batch_shape` is the number of
+independent distributions of this kind the instance represents.
+
+##### Args:
+
+
+*  <b>`name`</b>: name to give to the op.
+
+##### Returns:
+
+  `Tensor` `batch_shape`
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.cdf(value, name='cdf')` {#ContinuousTransformedDistribution.cdf}
+
+Cumulative distribution function.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.dtype` {#ContinuousTransformedDistribution.dtype}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.entropy(name='entropy')` {#ContinuousTransformedDistribution.entropy}
+
+Entropy of the distribution in nats.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.event_shape(name='event_shape')` {#ContinuousTransformedDistribution.event_shape}
+
+Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+##### Args:
+
+
+*  <b>`name`</b>: name to give to the op.
+
+##### Returns:
+
+  `Tensor` `event_shape`
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.get_batch_shape()` {#ContinuousTransformedDistribution.get_batch_shape}
+
+`TensorShape` available at graph construction time.
+
+Same meaning as `batch_shape`. May be only partially defined.
+
+##### Returns:
+
+  batch shape
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.get_event_shape()` {#ContinuousTransformedDistribution.get_event_shape}
+
+`TensorShape` available at graph construction time.
+
+Same meaning as `event_shape`. May be only partially defined.
+
+##### Returns:
+
+  event shape
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.inverse` {#ContinuousTransformedDistribution.inverse}
+
+Inverse function of transform, y => x.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.is_reparameterized` {#ContinuousTransformedDistribution.is_reparameterized}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_cdf(value, name='log_cdf')` {#ContinuousTransformedDistribution.log_cdf}
+
+Log CDF.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_det_jacobian` {#ContinuousTransformedDistribution.log_det_jacobian}
+
+Function computing the log determinant of the Jacobian of transform.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_likelihood(value, name='log_likelihood')` {#ContinuousTransformedDistribution.log_likelihood}
+
+Log likelihood of this distribution (same as log_pdf).
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_pdf(y, name='log_pdf')` {#ContinuousTransformedDistribution.log_pdf}
+
+Log pdf of observations in `y`.
+
+`log ( p(g(y)) / det|J(g(y))| )`, where `g` is the inverse of `transform`.
+
+##### Args:
+
+
+*  <b>`y`</b>: tensor of dtype `dtype`.
+*  <b>`name`</b>: The name to give this op.
+
+##### Returns:
+
+
+*  <b>`log_pdf`</b>: tensor of dtype `dtype`, the log-PDFs of `y`.
+
+##### Raises:
+
+
+*  <b>`ValueError`</b>: if `inverse` was not provided to the distribution and `y` was
+      not returned from `sample`.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.mean(name='mean')` {#ContinuousTransformedDistribution.mean}
+
+Mean of the distribution.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.mode(name='mode')` {#ContinuousTransformedDistribution.mode}
+
+Mode of the distribution.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.name` {#ContinuousTransformedDistribution.name}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.pdf(y, name='pdf')` {#ContinuousTransformedDistribution.pdf}
+
+The PDF of observations in `y`.
+
+`p(g(y)) / det|J(g(y))|`, where `g` is the inverse of `transform`.
+
+##### Args:
+
+
+*  <b>`y`</b>: `Tensor` of dtype `dtype`.
+*  <b>`name`</b>: The name to give this op.
+
+##### Returns:
+
+
+*  <b>`pdf`</b>: `Tensor` of dtype `dtype`, the pdf values of `y`.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.sample(n, seed=None, name='sample')` {#ContinuousTransformedDistribution.sample}
+
+Sample `n` observations.
+
+Samples from the base distribution and then passes through the transform.
+
+##### Args:
+
+
+*  <b>`n`</b>: scalar, type int32, the number of observations to sample.
+*  <b>`seed`</b>: Python integer, the random seed.
+*  <b>`name`</b>: The name to give this op.
+
+##### Returns:
+
+
+*  <b>`samples`</b>: `[n, ...]`, a `Tensor` of `n` samples.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.std(name='std')` {#ContinuousTransformedDistribution.std}
+
+Standard deviation of the distribution.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.strict` {#ContinuousTransformedDistribution.strict}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.strict_statistics` {#ContinuousTransformedDistribution.strict_statistics}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.transform` {#ContinuousTransformedDistribution.transform}
+
+Function transforming x => y.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.variance(name='variance')` {#ContinuousTransformedDistribution.variance}
+
+Variance of the distribution.
+
+
+
+
 ## Operators allowing for matrix-free methods
 
 ### Positive definite operators
diff --git a/tensorflow/g3doc/api_docs/python/contrib.framework.md b/tensorflow/g3doc/api_docs/python/contrib.framework.md
index 4c234555975..df4df30d199 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.framework.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.framework.md
@@ -227,50 +227,49 @@ adds them via `tf.add_n`.
 
 - - -
 
-### `tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, sparse_weights=None, combiner='mean', default_id=None, name=None, partition_strategy='div')` {#safe_embedding_lookup_sparse}
+### `tf.contrib.framework.safe_embedding_lookup_sparse(*args, **kwargs)` {#safe_embedding_lookup_sparse}
 
-Lookup embedding results, accounting for invalid IDs and empty features.
+Lookup embedding results, accounting for invalid IDs and empty features. (deprecated)
 
-The partitioned embedding in `embedding_weights` must all be the same shape
-except for the first dimension. The first dimension is allowed to vary as the
-vocabulary size is not necessarily a multiple of `P`.
+THIS FUNCTION IS DEPRECATED. It will be removed after 2016-09-01.
+Instructions for updating:
+Please use tf.contrib.layers.safe_embedding_lookup_sparse.
 
-Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
-with non-positive weight. For an entry with no features, the embedding vector
-for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+  The partitioned embedding in `embedding_weights` must all be the same shape
+  except for the first dimension. The first dimension is allowed to vary as the
+  vocabulary size is not necessarily a multiple of `P`.
 
-The ids and weights may be multi-dimensional. Embeddings are always aggregated
-along the last dimension.
+  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
+  with non-positive weight. For an entry with no features, the embedding vector
+  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
 
-##### Args:
+  The ids and weights may be multi-dimensional. Embeddings are always aggregated
+  along the last dimension.
+
+  Args:
+    embedding_weights:  A list of `P` float tensors or values representing
+        partitioned embedding tensors.  The total unpartitioned shape should be
+        `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size and
+        `e_1, ..., e_m` are the embedding dimensions.
+    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+        ids. `d_0` is typically batch size.
+    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+        float weights corresponding to `sparse_ids`, or `None` if all weights
+        are assumed to be 1.0.
+    combiner: A string specifying how to combine embedding results for each
+        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+        the default.
+    default_id: The id to use for an entry with no features.
+    name: A name for this operation (optional).
+    partition_strategy: A string specifying the partitioning strategy.
+        Currently `"div"` and `"mod"` are supported. Default is `"div"`.
 
 
-*  <b>`embedding_weights`</b>: A list of `P` float tensors or values representing
-      partitioned embedding tensors.  The total unpartitioned shape should be
-      `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size and
-      `e_1, ..., e_m` are the embedding dimensions.
-*  <b>`sparse_ids`</b>: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
-      ids. `d_0` is typically batch size.
-*  <b>`sparse_weights`</b>: `SparseTensor` of same shape as `sparse_ids`, containing
-      float weights corresponding to `sparse_ids`, or `None` if all weights
-      are be assumed to be 1.0.
-*  <b>`combiner`</b>: A string specifying how to combine embedding results for each
-      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
-      the default.
-*  <b>`default_id`</b>: The id to use for an entry with no features.
-*  <b>`name`</b>: A name for this operation (optional).
-*  <b>`partition_strategy`</b>: A string specifying the partitioning strategy.
-      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+  Returns:
+    Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
 
-
-##### Returns:
-
-  Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
-
-##### Raises:
-
-
-*  <b>`ValueError`</b>: if `embedding_weights` is empty.
+  Raises:
+    ValueError: if `embedding_weights` is empty.
 
 
 - - -
diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md
index b573b02bd71..25bfe07e7a0 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.learn.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md
@@ -5083,7 +5083,7 @@ Use `parse_fn` if you need to do parsing / processing on single examples.
 
 - - -
 
-### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, read_batch_size=1, name=None)` {#read_batch_features}
+### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name=None)` {#read_batch_features}
 
 Adds operations to read, queue, batch and parse `Example` protos.
 
@@ -5115,7 +5115,5 @@ All ops are added to the default graph.
 *  <b>`queue_capacity`</b>: Capacity for input queue.
 *  <b>`reader_num_threads`</b>: The number of threads to read examples.
 *  <b>`parser_num_threads`</b>: The number of threads to parse examples.
-*  <b>`read_batch_size`</b>: An int or scalar `Tensor` specifying the number of
-    records to read at once
 *  <b>`name`</b>: Name of resulting op.
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.ContinuousTransformedDistribution.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.ContinuousTransformedDistribution.md
new file mode 100644
index 00000000000..1cda4145d6d
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.ContinuousTransformedDistribution.md
@@ -0,0 +1,320 @@
+A Transformed Distribution.
+
+A Transformed Distribution models `p(y)` given a base distribution `p(x)`,
+an invertible transform, `y = f(x)`, and the determinant of the Jacobian of
+`f(x)`.
+
+Shapes, type, and reparameterization are taken from the base distribution.
+
+#### Mathematical details
+
+* `p(x)` - probability distribution for random variable X
+* `p(y)` - probability distribution for random variable Y
+* `f` - transform
+* `g` - inverse transform, `f(g(x)) = x`
+* `J(x)` - Jacobian of f(x)
+
+A Transformed Distribution exposes `sample` and `pdf`:
+
+  * `sample`: `y = f(x)`, after drawing a sample of X.
+  * `pdf`: `p(y) = p(x) / det|J(x)| = p(g(y)) / det|J(g(y))|`
+
+A simple example constructing a Logit-Normal distribution from a Normal
+distribution:
+
+```
+logit_normal = ContinuousTransformedDistribution(
+  base_dist_cls=Normal,
+  mu=mu, sigma=sigma,
+  transform=lambda x: tf.sigmoid(x),
+  inverse=lambda y: tf.log(y) - tf.log(1. - y),
+  log_det_jacobian=(lambda x:
+      tf.reduce_sum(tf.log(tf.sigmoid(x)) + tf.log(1. - tf.sigmoid(x)),
+                    reduction_indices=[-1])),
+  name="LogitNormalTransformedDistribution")
+```
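+
+A minimal sketch (not part of the API) of how `pdf` combines the pieces
+above, assuming `dist` is a `ContinuousTransformedDistribution` and `y` is
+a float `Tensor`:
+
+```
+# log p(y) = log p(g(y)) - log det|J(g(y))|; work in log space for stability.
+x = dist.inverse(y)  # g(y): map y back to the base space.
+log_pdf_y = dist.base_distribution.log_pdf(x) - dist.log_det_jacobian(x)
+pdf_y = tf.exp(log_pdf_y)
+```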
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.__init__(base_dist_cls, transform, inverse, log_det_jacobian, name='ContinuousTransformedDistribution', **base_dist_args)` {#ContinuousTransformedDistribution.__init__}
+
+Construct a Transformed Distribution.
+
+##### Args:
+
+
+*  <b>`base_dist_cls`</b>: the base distribution class to transform. Must be a
+      subclass of `ContinuousDistribution`.
+*  <b>`transform`</b>: a callable that takes a `Tensor` sample from `base_dist` and
+      returns a `Tensor` of the same shape and type. `x => y`.
+*  <b>`inverse`</b>: a callable that computes the inverse of transform. `y => x`. If
+      None, users can only call `log_pdf` on values returned by `sample`.
+*  <b>`log_det_jacobian`</b>: a callable that takes a `Tensor` sample from `base_dist`
+      and returns the log of the determinant of the Jacobian of `transform`.
+*  <b>`name`</b>: The name for the distribution.
+*  <b>`**base_dist_args`</b>: kwargs to pass on to `base_dist_cls` on construction.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: if `base_dist_cls` is not a subclass of
+      `ContinuousDistribution`.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.base_distribution` {#ContinuousTransformedDistribution.base_distribution}
+
+Base distribution, p(x).
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.batch_shape(name='batch_shape')` {#ContinuousTransformedDistribution.batch_shape}
+
+Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+The product of the dimensions of the `batch_shape` is the number of
+independent distributions of this kind the instance represents.
+
+##### Args:
+
+
+*  <b>`name`</b>: name to give to the op.
+
+##### Returns:
+
+  `Tensor` `batch_shape`
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.cdf(value, name='cdf')` {#ContinuousTransformedDistribution.cdf}
+
+Cumulative distribution function.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.dtype` {#ContinuousTransformedDistribution.dtype}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.entropy(name='entropy')` {#ContinuousTransformedDistribution.entropy}
+
+Entropy of the distribution in nats.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.event_shape(name='event_shape')` {#ContinuousTransformedDistribution.event_shape}
+
+Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+##### Args:
+
+
+*  <b>`name`</b>: name to give to the op.
+
+##### Returns:
+
+  `Tensor` `event_shape`
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.get_batch_shape()` {#ContinuousTransformedDistribution.get_batch_shape}
+
+`TensorShape` available at graph construction time.
+
+Same meaning as `batch_shape`. May be only partially defined.
+
+##### Returns:
+
+  batch shape
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.get_event_shape()` {#ContinuousTransformedDistribution.get_event_shape}
+
+`TensorShape` available at graph construction time.
+
+Same meaning as `event_shape`. May be only partially defined.
+
+##### Returns:
+
+  event shape
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.inverse` {#ContinuousTransformedDistribution.inverse}
+
+Inverse function of transform, y => x.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.is_reparameterized` {#ContinuousTransformedDistribution.is_reparameterized}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_cdf(value, name='log_cdf')` {#ContinuousTransformedDistribution.log_cdf}
+
+Log CDF.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_det_jacobian` {#ContinuousTransformedDistribution.log_det_jacobian}
+
+Function computing the log determinant of the Jacobian of transform.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_likelihood(value, name='log_likelihood')` {#ContinuousTransformedDistribution.log_likelihood}
+
+Log likelihood of this distribution (same as log_pdf).
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.log_pdf(y, name='log_pdf')` {#ContinuousTransformedDistribution.log_pdf}
+
+Log pdf of observations in `y`.
+
+`log ( p(g(y)) / det|J(g(y))| )`, where `g` is the inverse of `transform`.
+
+##### Args:
+
+
+*  <b>`y`</b>: tensor of dtype `dtype`.
+*  <b>`name`</b>: The name to give this op.
+
+##### Returns:
+
+
+*  <b>`log_pdf`</b>: tensor of dtype `dtype`, the log-PDFs of `y`.
+
+##### Raises:
+
+
+*  <b>`ValueError`</b>: if `inverse` was not provided to the distribution and `y` was
+      not returned from `sample`.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.mean(name='mean')` {#ContinuousTransformedDistribution.mean}
+
+Mean of the distribution.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.mode(name='mode')` {#ContinuousTransformedDistribution.mode}
+
+Mode of the distribution.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.name` {#ContinuousTransformedDistribution.name}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.pdf(y, name='pdf')` {#ContinuousTransformedDistribution.pdf}
+
+The PDF of observations in `y`.
+
+`p(g(y)) / det|J(g(y))|`, where `g` is the inverse of `transform`.
+
+##### Args:
+
+
+*  <b>`y`</b>: `Tensor` of dtype `dtype`.
+*  <b>`name`</b>: The name to give this op.
+
+##### Returns:
+
+
+*  <b>`pdf`</b>: `Tensor` of dtype `dtype`, the pdf values of `y`.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.sample(n, seed=None, name='sample')` {#ContinuousTransformedDistribution.sample}
+
+Sample `n` observations.
+
+Samples from the base distribution and then passes through the transform.
+
+##### Args:
+
+
+*  <b>`n`</b>: scalar, type int32, the number of observations to sample.
+*  <b>`seed`</b>: Python integer, the random seed.
+*  <b>`name`</b>: The name to give this op.
+
+##### Returns:
+
+
+*  <b>`samples`</b>: `[n, ...]`, a `Tensor` of `n` samples.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.std(name='std')` {#ContinuousTransformedDistribution.std}
+
+Standard deviation of the distribution.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.strict` {#ContinuousTransformedDistribution.strict}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.strict_statistics` {#ContinuousTransformedDistribution.strict_statistics}
+
+
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.transform` {#ContinuousTransformedDistribution.transform}
+
+Function transforming x => y.
+
+
+- - -
+
+#### `tf.contrib.distributions.ContinuousTransformedDistribution.variance(name='variance')` {#ContinuousTransformedDistribution.variance}
+
+Variance of the distribution.
+
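+// Returns the directory for trace output: the first of TEST_TMPDIR, TMP,
+// TMPDIR that is set and non-empty; otherwise /tmp if accessible, else ".".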
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.safe_embedding_lookup_sparse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.safe_embedding_lookup_sparse.md
index f56043cc0dc..3d5491c98aa 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.safe_embedding_lookup_sparse.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.safe_embedding_lookup_sparse.md
@@ -1,45 +1,44 @@
-### `tf.contrib.framework.safe_embedding_lookup_sparse(embedding_weights, sparse_ids, sparse_weights=None, combiner='mean', default_id=None, name=None, partition_strategy='div')` {#safe_embedding_lookup_sparse}
+### `tf.contrib.framework.safe_embedding_lookup_sparse(*args, **kwargs)` {#safe_embedding_lookup_sparse}
 
-Lookup embedding results, accounting for invalid IDs and empty features.
+Lookup embedding results, accounting for invalid IDs and empty features. (deprecated)
 
-The partitioned embedding in `embedding_weights` must all be the same shape
-except for the first dimension. The first dimension is allowed to vary as the
-vocabulary size is not necessarily a multiple of `P`.
+THIS FUNCTION IS DEPRECATED. It will be removed after 2016-09-01.
+Instructions for updating:
+Please use tf.contrib.layers.safe_embedding_lookup_sparse.
 
-Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
-with non-positive weight. For an entry with no features, the embedding vector
-for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+  The partitioned embedding in `embedding_weights` must all be the same shape
+  except for the first dimension. The first dimension is allowed to vary as the
+  vocabulary size is not necessarily a multiple of `P`.
 
-The ids and weights may be multi-dimensional. Embeddings are always aggregated
-along the last dimension.
+  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
+  with non-positive weight. For an entry with no features, the embedding vector
+  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
 
-##### Args:
+  The ids and weights may be multi-dimensional. Embeddings are always aggregated
+  along the last dimension.
+
+  Args:
+    embedding_weights:  A list of `P` float tensors or values representing
+        partitioned embedding tensors.  The total unpartitioned shape should be
+        `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size and
+        `e_1, ..., e_m` are the embedding dimensions.
+    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+        ids. `d_0` is typically batch size.
+    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+        float weights corresponding to `sparse_ids`, or `None` if all weights
+        are assumed to be 1.0.
+    combiner: A string specifying how to combine embedding results for each
+        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+        the default.
+    default_id: The id to use for an entry with no features.
+    name: A name for this operation (optional).
+    partition_strategy: A string specifying the partitioning strategy.
+        Currently `"div"` and `"mod"` are supported. Default is `"div"`.
 
 
-*  <b>`embedding_weights`</b>: A list of `P` float tensors or values representing
-      partitioned embedding tensors.  The total unpartitioned shape should be
-      `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size and
-      `e_1, ..., e_m` are the embedding dimensions.
-*  <b>`sparse_ids`</b>: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
-      ids. `d_0` is typically batch size.
-*  <b>`sparse_weights`</b>: `SparseTensor` of same shape as `sparse_ids`, containing
-      float weights corresponding to `sparse_ids`, or `None` if all weights
-      are be assumed to be 1.0.
-*  <b>`combiner`</b>: A string specifying how to combine embedding results for each
-      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
-      the default.
-*  <b>`default_id`</b>: The id to use for an entry with no features.
-*  <b>`name`</b>: A name for this operation (optional).
-*  <b>`partition_strategy`</b>: A string specifying the partitioning strategy.
-      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+  Returns:
+    Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
 
-
-##### Returns:
-
-  Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
-
-##### Raises:
-
-
-*  <b>`ValueError`</b>: if `embedding_weights` is empty.
+  Raises:
+    ValueError: if `embedding_weights` is empty.
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.read_batch_features.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.read_batch_features.md
index 6327760ca08..d18c316080e 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.read_batch_features.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.read_batch_features.md
@@ -1,4 +1,4 @@
-### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, read_batch_size=1, name=None)` {#read_batch_features}
+### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name=None)` {#read_batch_features}
 
 Adds operations to read, queue, batch and parse `Example` protos.
 
@@ -30,7 +30,5 @@ All ops are added to the default graph.
 *  <b>`queue_capacity`</b>: Capacity for input queue.
 *  <b>`reader_num_threads`</b>: The number of threads to read examples.
 *  <b>`parser_num_threads`</b>: The number of threads to parse examples.
-*  <b>`read_batch_size`</b>: An int or scalar `Tensor` specifying the number of
-    records to read at once
 *  <b>`name`</b>: Name of resulting op.
 
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.Coordinator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.Coordinator.md
index f51c0721ff0..bebc34754f4 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.Coordinator.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.Coordinator.md
@@ -93,10 +93,28 @@ except Exception:
 ```
 - - -
 
-#### `tf.train.Coordinator.__init__()` {#Coordinator.__init__}
+#### `tf.train.Coordinator.__init__(clean_stop_exception_types=None)` {#Coordinator.__init__}
 
 Create a new Coordinator.
 
+##### Args:
+
+
+*  <b>`clean_stop_exception_types`</b>: Optional tuple of Exception types that should
+    cause a clean stop of the coordinator. If an exception of one of these
+    types is reported to `request_stop(ex)` the coordinator will behave as
+    if `request_stop(None)` was called.  Defaults to
+    `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
+    the end of input. When feeding training data from a Python iterator it
+    is common to add `StopIteration` to this list.
+
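+A minimal usage sketch (assuming training batches come from a plain Python
+iterator, so `StopIteration` should also trigger a clean stop):
+
+```
+coord = tf.train.Coordinator(
+    clean_stop_exception_types=(tf.errors.OutOfRangeError, StopIteration))
+```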
 
 - - -
 
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 5c4f7107d19..53a971e694a 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -584,6 +584,7 @@
   * [`Categorical`](../../api_docs/python/contrib.distributions.md#Categorical)
   * [`Chi2`](../../api_docs/python/contrib.distributions.md#Chi2)
   * [`ContinuousDistribution`](../../api_docs/python/contrib.distributions.md#ContinuousDistribution)
+  * [`ContinuousTransformedDistribution`](../../api_docs/python/contrib.distributions.md#ContinuousTransformedDistribution)
   * [`DirichletMultinomial`](../../api_docs/python/contrib.distributions.md#DirichletMultinomial)
   * [`DiscreteDistribution`](../../api_docs/python/contrib.distributions.md#DiscreteDistribution)
   * [`Exponential`](../../api_docs/python/contrib.distributions.md#Exponential)
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index 64c8085c875..4f63fa0b7c1 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -1214,10 +1214,28 @@ except Exception:
 ```
 - - -
 
-#### `tf.train.Coordinator.__init__()` {#Coordinator.__init__}
+#### `tf.train.Coordinator.__init__(clean_stop_exception_types=None)` {#Coordinator.__init__}
 
 Create a new Coordinator.
 
+##### Args:
+
+
+*  <b>`clean_stop_exception_types`</b>: Optional tuple of Exception types that should
+    cause a clean stop of the coordinator. If an exception of one of these
+    types is reported to `request_stop(ex)` the coordinator will behave as
+    if `request_stop(None)` was called.  Defaults to
+    `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
+    the end of input. When feeding training data from a Python iterator it
+    is common to add `StopIteration` to this list.
+
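+A minimal usage sketch (assuming training batches come from a plain Python
+iterator, so `StopIteration` should also trigger a clean stop):
+
+```
+coord = tf.train.Coordinator(
+    clean_stop_exception_types=(tf.errors.OutOfRangeError, StopIteration))
+```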
 
 - - -
 
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 158c84b4ef0..e1cece4faa3 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -63,7 +63,7 @@ Then, select the correct binary to install:
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7
 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
 
@@ -73,14 +73,14 @@ $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/tensorflow-
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4
 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, CPU only, Python 3.5
 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
 
@@ -153,7 +153,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
 
@@ -163,14 +163,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, CPU only, Python 3.5
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
 
@@ -277,7 +277,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
 # Ubuntu/Linux 64-bit, CPU only, Python 2.7
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
 
@@ -287,14 +287,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
 # Ubuntu/Linux 64-bit, CPU only, Python 3.4
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
 
 # Ubuntu/Linux 64-bit, CPU only, Python 3.5
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
 
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
 
@@ -952,6 +952,16 @@ SyntaxError: invalid syntax
 
 Solution: make sure you are using Python 2.7.
 
+#### Ubuntu 16.04 build issue when building with --config=cuda: identifier "__builtin_ia32_mwaitx" is undefined
+GitHub issue: https://github.com/tensorflow/tensorflow/issues/1066
+
+Solution: Add the following compiler flags to third_party/gpus/crosstool/CROSSTOOL:
+
+```
+cxx_flag: "-D_MWAITXINTRIN_H_INCLUDED"
+cxx_flag: "-D_FORCE_INLINES"
+```
+
 ### Mac OS X: ImportError: No module named copyreg
 
 On Mac OS X, you may encounter the following when importing tensorflow.
diff --git a/tensorflow/g3doc/images/wide_n_deep.svg b/tensorflow/g3doc/images/wide_n_deep.svg
deleted file mode 100644
index 6dfe9e7f102..00000000000
--- a/tensorflow/g3doc/images/wide_n_deep.svg
+++ /dev/null
@@ -1,1540 +0,0 @@
-<?xml version="1.0" encoding="UTF-8" standalone="no"?>
-<svg
-   xmlns:dc="http://purl.org/dc/elements/1.1/"
-   xmlns:cc="http://creativecommons.org/ns#"
-   xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
-   xmlns:svg="http://www.w3.org/2000/svg"
-   xmlns="http://www.w3.org/2000/svg"
-   xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
-   xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
-   version="1.1"
-   viewBox="0 0 952.2796 201.30696"
-   stroke-miterlimit="10"
-   id="svg4775"
-   inkscape:version="0.91 r13725"
-   sodipodi:docname="wide_n_deep_resized.svg"
-   width="952.2796"
-   height="201.30696"
-   style="fill:none;stroke:none;stroke-linecap:square;stroke-miterlimit:10">
-  <metadata
-     id="metadata5374">
-    <rdf:RDF>
-      <cc:Work
-         rdf:about="">
-        <dc:format>image/svg+xml</dc:format>
-        <dc:type
-           rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
-        <dc:title></dc:title>
-      </cc:Work>
-    </rdf:RDF>
-  </metadata>
-  <defs
-     id="defs5372" />
-  <sodipodi:namedview
-     pagecolor="#ffffff"
-     bordercolor="#666666"
-     borderopacity="1"
-     objecttolerance="10"
-     gridtolerance="10"
-     guidetolerance="10"
-     inkscape:pageopacity="0"
-     inkscape:pageshadow="2"
-     inkscape:window-width="1421"
-     inkscape:window-height="797"
-     id="namedview5370"
-     showgrid="false"
-     fit-margin-top="0"
-     fit-margin-left="0"
-     fit-margin-right="0"
-     fit-margin-bottom="0"
-     inkscape:zoom="0.90138889"
-     inkscape:cx="430.75268"
-     inkscape:cy="135.99525"
-     inkscape:window-x="1"
-     inkscape:window-y="20"
-     inkscape:window-maximized="0"
-     inkscape:current-layer="g4780" />
-  <clipPath
-     id="p.0">
-    <path
-       d="M 0,0 960,0 960,720 0,720 0,0 Z"
-       id="path4778"
-       inkscape:connector-curvature="0"
-       style="clip-rule:nonzero" />
-  </clipPath>
-  <g
-     clip-path="url(#p.0)"
-     id="g4780"
-     transform="translate(-4.8713584,-250.31233)">
-    <path
-       d="m 0,0 960,0 0,720 -960,0 z"
-       id="path4782"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 2.6456692,433.43008 5.3765955,-12.23624 0,0 941.0262953,0 0,0 5.37659,12.23624 -5.37659,12.23621 0,0 -941.0262953,0 0,0 z"
-       id="path4784"
-       inkscape:connector-curvature="0"
-       style="fill:#efefef;fill-rule:nonzero" />
-    <path
-       d="m 393.94235,353.87927 562.7086,0 0,34.48819 -562.7086,0 z"
-       id="path4786"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 393.94235,353.87927 562.7086,0 0,34.48819 -562.7086,0 z"
-       id="path4788"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#b7b7b7;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round;stroke-dasharray:4, 3" />
-    <path
-       d="m 86.80062,252.30708 773.41736,0 0,30.11024 -773.41736,0 z"
-       id="path4790"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 86.80062,252.30708 773.41736,0 0,30.11024 -773.41736,0 z"
-       id="path4792"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#b7b7b7;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round;stroke-dasharray:4, 3" />
-    <path
-       d="m 430.66415,289.09183 484.09445,0 0,58.11023 -484.09445,0 z"
-       id="path4794"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 430.66415,289.09183 484.09445,0 0,58.11023 -484.09445,0 z"
-       id="path4796"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#b7b7b7;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round;stroke-dasharray:4, 3" />
-    <path
-       d="m 4.8713584,391.71652 952.1575016,0 0,24.47244 -952.1575016,0 z"
-       id="path4798"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 4.8713584,391.71652 952.1575016,0 0,24.47244 -952.1575016,0 z"
-       id="path4800"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#b7b7b7;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round;stroke-dasharray:4, 3" />
-    <path
-       d="m 8.301801,400.93906 0,0 c 0,-4.25455 3.448989,-7.70352 7.703532,-7.70352 l 0,0 c 2.043104,0 4.002527,0.81161 5.44722,2.25632 1.444693,1.44467 2.256311,3.40411 2.256311,5.4472 l 0,0 c 0,4.25455 -3.448988,7.70355 -7.703531,7.70355 l 0,0 c -4.254543,0 -7.7035319,-3.449 -7.7035319,-7.70355 z"
-       id="path4802"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 33.63098,400.94022 0,0 c 0,-4.25452 3.44899,-7.70352 7.703533,-7.70352 l 0,0 c 2.043102,0 4.002525,0.81161 5.44722,2.25632 1.444691,1.4447 2.256313,3.40411 2.256313,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70355 -7.703533,7.70355 l 0,0 c -4.254543,0 -7.703533,-3.449 -7.703533,-7.70355 z"
-       id="path4804"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 41.32435,393.228"
-       id="path4806"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 41.32435,393.228"
-       id="path4808"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 117.4221,273.73578 16.00975,393.228"
-       id="path4810"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 16.00975,393.228"
-       id="path4812"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 57.94737,400.93848 0,0 c 0,-4.25455 3.44899,-7.70352 7.703533,-7.70352 l 0,0 c 2.043106,0 4.002525,0.81161 5.44722,2.25632 1.444694,1.44467 2.25631,3.40411 2.25631,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70353,7.70352 l 0,0 c -4.254543,0 -7.703533,-3.44897 -7.703533,-7.70352 z"
-       id="path4814"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 82.263756,400.93848 0,0 c 0,-4.25455 3.44899,-7.70352 7.703529,-7.70352 l 0,0 c 2.043106,0 4.002533,0.81161 5.447228,2.25632 1.444687,1.44467 2.256309,3.40411 2.256309,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.703537,7.70352 l 0,0 c -4.254539,0 -7.703529,-3.44897 -7.703529,-7.70352 z"
-       id="path4816"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 108.60574,400.93848 0,0 c 0,-4.25455 3.44898,-7.70352 7.70353,-7.70352 l 0,0 c 2.0431,0 4.00252,0.81161 5.44722,2.25632 1.44469,1.44467 2.2563,3.40411 2.2563,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70352,7.70352 l 0,0 c -4.25455,0 -7.70353,-3.44897 -7.70353,-7.70352 z"
-       id="path4818"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 132.92212,400.93848 0,0 c 0,-4.25455 3.44899,-7.70352 7.70354,-7.70352 l 0,0 c 2.0431,0 4.00251,0.81161 5.44722,2.25632 1.44468,1.44467 2.2563,3.40411 2.2563,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70352,7.70352 l 0,0 c -4.25455,0 -7.70354,-3.44897 -7.70354,-7.70352 z"
-       id="path4820"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 157.23851,400.93848 0,0 c 0,-4.25455 3.44899,-7.70352 7.70352,-7.70352 l 0,0 c 2.04311,0 4.00253,0.81161 5.44722,2.25632 1.4447,1.44467 2.25632,3.40411 2.25632,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70354,7.70352 l 0,0 c -4.25453,0 -7.70352,-3.44897 -7.70352,-7.70352 z"
-       id="path4822"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 65.645405,393.228"
-       id="path4824"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 65.645405,393.228"
-       id="path4826"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 117.4221,273.73578 164.94471,393.228"
-       id="path4828"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 164.94471,393.228"
-       id="path4830"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 117.4221,273.73578 89.952465,393.228"
-       id="path4832"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 89.952465,393.228"
-       id="path4834"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 117.4221,273.73578 140.63765,393.228"
-       id="path4836"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 140.63765,393.228"
-       id="path4838"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 117.4221,273.73578 116.31659,393.228"
-       id="path4840"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 117.4221,273.73578 116.31659,393.228"
-       id="path4842"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 181.5549,399.8878 0,0 c 0,-4.25455 3.44897,-7.70352 7.70352,-7.70352 l 0,0 c 2.04311,0 4.00253,0.81161 5.44722,2.25628 1.4447,1.44471 2.25632,3.40415 2.25632,5.44724 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70354,7.70352 l 0,0 c -4.25455,0 -7.70352,-3.44897 -7.70352,-7.70352 z"
-       id="path4844"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 205.87128,399.8878 0,0 c 0,-4.25455 3.44899,-7.70352 7.70354,-7.70352 l 0,0 c 2.0431,0 4.00253,0.81161 5.44722,2.25628 1.44468,1.44471 2.25631,3.40415 2.25631,5.44724 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70353,7.70352 l 0,0 c -4.25455,0 -7.70354,-3.44897 -7.70354,-7.70352 z"
-       id="path4846"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 230.18767,399.8878 0,0 c 0,-4.25455 3.44899,-7.70352 7.70354,-7.70352 l 0,0 c 2.0431,0 4.00251,0.81161 5.44722,2.25628 1.44468,1.44471 2.2563,3.40415 2.2563,5.44724 l 0,0 c 0,4.25455 -3.44898,7.70352 -7.70352,7.70352 l 0,0 c -4.25455,0 -7.70354,-3.44897 -7.70354,-7.70352 z"
-       id="path4848"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 117.4221,273.73578 71.84365,118.4567"
-       id="path4850"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 117.4221,273.73578 71.84365,118.4567"
-       id="path4852"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 117.4221,273.73578 96.15072,118.4567"
-       id="path4854"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 117.4221,273.73578 96.15072,118.4567"
-       id="path4856"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 117.4221,273.73578 120.47176,118.4567"
-       id="path4858"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 117.4221,273.73578 120.47176,118.4567"
-       id="path4860"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 109.71856,266.03226 0,0 c 0,-4.25455 3.44899,-7.70352 7.70354,-7.70352 l 0,0 c 2.0431,0 4.00252,0.81161 5.44722,2.25632 1.44468,1.44467 2.25631,3.40411 2.25631,5.4472 l 0,0 c 0,4.25455 -3.44899,7.70352 -7.70353,7.70352 l 0,0 c -4.25455,0 -7.70354,-3.44897 -7.70354,-7.70352 z"
-       id="path4862"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 111.36923,270.73413 c 3.14159,0 4.71238,-2.35443 6.28317,-4.70886 1.5708,-2.35443 3.14159,-4.7089 6.28318,-4.7089"
-       id="path4864"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 111.36923,270.73413 c 3.14159,0 4.71238,-2.35443 6.28317,-4.70886 1.5708,-2.35443 3.14159,-4.7089 6.28318,-4.7089"
-       id="path4866"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 846.92847,363.60492 0,0 c 0,-4.9e-4 4.3e-4,-8.9e-4 9.2e-4,-8.9e-4 l 105.84033,8.9e-4 c 1.8e-4,0 4.2e-4,9e-5 6.1e-4,2.4e-4 1.2e-4,1.9e-4 2.4e-4,4e-4 2.4e-4,6.5e-4 l -8.5e-4,23.72979 c 0,4.9e-4 -4.3e-4,8.6e-4 -9.2e-4,8.6e-4 l -105.84033,-8.6e-4 0,0 c -4.9e-4,0 -8.5e-4,-3.9e-4 -8.5e-4,-8.8e-4 z"
-       id="path4868"
-       inkscape:connector-curvature="0"
-       style="fill:#cccccc;fill-rule:nonzero" />
-    <path
-       d="m 722.254,364.3828 0,0 c 0,-4.6e-4 3.7e-4,-8.2e-4 8.5e-4,-8.2e-4 l 98.01074,8.2e-4 c 1.9e-4,0 4.3e-4,9e-5 5.5e-4,2.4e-4 1.9e-4,1.6e-4 2.5e-4,3.7e-4 2.5e-4,5.8e-4 l -8e-4,23.72986 c 0,4.3e-4 -3.6e-4,8e-4 -8.5e-4,8e-4 l -98.01074,-8e-4 0,0 c -4.3e-4,0 -7.9e-4,-3.6e-4 -7.9e-4,-8.2e-4 z"
-       id="path4870"
-       inkscape:connector-curvature="0"
-       style="fill:#cccccc;fill-rule:nonzero" />
-    <path
-       d="m 731.6505,376.2962 0,0 c 0,-4.08316 3.31006,-7.39319 7.39319,-7.39319 l 0,0 c 1.96075,0 3.84125,0.77893 5.22772,2.16541 1.38654,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31,7.39316 -7.39313,7.39316 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4872"
-       inkscape:connector-curvature="0"
-       style="fill:#434343;fill-rule:nonzero" />
-    <path
-       d="m 754.547,376.2727 0,0 c 0,-4.08316 3.31006,-7.39319 7.39319,-7.39319 l 0,0 c 1.96081,0 3.84131,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31006,7.39316 -7.39319,7.39316 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4874"
-       inkscape:connector-curvature="0"
-       style="fill:#666666;fill-rule:nonzero" />
-    <path
-       d="m 775.6335,376.2727 0,0 c 0,-4.08316 3.31006,-7.39319 7.39319,-7.39319 l 0,0 c 1.96081,0 3.84125,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31006,7.39316 -7.39319,7.39316 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4876"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 798.9702,376.27158 0,0 c 0,-4.08313 3.31006,-7.39319 7.39319,-7.39319 l 0,0 c 1.96081,0 3.84125,0.77893 5.22778,2.16544 1.38648,1.38647 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08313 -3.31006,7.39316 -7.39319,7.39316 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4878"
-       inkscape:connector-curvature="0"
-       style="fill:#efefef;fill-rule:nonzero" />
-    <path
-       d="m 857.57184,375.3713 0,0 c 0,-4.08313 3.31006,-7.39316 7.39319,-7.39316 l 0,0 c 1.96081,0 3.84131,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08316 -3.31006,7.39319 -7.39319,7.39319 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39319 z"
-       id="path4880"
-       inkscape:connector-curvature="0"
-       style="fill:#434343;fill-rule:nonzero" />
-    <path
-       d="m 881.88055,375.37244 0,0 c 0,-4.08313 3.31006,-7.39316 7.39319,-7.39316 l 0,0 c 1.96081,0 3.84131,0.7789 5.22778,2.16541 1.38648,1.38647 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08313 -3.31006,7.39319 -7.39319,7.39319 l 0,0 c -4.08313,0 -7.39319,-3.31006 -7.39319,-7.39319 z"
-       id="path4882"
-       inkscape:connector-curvature="0"
-       style="fill:#666666;fill-rule:nonzero" />
-    <path
-       d="m 788.2263,328.63565 0,0 c 0,-4.08316 3.31,-7.39319 7.39313,-7.39319 l 0,0 c 1.96081,0 3.84131,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31006,7.39316 -7.39319,7.39316 l 0,0 c -4.08313,0 -7.39313,-3.31003 -7.39313,-7.39316 z"
-       id="path4884"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 -56.56683,32.87646"
-       id="path4886"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 -56.56683,32.87646"
-       id="path4888"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,336.0288 -33.66882,32.84958"
-       id="path4890"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 -33.66882,32.84958"
-       id="path4892"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,336.0288 93.67371,31.93637"
-       id="path4894"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 93.67371,31.93637"
-       id="path4896"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,336.0288 -12.59723,32.84958"
-       id="path4898"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 -12.59723,32.84958"
-       id="path4900"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,336.0288 69.35211,31.93637"
-       id="path4902"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 69.35211,31.93637"
-       id="path4904"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,336.0288 10.74396,32.84958"
-       id="path4906"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 10.74396,32.84958"
-       id="path4908"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 905.2173,375.37076 0,0 c 0,-4.08313 3.31006,-7.39316 7.39319,-7.39316 l 0,0 c 1.96081,0 3.84131,0.7789 5.22778,2.16541 1.38648,1.38647 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08313 -3.31006,7.39319 -7.39319,7.39319 l 0,0 c -4.08313,0 -7.39319,-3.31006 -7.39319,-7.39319 z"
-       id="path4910"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 928.554,375.37076 0,0 c 0,-4.08313 3.31006,-7.39316 7.39319,-7.39316 l 0,0 c 1.96081,0 3.84125,0.7789 5.22778,2.16541 1.38648,1.38647 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08313 -3.31006,7.39319 -7.39319,7.39319 l 0,0 c -4.08313,0 -7.39319,-3.31006 -7.39319,-7.39319 z"
-       id="path4912"
-       inkscape:connector-curvature="0"
-       style="fill:#efefef;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 116.97461,31.93637"
-       id="path4914"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 116.97461,31.93637"
-       id="path4916"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,336.0288 140.31586,31.93637"
-       id="path4918"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,336.0288 140.31586,31.93637"
-       id="path4920"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 826.32306,328.63565 0,0 c 0,-4.08316 3.31006,-7.39319 7.39319,-7.39319 l 0,0 c 1.96075,0 3.84125,0.77893 5.22772,2.16541 1.38654,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31,7.39316 -7.39313,7.39316 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4922"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 -94.68097,32.87646"
-       id="path4924"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 -94.68097,32.87646"
-       id="path4926"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 -27.34332,32.84958"
-       id="path4928"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 -27.34332,32.84958"
-       id="path4930"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 55.55957,31.93637"
-       id="path4932"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 55.55957,31.93637"
-       id="path4934"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 -71.78296,32.84958"
-       id="path4936"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 -71.78296,32.84958"
-       id="path4938"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 31.23798,31.93637"
-       id="path4940"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 31.23798,31.93637"
-       id="path4942"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 -45.4469,35.02524"
-       id="path4944"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 -45.4469,35.02524"
-       id="path4946"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 78.90076,31.93637"
-       id="path4948"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 78.90076,31.93637"
-       id="path4950"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,336.0288 102.242,31.93637"
-       id="path4952"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,336.0288 102.242,31.93637"
-       id="path4954"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 864.42096,328.63565 0,0 c 0,-4.08316 3.31,-7.39319 7.39319,-7.39319 l 0,0 c 1.96075,0 3.84125,0.77893 5.22772,2.16541 1.38648,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31,7.39316 -7.39313,7.39316 l 0,0 c -4.08319,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4956"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="M 871.81415,336.0288 739.05933,368.90526"
-       id="path4958"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 871.81415,336.0288 739.05933,368.90526"
-       id="path4960"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,336.0288 -65.44403,32.84958"
-       id="path4962"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,336.0288 -65.44403,32.84958"
-       id="path4964"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,336.0288 17.44544,31.93637"
-       id="path4966"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,336.0288 17.44544,31.93637"
-       id="path4968"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 871.81415,336.0288 761.95734,368.87838"
-       id="path4970"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 871.81415,336.0288 761.95734,368.87838"
-       id="path4972"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,336.0288 -6.83588,31.93637"
-       id="path4974"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,336.0288 -6.83588,31.93637"
-       id="path4976"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,336.0288 -88.78522,32.84958"
-       id="path4978"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,336.0288 -88.78522,32.84958"
-       id="path4980"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,336.0288 40.78662,31.93637"
-       id="path4982"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,336.0288 40.78662,31.93637"
-       id="path4984"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,336.0288 64.12787,31.93637"
-       id="path4986"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,336.0288 64.12787,31.93637"
-       id="path4988"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 788.2263,299.19727 0,0 c 0,-4.08316 3.31,-7.39319 7.39313,-7.39319 l 0,0 c 1.96081,0 3.84131,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31006,7.39316 -7.39319,7.39316 l 0,0 c -4.08313,0 -7.39313,-3.31003 -7.39313,-7.39316 z"
-       id="path4990"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 826.32306,299.19727 0,0 c 0,-4.08316 3.31006,-7.39319 7.39319,-7.39319 l 0,0 c 1.96075,0 3.84125,0.77893 5.22772,2.16541 1.38654,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31,7.39316 -7.39313,7.39316 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4992"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 864.42096,299.19727 0,0 c 0,-4.08316 3.31,-7.39319 7.39319,-7.39319 l 0,0 c 1.96075,0 3.84125,0.77893 5.22772,2.16541 1.38648,1.3865 2.16541,3.26696 2.16541,5.22778 l 0,0 c 0,4.08313 -3.31,7.39316 -7.39313,7.39316 l 0,0 c -4.08319,0 -7.39319,-3.31003 -7.39319,-7.39316 z"
-       id="path4994"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,306.59042 0,14.65204"
-       id="path4996"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,306.59042 0,14.65204"
-       id="path4998"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,306.59042 0,14.65204"
-       id="path5000"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,306.59042 0,14.65204"
-       id="path5002"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,306.59042 0,14.65204"
-       id="path5004"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,306.59042 0,14.65204"
-       id="path5006"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,306.59042 -76.18799,14.65204"
-       id="path5008"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,306.59042 -76.18799,14.65204"
-       id="path5010"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 871.81415,306.59042 -38.11414,14.65204"
-       id="path5012"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 871.81415,306.59042 -38.11414,14.65204"
-       id="path5014"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,306.59042 38.11407,14.65204"
-       id="path5016"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,306.59042 38.11407,14.65204"
-       id="path5018"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 795.61945,306.59042 38.11414,14.65204"
-       id="path5020"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 795.61945,306.59042 38.11414,14.65204"
-       id="path5022"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,306.59042 -38.11414,14.65204"
-       id="path5024"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,306.59042 -38.11414,14.65204"
-       id="path5026"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,273.60336 -38.11414,18.21097"
-       id="path5028"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,273.60336 -38.11414,18.21097"
-       id="path5030"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,273.60336 0,18.21097"
-       id="path5032"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,273.60336 0,18.21097"
-       id="path5034"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 833.71625,273.60336 38.11407,18.21097"
-       id="path5036"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 833.71625,273.60336 38.11407,18.21097"
-       id="path5038"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 763.86664,402.68573 0,0 c 0,-4.08313 3.31006,-7.39316 7.39319,-7.39316 l 0,0 c 1.96081,0 3.84125,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08316 -3.31006,7.39319 -7.39319,7.39319 l 0,0 c -4.08313,0 -7.39319,-3.31003 -7.39319,-7.39319 z"
-       id="path5040"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 771.2598,388.1135 0,7.185"
-       id="path5042"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 771.2598,388.1135 0,7.185"
-       id="path5044"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 892.45593,402.68573 0,0 c 0,-4.08313 3.31,-7.39316 7.39313,-7.39316 l 0,0 c 1.96081,0 3.84131,0.77893 5.22778,2.16541 1.38648,1.3865 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08316 -3.31006,7.39319 -7.39319,7.39319 l 0,0 c -4.08313,0 -7.39313,-3.31003 -7.39313,-7.39319 z"
-       id="path5046"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 899.84906,387.3356 0,7.96393"
-       id="path5048"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 899.84906,387.3356 0,7.96393"
-       id="path5050"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 788.3015,330.95367 8.12622,0 4.51245,-7.81622"
-       id="path5052"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 788.3015,330.95367 8.12622,0 4.51245,-7.81622"
-       id="path5054"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 826.3613,330.95367 8.12622,0 4.51251,-7.81622"
-       id="path5056"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 826.3613,330.95367 8.12622,0 4.51251,-7.81622"
-       id="path5058"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 864.42065,330.95367 8.12622,0 4.51245,-7.81622"
-       id="path5060"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 864.42065,330.95367 8.12622,0 4.51245,-7.81622"
-       id="path5062"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 864.42065,301.47162 8.12622,0 4.51245,-7.81619"
-       id="path5064"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 864.42065,301.47162 8.12622,0 4.51245,-7.81619"
-       id="path5066"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 826.62317,301.47162 8.12622,0 4.51245,-7.81619"
-       id="path5068"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 826.62317,301.47162 8.12622,0 4.51245,-7.81619"
-       id="path5070"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 788.3015,301.47162 8.12622,0 4.51245,-7.81619"
-       id="path5072"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 788.3015,301.47162 8.12622,0 4.51245,-7.81619"
-       id="path5074"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 826.32306,266.21017 0,0 c 0,-4.08313 3.31006,-7.39316 7.39319,-7.39316 l 0,0 c 1.96075,0 3.84125,0.7789 5.22772,2.16541 1.38654,1.38647 2.16541,3.26697 2.16541,5.22775 l 0,0 c 0,4.08313 -3.31,7.39319 -7.39313,7.39319 l 0,0 c -4.08313,0 -7.39319,-3.31006 -7.39319,-7.39319 z"
-       id="path5076"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 827.6896,270.72263 c 3.0083,0 4.51245,-2.25958 6.0166,-4.51917 1.50415,-2.25958 3.0083,-4.51916 6.0166,-4.51916"
-       id="path5078"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 827.6896,270.72263 c 3.0083,0 4.51245,-2.25958 6.0166,-4.51917 1.50415,-2.25958 3.0083,-4.51916 6.0166,-4.51916"
-       id="path5080"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 411.3028,272.46396 51.11435,20.4736"
-       id="path5082"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 411.3028,272.46396 51.11435,20.4736"
-       id="path5084"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 411.3028,272.46396 87.09213,20.4736"
-       id="path5086"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 411.3028,272.46396 87.09213,20.4736"
-       id="path5088"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 411.3028,272.46396 123.05725,20.4736"
-       id="path5090"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 411.3028,272.46396 123.05725,20.4736"
-       id="path5092"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 260.54385,400.8314 0,0 c 0,-3.85428 3.12451,-6.97879 6.97876,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5094"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 282.57254,400.8314 0,0 c 0,-3.85428 3.12451,-6.97879 6.97876,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5096"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 304.60123,400.8314 0,0 c 0,-3.85428 3.12451,-6.97879 6.97876,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5098"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 326.6299,400.83032 0,0 c 0,-3.85425 3.12451,-6.97876 6.97876,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5100"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 349.5761,400.83032 0,0 c 0,-3.85425 3.12451,-6.97876 6.97876,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5102"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 372.5223,400.8314 0,0 c 0,-3.85428 3.12451,-6.97879 6.97876,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5104"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 267.5184,393.84781"
-       id="path5106"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 267.5184,393.84781"
-       id="path5108"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 411.3028,272.46396 289.53864,393.84781"
-       id="path5110"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 289.53864,393.84781"
-       id="path5112"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 411.3028,272.46396 379.49579,393.84781"
-       id="path5114"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 379.49579,393.84781"
-       id="path5116"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 411.3028,272.46396 311.58423,393.84781"
-       id="path5118"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 311.58423,393.84781"
-       id="path5120"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 411.3028,272.46396 356.56281,393.84781"
-       id="path5122"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 356.56281,393.84781"
-       id="path5124"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 411.3028,272.46396 333.60447,393.84781"
-       id="path5126"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 411.3028,272.46396 333.60447,393.84781"
-       id="path5128"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 404.32404,265.48517 0,0 c 0,-3.85428 3.12451,-6.97879 6.97876,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5130"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 405.61658,269.7447 c 2.84918,0 4.27377,-2.12659 5.69839,-4.25317 1.4246,-2.12662 2.84919,-4.25321 5.69837,-4.25321"
-       id="path5132"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 405.61658,269.7447 c 2.84918,0 4.27377,-2.12659 5.69839,-4.25317 1.4246,-2.12662 2.84919,-4.25321 5.69837,-4.25321"
-       id="path5134"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#434343;stroke-width:3;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.61996,365.02457 0,0 c 0,-4e-4 3.4e-4,-7.3e-4 7.3e-4,-7.3e-4 l 89.61417,7.3e-4 c 1.9e-4,0 3.7e-4,9e-5 4.9e-4,2.1e-4 1.2e-4,1.6e-4 2.4e-4,3.4e-4 2.4e-4,5.5e-4 l -7.3e-4,22.39978 c 0,4.3e-4 -3.6e-4,7.7e-4 -7.9e-4,7.7e-4 l -89.61411,-7.7e-4 0,0 c -4.2e-4,0 -7.6e-4,-3e-4 -7.6e-4,-7.3e-4 z"
-       id="path5136"
-       inkscape:connector-curvature="0"
-       style="fill:#d9d9d9;fill-rule:nonzero" />
-    <path
-       d="m 401.27518,364.6629 0,0 c 0,-4e-4 3.1e-4,-7.3e-4 7.3e-4,-7.3e-4 l 87.64917,7.3e-4 c 2.2e-4,0 4e-4,9e-5 5.2e-4,2.1e-4 1.5e-4,1.6e-4 2.1e-4,3.4e-4 2.1e-4,5.2e-4 l -7.3e-4,22.39984 c 0,4e-4 -3e-4,7.1e-4 -7.3e-4,7.1e-4 l -87.64917,-7.1e-4 0,0 c -4e-4,0 -7.3e-4,-3.3e-4 -7.3e-4,-7.3e-4 z"
-       id="path5138"
-       inkscape:connector-curvature="0"
-       style="fill:#d9d9d9;fill-rule:nonzero" />
-    <path
-       d="m 405.2727,375.9086 0,0 c 0,-3.85425 3.12451,-6.97876 6.97879,-6.97876 l 0,0 c 1.85086,0 3.62595,0.73526 4.93472,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97876,6.97879 l 0,0 c -3.85428,0 -6.97879,-3.12451 -6.97879,-6.97879 z"
-       id="path5140"
-       inkscape:connector-curvature="0"
-       style="fill:#434343;fill-rule:nonzero" />
-    <path
-       d="m 426.88586,375.8864 0,0 c 0,-3.85425 3.12448,-6.97876 6.97876,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85428,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5142"
-       inkscape:connector-curvature="0"
-       style="fill:#666666;fill-rule:nonzero" />
-    <path
-       d="m 446.7904,375.8864 0,0 c 0,-3.85425 3.12451,-6.97876 6.97879,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85428,0 -6.97879,-3.12451 -6.97879,-6.97879 z"
-       id="path5144"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 468.8191,375.88538 0,0 c 0,-3.85428 3.12451,-6.97879 6.97879,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97876 -6.97879,6.97876 l 0,0 c -3.85428,0 -6.97879,-3.12448 -6.97879,-6.97876 z"
-       id="path5146"
-       inkscape:connector-curvature="0"
-       style="fill:#b7b7b7;fill-rule:nonzero" />
-    <path
-       d="m 502.69196,375.76062 0,0 c 0,-3.85428 3.12451,-6.97879 6.97879,-6.97879 l 0,0 c 1.85089,0 3.62595,0.73526 4.93472,2.04404 1.30878,1.30877 2.04407,3.08386 2.04407,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97876 -6.97879,6.97876 l 0,0 c -3.85428,0 -6.97879,-3.12448 -6.97879,-6.97876 z"
-       id="path5148"
-       inkscape:connector-curvature="0"
-       style="fill:#b7b7b7;fill-rule:nonzero" />
-    <path
-       d="m 525.6382,375.76166 0,0 c 0,-3.85425 3.12451,-6.97876 6.97876,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04401,3.08386 2.04401,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97876,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5150"
-       inkscape:connector-curvature="0"
-       style="fill:#999999;fill-rule:nonzero" />
-    <path
-       d="m 455.45737,330.91946 0,0 c 0,-3.85428 3.12451,-6.97879 6.97879,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30875,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97879,6.97879 l 0,0 c -3.85428,0 -6.97879,-3.12451 -6.97879,-6.97879 z"
-       id="path5152"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 -50.17621,31.02103"
-       id="path5154"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 -50.17621,31.02103"
-       id="path5156"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 462.43616,337.89825 -28.56165,31.02103"
-       id="path5158"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 -28.56165,31.02103"
-       id="path5160"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 462.43616,337.89825 70.16809,30.89426"
-       id="path5162"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 70.16809,30.89426"
-       id="path5164"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 462.43616,337.89825 -8.65851,31.02103"
-       id="path5166"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 -8.65851,31.02103"
-       id="path5168"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 462.43616,337.89825 47.2478,30.89426"
-       id="path5170"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 47.2478,30.89426"
-       id="path5172"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 462.43616,337.89825 13.34906,31.02103"
-       id="path5174"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 13.34906,31.02103"
-       id="path5176"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 547.6669,375.76007 0,0 c 0,-3.85425 3.12451,-6.97876 6.97876,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04401,3.08386 2.04401,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97876,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5178"
-       inkscape:connector-curvature="0"
-       style="fill:#434343;fill-rule:nonzero" />
-    <path
-       d="m 569.69556,375.76007 0,0 c 0,-3.85425 3.12451,-6.97876 6.97876,-6.97876 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04401,3.08386 2.04401,4.93472 l 0,0 c 0,3.85428 -3.12451,6.97879 -6.97876,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5180"
-       inkscape:connector-curvature="0"
-       style="fill:#666666;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 92.22638,30.89426"
-       id="path5182"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 92.22638,30.89426"
-       id="path5184"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 462.43616,337.89825 114.23395,30.89426"
-       id="path5186"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 462.43616,337.89825 114.23395,30.89426"
-       id="path5188"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 491.4188,330.91946 0,0 c 0,-3.85428 3.12451,-6.97879 6.97879,-6.97879 l 0,0 c 1.85089,0 3.62595,0.73526 4.93472,2.04404 1.30878,1.30877 2.04404,3.08386 2.04404,4.93475 l 0,0 c 0,3.85428 -3.12448,6.97879 -6.97876,6.97879 l 0,0 c -3.85428,0 -6.97879,-3.12451 -6.97879,-6.97879 z"
-       id="path5190"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 -86.14136,31.02103"
-       id="path5192"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 -86.14136,31.02103"
-       id="path5194"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 -22.61606,31.02103"
-       id="path5196"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 -22.61606,31.02103"
-       id="path5198"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 34.203,30.89426"
-       id="path5200"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 34.203,30.89426"
-       id="path5202"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 -64.53946,31.02103"
-       id="path5204"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 -64.53946,31.02103"
-       id="path5206"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 11.26999,30.89426"
-       id="path5208"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 11.26999,30.89426"
-       id="path5210"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 -39.67953,33.03668"
-       id="path5212"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 -39.67953,33.03668"
-       id="path5214"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 56.2486,30.89426"
-       id="path5216"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 56.2486,30.89426"
-       id="path5218"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 498.39758,337.89825 78.2688,30.89426"
-       id="path5220"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 498.39758,337.89825 78.2688,30.89426"
-       id="path5222"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 527.3813,330.91946 0,0 c 0,-3.85428 3.12451,-6.97879 6.97876,-6.97879 l 0,0 c 1.85089,0 3.62598,0.73526 4.93475,2.04404 1.30878,1.30877 2.04401,3.08386 2.04401,4.93475 l 0,0 c 0,3.85428 -3.12445,6.97879 -6.97876,6.97879 l 0,0 c -3.85425,0 -6.97876,-3.12451 -6.97876,-6.97879 z"
-       id="path5224"
-       inkscape:connector-curvature="0"
-       style="fill:#f1c232;fill-rule:nonzero" />
-    <path
-       d="M 534.36005,337.89825 412.2536,368.91928"
-       id="path5226"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 534.36005,337.89825 412.2536,368.91928"
-       id="path5228"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 534.36005,337.89825 -58.55582,31.02103"
-       id="path5230"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 534.36005,337.89825 -58.55582,31.02103"
-       id="path5232"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 534.36005,337.89825 -1.73676,30.89426"
-       id="path5234"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 534.36005,337.89825 -1.73676,30.89426"
-       id="path5236"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="M 534.36005,337.89825 433.8555,368.91928"
-       id="path5238"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="M 534.36005,337.89825 433.8555,368.91928"
-       id="path5240"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 534.36005,337.89825 -24.6951,30.89426"
-       id="path5242"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 534.36005,337.89825 -24.6951,30.89426"
-       id="path5244"
-       inkscape:connector-curvature="0"
-       style="fill-rule:nonzero;stroke:#999999;stroke-width:1;stroke-linecap:butt;stroke-linejoin:round" />
-    <path
-       d="m 534.36005,337.89825 -80.60141,31.02103"
-       id="path5246"
-       inkscape:connector-curvature="0"
-       style="fill:#000000;fill-opacity:0;fill-rule:nonzero" />
-    <path
-       d="m 534.36005,337.89825 -80.60141,31.02103"
-    [... some 300 further lines of machine-generated SVG path data for this deleted figure elided: gray connector strokes, yellow (#f1c232) node circles, and text labels rendered as glyph-outline paths ...]
-  </g>
-</svg>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c2e5b0cc1c8..c5418cf076f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -670,6 +670,8 @@ tf_gen_op_wrapper_py(
         "MatMul",
         "Sigmoid",
         "Tanh",
+        "SigmoidGrad",
+        "TanhGrad",
     ],
     require_shape_functions = True,
 )
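[Note] Listing "SigmoidGrad" and "TanhGrad" here makes tf_gen_op_wrapper_py emit Python wrappers for the existing C++ gradient kernels; math_grad.py below calls them as gen_math_ops._sigmoid_grad and gen_math_ops._tanh_grad (the leading underscore keeps the wrappers private to the ops package). A minimal sketch of the intended call, assuming a build with this patch applied:

    import tensorflow as tf
    from tensorflow.python.ops import gen_math_ops

    y = tf.tanh(tf.constant([0.5, -0.25]))  # cached forward output
    dy = tf.constant([1.0, 1.0])            # incoming gradient
    # One fused kernel launch instead of the three elementwise ops
    # behind dy * (1 - y * y):
    dx = gen_math_ops._tanh_grad(y, dy)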
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index cc769ec2748..adfedbed64b 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -516,6 +516,9 @@ class FillTest(tf.test.TestCase):
         tf.placeholder(tf.int32, shape=(4,)), 3.0)
     self.assertEqual([None, None, None, None], f.get_shape().as_list())
 
+    f = tf.fill([tf.placeholder(tf.int32, shape=()), 17], 1.0)
+    self.assertEqual([None, 17], f.get_shape().as_list())
+
   def testGradient(self):
     with self.test_session():
       in_v = tf.constant(5.0)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index eb6bdff8b5a..ea379fbac01 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -120,6 +120,13 @@ class SparseXentTest(tf.test.TestCase):
         tf.nn.sparse_softmax_cross_entropy_with_logits(
             tf.constant(1.0), tf.constant(0))
 
+  def testLabelsPlaceholderScalar(self):
+    with self.test_session():
+      labels = tf.placeholder(np.int32)
+      y = tf.nn.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
+      with self.assertRaisesOpError("labels must be 1-D"):
+        y.eval(feed_dict={labels: 0})
+
   def testVector(self):
     with self.test_session():
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -145,6 +152,9 @@ class SparseXentTest(tf.test.TestCase):
           np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
           np.array([3, 0]).astype(label_dtype))
 
+  def testEmpty(self):
+    self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
+
   def testGradient(self):
     with self.test_session():
       l = tf.constant([3, 0, 1], name="l")
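[Note] The two new tests pin down edge-case behavior: scalar labels fed through an unshaped placeholder can only be rejected at run time (hence assertRaisesOpError rather than a graph-construction error), and zero-sized batches must flow through cleanly. A hedged sketch of the empty-batch case, assuming a build with this patch:

    import numpy as np
    import tensorflow as tf

    with tf.Session():
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          np.zeros((0, 3), dtype=np.float32),
          np.zeros((0,), dtype=np.int32))
      print(loss.eval())  # an empty, well-formed loss vector: []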
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 82f010cf5b2..23a7f3717cc 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -197,6 +197,17 @@ class VariablesTestCase(tf.test.TestCase):
       self.assertAllClose(3.0, var_y.eval())
       self.assertAllClose(5.0, tf.add(var_x, var_y).eval())
 
+  def testZeroSizeVarSameAsConst(self):
+    with self.test_session():
+      zero_size_var = tf.Variable(tf.zeros([0, 2]))
+      zero_size_const = tf.ones([2, 0])
+      variable_mul = tf.matmul(zero_size_const, zero_size_var)
+      const_mul = tf.matmul(zero_size_const, zero_size_const, transpose_b=True)
+      tf.initialize_all_variables().run()
+      variable_output = variable_mul.eval()
+      self.assertAllClose(const_mul.eval(), variable_output)
+      self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
+
   def testCachingDevice(self):
     with self.test_session():
       var = tf.Variable(2.0)
@@ -387,6 +398,23 @@ class IsInitializedTest(tf.test.TestCase):
       v.initializer.run()
       self.assertEqual(0, sess.run(uninited).size)
 
+  def testZeroSizeVarInitialized(self):
+    with tf.Graph().as_default(), self.test_session() as sess:
+      v = tf.Variable(tf.zeros([0, 2]), name="v")
+      uninited = tf.report_uninitialized_variables()
+      v.initializer.run()  # not strictly necessary
+      self.assertEqual(0, sess.run(uninited).size)
+
+  def testTrainingWithZeroSizeVar(self):

+    with tf.Graph().as_default(), self.test_session() as sess:
+      a = tf.Variable(tf.zeros([0, 2]))
+      b = tf.Variable(tf.ones([2, 2]))
+      objective = tf.reduce_sum(b + tf.matmul(a, a, transpose_a=True))
+      tf.initialize_all_variables().run()
+      do_opt = tf.train.GradientDescentOptimizer(0.1).minimize(objective)
+      sess.run([do_opt])
+      self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], b.eval())
+
 
 class ObsoleteIsInitializedTest(tf.test.TestCase):
 
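
The three tests above make zero-size variables first-class citizens: they compare equal to their constant counterparts, report as initialized, and can participate in training. The key property, sketched below under the same API assumptions as the tests, is that matmul with a zero-size operand still yields a well-defined result of the right shape:

```python
import tensorflow as tf

a = tf.Variable(tf.zeros([0, 2]))  # zero rows, two columns

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # [2, 0] x [0, 2] contracts over an empty axis, giving [2, 2] zeros.
    print(sess.run(tf.matmul(a, a, transpose_a=True)))
```
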
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 509f6271700..658844dfb3f 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1835,16 +1835,16 @@ def _FillShape(op):
 
   Returns:
     A single-element list containing the shape of the output.
+
+  Raises:
+    ValueError: If the shapes or arguments are known to be invalid.
   """
-  dimensions_shape = op.inputs[0].get_shape().with_rank(1)
-  op.inputs[1].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+  op.inputs[0].get_shape().assert_has_rank(1)
+  op.inputs[1].get_shape().assert_has_rank(0)
   fill_dims = tensor_util.constant_value(op.inputs[0])
-  if fill_dims is None:
-    # Attempt to infer the rank of the output from the length of
-    # dimensions.
-    return [tensor_shape.unknown_shape(ndims=dimensions_shape[0].value)]
-  else:
-    return [tensor_shape.TensorShape(fill_dims.tolist())]
+  if fill_dims is not None and any(d < 0 for d in fill_dims):
+    raise ValueError("Fill dimensions must be >= 0")
+  return [tensor_util.constant_value_as_shape(op.inputs[0])]
 
 
 @ops.RegisterShape("InvertPermutation")
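
The rewrite replaces the all-or-nothing `constant_value` lookup with `constant_value_as_shape`, which recovers per-element knowledge of the `dims` tensor (unknown entries become `None` dimensions), and adds an eager validity check. A small sketch of the new error path, hedged on shape functions running at graph-construction time as they do in this codebase:

```python
import tensorflow as tf

try:
    tf.fill([-2, 3], 0.0)  # negative dims should now be rejected at graph time
except ValueError as e:
    print(e)               # Fill dimensions must be >= 0
```
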
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8bfd9ce8bf8..348ab9fd121 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -272,7 +273,7 @@ def _TanhGrad(op, grad):
   with ops.control_dependencies([grad.op]):
     if y.dtype.is_complex:
       y = math_ops.conj(y)
-    return grad * (1 - math_ops.square(y))
+    return gen_math_ops._tanh_grad(y, grad)
 
 
 @ops.RegisterGradient("Erf")
@@ -374,7 +375,7 @@ def _SigmoidGrad(op, grad):
   with ops.control_dependencies([grad.op]):
     if y.dtype.is_complex:
       y = math_ops.conj(y)
-    return grad * (y * (1 - y))
+    return gen_math_ops._sigmoid_grad(y, grad)
 
 
 @ops.RegisterGradient("Sign")
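
Both gradients switch from composing elementwise ops in Python to single fused kernels; `SigmoidGrad` and `TanhGrad` were added to the op-wrapper list in `tensorflow/python/BUILD` above, which appears to be the hidden-ops list, matching the underscore-prefixed `gen_math_ops` wrappers used here. The fused ops compute the same values as the expressions they replace, which a quick finite-difference check confirms (plain NumPy sketch):

```python
import numpy as np

x, dy, eps = 0.5, 2.0, 1e-6

# TanhGrad(y, dy) = dy * (1 - y^2), with y = tanh(x).
y = np.tanh(x)
fd = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)
assert np.isclose(dy * (1 - y * y), dy * fd)

# SigmoidGrad(y, dy) = dy * y * (1 - y), with y = sigmoid(x).
sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
y = sigmoid(x)
fd = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps)
assert np.isclose(dy * y * (1 - y), dy * fd)
```
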
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 38c7e515941..07d93160ad8 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1628,6 +1628,8 @@ ops.RegisterShape("BatchFFT2D")(common_shapes.unchanged_shape)
 ops.RegisterShape("BatchIFFT2D")(common_shapes.unchanged_shape)
 ops.RegisterShape("BatchFFT3D")(common_shapes.unchanged_shape)
 ops.RegisterShape("BatchIFFT3D")(common_shapes.unchanged_shape)
+ops.RegisterShape("TanhGrad")(common_shapes.unchanged_shape)
+ops.RegisterShape("SigmoidGrad")(common_shapes.unchanged_shape)
 
 
 @ops.RegisterShape("Add")
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 5001b6acd9e..1785ff372b6 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -125,8 +125,21 @@ class Coordinator(object):
   ```
   """
 
-  def __init__(self):
-    """Create a new Coordinator."""
+  def __init__(self, clean_stop_exception_types=None):
+    """Create a new Coordinator.
+
+    Args:
+      clean_stop_exception_types: Optional tuple of Exception types that should
+        cause a clean stop of the coordinator. If an exception of one of these
+        types is reported to `request_stop(ex)` the coordinator will behave as
+        if `request_stop(None)` was called.  Defaults to
+        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
+        the end of input. When feeding training data from a Python iterator it
+        is common to add `StopIteration` to this list.
+    """
+    if clean_stop_exception_types is None:
+      clean_stop_exception_types = (errors.OutOfRangeError,)
+    self._clean_stop_exception_types = clean_stop_exception_types
     # Protects all attributes.
     self._lock = threading.Lock()
     # Event set when threads must stop.
@@ -143,9 +156,8 @@ class Coordinator(object):
     reported to the users.  If yes, it returns `ex` as is, otherwise it returns
     None.
 
-    The code returns None for exceptions that are used for control flow such as
-    the OutOfRangeError raised by the dequeue operations to indicate that a
-    queue was closed after its contents were dequeued.
+    The code returns None for exception types listed in
+    `_clean_stop_exception_types`.
 
     Args:
       ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
@@ -158,12 +170,7 @@ class Coordinator(object):
       ex2 = ex[1]
     else:
       ex2 = ex
-    # OutOfRangeError is used to indicate "end of input".  We do not want to
-    # report an exception for it.  TODO(touts): Likely also need to ignore
-    # some of the Aborted and Cancelled exceptions raised by queue ops after
-    # queues are closed, but this can only be done after these exceptions have
-    # been clearly identified.
-    if isinstance(ex2, (errors.OutOfRangeError)):
+    if isinstance(ex2, self._clean_stop_exception_types):
       # Ignore the exception.
       ex = None
     return ex
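
As the new docstring suggests, the typical use of `clean_stop_exception_types` is feeding training data from a Python iterator and letting `StopIteration` end the run cleanly. A hedged sketch (the `feeder` helper is illustrative, not part of the patch):

```python
import threading
import tensorflow as tf

coord = tf.train.Coordinator(clean_stop_exception_types=(StopIteration,))

def feeder(it):
    # stop_on_exception() reports any exception via request_stop(ex);
    # StopIteration is now filtered out as a clean stop.
    with coord.stop_on_exception():
        while not coord.should_stop():
            next(it)

t = threading.Thread(target=feeder, args=(iter(range(3)),))
t.start()
coord.join([t])  # returns normally instead of re-raising StopIteration
```
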
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
index dac5fe59e9d..9e700cbe68c 100644
--- a/tensorflow/python/training/coordinator_test.py
+++ b/tensorflow/python/training/coordinator_test.py
@@ -132,6 +132,16 @@ class CoordinatorTest(tf.test.TestCase):
       t.start()
     coord.join(threads)
 
+  def testJoinIgnoresMyExceptionType(self):
+    coord = tf.train.Coordinator(clean_stop_exception_types=(ValueError,))
+    threads = [
+        threading.Thread(target=RaiseInN,
+                         args=(coord, 0.01, ValueError("Clean stop"), True))
+        ]
+    for t in threads:
+      t.start()
+    coord.join(threads)
+
   def testJoinRaiseReportExceptionUsingHandler(self):
     coord = tf.train.Coordinator()
     threads = [
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index 40d399ccedf..f72e96f5af4 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -280,6 +280,21 @@ class GrpcServerTest(tf.test.TestCase):
                           job_name="local",
                           task_index=0)
 
+  def testInteractiveSession(self):
+    server = tf.train.Server.create_local_server()
+    # TODO(b/29900832): Remove this assertion when the bug is fixed.
+    a = tf.constant(1.0)
+    with self.assertRaisesRegexp(tf.errors.UnimplementedError, "pruned"):
+      sess = tf.InteractiveSession(target=server.target)
+      sess.run(a)
+
+    # TODO(b/29900832): The following code fails (without the unimplemented
+    # check in `tensorflow::MasterSession`):
+    # a = tf.constant(1.0)
+    # b = tf.constant(2.0)
+    # self.assertEqual(1.0, sess.run(a))
+    # self.assertEqual(2.0, sess.run(b))
+
 
 class ServerDefTest(tf.test.TestCase):
 
diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG
index aabe6ec3909..2bd5a0a98a3 100644
--- a/tensorflow/tensorboard/TAG
+++ b/tensorflow/tensorboard/TAG
@@ -1 +1 @@
-21
+22
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
index dbc0213ac7b..64f245c7311 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
@@ -827,7 +827,7 @@ export function traceAllInputsOfOpNode(
   // Get visible parent.
   let currentVisibleParent = getVisibleParent(renderGraphInfo, startNode);
   // Mark as input node.
-  d3.selectAll(`[data-name="${currentVisibleParent.name}"]`)
+  d3.select(`.node[data-name="${currentVisibleParent.name}"]`)
       .classed('input-highlight', true);
 
   // Find the visible parent of each input.
@@ -1018,7 +1018,7 @@ function _markParentsOfNodes(visibleNodes: {[nodeName: string]: Node}) {
     let currentNode = nodeInstance;
 
     while (currentNode.name !== tf.graph.ROOT_NAME) {
-      let renderedElement = d3.select(`[data-name="${currentNode.name}"]`);
+      let renderedElement = d3.select(`.node[data-name="${currentNode.name}"]`);
       // Only mark the element as a parent node to an input if it is not
       // marked as input node itself.
       if (renderedElement[0][0] &&
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
index d7bc9b8c04d..04c34bf935c 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
@@ -72,6 +72,7 @@
 ::content .faded rect,
 ::content .faded ellipse,
 ::content .faded path,
+::content .faded use,
 ::content #rectHatch line,
 ::content #ellipseHatch line {
   color: #e0d4b3 !important;
@@ -88,7 +89,8 @@
   fill: url(#rectHatch) !important;
 }
 
-::content .faded ellipse {
+::content .faded ellipse,
+::content .faded use {
   fill: url(#ellipseHatch) !important;
 }
 
@@ -111,8 +113,12 @@
 ::content .non-input > * > use,
 /* For Const nodes. */
 ::content .non-input > * > .constant:not([class*="input-highlight"]) >
-.annotation-node > ellipse {
+  .annotation-node > ellipse,
+/* For styling of annotation nodes of non-input nodes. */
+::content .non-input > g > .annotation > .annotation-node > rect {
   stroke: #e0d4b3 !important;
+  stroke-width: inherit;
+  stroke-dasharray: inherit;
 }
 
 
@@ -121,7 +127,9 @@
 }
 
 ::content .non-input > .nodeshape > rect,
-::content .non-input > .annotation-node > rect
+::content .non-input > .annotation-node > rect,
+/* For styling of annotation nodes of non-input nodes. */
+::content .non-input > g > .annotation > .annotation-node > rect
 {
   fill: url(#rectHatch) !important;
 }
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index 77e224dd205..1c0297dbdb4 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -3963,6 +3963,7 @@ var tf;
             }
             return OpNodeImpl;
         }());
+        graph_1.OpNodeImpl = OpNodeImpl;
         ;
         function createMetanode(name, opt) {
             if (opt === void 0) { opt = {}; }
@@ -4171,6 +4172,7 @@ var tf;
             };
             return MetanodeImpl;
         }());
+        graph_1.MetanodeImpl = MetanodeImpl;
         ;
         function createMetaedge(v, w) {
             return new MetaedgeImpl(v, w);
@@ -4527,6 +4529,7 @@ var tf;
             var parts = name.split(graph_1.NAMESPACE_DELIM);
             return name + graph_1.NAMESPACE_DELIM + '(' + parts[parts.length - 1] + ')';
         }
+        graph_1.getStrictName = getStrictName;
         /**
          * For each op node (embedding or non-embedding), rename it if there is a
          * non-embedding node under its namespace. For example, assume node name 'A'.
@@ -6494,6 +6497,7 @@ var tf;
                     this.index[hierarchy.root.name] = this.root;
                     this.buildSubhierarchy(hierarchy.root.name);
                     this.root.expanded = true;
+                    this.traceInputs = false;
                 }
                 RenderGraphInfo.prototype.computeScales = function () {
                     this.deviceColorMap = d3.scale.ordinal()
@@ -8788,10 +8792,333 @@ var tf;
                         d3.rgb(fill).darker().toString();
                 }
                 node_1.getStrokeForFill = getStrokeForFill;
+                /**
+                 * Finds the selected node and highlights all nodes that provide direct
+                 * or indirect input to it, along with all edges connecting these nodes
+                 * to each other and to the selected node.
+                 *
+                 * @param renderGraphInfo Information on the rendered state of the graph.
+                 */
+                function traceInputs(renderGraphInfo) {
+                    // Reset all styling.
+                    d3.selectAll('.input-highlight').classed('input-highlight', false);
+                    d3.selectAll('.non-input').classed('non-input', false);
+                    d3.selectAll('.input-parent').classed('input-parent', false);
+                    d3.selectAll('.input-child').classed('input-child', false);
+                    d3.selectAll('.input-edge-highlight').classed('input-edge-highlight', false);
+                    d3.selectAll('.non-input-edge-highlight')
+                        .classed('non-input-edge-highlight', false);
+                    d3.selectAll('.input-highlight-selected')
+                        .classed('input-highlight-selected', false);
+                    // Extract currently selected node. Return if input tracing disabled or no
+                    // node is selected.
+                    var selectedNodeSelectorString = 'g.node.selected,g.op.selected';
+                    var node = d3.select(selectedNodeSelectorString);
+                    var currentNode = undefined;
+                    if (renderGraphInfo && renderGraphInfo.traceInputs && node && node[0] &&
+                        node[0][0]) {
+                        currentNode = node[0][0];
+                    }
+                    else {
+                        return;
+                    }
+                    var nodeName = currentNode.getAttribute('data-name');
+                    var opNodes = _getAllContainedOpNodes(nodeName, renderGraphInfo);
+                    var allTracedNodes = {};
+                    _.each(opNodes, function (nodeInstance) {
+                        allTracedNodes =
+                            traceAllInputsOfOpNode(renderGraphInfo, nodeInstance, allTracedNodes);
+                    });
+                    d3.selectAll(selectedNodeSelectorString).classed({
+                        // Remove the input-highlight from the selected node.
+                        'input-highlight': false,
+                        // Add input-highlight-selected class to the selected node, which
+                        // allows treating the selected node as a special case of an input
+                        // node.
+                        'input-highlight-selected': true
+                    });
+                    // Highlight all parent nodes of each OpNode as input parent to allow
+                    // specific highlighting.
+                    var highlightedNodes = Object.keys(allTracedNodes);
+                    var visibleNodes = _findVisibleParentsFromOpNodes(renderGraphInfo, highlightedNodes);
+                    _markParentsOfNodes(visibleNodes);
+                    // Attach class to all non-input nodes and edges for styling.
+                    d3.selectAll('g.node:not(.selected):not(.input-highlight)' +
+                        ':not(.input-parent):not(.input-children)')
+                        .classed('non-input', true)
+                        .each(function (d) {
+                        // Mark all nodes with the specified name as non-inputs. This
+                        // ensures that annotation nodes attached to inputs are tagged
+                        // as well.
+                        var nodeName = d.node.name;
+                        d3.selectAll("[data-name=\"" + nodeName + "\"]").classed('non-input', true);
+                    });
+                    d3.selectAll('g.edge:not(.input-edge-highlight)')
+                        .classed('non-input-edge-highlight', true);
+                }
+                node_1.traceInputs = traceInputs;
+                /**
+                 * Recursively find all op nodes contained by the node identified by the
+                 * provided name.
+                 * @param nodeName The name of the meta or op node whose contained OpNode
+                 * instances are required.
+                 * @param renderGraphInfo The rendered graph information object.
+                 * @returns {Array} An array of OpNodeImpl instances.
+                 */
+                function _getAllContainedOpNodes(nodeName, renderGraphInfo) {
+                    var opNodes = [];
+                    // Get current node.
+                    var node = renderGraphInfo.getNodeByName(nodeName);
+                    // If node is already OpNode then return the node plus its input embeddings.
+                    if (node instanceof tf.graph.OpNodeImpl) {
+                        return [node].concat(node.inEmbeddings);
+                    }
+                    // Otherwise, make recursive call for each node contained by the GroupNode.
+                    var childNodeNames = node.metagraph.nodes();
+                    _.each(childNodeNames, function (childNodeName) {
+                        opNodes =
+                            opNodes.concat(_getAllContainedOpNodes(childNodeName, renderGraphInfo));
+                    });
+                    return opNodes;
+                }
+                node_1._getAllContainedOpNodes = _getAllContainedOpNodes;
+                function traceAllInputsOfOpNode(renderGraphInfo, startNode, allTracedNodes) {
+                    // To prevent infinite loops due to cyclical relationships, and to
+                    // improve performance, trace an OpNode that is an input to 2+ nodes
+                    // only once.
+                    if (allTracedNodes[startNode.name]) {
+                        return allTracedNodes;
+                    }
+                    else {
+                        allTracedNodes[startNode.name] = true;
+                    }
+                    // Extract the inputs.
+                    var inputs = startNode.inputs;
+                    // Get visible parent.
+                    var currentVisibleParent = getVisibleParent(renderGraphInfo, startNode);
+                    // Mark as input node.
+                    d3.select(".node[data-name=\"" + currentVisibleParent.name + "\"]")
+                        .classed('input-highlight', true);
+                    // Find the visible parent of each input.
+                    var visibleInputs = {};
+                    _.each(inputs, function (nodeInstance) {
+                        var resolvedNode = renderGraphInfo.getNodeByName(nodeInstance.name);
+                        if (resolvedNode === undefined) {
+                            // Node could not be found in rendered Hierarchy, which happens when
+                            // tracing inputs of a SummaryNode.
+                            return;
+                        }
+                        // Ensure node is resolved to OpNode if name collision with Metanode exists.
+                        if (resolvedNode instanceof graph.MetanodeImpl) {
+                            var resolvedNodeName = tf.graph.getStrictName(resolvedNode.name);
+                            resolvedNode = renderGraphInfo.getNodeByName(resolvedNodeName);
+                        }
+                        var visibleParent = getVisibleParent(renderGraphInfo, resolvedNode);
+                        // Append OpNode to visible parent entry.
+                        var visibleInputsEntry = visibleInputs[visibleParent.name];
+                        if (visibleInputsEntry) {
+                            visibleInputsEntry.opNodes.push(resolvedNode);
+                        }
+                        else {
+                            visibleInputs[visibleParent.name] = {
+                                visibleParent: visibleParent,
+                                opNodes: [resolvedNode]
+                            };
+                        }
+                    });
+                    // Find all parents of the start node.
+                    var startNodeParents = {};
+                    var indexedStartNodeParents = [currentVisibleParent];
+                    startNodeParents[currentVisibleParent.name] = {
+                        traced: false,
+                        index: 0,
+                        connectionEndpoints: []
+                    };
+                    var currentNode = currentVisibleParent;
+                    for (var index = 1; currentNode.name !== tf.graph.ROOT_NAME; index++) {
+                        currentNode = currentNode.parentNode;
+                        startNodeParents[currentNode.name] = {
+                            traced: false,
+                            index: index,
+                            connectionEndpoints: []
+                        };
+                        indexedStartNodeParents[index] = currentNode;
+                    }
+                    // Find first mutual parent of each input node and highlight connection.
+                    _.forOwn(visibleInputs, function (visibleParentInfo, key) {
+                        var nodeInstance = visibleParentInfo.visibleParent;
+                        // Make recursive call for each input-OpNode contained by the visible
+                        // parent.
+                        _.each(visibleParentInfo.opNodes, function (opNode) {
+                            allTracedNodes =
+                                traceAllInputsOfOpNode(renderGraphInfo, opNode, allTracedNodes);
+                        });
+                        if (nodeInstance.name !== currentVisibleParent.name) {
+                            _createVisibleTrace(nodeInstance, startNodeParents, indexedStartNodeParents);
+                        }
+                    });
+                    return allTracedNodes;
+                }
+                node_1.traceAllInputsOfOpNode = traceAllInputsOfOpNode;
+                /**
+                 * Colors the edges to connect the passed node to the start node. This is
+                 * done by:
+                 *
+                 * a) Finding the first (visible) common parent in the rendered
+                 * hierarchy.
+                 * NB: There are two types of connections:
+                 * 1) Direct connections between nodes A and B, marked below as II.
+                 * 2) Connections from a node A to its parent A', marked below as I and III.
+                 * For a type-2 connection you need to know the inner nested node, the
+                 * direct parent, and the ultimate destination of the connection.
+                 *
+                 *  A_parent      B_parent
+                 * +--------+    +---------+
+                 * |        |    |         |
+                 * |  +--+ I| II |III+--+  |
+                 * |  |A +----------\x3e+B |  |
+                 * |  +--+  |    |   +--+  |
+                 * |        |    |         |
+                 * +--------+    +---------+
+                 *
+                 *
+                 * b) Highlighting the direct connection between the parents of A and B,
+                 * called A_parent and B_parent, s.t. A_parent and B_parent are children of the
+                 * mutual parent of A and B found in a), marked above as II.
+                 *
+                 * c) Highlighting the connection from A to A_parent and B to B_parent
+                 * (through all layers of parents between A and A_parent and B and B_parent,
+                 * respectively). Marked above as I and III.
+                 *
+                 * @param nodeInstance The instance of the node to use as destination node, B.
+                 * @param startNodeParents Map of startNodeParent names to information objects
+                 * about the parent.
+                 * @param indexedStartNodeParents An array of all parents of the start node.
+                 * This is required to find the child of the mutual parent which is a parent
+                 * of the start node.
+                 * @private
+                 */
+                function _createVisibleTrace(nodeInstance, startNodeParents, indexedStartNodeParents) {
+                    var currentNode = nodeInstance;
+                    var previousNode = nodeInstance;
+                    // Ascend through parents until a mutual parent is found with the start
+                    // node.
+                    var destinationParentPairs = [];
+                    while (!startNodeParents[currentNode.name]) {
+                        if (previousNode.name !== currentNode.name) {
+                            destinationParentPairs.push([previousNode, currentNode]);
+                        }
+                        previousNode = currentNode;
+                        currentNode = currentNode.parentNode;
+                    }
+                    // Connection between nodes is drawn between the parents of each
+                    // respective node, both of which share the mutual parent.
+                    var startNodeIndex = startNodeParents[currentNode.name].index;
+                    var startNodeName = indexedStartNodeParents[Math.max(startNodeIndex - 1, 0)].name;
+                    var startNodeTopParentName = startNodeName;
+                    var targetNodeTopParentName = previousNode.name;
+                    var endNodeName = previousNode.name;
+                    d3.selectAll("[data-edge=\"" + endNodeName + "--" + startNodeName + "\"]")
+                        .classed('input-edge-highlight', true);
+                    // Trace up the parents of the input.
+                    _.each(destinationParentPairs, function (value) {
+                        var inner = value[0];
+                        var outer = value[1];
+                        var edgeSelector = ("[data-edge=\"" + inner.name + "--" + startNodeTopParentName) +
+                            ("~~" + outer.name + "~~OUT\"]");
+                        d3.selectAll(edgeSelector).classed('input-edge-highlight', true);
+                    });
+                    // Trace up the parents of the start node.
+                    for (var index = 1; index < startNodeIndex; index++) {
+                        var inner = indexedStartNodeParents[index - 1];
+                        var outer = indexedStartNodeParents[index];
+                        var edgeSelector = ("[data-edge=\"" + targetNodeTopParentName + "~~" + outer.name) +
+                            ("~~IN--" + inner.name + "\"]");
+                        d3.selectAll(edgeSelector).classed('input-edge-highlight', true);
+                    }
+                }
+                /**
+                 * Creates map { [name: string] -> Node } of all visible / rendered parents
+                 * of the nodes identified by the node names passed in.
+                 *
+                 * @param renderGraphInfo The information on the rendered graph.
+                 * @param nodeNames String array of node names.
+                 * @returns {[nodeName: string]: Node}
+                 * @private
+                 */
+                function _findVisibleParentsFromOpNodes(renderGraphInfo, nodeNames) {
+                    var visibleParents = {};
+                    _.each(nodeNames, function (nodeName) {
+                        var currentNode = renderGraphInfo.getNodeByName(nodeName);
+                        var visibleParent = getVisibleParent(renderGraphInfo, currentNode);
+                        visibleParents[visibleParent.name] = visibleParent;
+                    });
+                    return visibleParents;
+                }
+                /**
+                 * Traverse through the parents of all nodes in the list and mark each
+                 * encountered node as input-parent.
+                 * @param visibleNodes Map of input nodes; they must be visible/rendered
+                 * when this is called.
+                 * @private
+                 */
+                function _markParentsOfNodes(visibleNodes) {
+                    _.forOwn(visibleNodes, function (nodeInstance) {
+                        // Mark all parents of the node as input-parents.
+                        var currentNode = nodeInstance;
+                        while (currentNode.name !== tf.graph.ROOT_NAME) {
+                            var renderedElement = d3.select(".node[data-name=\"" + currentNode.name + "\"]");
+                            // Only mark the element as a parent node to an input if it is not
+                            // marked as input node itself.
+                            if (renderedElement[0][0] &&
+                                !renderedElement.classed('input-highlight') &&
+                                !renderedElement.classed('selected') &&
+                            // An OpNode is only a parent if the start node is an embedded
+                            // node, in which case the OpNode should be faded as well.
+                                !renderedElement.classed('op')) {
+                                renderedElement.classed('input-parent', true);
+                            }
+                            currentNode = currentNode.parentNode;
+                        }
+                    });
+                }
+                /**
+                 * Find the parent of the passed-in op node which is expanded. This is done
+                 * by ascending through parents until a parent is expanded, thus
+                 * finding the first unexpanded parent which is rendered on the screen.
+                 * @param renderGraphInfo The graph info object used to gain access to the
+                 * render info of the parents.
+                 * @param currentNode The node whose parent is to be found.
+                 * @returns Node
+                 */
+                function getVisibleParent(renderGraphInfo, currentNode) {
+                    var found = false;
+                    var currentParent = currentNode;
+                    while (!found) {
+                        // Get parent element, to extract name.
+                        currentNode = currentParent;
+                        currentParent = currentNode.parentNode;
+                        if (currentParent === undefined) {
+                            found = true;
+                        }
+                        else {
+                            var renderNode = renderGraphInfo.getRenderNodeByName(currentParent.name);
+                            // Found if node is rendered on the screen (renderNode truthy), and
+                            // the parent is either expanded (i.e. it is a metanode or seriesnode)
+                            // or the parent is an OpNode in which case currentNode is an embedded
+                            // node which has another OpNode as parent.
+                            if (renderNode &&
+                                (renderNode.expanded || currentParent instanceof graph.OpNodeImpl)) {
+                                found = true;
+                            }
+                        }
+                    } // Close while loop.
+                    return currentNode;
+                }
+                node_1.getVisibleParent = getVisibleParent;
             })(node = scene.node || (scene.node = {}));
         })(scene = graph.scene || (graph.scene = {}));
     })(graph = tf.graph || (tf.graph = {}));
-})(tf || (tf = {})); // close module
+})(tf || (tf = {})); // Close module.
 </script>
 <script>/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 
@@ -10137,6 +10464,7 @@ Polymer({
 ::content .faded rect,
 ::content .faded ellipse,
 ::content .faded path,
+::content .faded use,
 ::content #rectHatch line,
 ::content #ellipseHatch line {
   color: #e0d4b3 !important;
@@ -10153,7 +10481,8 @@ Polymer({
   fill: url("#rectHatch") !important;
 }
 
-::content .faded ellipse {
+::content .faded ellipse,
+::content .faded use {
   fill: url("#ellipseHatch") !important;
 }
 
@@ -10161,6 +10490,81 @@ Polymer({
   opacity: 0;
 }
 
+/* Rules used for input-tracing. */
+::content .input-highlight > * > rect,
+::content .input-highlight > * > ellipse,
+::content .input-highlight > * > use
+{
+  fill: white;
+  stroke: #ff9800 !important;
+}
+
+/*  - Faded non-input styling */
+::content .non-input > * > rect,
+::content .non-input > * > ellipse,
+::content .non-input > * > use,
+/* For Const nodes. */
+::content .non-input > * > .constant:not([class*="input-highlight"]) >
+  .annotation-node > ellipse,
+/* For styling of annotation nodes of non-input nodes. */
+::content .non-input > g > .annotation > .annotation-node > rect {
+  stroke: #e0d4b3 !important;
+  stroke-width: inherit;
+  stroke-dasharray: inherit;
+}
+
+
+::content .non-input path {
+  visibility: hidden;
+}
+
+::content .non-input > .nodeshape > rect,
+::content .non-input > .annotation-node > rect,
+/* For styling of annotation nodes of non-input nodes. */
+::content .non-input > g > .annotation > .annotation-node > rect
+{
+  fill: url("#rectHatch") !important;
+}
+
+::content .non-input ellipse,
+::content .non-input use {
+  fill: url("#ellipseHatch") !important;
+}
+
+::content .non-input > text {
+  opacity: 0;
+}
+
+::content .non-input .annotation > .annotation-edge {
+  marker-end: url("#annotation-arrowhead-faded");
+}
+
+::content .non-input .annotation > .annotation-edge.refline {
+  marker-start: url("#ref-annotation-arrowhead-faded");
+}
+
+/* Input edges. */
+::content .input-edge-highlight > text {
+  fill: black !important;
+}
+::content .input-edge-highlight > path,
+::content .input-highlight > .in-annotations > .annotation > .annotation-edge,
+::content .input-highlight-selected > .in-annotations > .annotation >
+.annotation-edge {
+  stroke: #999 !important;
+}
+
+/* Non-input edges. */
+::content .non-input-edge-highlight,
+::content .non-input > g > .annotation > path,
+/* Annotation styles (label and edges respectively). */
+::content .non-input > g >
+.annotation:not(.input-highlight):not(.input-highlight-selected) >
+.annotation-label
+/*.annotation-edge*/
+{
+  visibility: hidden;
+}
 
 /* --- Op Node --- */
 
@@ -10689,6 +11093,7 @@ Polymer({
     tf.graph.util.time('tf-graph-scene (build scene):', function() {
       tf.graph.scene.buildGroup(d3.select(this.$.root), renderHierarchy.root, this);
       tf.graph.scene.addGraphClickListener(this.$.svg, this);
+      tf.graph.scene.node.traceInputs(renderHierarchy);
     }.bind(this));
     // Update the minimap again when the graph is done animating.
     setTimeout(function() {
@@ -10853,6 +11258,13 @@ Polymer({
     }, this);
   },
 
+  /**
+   * Handles new node selection. 1) Updates the selected-state of each node,
+   * 2) triggers input tracing.
+   * @param selectedNode {string} The name of the newly selected node.
+   * @param oldSelectedNode {string} The name of the previously selected node.
+   * @private
+   */
   _selectedNodeChanged: function(selectedNode, oldSelectedNode) {
     if (selectedNode === oldSelectedNode) {
       return;
@@ -10865,9 +11277,13 @@ Polymer({
       this._updateNodeState(oldSelectedNode);
     }
 
+    tf.graph.scene.node.traceInputs(this.renderHierarchy);
+
     if (!selectedNode) {
       return;
     }
+
     // Update the minimap to reflect the highlighted (selected) node.
     this.minimap.update();
     var node = this.renderHierarchy.hierarchy.node(selectedNode);
@@ -12296,10 +12712,10 @@ paper-progress {
 </template>
 <div class$="[[_getContainerClass(progress)]]">
   <div id="main">
-    <tf-graph id="graph" graph-hierarchy="{{graphHierarchy}}" basic-graph="[[graph]]" hierarchy-params="[[hierarchyParams]]" render-hierarchy="{{_renderHierarchy}}" devices-for-stats="[[devicesForStats]]" stats="[[stats]]" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="{{colorByParams}}" progress="{{progress}}"></tf-graph>
+    <tf-graph id="graph" graph-hierarchy="{{graphHierarchy}}" basic-graph="[[graph]]" hierarchy-params="[[hierarchyParams]]" render-hierarchy="{{renderHierarchy}}" devices-for-stats="[[devicesForStats]]" stats="[[stats]]" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="{{colorByParams}}" progress="{{progress}}"></tf-graph>
   </div>
   <div id="info">
-    <tf-graph-info id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" render-hierarchy="[[_renderHierarchy]]" graph="[[graph]]" selected-node="{{_selectedNode}}" selected-node-include="{{_selectedNodeInclude}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="[[colorByParams]]"></tf-graph-info>
+    <tf-graph-info id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" render-hierarchy="[[renderHierarchy]]" graph="[[graph]]" selected-node="{{_selectedNode}}" selected-node-include="{{_selectedNodeInclude}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="[[colorByParams]]"></tf-graph-info>
   </div>
   <div class="context-menu"></div>
 </div>
@@ -12324,14 +12740,17 @@ Polymer({
     colorBy: String,
     colorByParams: {
       type: Object,
-      notify: true,
+      notify: true
+    },
+    renderHierarchy: {
+      type: Object,
+      notify: true
     },
     // Private API: Data routing between child components.
     _selectedNode: String,
     // The enum value of the include property of the selected node.
     _selectedNodeInclude: Number,
-    _highlightedNode: String,
-    _renderHierarchy: Object,
+    _highlightedNode: String
   },
   listeners: {
     'node-toggle-extract': '_nodeToggleExtract'
@@ -12617,6 +13036,14 @@ span.counter {
       <input type="file" id="file" name="file" on-change="_updateFileInput">
     </div>
   </div>
+  <div class="control-holder">
+    <div class="title">
+      Trace inputs
+    </div>
+    <paper-toggle-button id="trace-inputs">
+
+    </paper-toggle-button>
+  </div>
   <div class="control-holder">
     <div class="title">Color</div>
     <paper-radio-group selected="{{colorBy}}">
@@ -12846,6 +13273,10 @@ Polymer({
       type: Array,
       observer: '_datasetsChanged'
     },
+    renderHierarchy: {
+      type: Object,
+      notify: true,
+    },
     metadataTags: {
       type: Array,
       computed: '_getMetadataTags(selectedDataset, datasets)'
@@ -12870,6 +13301,14 @@ Polymer({
       computed: '_getCurrentGradientParams(colorByParams, colorBy)'
     }
   },
+  listeners: {
+    'trace-inputs.change': '_traceInputToggleChanged'
+  },
+  _traceInputToggleChanged: function(event) {
+    // Flip the state of the trace inputs flag.
+    this.renderHierarchy.traceInputs = event.target.active;
+    tf.graph.scene.node.traceInputs(this.renderHierarchy);
+  },
   _statsNotNull: function(stats) {
     return stats != null;
   },
@@ -12988,6 +13427,7 @@ Polymer({
     if (this.datasets) {
       this.set('selectedMetadataTag', -1);
       this.set('colorBy', 'structure');
+      this.$['trace-inputs'].active = false; // Set trace input to off-state.
       this._setDownloadFilename(this.datasets[newDataset].path);
     }
   },
@@ -13019,12 +13459,11 @@ Polymer({
 <template is="dom-if" if="[[!_datasetsEmpty(_datasets)]]">
 <tf-dashboard-layout>
 <div class="sidebar">
-  <tf-graph-controls id="controls" devices-for-stats="{{_devicesForStats}}" color-by-params="[[_colorByParams]]" stats="[[_stats]]" color-by="{{_colorBy}}" ,="" datasets="[[_datasets]]" selected-dataset="{{_selectedDataset}}" selected-file="{{_selectedFile}}" selected-metadata-tag="{{_selectedMetadataTag}}"></tf-graph-controls>
-  <tf-graph-loader id="loader" datasets="[[_datasets]]" ,="" selected-dataset="[[_selectedDataset]]" selected-metadata-tag="[[_selectedMetadataTag]]" selected-file="[[_selectedFile]]" out-graph-hierarchy="{{_graphHierarchy}}" out-graph="{{_graph}}" out-stats="{{_stats}}" progress="{{_progress}}" out-hierarchy-params="{{_hierarchyParams}}"></tf-graph-loader>
+  <tf-graph-controls id="controls" devices-for-stats="{{_devicesForStats}}" color-by-params="[[_colorByParams]]" stats="[[_stats]]" color-by="{{_colorBy}}" datasets="[[_datasets]]" render-hierarchy="[[_renderHierarchy]]" selected-dataset="{{_selectedDataset}}" selected-file="{{_selectedFile}}" selected-metadata-tag="{{_selectedMetadataTag}}"></tf-graph-controls>
+  <tf-graph-loader id="loader" datasets="[[_datasets]]" selected-dataset="[[_selectedDataset]]" selected-metadata-tag="[[_selectedMetadataTag]]" selected-file="[[_selectedFile]]" out-graph-hierarchy="{{_graphHierarchy}}" out-graph="{{_graph}}" out-stats="{{_stats}}" progress="{{_progress}}" out-hierarchy-params="{{_hierarchyParams}}"></tf-graph-loader>
 </div>
 <div class="center">
-    <tf-graph-board id="graphboard" devices-for-stats="[[_devicesForStats]]" graph-hierarchy="[[_graphHierarchy]]" graph="[[_graph]]" stats="[[_stats]]" progress="[[_progress]]" color-by="[[_colorBy]]" color-by-params="{{_colorByParams}}" hierarchy-params="[[_hierarchyParams]]">
-    </tf-graph-board>
+    <tf-graph-board id="graphboard" devices-for-stats="[[_devicesForStats]]" color-by="[[_colorBy]]" color-by-params="{{_colorByParams}}" graph-hierarchy="[[_graphHierarchy]]" graph="[[_graph]]" hierarchy-params="[[_hierarchyParams]]" progress="[[_progress]]" render-hierarchy="{{_renderHierarchy}}" stats="[[_stats]]"></tf-graph-board>
 </div>
 </tf-dashboard-layout>
 </template>
@@ -13049,8 +13488,9 @@ Polymer({
   is: 'tf-graph-dashboard',
   properties: {
     _datasets: Object,
+    _renderHierarchy: Object,
     backend: {type: Object, observer: 'reload'},
-    runs: Array,
+    runs: Array
   },
   reload: function() {
     Promise.all([this.backend.graphRuns(), this.backend.runMetadataRuns()])
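
The bulk of the compiled bundle above is the new input-tracing feature. Its central helper, `getVisibleParent`, ascends the node hierarchy until it reaches an expanded (or missing) parent, so the returned node is the deepest ancestor actually drawn on screen. A language-neutral re-expression of that loop in Python (hypothetical names, hedged sketch only):

```python
def get_visible_parent(render_info, node):
    """Return the deepest ancestor of `node` that is rendered on screen."""
    current = node
    parent = node.parent
    while parent is not None:
        render_node = render_info.get(parent.name)
        # Found when the parent is rendered and expanded (or, in the original,
        # is an OpNode wrapping an embedded node): `current` is then visible.
        if render_node is not None and render_node.expanded:
            return current
        current, parent = parent, parent.parent
    return current
```
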
diff --git a/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh b/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh
index b56225c0b84..15526ef4f8c 100755
--- a/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_bootstrap_deb_packages.sh
@@ -18,7 +18,7 @@ set -e
 
 # Install bootstrap dependencies from ubuntu deb repository.
 apt-get update
-apt-get install -y \
+apt-get install -y --no-install-recommends \
     software-properties-common
 apt-get clean
 rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh
index 9b5f1418b14..4dc58c8ce45 100755
--- a/tensorflow/tools/ci_build/install/install_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh
@@ -19,7 +19,7 @@ set -e
 # Install dependencies from ubuntu deb repository.
 apt-get update
 
-apt-get install -y \
+apt-get install -y --no-install-recommends \
     autoconf \
     automake \
     build-essential \
@@ -37,6 +37,7 @@ apt-get install -y \
     python-virtualenv \
     python3-dev \
     python3-pip \
+    rsync \
     sudo \
     swig \
     unzip \
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 1e667690d20..d244a249504 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -19,6 +19,9 @@ set -e
 # Install pip packages from whl files to avoid the time-consuming process of
 # building from source.
 
+pip install wheel
+pip3 install wheel
+
 # Use pip to install numpy to the latest version, instead of 1.8.2 through
 # apt-get
 wget -q https://pypi.python.org/packages/17/f3/404bc85be67150663024d2bb5af654c7d16cf678077690dda27b91be14eb/numpy-1.8.2-cp27-cp27mu-manylinux1_x86_64.whl#md5=3ccf5c004fc99bd06dd443de80d622e6
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index e781361627f..ba4293cb276 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -30,7 +30,7 @@ tar xzf swig-3.0.8.tar.gz
 
 pushd /swig-3.0.8
 
-apt-get install -y libpcre3-dev
+apt-get install -y --no-install-recommends libpcre3-dev
 ./configure
 make
 make install
@@ -43,7 +43,7 @@ rm -rf swig-3.0.8
 rm -f swig-3.0.8.tar.gz
 
 # Install Python 3.5 and dev library
-apt-get install -y python3.5 libpython3.5-dev
+apt-get install -y --no-install-recommends python3.5 libpython3.5-dev
 
 # Install pip3.4 and numpy for Python 3.4
 # This strange-looking install step is a stopgap measure to make the genrule
diff --git a/tensorflow/tools/ci_build/install/install_tensorboard_packages.sh b/tensorflow/tools/ci_build/install/install_tensorboard_packages.sh
index 95b8314f4c6..ca5092cd475 100755
--- a/tensorflow/tools/ci_build/install/install_tensorboard_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_tensorboard_packages.sh
@@ -18,7 +18,7 @@ set -e
 
 # Install dependencies from ubuntu deb repository.
 apt-get update
-apt-get install -y \
+apt-get install -y --no-install-recommends \
     chromium-browser \
     nodejs \
     nodejs-legacy \
diff --git a/tensorflow/tools/dist_test/Dockerfile b/tensorflow/tools/dist_test/Dockerfile
index 5dc3e5565c3..66787ca7f8b 100644
--- a/tensorflow/tools/dist_test/Dockerfile
+++ b/tensorflow/tools/dist_test/Dockerfile
@@ -3,7 +3,7 @@ FROM ubuntu:14.04
 MAINTAINER Shanqing Cai <cais@google.com>
 
 RUN apt-get update
-RUN apt-get install -y \
+RUN apt-get install -y --no-install-recommends \
     curl \
     python \
     python-numpy \
diff --git a/tensorflow/tools/dist_test/Dockerfile.local b/tensorflow/tools/dist_test/Dockerfile.local
index fae7ddb14a4..e23fa034a3d 100644
--- a/tensorflow/tools/dist_test/Dockerfile.local
+++ b/tensorflow/tools/dist_test/Dockerfile.local
@@ -4,7 +4,7 @@ MAINTAINER Shanqing Cai <cais@google.com>
 
 RUN apt-get update
 
-RUN apt-get install -y \
+RUN apt-get install -y --no-install-recommends \
     build-essential \
     dbus \
     git \
diff --git a/tensorflow/tools/dist_test/local/Dockerfile b/tensorflow/tools/dist_test/local/Dockerfile
index dece508c0df..96846f65648 100644
--- a/tensorflow/tools/dist_test/local/Dockerfile
+++ b/tensorflow/tools/dist_test/local/Dockerfile
@@ -4,7 +4,7 @@ MAINTAINER Shanqing Cai <cais@google.com>
 
 RUN apt-get update
 
-RUN apt-get install -y \
+RUN apt-get install -y --no-install-recommends \
     build-essential \
     git \
     software-properties-common
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index 1645d3f7f9b..31c3cd4d30a 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -3,15 +3,20 @@ FROM ubuntu:14.04
 MAINTAINER Craig Citro <craigcitro@google.com>
 
 # Pick up some TF dependencies
-RUN apt-get update && apt-get install -y \
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        build-essential \
         curl \
         libfreetype6-dev \
         libpng12-dev \
         libzmq3-dev \
         pkg-config \
+        python \
+        python-dev \
         python-numpy \
         python-pip \
         python-scipy \
+        rsync \
+        unzip \
         && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 4ed80b12ad8..5e8693525be 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -2,7 +2,7 @@ FROM ubuntu:14.04
 
 MAINTAINER Craig Citro <craigcitro@google.com>
 
-RUN apt-get update && apt-get install -y \
+RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         curl \
         git \
@@ -13,8 +13,10 @@ RUN apt-get update && apt-get install -y \
         python-dev \
         python-numpy \
         python-pip \
+        rsync \
         software-properties-common \
         swig \
+        unzip \
         zip \
         zlib1g-dev \
         && \
@@ -50,7 +52,7 @@ COPY run_jupyter.sh /
 #   https://bugs.launchpad.net/trusty-backports/+bug/1368094
 RUN add-apt-repository -y ppa:openjdk-r/ppa && \
     apt-get update && \
-    apt-get install -y openjdk-8-jdk openjdk-8-jre-headless && \
+    apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre-headless && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
 
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index b3db52e081d..74f41ca746f 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -2,7 +2,7 @@ FROM nvidia/cuda:7.5-cudnn4-devel
 
 MAINTAINER Craig Citro <craigcitro@google.com>
 
-RUN apt-get update && apt-get install -y \
+RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         curl \
         git \
@@ -10,11 +10,14 @@ RUN apt-get update && apt-get install -y \
         libpng12-dev \
         libzmq3-dev \
         pkg-config \
+        python \
         python-dev \
         python-numpy \
         python-pip \
+        rsync \
         software-properties-common \
         swig \
+        unzip \
         zip \
         zlib1g-dev \
         && \
@@ -50,7 +53,7 @@ COPY run_jupyter.sh /
 #   https://bugs.launchpad.net/trusty-backports/+bug/1368094
 RUN add-apt-repository -y ppa:openjdk-r/ppa && \
     apt-get update && \
-    apt-get install -y openjdk-8-jdk openjdk-8-jre-headless && \
+    apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre-headless && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
 
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index 7f0765b4e81..db91720cd9e 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -3,15 +3,20 @@ FROM nvidia/cuda:7.5-cudnn4-devel
 MAINTAINER Craig Citro <craigcitro@google.com>
 
 # Pick up some TF dependencies
-RUN apt-get update && apt-get install -y \
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        build-essential \
         curl \
         libfreetype6-dev \
         libpng12-dev \
         libzmq3-dev \
         pkg-config \
+        python \
+        python-dev \
         python-numpy \
         python-pip \
         python-scipy \
+        rsync \
+        unzip \
         && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index ae631812586..b4fc9e4df2e 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -35,6 +35,7 @@ sh_binary(
 #        "//tensorflow/contrib/session_bundle/example:half_plus_two",
         "//tensorflow/contrib/slim:all_files",
         "//tensorflow/contrib/slim/python/slim/data:all_files",
+        "//tensorflow/contrib/tensor_forest:all_files",
         "//tensorflow/core:framework_headers",
         "//tensorflow/examples/tutorials/mnist:package",
         "//tensorflow/models/embedding:package",
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index e6592479b6d..980a6d651e9 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -47,8 +47,8 @@ cc_library(
         "-lpthread",
     ] + select({
         "//tensorflow:darwin": [],
-        "//conditions:default": ["-lrt"]
-    }),    
+        "//conditions:default": ["-lrt"],
+    }),
     deps = [
         "//tensorflow/core:lib",
     ],