From 011e9baccd343eb943d25014c4e8aec53eac396b Mon Sep 17 00:00:00 2001
From: Vijay Vasudevan <vrv@google.com>
Date: Thu, 12 Nov 2015 18:34:45 -0800
Subject: [PATCH] TensorFlow: a few small updates.

Changes:
- Fix softmax formula in word2vec to remove an extra exp()
  by @gouwsmeister

- Python3 fixes to remove basestring / support for unicode by @mrry

- Remove some comments by Josh

- Specify exact versions of bower dependencies for TensorBoard by
  @danmane.

Base CL: 107742361
---
 tensorflow/core/framework/node_def_util.h     |  3 +-
 tensorflow/core/graph/testlib.h               |  2 -
 tensorflow/g3doc/tutorials/word2vec/index.md  |  2 +-
 tensorflow/python/client/session.py           | 81 +++++++++++++++----
 tensorflow/python/client/session_test.py      | 49 ++++++++++-
 tensorflow/python/client/tf_session.i         |  4 +-
 tensorflow/python/client/tf_session_helper.cc | 42 ++++++----
 tensorflow/python/framework/importer.py       |  7 +-
 tensorflow/python/framework/importer_test.py  | 52 ++++++++++++
 tensorflow/python/framework/ops.py            | 20 ++---
 tensorflow/python/framework/tensor_util.py    |  5 +-
 tensorflow/python/lib/core/strings.i          | 45 ++---------
 tensorflow/python/ops/clip_ops.py             |  6 +-
 tensorflow/python/ops/op_def_library.py       |  5 +-
 tensorflow/python/ops/variable_scope.py       |  5 +-
 tensorflow/python/platform/base.i             | 30 +++----
 tensorflow/python/training/saver.py           |  4 +-
 tensorflow/python/training/saver_test.py      |  9 ++-
 tensorflow/python/training/summary_io.py      |  4 +-
 tensorflow/python/util/protobuf/compare.py    |  8 +-
 .../python/util/protobuf/compare_test.py      |  4 +-
 tensorflow/tensorboard/bower.json             |  8 +-
 22 files changed, 264 insertions(+), 131 deletions(-)

diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index fce6fd24338..9efcd5567f5 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -111,8 +111,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
 Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                    const NameAttrList** value);  // type: "func"
 
-// Computes the input and output types for a specific node, for
-// attr-style ops.
+// Computes the input and output types for a specific node.
 // REQUIRES: ValidateOpDef(op_def).ok()
 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                          DataTypeVector* inputs, DataTypeVector* outputs);
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 11905bbf6aa..2a5ad6e14aa 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -64,11 +64,9 @@ Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
 Node* Identity(Graph* g, Node* input, int index = 0);
 
 // Adds a binary function "func" node in "g" taking "in0" and "in1".
-// Requires that "func" name an attr-style Op.
 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
 
 // Adds a function "func" node in "g" taking inputs "ins".
-// Requires that "func" name an attr-style Op.
 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
 
 // Adds a binary add node in "g" doing in0 + in1.
diff --git a/tensorflow/g3doc/tutorials/word2vec/index.md b/tensorflow/g3doc/tutorials/word2vec/index.md
index a046d70e193..a8916d6f6a2 100644
--- a/tensorflow/g3doc/tutorials/word2vec/index.md
+++ b/tensorflow/g3doc/tutorials/word2vec/index.md
@@ -100,7 +100,7 @@ given the previous words \\(h\\) (for "history") in terms of a
 
 $$
 \begin{align}
-P(w_t | h) &= \text{softmax}(\exp \{ \text{score}(w_t, h) \}) \\
+P(w_t | h) &= \text{softmax}(\text{score}(w_t, h)) \\
            &= \frac{\exp \{ \text{score}(w_t, h) \} }
              {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }.
 \end{align}
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index fedaa2c2ca9..38e01798ac1 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -11,6 +11,7 @@ import threading
 import tensorflow.python.platform
 
 import numpy as np
+import six
 
 from tensorflow.python import pywrap_tensorflow as tf_session
 from tensorflow.python.framework import errors
@@ -36,6 +37,56 @@ class SessionInterface(object):
     raise NotImplementedError('Run')
 
 
+def _as_bytes(bytes_or_unicode):
+  """Returns the given argument as a byte array.
+
+  NOTE(mrry): For Python 2 and 3 compatibility, we convert all string
+  arguments to SWIG methods into byte arrays. Unicode strings are
+  encoded as UTF-8; however the valid arguments for all of the
+  human-readable arguments must currently be a subset of ASCII.
+
+  Args:
+    bytes_or_unicode: A `unicode`, `string`, or `bytes` object.
+
+  Returns:
+    A `bytes` object.
+
+  Raises:
+    TypeError: If `bytes_or_unicode` is not a binary or unicode string.
+  """
+  if isinstance(bytes_or_unicode, six.text_type):
+    return bytes_or_unicode.encode('utf-8')
+  elif isinstance(bytes_or_unicode, six.binary_type):
+    return bytes_or_unicode
+  else:
+    raise TypeError('bytes_or_unicode must be a binary or unicode string.')
+
+
+def _as_text(bytes_or_unicode):
+  """Returns the given argument as a unicode string.
+
+  NOTE(mrry): For Python 2 and 3 compatibility, we interpret all
+  returned strings from SWIG methods as byte arrays. This function
+  converts those strings that are intended to be human-readable into
+  UTF-8 unicode strings.
+
+  Args:
+    bytes_or_unicode: A `unicode`, `string`, or `bytes` object.
+
+  Returns:
+    A `unicode` (Python 2) or `str` (Python 3) object.
+
+  Raises:
+    TypeError: If `bytes_or_unicode` is not a binary or unicode string.
+  """
+  if isinstance(bytes_or_unicode, six.text_type):
+    return bytes_or_unicode
+  elif isinstance(bytes_or_unicode, six.binary_type):
+    return bytes_or_unicode.decode('utf-8')
+  else:
+    raise TypeError('bytes_or_unicode must be a binary or unicode string.')
+
+
 class BaseSession(SessionInterface):
   """A class for interacting with a TensorFlow computation.
 
@@ -75,8 +126,7 @@ class BaseSession(SessionInterface):
       status = tf_session.TF_NewStatus()
       self._session = tf_session.TF_NewSession(opts, status)
       if tf_session.TF_GetCode(status) != 0:
-        message = tf_session.TF_Message(status)
-        raise RuntimeError(message)
+        raise RuntimeError(_as_text(tf_session.TF_Message(status)))
 
     finally:
       tf_session.TF_DeleteSessionOptions(opts)
@@ -97,7 +147,7 @@ class BaseSession(SessionInterface):
           status = tf_session.TF_NewStatus()
           tf_session.TF_CloseSession(self._session, status)
           if tf_session.TF_GetCode(status) != 0:
-            raise RuntimeError(tf_session.TF_Message(status))
+            raise RuntimeError(_as_text(tf_session.TF_Message(status)))
         finally:
           tf_session.TF_DeleteStatus(status)
 
@@ -108,7 +158,7 @@ class BaseSession(SessionInterface):
       if self._session is not None:
         tf_session.TF_DeleteSession(self._session, status)
         if tf_session.TF_GetCode(status) != 0:
-          raise RuntimeError(tf_session.TF_Message(status))
+          raise RuntimeError(_as_text(tf_session.TF_Message(status)))
         self._session = None
     finally:
       tf_session.TF_DeleteStatus(status)
@@ -265,7 +315,6 @@ class BaseSession(SessionInterface):
       TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
       ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
         `Tensor` that doesn't exist.
-
     """
     def _fetch_fn(fetch):
       for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
@@ -302,9 +351,9 @@ class BaseSession(SessionInterface):
           fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
                                                 allow_operation=True)
           if isinstance(fetch_t, ops.Operation):
-            target_list.append(fetch_t.name)
+            target_list.append(_as_bytes(fetch_t.name))
           else:
-            subfetch_names.append(fetch_t.name)
+            subfetch_names.append(_as_bytes(fetch_t.name))
         except TypeError as e:
           raise TypeError('Fetch argument %r of %r has invalid type %r, '
                           'must be a string or Tensor. (%s)'
@@ -343,7 +392,7 @@ class BaseSession(SessionInterface):
                   'which has shape %r'
                   % (np_val.shape, subfeed_t.name,
                      tuple(subfeed_t.get_shape().dims)))
-          feed_dict_string[str(subfeed_t.name)] = np_val
+          feed_dict_string[_as_bytes(subfeed_t.name)] = np_val
 
     # Run request and get response.
     results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
@@ -373,11 +422,12 @@ class BaseSession(SessionInterface):
     """Runs a step based on the given fetches and feeds.
 
     Args:
-      target_list: A list of strings corresponding to names of tensors
+      target_list: A list of byte arrays corresponding to names of tensors
         or operations to be run to, but not fetched.
-      fetch_list: A list of strings corresponding to names of tensors to be
-        fetched and operations to be run.
-      feed_dict: A dictionary that maps tensor names to numpy ndarrays.
+      fetch_list: A list of byte arrays corresponding to names of tensors to
+        be fetched and operations to be run.
+      feed_dict: A dictionary that maps tensor names (as byte arrays) to
+        numpy ndarrays.
 
     Returns:
       A list of numpy ndarrays, corresponding to the elements of
@@ -397,7 +447,7 @@ class BaseSession(SessionInterface):
             tf_session.TF_ExtendGraph(
                 self._session, graph_def.SerializeToString(), status)
             if tf_session.TF_GetCode(status) != 0:
-              raise RuntimeError(tf_session.TF_Message(status))
+              raise RuntimeError(_as_text(tf_session.TF_Message(status)))
             self._opened = True
           finally:
             tf_session.TF_DeleteStatus(status)
@@ -409,7 +459,8 @@ class BaseSession(SessionInterface):
 
     except tf_session.StatusNotOK as e:
       e_type, e_value, e_traceback = sys.exc_info()
-      m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
+      error_message = _as_text(e.error_message)
+      m = BaseSession._NODEDEF_NAME_RE.search(error_message)
       if m is not None:
         node_name = m.group(1)
         node_def = None
@@ -419,7 +470,7 @@ class BaseSession(SessionInterface):
         except KeyError:
           op = None
         # pylint: disable=protected-access
-        raise errors._make_specific_exception(node_def, op, e.error_message,
+        raise errors._make_specific_exception(node_def, op, error_message,
                                               e.code)
         # pylint: enable=protected-access
       raise e_type, e_value, e_traceback
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 5472c96a75c..96b7136b15c 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -9,6 +9,7 @@ import time
 import tensorflow.python.platform
 
 import numpy as np
+import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import config_pb2
@@ -551,9 +552,55 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertEqual(c_list[0], out[0])
       self.assertEqual(c_list[1], out[1])
 
+  def testStringFeedWithUnicode(self):
+    with session.Session():
+      c_list = [u'\n\x01\x00', u'\n\x00\x01']
+      feed_t = array_ops.placeholder(dtype=types.string, shape=[2])
+      c = array_ops.identity(feed_t)
+
+      out = c.eval(feed_dict={feed_t: c_list})
+      self.assertEqual(c_list[0], out[0].decode('utf-8'))
+      self.assertEqual(c_list[1], out[1].decode('utf-8'))
+
+      out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)})
+      self.assertEqual(c_list[0], out[0].decode('utf-8'))
+      self.assertEqual(c_list[1], out[1].decode('utf-8'))
+
   def testInvalidTargetFails(self):
     with self.assertRaises(RuntimeError):
-      session.Session("INVALID_TARGET")
+      session.Session('INVALID_TARGET')
+
+  def testFetchByNameDifferentStringTypes(self):
+    with session.Session() as sess:
+      c = constant_op.constant(42.0, name='c')
+      d = constant_op.constant(43.0, name=u'd')
+      e = constant_op.constant(44.0, name=b'e')
+      f = constant_op.constant(45.0, name=r'f')
+
+      self.assertTrue(isinstance(c.name, six.text_type))
+      self.assertTrue(isinstance(d.name, six.text_type))
+      self.assertTrue(isinstance(e.name, six.text_type))
+      self.assertTrue(isinstance(f.name, six.text_type))
+
+      self.assertEqual(42.0, sess.run('c:0'))
+      self.assertEqual(42.0, sess.run(u'c:0'))
+      self.assertEqual(42.0, sess.run(b'c:0'))
+      self.assertEqual(42.0, sess.run(r'c:0'))
+
+      self.assertEqual(43.0, sess.run('d:0'))
+      self.assertEqual(43.0, sess.run(u'd:0'))
+      self.assertEqual(43.0, sess.run(b'd:0'))
+      self.assertEqual(43.0, sess.run(r'd:0'))
+
+      self.assertEqual(44.0, sess.run('e:0'))
+      self.assertEqual(44.0, sess.run(u'e:0'))
+      self.assertEqual(44.0, sess.run(b'e:0'))
+      self.assertEqual(44.0, sess.run(r'e:0'))
+
+      self.assertEqual(45.0, sess.run('f:0'))
+      self.assertEqual(45.0, sess.run(u'f:0'))
+      self.assertEqual(45.0, sess.run(b'f:0'))
+      self.assertEqual(45.0, sess.run(r'f:0'))
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 30e80f779f7..0e813896ff2 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -68,7 +68,7 @@ import_array();
   PyObject* value;
   Py_ssize_t pos = 0;
   while (PyDict_Next($input, &pos, &key, &value)) {
-    const char* key_string = PyString_AsString(key);
+    char* key_string = PyBytes_AsString(key);
     if (!key_string) {
       SWIG_fail;
     }
@@ -121,7 +121,7 @@ import_array();
     PyList_SET_ITEM(temp_string_list.get(), i, elem);
     Py_INCREF(elem);
 
-    const char* fetch_name = PyString_AsString(elem);
+    char* fetch_name = PyBytes_AsString(elem);
     if (!fetch_name) {
       PyErr_SetString(PyExc_TypeError,
                       "a fetch or target name was not a string");
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 06483da87b5..d986e99e8bd 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -39,7 +39,7 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
   PyObject* value;
   Py_ssize_t pos = 0;
   if (PyDict_Next(descr->fields, &pos, &key, &value)) {
-    const char* key_string = PyString_AsString(key);
+    const char* key_string = PyBytes_AsString(key);
     if (!key_string) {
       return errors::Internal("Corrupt numpy type descriptor");
     }
@@ -166,7 +166,7 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
 // Iterate over the string array 'array', extract the ptr and len of each string
 // element and call f(ptr, len).
 template <typename F>
-Status PyStringArrayMap(PyArrayObject* array, F f) {
+Status PyBytesArrayMap(PyArrayObject* array, F f) {
   Safe_PyObjectPtr iter = tensorflow::make_safe(
       PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
   while (PyArray_ITER_NOTDONE(iter.get())) {
@@ -177,10 +177,23 @@ Status PyStringArrayMap(PyArrayObject* array, F f) {
     }
     char* ptr;
     Py_ssize_t len;
-    int success = PyString_AsStringAndSize(item.get(), &ptr, &len);
-    if (success != 0) {
-      return errors::Internal("Unable to get element from the feed.");
+
+#if PY_VERSION_HEX >= 0x03030000
+    // Accept unicode in Python 3, by converting to UTF-8 bytes.
+    if (PyUnicode_Check(item.get())) {
+      ptr = PyUnicode_AsUTF8AndSize(item.get(), &len);
+      if (!buf) {
+        return errors::Internal("Unable to get element from the feed.");
+      }
+    } else {
+#endif
+      int success = PyBytes_AsStringAndSize(item.get(), &ptr, &len);
+      if (success != 0) {
+        return errors::Internal("Unable to get element from the feed.");
+      }
+#if PY_VERSION_HEX >= 0x03030000
     }
+#endif
     f(ptr, len);
     PyArray_ITER_NEXT(iter.get());
   }
@@ -189,15 +202,14 @@ Status PyStringArrayMap(PyArrayObject* array, F f) {
 
 // Encode the strings in 'array' into a contiguous buffer and return the base of
 // the buffer. The caller takes ownership of the buffer.
-Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
-                           size_t* size, void** buffer) {
+Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
+                          size_t* size, void** buffer) {
   // Compute bytes needed for encoding.
   *size = 0;
-  TF_RETURN_IF_ERROR(
-      PyStringArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
-        *size += sizeof(tensorflow::uint64) +
-                 tensorflow::core::VarintLength(len) + len;
-      }));
+  TF_RETURN_IF_ERROR(PyBytesArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
+    *size +=
+        sizeof(tensorflow::uint64) + tensorflow::core::VarintLength(len) + len;
+  }));
   // Encode all strings.
   std::unique_ptr<char[]> base_ptr(new char[*size]);
   char* base = base_ptr.get();
@@ -205,7 +217,7 @@ Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
   char* dst = data_start;  // Where next string is encoded.
   tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
 
-  TF_RETURN_IF_ERROR(PyStringArrayMap(
+  TF_RETURN_IF_ERROR(PyBytesArrayMap(
       array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) {
         *offsets = (dst - data_start);
         offsets++;
@@ -251,7 +263,7 @@ static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
   tensorflow::uint64 len;
   TF_RETURN_IF_ERROR(
       TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
-  auto py_string = tensorflow::make_safe(PyString_FromStringAndSize(ptr, len));
+  auto py_string = tensorflow::make_safe(PyBytes_FromStringAndSize(ptr, len));
   int success =
       PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get());
   if (success != 0) {
@@ -446,7 +458,7 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
     } else {
       size_t size;
       void* encoded;
-      Status s = EncodePyStringArray(array, nelems, &size, &encoded);
+      Status s = EncodePyBytesArray(array, nelems, &size, &encoded);
       if (!s.ok()) {
         *out_status = s;
         return;
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index f10a33ae331..00369816ea8 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -7,6 +7,8 @@ import contextlib
 
 import tensorflow.python.platform
 
+import six
+
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.framework import op_def_registry
@@ -170,12 +172,13 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
     input_map = {}
   else:
     if not (isinstance(input_map, dict)
-            and all(isinstance(k, basestring) for k in input_map.keys())):
+            and all(isinstance(k, six.string_types) for k in input_map.keys())):
       raise TypeError('input_map must be a dictionary mapping strings to '
                       'Tensor objects.')
   if (return_elements is not None
       and not (isinstance(return_elements, (list, tuple))
-               and all(isinstance(x, basestring) for x in return_elements))):
+               and all(isinstance(x, six.string_types)
+                       for x in return_elements))):
     raise TypeError('return_elements must be a list of strings.')
 
   # Use a canonical representation for all tensor names.
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 12f1cc0de77..c5eb23f05be 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -176,6 +176,58 @@ class ImportGraphDefTest(tf.test.TestCase):
       self.assertEqual(d.inputs[0], a.outputs[1])
       self.assertEqual(d.inputs[1], feed_b_1)
 
+  def testInputMapBytes(self):
+    with tf.Graph().as_default():
+      feed_a_0 = tf.constant(0, dtype=tf.int32)
+      feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+      a, b, c, d = tf.import_graph_def(
+          self._MakeGraphDef("""
+          node { name: 'A' op: 'Oii' }
+          node { name: 'B' op: 'Oii' }
+          node { name: 'C' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:0' input: 'B:0' }
+          node { name: 'D' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:1' input: 'B:1' }
+          """),
+          input_map={b'A:0': feed_a_0, b'B:1': feed_b_1},
+          return_elements=[b'A', b'B', b'C', b'D'])
+
+      self.assertEqual(c.inputs[0], feed_a_0)
+      self.assertEqual(c.inputs[1], b.outputs[0])
+      self.assertEqual(d.inputs[0], a.outputs[1])
+      self.assertEqual(d.inputs[1], feed_b_1)
+
+  def testInputMapUnicode(self):
+    with tf.Graph().as_default():
+      feed_a_0 = tf.constant(0, dtype=tf.int32)
+      feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+      a, b, c, d = tf.import_graph_def(
+          self._MakeGraphDef("""
+          node { name: 'A' op: 'Oii' }
+          node { name: 'B' op: 'Oii' }
+          node { name: 'C' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:0' input: 'B:0' }
+          node { name: 'D' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:1' input: 'B:1' }
+          """),
+          input_map={u'A:0': feed_a_0, u'B:1': feed_b_1},
+          return_elements=[u'A', u'B', u'C', u'D'])
+
+      self.assertEqual(c.inputs[0], feed_a_0)
+      self.assertEqual(c.inputs[1], b.outputs[0])
+      self.assertEqual(d.inputs[0], a.outputs[1])
+      self.assertEqual(d.inputs[1], feed_b_1)
+
   def testImplicitZerothOutput(self):
     with tf.Graph().as_default():
       a, b = tf.import_graph_def(
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2801d588e89..54917a506f1 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1327,7 +1327,7 @@ class RegisterGradient(object):
       op_type: The string type of an operation. This corresponds to the
         `OpDef.name` field for the proto that defines the operation.
     """
-    if not isinstance(op_type, basestring):
+    if not isinstance(op_type, six.string_types):
       raise TypeError("op_type must be a string")
     self._op_type = op_type
 
@@ -1356,7 +1356,7 @@ def NoGradient(op_type):
     TypeError: If `op_type` is not a string.
 
   """
-  if not isinstance(op_type, basestring):
+  if not isinstance(op_type, six.string_types):
     raise TypeError("op_type must be a string")
   _gradient_registry.register(None, op_type)
 
@@ -1400,7 +1400,7 @@ class RegisterShape(object):
 
   def __init__(self, op_type):
     """Saves the "op_type" as the Operation type."""
-    if not isinstance(op_type, basestring):
+    if not isinstance(op_type, six.string_types):
       raise TypeError("op_type must be a string")
     self._op_type = op_type
 
@@ -1818,7 +1818,7 @@ class Graph(object):
       obj = conv_fn()
 
     # If obj appears to be a name...
-    if isinstance(obj, basestring):
+    if isinstance(obj, six.string_types):
       name = obj
 
       if ":" in name and allow_tensor:
@@ -1909,7 +1909,7 @@ class Graph(object):
       KeyError: If `name` does not correspond to an operation in this graph.
     """
 
-    if not isinstance(name, basestring):
+    if not isinstance(name, six.string_types):
       raise TypeError("Operation names are strings (or similar), not %s."
                       % type(name).__name__)
     return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
@@ -1930,7 +1930,7 @@ class Graph(object):
       KeyError: If `name` does not correspond to a tensor in this graph.
     """
     # Names should be strings.
-    if not isinstance(name, basestring):
+    if not isinstance(name, six.string_types):
       raise TypeError("Tensor names are strings (or similar), not %s."
                       % type(name).__name__)
     return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
@@ -2495,8 +2495,8 @@ class Graph(object):
     saved_labels = {}
     # Install the given label
     for op_type, label in op_to_kernel_label_map.items():
-      if not (isinstance(op_type, basestring)
-              and isinstance(label, basestring)):
+      if not (isinstance(op_type, six.string_types)
+              and isinstance(label, six.string_types)):
         raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                         "strings to strings")
       try:
@@ -2558,8 +2558,8 @@ class Graph(object):
     saved_mappings = {}
     # Install the given label
     for op_type, mapped_op_type in op_type_map.items():
-      if not (isinstance(op_type, basestring)
-              and isinstance(mapped_op_type, basestring)):
+      if not (isinstance(op_type, six.string_types)
+              and isinstance(mapped_op_type, six.string_types)):
         raise TypeError("op_type_map must be a dictionary mapping "
                         "strings to strings")
       try:
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index e00b3b6d910..c872223b39b 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -207,7 +207,10 @@ def _FilterComplex(v):
 def _FilterStr(v):
   if isinstance(v, (list, tuple)):
     return _FirstNotNone([_FilterStr(x) for x in v])
-  return None if isinstance(v, basestring) else _NotNone(v)
+  if isinstance(v, (six.string_types, six.binary_type)):
+    return None
+  else:
+    return _NotNone(v)
 
 
 def _FilterBool(v):
diff --git a/tensorflow/python/lib/core/strings.i b/tensorflow/python/lib/core/strings.i
index c88e426a54c..7ee912c0f7c 100644
--- a/tensorflow/python/lib/core/strings.i
+++ b/tensorflow/python/lib/core/strings.i
@@ -25,31 +25,15 @@
 %typemap(typecheck) tensorflow::StringPiece = char *;
 %typemap(typecheck) const tensorflow::StringPiece & = char *;
 
-// "tensorflow::StringPiece" arguments can be provided by a simple Python 'str' string
-// or a 'unicode' object. If 'unicode', it's translated using the default
-// encoding, i.e., sys.getdefaultencoding(). If passed None, a tensorflow::StringPiece
-// of zero length with a NULL pointer is provided.
+// "tensorflow::StringPiece" arguments must be specified as a 'str' or 'bytes' object.
 %typemap(in) tensorflow::StringPiece {
   if ($input != Py_None) {
     char * buf;
     Py_ssize_t len;
-%#if PY_VERSION_HEX >= 0x03030000
-    /* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
-    if (PyUnicode_Check($input)) {
-      buf = PyUnicode_AsUTF8AndSize($input, &len);
-      if (buf == NULL)
-        SWIG_fail;
-    } else {
-%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
-%#  error "Unsupported Python 3.x C API version (3.3 or later required)."
-%#endif
-      if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
-        // Python has raised an error (likely TypeError or UnicodeEncodeError).
-        SWIG_fail;
-      }
-%#if PY_VERSION_HEX >= 0x03030000
+    if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
+      // Python has raised an error (likely TypeError or UnicodeEncodeError).
+      SWIG_fail;
     }
-%#endif
     $1.set(buf, len);
   }
 }
@@ -60,23 +44,10 @@
   if ($input != Py_None) {
     char * buf;
     Py_ssize_t len;
-%#if PY_VERSION_HEX >= 0x03030000
-    /* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
-    if (PyUnicode_Check($input)) {
-      buf = PyUnicode_AsUTF8AndSize($input, &len);
-      if (buf == NULL)
-        SWIG_fail;
-    } else {
-%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
-%#  error "Unsupported Python 3.x C API version (3.3 or later required)."
-%#endif
-      if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
-        // Python has raised an error (likely TypeError or UnicodeEncodeError).
-        SWIG_fail;
-      }
-%#if PY_VERSION_HEX >= 0x03030000
+    if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
+      // Python has raised an error (likely TypeError).
+      SWIG_fail;
     }
-%#endif
     temp.set(buf, len);
   }
   $1 = &temp;
@@ -86,7 +57,7 @@
 // or None if the StringPiece contained a NULL pointer.
 %typemap(out) tensorflow::StringPiece {
   if ($1.data()) {
-    $result = PyString_FromStringAndSize($1.data(), $1.size());
+    $result = PyBytes_FromStringAndSize($1.data(), $1.size());
   } else {
     Py_INCREF(Py_None);
     $result = Py_None;
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 9893c9b824b..e682787f5b4 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -5,6 +5,8 @@ from __future__ import print_function
 
 import collections
 
+import six
+
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import types
 from tensorflow.python.ops import array_ops
@@ -102,7 +104,7 @@ def global_norm(t_list, name=None):
     TypeError: If `t_list` is not a sequence.
   """
   if (not isinstance(t_list, collections.Sequence)
-      or isinstance(t_list, basestring)):
+      or isinstance(t_list, six.string_types)):
     raise TypeError("t_list should be a sequence")
   t_list = list(t_list)
   with ops.op_scope(t_list, name, "global_norm") as name:
@@ -164,7 +166,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
     TypeError: If `t_list` is not a sequence.
   """
   if (not isinstance(t_list, collections.Sequence)
-      or isinstance(t_list, basestring)):
+      or isinstance(t_list, six.string_types)):
     raise TypeError("t_list should be a sequence")
   t_list = list(t_list)
   if use_norm is None:
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index 957c5123bcc..c6acd84bb69 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -5,6 +5,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numbers
+import six
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import op_def_pb2
@@ -128,7 +129,7 @@ def _MakeFloat(v, arg_name):
 
 
 def _MakeInt(v, arg_name):
-  if isinstance(v, basestring):
+  if isinstance(v, six.string_types):
     raise TypeError("Expected int for argument '%s' not %s." %
                     (arg_name, repr(v)))
   try:
@@ -139,7 +140,7 @@ def _MakeInt(v, arg_name):
 
 
 def _MakeStr(v, arg_name):
-  if not isinstance(v, basestring):
+  if not isinstance(v, six.string_types):
     raise TypeError("Expected string for argument '%s' not %s." %
                     (arg_name, repr(v)))
   # TODO(irving): Figure out what to do here from Python 3
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 47149163af6..0b8ed237764 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -5,6 +5,7 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
+import six
 
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -296,14 +297,14 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
       a reuse scope, or if reuse is not `None` or `True`.
     TypeError: when the types of some arguments are not appropriate.
   """
-  if not isinstance(name_or_scope, (_VariableScope, basestring)):
+  if not isinstance(name_or_scope, (_VariableScope,) + six.string_types):
     raise TypeError("VariableScope: name_scope must be a string or "
                     "VariableScope.")
   if reuse not in [None, True]:
     raise ValueError("VariableScope reuse parameter must be True or None.")
   if not reuse and isinstance(name_or_scope, (_VariableScope)):
     logging.info("Passing VariableScope to a non-reusing scope, intended?")
-  if reuse and isinstance(name_or_scope, (basestring)):
+  if reuse and isinstance(name_or_scope, six.string_types):
     logging.info("Re-using string-named scope, consider capturing as object.")
   get_variable_scope()  # Ensure that a default exists, then get a pointer.
   default_varscope = ops.get_collection(_VARSCOPE_KEY)
diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i
index 85fa3968a13..22792669fa9 100644
--- a/tensorflow/python/platform/base.i
+++ b/tensorflow/python/platform/base.i
@@ -21,13 +21,7 @@
       bool _PyObjAs(PyObject *pystr, ::string* cstr) {
     char *buf;
     Py_ssize_t len;
-#if PY_VERSION_HEX >= 0x03030000
-    if (PyUnicode_Check(pystr)) {
-      buf = PyUnicode_AsUTF8AndSize(pystr, &len);
-      if (!buf) return false;
-    } else  // NOLINT
-#endif
-      if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
+    if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
     if (cstr) cstr->assign(buf, len);
     return true;
   }
@@ -36,29 +30,23 @@
       bool _PyObjAs(PyObject *pystr, std::string* cstr) {
     char *buf;
     Py_ssize_t len;
-#if PY_VERSION_HEX >= 0x03030000
-    if (PyUnicode_Check(pystr)) {
-      buf = PyUnicode_AsUTF8AndSize(pystr, &len);
-      if (!buf) return false;
-    } else  // NOLINT
-#endif
-      if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
+    if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
     if (cstr) cstr->assign(buf, len);
     return true;
   }
 #ifdef HAS_GLOBAL_STRING
   template<>
       PyObject* _PyObjFrom(const ::string& c) {
-    return PyString_FromStringAndSize(c.data(), c.size());
+    return PyBytes_FromStringAndSize(c.data(), c.size());
   }
 #endif
   template<>
       PyObject* _PyObjFrom(const std::string& c) {
-    return PyString_FromStringAndSize(c.data(), c.size());
+    return PyBytes_FromStringAndSize(c.data(), c.size());
   }
 
   PyObject* _SwigString_FromString(const string& s) {
-    return PyUnicode_FromStringAndSize(s.data(), s.size());
+    return PyBytes_FromStringAndSize(s.data(), s.size());
   }
 %}
 
@@ -72,11 +60,11 @@
 }
 
 %typemap(out) string {
-  $result = PyString_FromStringAndSize($1.data(), $1.size());
+  $result = PyBytes_FromStringAndSize($1.data(), $1.size());
 }
 
 %typemap(out) const string& {
-  $result = PyString_FromStringAndSize($1->data(), $1->size());
+  $result = PyBytes_FromStringAndSize($1->data(), $1->size());
 }
 
 %typemap(in, numinputs = 0) string* OUTPUT (string temp) {
@@ -84,7 +72,7 @@
 }
 
 %typemap(argout) string * OUTPUT {
-  PyObject *str = PyString_FromStringAndSize($1->data(), $1->length());
+  PyObject *str = PyBytes_FromStringAndSize($1->data(), $1->length());
   if (!str) SWIG_fail;
   %append_output(str);
 }
@@ -92,7 +80,7 @@
 %typemap(argout) string* INOUT = string* OUTPUT;
 
 %typemap(varout) string {
-  $result = PyString_FromStringAndSize($1.data(), $1.size());
+  $result = PyBytes_FromStringAndSize($1.data(), $1.size());
 }
 
 %define _LIST_OUTPUT_TYPEMAP(type, py_converter)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index fcd02716d28..8f142edb4bb 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -9,6 +9,8 @@ import numbers
 import os.path
 import time
 
+import six
+
 from google.protobuf import text_format
 
 from tensorflow.python.client import graph_util
@@ -299,7 +301,7 @@ class BaseSaverBuilder(object):
     vars_to_save = []
     seen_variables = set()
     for name in sorted(names_to_variables.keys()):
-      if not isinstance(name, basestring):
+      if not isinstance(name, six.string_types):
         raise TypeError("names_to_variables must be a dict mapping string "
                         "names to variable Tensors. Name is not a string: %s" %
                         name)
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 4e248f625ce..76ca4d90d6b 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -10,6 +10,7 @@ import tensorflow.python.platform
 
 import tensorflow as tf
 import numpy as np
+import six
 
 from tensorflow.python.platform import gfile
 
@@ -33,7 +34,7 @@ class SaverTest(tf.test.TestCase):
 
       # Save the initialized values in the file at "save_path"
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
     # Start a second session.  In that session the parameter nodes
@@ -84,7 +85,7 @@ class SaverTest(tf.test.TestCase):
 
       # Save the initialized values in the file at "save_path"
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
       with self.test_session() as sess:
@@ -131,7 +132,7 @@ class SaverTest(tf.test.TestCase):
 
       # Save the initialized values in the file at "save_path"
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
     # Start a second session.  In that session the variables
@@ -517,7 +518,7 @@ class SaveRestoreWithVariableNameMap(tf.test.TestCase):
       # Save the initialized values in the file at "save_path"
       # Use a variable name map to set the saved tensor names
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
       # Verify that the original names are not in the Saved file
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
index 7d1838d9f22..f96bfde9dd5 100644
--- a/tensorflow/python/training/summary_io.py
+++ b/tensorflow/python/training/summary_io.py
@@ -9,6 +9,8 @@ import Queue
 import threading
 import time
 
+import six
+
 from tensorflow.core.framework import summary_pb2
 from tensorflow.core.util import event_pb2
 from tensorflow.python import pywrap_tensorflow
@@ -103,7 +105,7 @@ class SummaryWriter(object):
       global_step: Number. Optional global step value to record with the
         summary.
     """
-    if isinstance(summary, basestring):
+    if isinstance(summary, six.binary_type):
       summ = summary_pb2.Summary()
       summ.ParseFromString(summary)
       summary = summ
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index 51cef58a987..8cbb10f2493 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -82,7 +82,7 @@ def assertProto2Equal(self, a, b, check_initialized=True,
       numbers before comparison.
     msg: if specified, is used as the error message on failure
   """
-  if isinstance(a, basestring):
+  if isinstance(a, six.string_types):
     a = text_format.Merge(a, b.__class__())
 
   for pb in a, b:
@@ -121,7 +121,7 @@ def assertProto2SameElements(self, a, b, number_matters=False,
       numbers before comparison.
     msg: if specified, is used as the error message on failure
   """
-  if isinstance(a, basestring):
+  if isinstance(a, six.string_types):
     a = text_format.Merge(a, b.__class__())
   else:
     a = copy.deepcopy(a)
@@ -152,7 +152,7 @@ def assertProto2Contains(self, a, b,  # pylint: disable=invalid-name
     check_initialized: boolean, whether to fail if b isn't initialized
     msg: if specified, is used as the error message on failure
   """
-  if isinstance(a, basestring):
+  if isinstance(a, six.string_types):
     a = text_format.Merge(a, b.__class__())
   else:
     a = copy.deepcopy(a)
@@ -299,7 +299,7 @@ def NormalizeNumberFields(pb):
 
 
 def _IsRepeatedContainer(value):
-  if isinstance(value, basestring):
+  if isinstance(value, six.string_types):
     return False
   try:
     iter(value)
diff --git a/tensorflow/python/util/protobuf/compare_test.py b/tensorflow/python/util/protobuf/compare_test.py
index d8cb53bc2b0..4a8dcf3add4 100644
--- a/tensorflow/python/util/protobuf/compare_test.py
+++ b/tensorflow/python/util/protobuf/compare_test.py
@@ -342,12 +342,12 @@ class NormalizeNumbersTest(googletest.TestCase):
 class AssertTest(googletest.TestCase):
   """Tests both assertProto2Equal() and assertProto2SameElements()."""
   def assertProto2Equal(self, a, b, **kwargs):
-    if isinstance(a, basestring) and isinstance(b, basestring):
+    if isinstance(a, six.string_types) and isinstance(b, six.string_types):
       a, b = LargePbs(a, b)
     compare.assertProto2Equal(self, a, b, **kwargs)
 
   def assertProto2SameElements(self, a, b, **kwargs):
-    if isinstance(a, basestring) and isinstance(b, basestring):
+    if isinstance(a, six.string_types) and isinstance(b, six.string_types):
       a, b = LargePbs(a, b)
     compare.assertProto2SameElements(self, a, b, **kwargs)
 
diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json
index 995ba30363f..a9419c5140a 100644
--- a/tensorflow/tensorboard/bower.json
+++ b/tensorflow/tensorboard/bower.json
@@ -16,9 +16,9 @@
   "private": true,
   "dependencies": {
     "d3": "3.5.6",
-    "dagre": "~0.7.4",
-    "es6-promise": "~3.0.2",
-    "graphlib": "~1.0.7",
+    "dagre": "0.7.4",
+    "es6-promise": "3.0.2",
+    "graphlib": "1.0.7",
     "iron-ajax": "PolymerElements/iron-ajax#1.0.7",
     "iron-collapse": "PolymerElements/iron-collapse#1.0.4",
     "iron-list": "PolymerElements/iron-list#1.1.5",
@@ -38,7 +38,7 @@
     "paper-styles": "PolymerElements/paper-styles#1.0.12",
     "paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11",
     "paper-toolbar": "PolymerElements/paper-toolbar#1.0.4",
-    "plottable": "~1.16.1",
+    "plottable": "1.16.2",
     "polymer": "1.1.5"
   },
   "devDependencies": {