TensorFlow: a few small updates.

Changes:
- Fix softmax formula in word2vec to remove an extra exp()
  by @gouwsmeister

- Python3 fixes to remove basestring / support for unicode by @mrry

- Remove some comments by Josh

- Specify exact versions of bower dependencies for TensorBoard by
  @danmane.

Base CL: 107742361
This commit is contained in:
Vijay Vasudevan 2015-11-12 18:34:45 -08:00
parent d50565b35e
commit 011e9baccd
22 changed files with 264 additions and 131 deletions

View File

@ -111,8 +111,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
const NameAttrList** value); // type: "func"
// Computes the input and output types for a specific node, for
// attr-style ops.
// Computes the input and output types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs);

View File

@ -64,11 +64,9 @@ Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
Node* Identity(Graph* g, Node* input, int index = 0);
// Adds a binary function "func" node in "g" taking "in0" and "in1".
// Requires that "func" name an attr-style Op.
Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
// Adds a function "func" node in "g" taking inputs "ins".
// Requires that "func" name an attr-style Op.
Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
// Adds a binary add node in "g" doing in0 + in1.

View File

@ -100,7 +100,7 @@ given the previous words \\(h\\) (for "history") in terms of a
$$
\begin{align}
P(w_t | h) &= \text{softmax}(\exp \{ \text{score}(w_t, h) \}) \\
P(w_t | h) &= \text{softmax}(\text{score}(w_t, h)) \\
&= \frac{\exp \{ \text{score}(w_t, h) \} }
{\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }.
\end{align}

View File

@ -11,6 +11,7 @@ import threading
import tensorflow.python.platform
import numpy as np
import six
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import errors
@ -36,6 +37,56 @@ class SessionInterface(object):
raise NotImplementedError('Run')
def _as_bytes(bytes_or_unicode):
"""Returns the given argument as a byte array.
NOTE(mrry): For Python 2 and 3 compatibility, we convert all string
arguments to SWIG methods into byte arrays. Unicode strings are
encoded as UTF-8; however the valid arguments for all of the
human-readable arguments must currently be a subset of ASCII.
Args:
bytes_or_unicode: A `unicode`, `string`, or `bytes` object.
Returns:
A `bytes` object.
Raises:
TypeError: If `bytes_or_unicode` is not a binary or unicode string.
"""
if isinstance(bytes_or_unicode, six.text_type):
return bytes_or_unicode.encode('utf-8')
elif isinstance(bytes_or_unicode, six.binary_type):
return bytes_or_unicode
else:
raise TypeError('bytes_or_unicode must be a binary or unicode string.')
def _as_text(bytes_or_unicode):
"""Returns the given argument as a unicode string.
NOTE(mrry): For Python 2 and 3 compatibility, we interpret all
returned strings from SWIG methods as byte arrays. This function
converts those strings that are intended to be human-readable into
UTF-8 unicode strings.
Args:
bytes_or_unicode: A `unicode`, `string`, or `bytes` object.
Returns:
A `unicode` (Python 2) or `str` (Python 3) object.
Raises:
TypeError: If `bytes_or_unicode` is not a binary or unicode string.
"""
if isinstance(bytes_or_unicode, six.text_type):
return bytes_or_unicode
elif isinstance(bytes_or_unicode, six.binary_type):
return bytes_or_unicode.decode('utf-8')
else:
raise TypeError('bytes_or_unicode must be a binary or unicode string.')
class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
@ -75,8 +126,7 @@ class BaseSession(SessionInterface):
status = tf_session.TF_NewStatus()
self._session = tf_session.TF_NewSession(opts, status)
if tf_session.TF_GetCode(status) != 0:
message = tf_session.TF_Message(status)
raise RuntimeError(message)
raise RuntimeError(_as_text(tf_session.TF_Message(status)))
finally:
tf_session.TF_DeleteSessionOptions(opts)
@ -97,7 +147,7 @@ class BaseSession(SessionInterface):
status = tf_session.TF_NewStatus()
tf_session.TF_CloseSession(self._session, status)
if tf_session.TF_GetCode(status) != 0:
raise RuntimeError(tf_session.TF_Message(status))
raise RuntimeError(_as_text(tf_session.TF_Message(status)))
finally:
tf_session.TF_DeleteStatus(status)
@ -108,7 +158,7 @@ class BaseSession(SessionInterface):
if self._session is not None:
tf_session.TF_DeleteSession(self._session, status)
if tf_session.TF_GetCode(status) != 0:
raise RuntimeError(tf_session.TF_Message(status))
raise RuntimeError(_as_text(tf_session.TF_Message(status)))
self._session = None
finally:
tf_session.TF_DeleteStatus(status)
@ -265,7 +315,6 @@ class BaseSession(SessionInterface):
TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
`Tensor` that doesn't exist.
"""
def _fetch_fn(fetch):
for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
@ -302,9 +351,9 @@ class BaseSession(SessionInterface):
fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
allow_operation=True)
if isinstance(fetch_t, ops.Operation):
target_list.append(fetch_t.name)
target_list.append(_as_bytes(fetch_t.name))
else:
subfetch_names.append(fetch_t.name)
subfetch_names.append(_as_bytes(fetch_t.name))
except TypeError as e:
raise TypeError('Fetch argument %r of %r has invalid type %r, '
'must be a string or Tensor. (%s)'
@ -343,7 +392,7 @@ class BaseSession(SessionInterface):
'which has shape %r'
% (np_val.shape, subfeed_t.name,
tuple(subfeed_t.get_shape().dims)))
feed_dict_string[str(subfeed_t.name)] = np_val
feed_dict_string[_as_bytes(subfeed_t.name)] = np_val
# Run request and get response.
results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
@ -373,11 +422,12 @@ class BaseSession(SessionInterface):
"""Runs a step based on the given fetches and feeds.
Args:
target_list: A list of strings corresponding to names of tensors
target_list: A list of byte arrays corresponding to names of tensors
or operations to be run to, but not fetched.
fetch_list: A list of strings corresponding to names of tensors to be
fetched and operations to be run.
feed_dict: A dictionary that maps tensor names to numpy ndarrays.
fetch_list: A list of byte arrays corresponding to names of tensors to
be fetched and operations to be run.
feed_dict: A dictionary that maps tensor names (as byte arrays) to
numpy ndarrays.
Returns:
A list of numpy ndarrays, corresponding to the elements of
@ -397,7 +447,7 @@ class BaseSession(SessionInterface):
tf_session.TF_ExtendGraph(
self._session, graph_def.SerializeToString(), status)
if tf_session.TF_GetCode(status) != 0:
raise RuntimeError(tf_session.TF_Message(status))
raise RuntimeError(_as_text(tf_session.TF_Message(status)))
self._opened = True
finally:
tf_session.TF_DeleteStatus(status)
@ -409,7 +459,8 @@ class BaseSession(SessionInterface):
except tf_session.StatusNotOK as e:
e_type, e_value, e_traceback = sys.exc_info()
m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
error_message = _as_text(e.error_message)
m = BaseSession._NODEDEF_NAME_RE.search(error_message)
if m is not None:
node_name = m.group(1)
node_def = None
@ -419,7 +470,7 @@ class BaseSession(SessionInterface):
except KeyError:
op = None
# pylint: disable=protected-access
raise errors._make_specific_exception(node_def, op, e.error_message,
raise errors._make_specific_exception(node_def, op, error_message,
e.code)
# pylint: enable=protected-access
raise e_type, e_value, e_traceback

View File

@ -9,6 +9,7 @@ import time
import tensorflow.python.platform
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import config_pb2
@ -551,9 +552,55 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(c_list[0], out[0])
self.assertEqual(c_list[1], out[1])
def testStringFeedWithUnicode(self):
with session.Session():
c_list = [u'\n\x01\x00', u'\n\x00\x01']
feed_t = array_ops.placeholder(dtype=types.string, shape=[2])
c = array_ops.identity(feed_t)
out = c.eval(feed_dict={feed_t: c_list})
self.assertEqual(c_list[0], out[0].decode('utf-8'))
self.assertEqual(c_list[1], out[1].decode('utf-8'))
out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)})
self.assertEqual(c_list[0], out[0].decode('utf-8'))
self.assertEqual(c_list[1], out[1].decode('utf-8'))
def testInvalidTargetFails(self):
with self.assertRaises(RuntimeError):
session.Session("INVALID_TARGET")
session.Session('INVALID_TARGET')
def testFetchByNameDifferentStringTypes(self):
with session.Session() as sess:
c = constant_op.constant(42.0, name='c')
d = constant_op.constant(43.0, name=u'd')
e = constant_op.constant(44.0, name=b'e')
f = constant_op.constant(45.0, name=r'f')
self.assertTrue(isinstance(c.name, six.text_type))
self.assertTrue(isinstance(d.name, six.text_type))
self.assertTrue(isinstance(e.name, six.text_type))
self.assertTrue(isinstance(f.name, six.text_type))
self.assertEqual(42.0, sess.run('c:0'))
self.assertEqual(42.0, sess.run(u'c:0'))
self.assertEqual(42.0, sess.run(b'c:0'))
self.assertEqual(42.0, sess.run(r'c:0'))
self.assertEqual(43.0, sess.run('d:0'))
self.assertEqual(43.0, sess.run(u'd:0'))
self.assertEqual(43.0, sess.run(b'd:0'))
self.assertEqual(43.0, sess.run(r'd:0'))
self.assertEqual(44.0, sess.run('e:0'))
self.assertEqual(44.0, sess.run(u'e:0'))
self.assertEqual(44.0, sess.run(b'e:0'))
self.assertEqual(44.0, sess.run(r'e:0'))
self.assertEqual(45.0, sess.run('f:0'))
self.assertEqual(45.0, sess.run(u'f:0'))
self.assertEqual(45.0, sess.run(b'f:0'))
self.assertEqual(45.0, sess.run(r'f:0'))
if __name__ == '__main__':

View File

@ -68,7 +68,7 @@ import_array();
PyObject* value;
Py_ssize_t pos = 0;
while (PyDict_Next($input, &pos, &key, &value)) {
const char* key_string = PyString_AsString(key);
char* key_string = PyBytes_AsString(key);
if (!key_string) {
SWIG_fail;
}
@ -121,7 +121,7 @@ import_array();
PyList_SET_ITEM(temp_string_list.get(), i, elem);
Py_INCREF(elem);
const char* fetch_name = PyString_AsString(elem);
char* fetch_name = PyBytes_AsString(elem);
if (!fetch_name) {
PyErr_SetString(PyExc_TypeError,
"a fetch or target name was not a string");

View File

@ -39,7 +39,7 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
PyObject* value;
Py_ssize_t pos = 0;
if (PyDict_Next(descr->fields, &pos, &key, &value)) {
const char* key_string = PyString_AsString(key);
const char* key_string = PyBytes_AsString(key);
if (!key_string) {
return errors::Internal("Corrupt numpy type descriptor");
}
@ -166,7 +166,7 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
// Iterate over the string array 'array', extract the ptr and len of each string
// element and call f(ptr, len).
template <typename F>
Status PyStringArrayMap(PyArrayObject* array, F f) {
Status PyBytesArrayMap(PyArrayObject* array, F f) {
Safe_PyObjectPtr iter = tensorflow::make_safe(
PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
while (PyArray_ITER_NOTDONE(iter.get())) {
@ -177,10 +177,23 @@ Status PyStringArrayMap(PyArrayObject* array, F f) {
}
char* ptr;
Py_ssize_t len;
int success = PyString_AsStringAndSize(item.get(), &ptr, &len);
if (success != 0) {
return errors::Internal("Unable to get element from the feed.");
#if PY_VERSION_HEX >= 0x03030000
// Accept unicode in Python 3, by converting to UTF-8 bytes.
if (PyUnicode_Check(item.get())) {
ptr = PyUnicode_AsUTF8AndSize(item.get(), &len);
if (!buf) {
return errors::Internal("Unable to get element from the feed.");
}
} else {
#endif
int success = PyBytes_AsStringAndSize(item.get(), &ptr, &len);
if (success != 0) {
return errors::Internal("Unable to get element from the feed.");
}
#if PY_VERSION_HEX >= 0x03030000
}
#endif
f(ptr, len);
PyArray_ITER_NEXT(iter.get());
}
@ -189,15 +202,14 @@ Status PyStringArrayMap(PyArrayObject* array, F f) {
// Encode the strings in 'array' into a contiguous buffer and return the base of
// the buffer. The caller takes ownership of the buffer.
Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
size_t* size, void** buffer) {
Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
size_t* size, void** buffer) {
// Compute bytes needed for encoding.
*size = 0;
TF_RETURN_IF_ERROR(
PyStringArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
*size += sizeof(tensorflow::uint64) +
tensorflow::core::VarintLength(len) + len;
}));
TF_RETURN_IF_ERROR(PyBytesArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
*size +=
sizeof(tensorflow::uint64) + tensorflow::core::VarintLength(len) + len;
}));
// Encode all strings.
std::unique_ptr<char[]> base_ptr(new char[*size]);
char* base = base_ptr.get();
@ -205,7 +217,7 @@ Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
char* dst = data_start; // Where next string is encoded.
tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
TF_RETURN_IF_ERROR(PyStringArrayMap(
TF_RETURN_IF_ERROR(PyBytesArrayMap(
array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) {
*offsets = (dst - data_start);
offsets++;
@ -251,7 +263,7 @@ static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
tensorflow::uint64 len;
TF_RETURN_IF_ERROR(
TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
auto py_string = tensorflow::make_safe(PyString_FromStringAndSize(ptr, len));
auto py_string = tensorflow::make_safe(PyBytes_FromStringAndSize(ptr, len));
int success =
PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get());
if (success != 0) {
@ -446,7 +458,7 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
} else {
size_t size;
void* encoded;
Status s = EncodePyStringArray(array, nelems, &size, &encoded);
Status s = EncodePyBytesArray(array, nelems, &size, &encoded);
if (!s.ok()) {
*out_status = s;
return;

View File

@ -7,6 +7,8 @@ import contextlib
import tensorflow.python.platform
import six
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import op_def_registry
@ -170,12 +172,13 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
input_map = {}
else:
if not (isinstance(input_map, dict)
and all(isinstance(k, basestring) for k in input_map.keys())):
and all(isinstance(k, six.string_types) for k in input_map.keys())):
raise TypeError('input_map must be a dictionary mapping strings to '
'Tensor objects.')
if (return_elements is not None
and not (isinstance(return_elements, (list, tuple))
and all(isinstance(x, basestring) for x in return_elements))):
and all(isinstance(x, six.string_types)
for x in return_elements))):
raise TypeError('return_elements must be a list of strings.')
# Use a canonical representation for all tensor names.

View File

@ -176,6 +176,58 @@ class ImportGraphDefTest(tf.test.TestCase):
self.assertEqual(d.inputs[0], a.outputs[1])
self.assertEqual(d.inputs[1], feed_b_1)
def testInputMapBytes(self):
with tf.Graph().as_default():
feed_a_0 = tf.constant(0, dtype=tf.int32)
feed_b_1 = tf.constant(1, dtype=tf.int32)
a, b, c, d = tf.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'Oii' }
node { name: 'B' op: 'Oii' }
node { name: 'C' op: 'In'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:0' input: 'B:0' }
node { name: 'D' op: 'In'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:1' input: 'B:1' }
"""),
input_map={b'A:0': feed_a_0, b'B:1': feed_b_1},
return_elements=[b'A', b'B', b'C', b'D'])
self.assertEqual(c.inputs[0], feed_a_0)
self.assertEqual(c.inputs[1], b.outputs[0])
self.assertEqual(d.inputs[0], a.outputs[1])
self.assertEqual(d.inputs[1], feed_b_1)
def testInputMapUnicode(self):
with tf.Graph().as_default():
feed_a_0 = tf.constant(0, dtype=tf.int32)
feed_b_1 = tf.constant(1, dtype=tf.int32)
a, b, c, d = tf.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'Oii' }
node { name: 'B' op: 'Oii' }
node { name: 'C' op: 'In'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:0' input: 'B:0' }
node { name: 'D' op: 'In'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:1' input: 'B:1' }
"""),
input_map={u'A:0': feed_a_0, u'B:1': feed_b_1},
return_elements=[u'A', u'B', u'C', u'D'])
self.assertEqual(c.inputs[0], feed_a_0)
self.assertEqual(c.inputs[1], b.outputs[0])
self.assertEqual(d.inputs[0], a.outputs[1])
self.assertEqual(d.inputs[1], feed_b_1)
def testImplicitZerothOutput(self):
with tf.Graph().as_default():
a, b = tf.import_graph_def(

View File

@ -1327,7 +1327,7 @@ class RegisterGradient(object):
op_type: The string type of an operation. This corresponds to the
`OpDef.name` field for the proto that defines the operation.
"""
if not isinstance(op_type, basestring):
if not isinstance(op_type, six.string_types):
raise TypeError("op_type must be a string")
self._op_type = op_type
@ -1356,7 +1356,7 @@ def NoGradient(op_type):
TypeError: If `op_type` is not a string.
"""
if not isinstance(op_type, basestring):
if not isinstance(op_type, six.string_types):
raise TypeError("op_type must be a string")
_gradient_registry.register(None, op_type)
@ -1400,7 +1400,7 @@ class RegisterShape(object):
def __init__(self, op_type):
"""Saves the "op_type" as the Operation type."""
if not isinstance(op_type, basestring):
if not isinstance(op_type, six.string_types):
raise TypeError("op_type must be a string")
self._op_type = op_type
@ -1818,7 +1818,7 @@ class Graph(object):
obj = conv_fn()
# If obj appears to be a name...
if isinstance(obj, basestring):
if isinstance(obj, six.string_types):
name = obj
if ":" in name and allow_tensor:
@ -1909,7 +1909,7 @@ class Graph(object):
KeyError: If `name` does not correspond to an operation in this graph.
"""
if not isinstance(name, basestring):
if not isinstance(name, six.string_types):
raise TypeError("Operation names are strings (or similar), not %s."
% type(name).__name__)
return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
@ -1930,7 +1930,7 @@ class Graph(object):
KeyError: If `name` does not correspond to a tensor in this graph.
"""
# Names should be strings.
if not isinstance(name, basestring):
if not isinstance(name, six.string_types):
raise TypeError("Tensor names are strings (or similar), not %s."
% type(name).__name__)
return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
@ -2495,8 +2495,8 @@ class Graph(object):
saved_labels = {}
# Install the given label
for op_type, label in op_to_kernel_label_map.items():
if not (isinstance(op_type, basestring)
and isinstance(label, basestring)):
if not (isinstance(op_type, six.string_types)
and isinstance(label, six.string_types)):
raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
"strings to strings")
try:
@ -2558,8 +2558,8 @@ class Graph(object):
saved_mappings = {}
# Install the given label
for op_type, mapped_op_type in op_type_map.items():
if not (isinstance(op_type, basestring)
and isinstance(mapped_op_type, basestring)):
if not (isinstance(op_type, six.string_types)
and isinstance(mapped_op_type, six.string_types)):
raise TypeError("op_type_map must be a dictionary mapping "
"strings to strings")
try:

View File

@ -207,7 +207,10 @@ def _FilterComplex(v):
def _FilterStr(v):
if isinstance(v, (list, tuple)):
return _FirstNotNone([_FilterStr(x) for x in v])
return None if isinstance(v, basestring) else _NotNone(v)
if isinstance(v, (six.string_types, six.binary_type)):
return None
else:
return _NotNone(v)
def _FilterBool(v):

View File

@ -25,31 +25,15 @@
%typemap(typecheck) tensorflow::StringPiece = char *;
%typemap(typecheck) const tensorflow::StringPiece & = char *;
// "tensorflow::StringPiece" arguments can be provided by a simple Python 'str' string
// or a 'unicode' object. If 'unicode', it's translated using the default
// encoding, i.e., sys.getdefaultencoding(). If passed None, a tensorflow::StringPiece
// of zero length with a NULL pointer is provided.
// "tensorflow::StringPiece" arguments must be specified as a 'str' or 'bytes' object.
%typemap(in) tensorflow::StringPiece {
if ($input != Py_None) {
char * buf;
Py_ssize_t len;
%#if PY_VERSION_HEX >= 0x03030000
/* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
if (PyUnicode_Check($input)) {
buf = PyUnicode_AsUTF8AndSize($input, &len);
if (buf == NULL)
SWIG_fail;
} else {
%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
%# error "Unsupported Python 3.x C API version (3.3 or later required)."
%#endif
if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
%#if PY_VERSION_HEX >= 0x03030000
if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
%#endif
$1.set(buf, len);
}
}
@ -60,23 +44,10 @@
if ($input != Py_None) {
char * buf;
Py_ssize_t len;
%#if PY_VERSION_HEX >= 0x03030000
/* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
if (PyUnicode_Check($input)) {
buf = PyUnicode_AsUTF8AndSize($input, &len);
if (buf == NULL)
SWIG_fail;
} else {
%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
%# error "Unsupported Python 3.x C API version (3.3 or later required)."
%#endif
if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
%#if PY_VERSION_HEX >= 0x03030000
if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
// Python has raised an error (likely TypeError).
SWIG_fail;
}
%#endif
temp.set(buf, len);
}
$1 = &temp;
@ -86,7 +57,7 @@
// or None if the StringPiece contained a NULL pointer.
%typemap(out) tensorflow::StringPiece {
if ($1.data()) {
$result = PyString_FromStringAndSize($1.data(), $1.size());
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
} else {
Py_INCREF(Py_None);
$result = Py_None;

View File

@ -5,6 +5,8 @@ from __future__ import print_function
import collections
import six
from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
@ -102,7 +104,7 @@ def global_norm(t_list, name=None):
TypeError: If `t_list` is not a sequence.
"""
if (not isinstance(t_list, collections.Sequence)
or isinstance(t_list, basestring)):
or isinstance(t_list, six.string_types)):
raise TypeError("t_list should be a sequence")
t_list = list(t_list)
with ops.op_scope(t_list, name, "global_norm") as name:
@ -164,7 +166,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
TypeError: If `t_list` is not a sequence.
"""
if (not isinstance(t_list, collections.Sequence)
or isinstance(t_list, basestring)):
or isinstance(t_list, six.string_types)):
raise TypeError("t_list should be a sequence")
t_list = list(t_list)
if use_norm is None:

View File

@ -5,6 +5,7 @@ from __future__ import division
from __future__ import print_function
import numbers
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
@ -128,7 +129,7 @@ def _MakeFloat(v, arg_name):
def _MakeInt(v, arg_name):
if isinstance(v, basestring):
if isinstance(v, six.string_types):
raise TypeError("Expected int for argument '%s' not %s." %
(arg_name, repr(v)))
try:
@ -139,7 +140,7 @@ def _MakeInt(v, arg_name):
def _MakeStr(v, arg_name):
if not isinstance(v, basestring):
if not isinstance(v, six.string_types):
raise TypeError("Expected string for argument '%s' not %s." %
(arg_name, repr(v)))
# TODO(irving): Figure out what to do here from Python 3

View File

@ -5,6 +5,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
import six
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -296,14 +297,14 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
a reuse scope, or if reuse is not `None` or `True`.
TypeError: when the types of some arguments are not appropriate.
"""
if not isinstance(name_or_scope, (_VariableScope, basestring)):
if not isinstance(name_or_scope, (_VariableScope,) + six.string_types):
raise TypeError("VariableScope: name_scope must be a string or "
"VariableScope.")
if reuse not in [None, True]:
raise ValueError("VariableScope reuse parameter must be True or None.")
if not reuse and isinstance(name_or_scope, (_VariableScope)):
logging.info("Passing VariableScope to a non-reusing scope, intended?")
if reuse and isinstance(name_or_scope, (basestring)):
if reuse and isinstance(name_or_scope, six.string_types):
logging.info("Re-using string-named scope, consider capturing as object.")
get_variable_scope() # Ensure that a default exists, then get a pointer.
default_varscope = ops.get_collection(_VARSCOPE_KEY)

View File

@ -21,13 +21,7 @@
bool _PyObjAs(PyObject *pystr, ::string* cstr) {
char *buf;
Py_ssize_t len;
#if PY_VERSION_HEX >= 0x03030000
if (PyUnicode_Check(pystr)) {
buf = PyUnicode_AsUTF8AndSize(pystr, &len);
if (!buf) return false;
} else // NOLINT
#endif
if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
if (cstr) cstr->assign(buf, len);
return true;
}
@ -36,29 +30,23 @@
bool _PyObjAs(PyObject *pystr, std::string* cstr) {
char *buf;
Py_ssize_t len;
#if PY_VERSION_HEX >= 0x03030000
if (PyUnicode_Check(pystr)) {
buf = PyUnicode_AsUTF8AndSize(pystr, &len);
if (!buf) return false;
} else // NOLINT
#endif
if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
if (cstr) cstr->assign(buf, len);
return true;
}
#ifdef HAS_GLOBAL_STRING
template<>
PyObject* _PyObjFrom(const ::string& c) {
return PyString_FromStringAndSize(c.data(), c.size());
return PyBytes_FromStringAndSize(c.data(), c.size());
}
#endif
template<>
PyObject* _PyObjFrom(const std::string& c) {
return PyString_FromStringAndSize(c.data(), c.size());
return PyBytes_FromStringAndSize(c.data(), c.size());
}
PyObject* _SwigString_FromString(const string& s) {
return PyUnicode_FromStringAndSize(s.data(), s.size());
return PyBytes_FromStringAndSize(s.data(), s.size());
}
%}
@ -72,11 +60,11 @@
}
%typemap(out) string {
$result = PyString_FromStringAndSize($1.data(), $1.size());
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
%typemap(out) const string& {
$result = PyString_FromStringAndSize($1->data(), $1->size());
$result = PyBytes_FromStringAndSize($1->data(), $1->size());
}
%typemap(in, numinputs = 0) string* OUTPUT (string temp) {
@ -84,7 +72,7 @@
}
%typemap(argout) string * OUTPUT {
PyObject *str = PyString_FromStringAndSize($1->data(), $1->length());
PyObject *str = PyBytes_FromStringAndSize($1->data(), $1->length());
if (!str) SWIG_fail;
%append_output(str);
}
@ -92,7 +80,7 @@
%typemap(argout) string* INOUT = string* OUTPUT;
%typemap(varout) string {
$result = PyString_FromStringAndSize($1.data(), $1.size());
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
%define _LIST_OUTPUT_TYPEMAP(type, py_converter)

View File

@ -9,6 +9,8 @@ import numbers
import os.path
import time
import six
from google.protobuf import text_format
from tensorflow.python.client import graph_util
@ -299,7 +301,7 @@ class BaseSaverBuilder(object):
vars_to_save = []
seen_variables = set()
for name in sorted(names_to_variables.keys()):
if not isinstance(name, basestring):
if not isinstance(name, six.string_types):
raise TypeError("names_to_variables must be a dict mapping string "
"names to variable Tensors. Name is not a string: %s" %
name)

View File

@ -10,6 +10,7 @@ import tensorflow.python.platform
import tensorflow as tf
import numpy as np
import six
from tensorflow.python.platform import gfile
@ -33,7 +34,7 @@ class SaverTest(tf.test.TestCase):
# Save the initialized values in the file at "save_path"
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, basestring))
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
# Start a second session. In that session the parameter nodes
@ -84,7 +85,7 @@ class SaverTest(tf.test.TestCase):
# Save the initialized values in the file at "save_path"
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, basestring))
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
with self.test_session() as sess:
@ -131,7 +132,7 @@ class SaverTest(tf.test.TestCase):
# Save the initialized values in the file at "save_path"
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, basestring))
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
# Start a second session. In that session the variables
@ -517,7 +518,7 @@ class SaveRestoreWithVariableNameMap(tf.test.TestCase):
# Save the initialized values in the file at "save_path"
# Use a variable name map to set the saved tensor names
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, basestring))
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
# Verify that the original names are not in the Saved file

View File

@ -9,6 +9,8 @@ import Queue
import threading
import time
import six
from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
@ -103,7 +105,7 @@ class SummaryWriter(object):
global_step: Number. Optional global step value to record with the
summary.
"""
if isinstance(summary, basestring):
if isinstance(summary, six.binary_type):
summ = summary_pb2.Summary()
summ.ParseFromString(summary)
summary = summ

View File

@ -82,7 +82,7 @@ def assertProto2Equal(self, a, b, check_initialized=True,
numbers before comparison.
msg: if specified, is used as the error message on failure
"""
if isinstance(a, basestring):
if isinstance(a, six.string_types):
a = text_format.Merge(a, b.__class__())
for pb in a, b:
@ -121,7 +121,7 @@ def assertProto2SameElements(self, a, b, number_matters=False,
numbers before comparison.
msg: if specified, is used as the error message on failure
"""
if isinstance(a, basestring):
if isinstance(a, six.string_types):
a = text_format.Merge(a, b.__class__())
else:
a = copy.deepcopy(a)
@ -152,7 +152,7 @@ def assertProto2Contains(self, a, b, # pylint: disable=invalid-name
check_initialized: boolean, whether to fail if b isn't initialized
msg: if specified, is used as the error message on failure
"""
if isinstance(a, basestring):
if isinstance(a, six.string_types):
a = text_format.Merge(a, b.__class__())
else:
a = copy.deepcopy(a)
@ -299,7 +299,7 @@ def NormalizeNumberFields(pb):
def _IsRepeatedContainer(value):
if isinstance(value, basestring):
if isinstance(value, six.string_types):
return False
try:
iter(value)

View File

@ -342,12 +342,12 @@ class NormalizeNumbersTest(googletest.TestCase):
class AssertTest(googletest.TestCase):
"""Tests both assertProto2Equal() and assertProto2SameElements()."""
def assertProto2Equal(self, a, b, **kwargs):
if isinstance(a, basestring) and isinstance(b, basestring):
if isinstance(a, six.string_types) and isinstance(b, six.string_types):
a, b = LargePbs(a, b)
compare.assertProto2Equal(self, a, b, **kwargs)
def assertProto2SameElements(self, a, b, **kwargs):
if isinstance(a, basestring) and isinstance(b, basestring):
if isinstance(a, six.string_types) and isinstance(b, six.string_types):
a, b = LargePbs(a, b)
compare.assertProto2SameElements(self, a, b, **kwargs)

View File

@ -16,9 +16,9 @@
"private": true,
"dependencies": {
"d3": "3.5.6",
"dagre": "~0.7.4",
"es6-promise": "~3.0.2",
"graphlib": "~1.0.7",
"dagre": "0.7.4",
"es6-promise": "3.0.2",
"graphlib": "1.0.7",
"iron-ajax": "PolymerElements/iron-ajax#1.0.7",
"iron-collapse": "PolymerElements/iron-collapse#1.0.4",
"iron-list": "PolymerElements/iron-list#1.1.5",
@ -38,7 +38,7 @@
"paper-styles": "PolymerElements/paper-styles#1.0.12",
"paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11",
"paper-toolbar": "PolymerElements/paper-toolbar#1.0.4",
"plottable": "~1.16.1",
"plottable": "1.16.2",
"polymer": "1.1.5"
},
"devDependencies": {