diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index f3f3887afc5..7eecbf79e5c 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -382,7 +382,8 @@ def map_structure(func, *structure, **check_types_dict):
 def _yield_flat_up_to(shallow_tree, input_tree):
   """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
   if is_sequence(shallow_tree):
-    for shallow_branch, input_branch in zip(shallow_tree, input_tree):
+    for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
+                                            _yield_value(input_tree)):
       for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
         yield input_leaf
   else:
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 375e30e9534..b4a0525e6c3 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -301,6 +301,7 @@ class NestTest(test.TestCase):
     nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
 
   def testFlattenUpTo(self):
+    # Shallow tree ends at scalar.
     input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
     shallow_tree = [[True, True], [False, True]]
     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
@@ -308,6 +309,7 @@ class NestTest(test.TestCase):
     self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
     self.assertEqual(flattened_shallow_tree, [True, True, False, True])
 
+    # Shallow tree ends at string.
     input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
     shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
@@ -317,6 +319,46 @@ class NestTest(test.TestCase):
                      [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
     self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
 
+    # Make sure dicts are correctly flattened, yielding values, not keys.
+    input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
+    shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
+    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+                                                              input_tree)
+    self.assertEqual(input_tree_flattened_as_shallow_tree,
+                     [1, {"c": 2}, 3, (4, 5)])
+
+    # Namedtuples.
+    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+    input_tree = ab_tuple(a=[0, 1], b=2)
+    shallow_tree = ab_tuple(a=0, b=1)
+    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+                                                              input_tree)
+    self.assertEqual(input_tree_flattened_as_shallow_tree,
+                     [[0, 1], 2])
+
+    # Nested dicts, OrderedDicts and namedtuples.
+    input_tree = collections.OrderedDict(
+        [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
+         ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
+    shallow_tree = input_tree
+    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+                                                              input_tree)
+    self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
+    shallow_tree = collections.OrderedDict([("a", 0),
+                                            ("b", {"d": 3, "e": 1})])
+    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+                                                              input_tree)
+    self.assertEqual(input_tree_flattened_as_shallow_tree,
+                     [ab_tuple(a=[0, {"b": 1}], b=2),
+                      3,
+                      collections.OrderedDict([("f", 4)])])
+    shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
+    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+                                                              input_tree)
+    self.assertEqual(input_tree_flattened_as_shallow_tree,
+                     [ab_tuple(a=[0, {"b": 1}], b=2),
+                      {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
+
     ## Shallow non-list edge-case.
     # Using iterable elements.
     input_tree = ["input_tree"]
@@ -401,6 +443,7 @@ class NestTest(test.TestCase):
     self.assertEqual(flattened_shallow_tree, shallow_tree)
 
   def testMapStructureUpTo(self):
+    # Named tuples.
     ab_tuple = collections.namedtuple("ab_tuple", "a, b")
     op_tuple = collections.namedtuple("op_tuple", "add, mul")
     inp_val = ab_tuple(a=2, b=3)
@@ -410,6 +453,7 @@ class NestTest(test.TestCase):
     self.assertEqual(out.a, 6)
     self.assertEqual(out.b, 15)
 
+    # Lists.
     data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
     name_list = ["evens", ["odds", "primes"]]
     out = nest.map_structure_up_to(