TensorFlow: a few small updates.
Changes:
- Fix softmax formula in word2vec to remove an extra exp() by @gouwsmeister
- Python3 fixes to remove basestring / support for unicode by @mrry
- Remove some comments by Josh
- Specify exact versions of bower dependencies for TensorBoard by @danmane

Base CL: 107742361
parent d50565b35e
commit 011e9baccd
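A note on the Python 3 hunks below: they all follow one pattern. `basestring` does not exist in Python 3, so string type checks go through `six`. A minimal sketch of the pattern (standalone illustration, not part of the diff; the function name is made up):

import six

def require_string(name):
  # six.string_types is (basestring,) under Python 2 and (str,) under
  # Python 3, so this one check works unchanged on both.
  if not isinstance(name, six.string_types):
    raise TypeError('name must be a string')
  return name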
@@ -111,8 +111,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
 Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                    const NameAttrList** value);  // type: "func"
 
-// Computes the input and output types for a specific node, for
-// attr-style ops.
+// Computes the input and output types for a specific node.
 // REQUIRES: ValidateOpDef(op_def).ok()
 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                          DataTypeVector* inputs, DataTypeVector* outputs);
@@ -64,11 +64,9 @@ Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
 Node* Identity(Graph* g, Node* input, int index = 0);
 
 // Adds a binary function "func" node in "g" taking "in0" and "in1".
-// Requires that "func" name an attr-style Op.
 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
 
 // Adds a function "func" node in "g" taking inputs "ins".
-// Requires that "func" name an attr-style Op.
 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
 
 // Adds a binary add node in "g" doing in0 + in1.
@@ -100,7 +100,7 @@ given the previous words \\(h\\) (for "history") in terms of a
 
 $$
 \begin{align}
-P(w_t | h) &= \text{softmax}(\exp \{ \text{score}(w_t, h) \}) \\
+P(w_t | h) &= \text{softmax}(\text{score}(w_t, h)) \\
            &= \frac{\exp \{ \text{score}(w_t, h) \} }
               {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }.
 \end{align}
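The formula fix above matters because softmax already exponentiates its argument; wrapping the score in exp() first would exponentiate twice. A small NumPy illustration (scores are made-up values, not from the commit):

import numpy as np

scores = np.array([1.0, 2.0, 3.0])  # hypothetical score(w', h) values

def softmax(x):
  e = np.exp(x - np.max(x))  # shift by the max for numerical stability
  return e / e.sum()

p = softmax(scores)                # corrected text: softmax(score)
p_wrong = softmax(np.exp(scores))  # old text: softmax(exp(score)), a double exp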
@@ -11,6 +11,7 @@ import threading
 import tensorflow.python.platform
 
 import numpy as np
+import six
 
 from tensorflow.python import pywrap_tensorflow as tf_session
 from tensorflow.python.framework import errors
@@ -36,6 +37,56 @@ class SessionInterface(object):
     raise NotImplementedError('Run')
 
 
+def _as_bytes(bytes_or_unicode):
+  """Returns the given argument as a byte array.
+
+  NOTE(mrry): For Python 2 and 3 compatibility, we convert all string
+  arguments to SWIG methods into byte arrays. Unicode strings are
+  encoded as UTF-8; however the valid arguments for all of the
+  human-readable arguments must currently be a subset of ASCII.
+
+  Args:
+    bytes_or_unicode: A `unicode`, `string`, or `bytes` object.
+
+  Returns:
+    A `bytes` object.
+
+  Raises:
+    TypeError: If `bytes_or_unicode` is not a binary or unicode string.
+  """
+  if isinstance(bytes_or_unicode, six.text_type):
+    return bytes_or_unicode.encode('utf-8')
+  elif isinstance(bytes_or_unicode, six.binary_type):
+    return bytes_or_unicode
+  else:
+    raise TypeError('bytes_or_unicode must be a binary or unicode string.')
+
+
+def _as_text(bytes_or_unicode):
+  """Returns the given argument as a unicode string.
+
+  NOTE(mrry): For Python 2 and 3 compatibility, we interpret all
+  returned strings from SWIG methods as byte arrays. This function
+  converts those strings that are intended to be human-readable into
+  UTF-8 unicode strings.
+
+  Args:
+    bytes_or_unicode: A `unicode`, `string`, or `bytes` object.
+
+  Returns:
+    A `unicode` (Python 2) or `str` (Python 3) object.
+
+  Raises:
+    TypeError: If `bytes_or_unicode` is not a binary or unicode string.
+  """
+  if isinstance(bytes_or_unicode, six.text_type):
+    return bytes_or_unicode
+  elif isinstance(bytes_or_unicode, six.binary_type):
+    return bytes_or_unicode.decode('utf-8')
+  else:
+    raise TypeError('bytes_or_unicode must be a binary or unicode string.')
+
+
 class BaseSession(SessionInterface):
   """A class for interacting with a TensorFlow computation.
 
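For reference, how the two private helpers added above behave (illustrative round-trips only; behavior follows directly from the code):

# _as_bytes(u'tensor:0') == b'tensor:0'        (unicode -> UTF-8 bytes)
# _as_bytes(b'tensor:0') == b'tensor:0'        (bytes pass through)
# _as_text(b'error message') == u'error message'   (bytes -> unicode)
# _as_text(u'error message') == u'error message'   (unicode passes through)
# Anything else (e.g. an int) raises TypeError in both helpers.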
@@ -75,8 +126,7 @@ class BaseSession(SessionInterface):
       status = tf_session.TF_NewStatus()
       self._session = tf_session.TF_NewSession(opts, status)
       if tf_session.TF_GetCode(status) != 0:
-        message = tf_session.TF_Message(status)
-        raise RuntimeError(message)
+        raise RuntimeError(_as_text(tf_session.TF_Message(status)))
 
     finally:
       tf_session.TF_DeleteSessionOptions(opts)
@@ -97,7 +147,7 @@ class BaseSession(SessionInterface):
       status = tf_session.TF_NewStatus()
       tf_session.TF_CloseSession(self._session, status)
       if tf_session.TF_GetCode(status) != 0:
-        raise RuntimeError(tf_session.TF_Message(status))
+        raise RuntimeError(_as_text(tf_session.TF_Message(status)))
     finally:
       tf_session.TF_DeleteStatus(status)
 
@@ -108,7 +158,7 @@ class BaseSession(SessionInterface):
       if self._session is not None:
         tf_session.TF_DeleteSession(self._session, status)
         if tf_session.TF_GetCode(status) != 0:
-          raise RuntimeError(tf_session.TF_Message(status))
+          raise RuntimeError(_as_text(tf_session.TF_Message(status)))
         self._session = None
     finally:
       tf_session.TF_DeleteStatus(status)
@@ -265,7 +315,6 @@ class BaseSession(SessionInterface):
       TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
       ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
         `Tensor` that doesn't exist.
-
     """
     def _fetch_fn(fetch):
       for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
@@ -302,9 +351,9 @@ class BaseSession(SessionInterface):
           fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
                                                 allow_operation=True)
           if isinstance(fetch_t, ops.Operation):
-            target_list.append(fetch_t.name)
+            target_list.append(_as_bytes(fetch_t.name))
           else:
-            subfetch_names.append(fetch_t.name)
+            subfetch_names.append(_as_bytes(fetch_t.name))
         except TypeError as e:
           raise TypeError('Fetch argument %r of %r has invalid type %r, '
                           'must be a string or Tensor. (%s)'
@@ -343,7 +392,7 @@ class BaseSession(SessionInterface):
                            'which has shape %r'
                            % (np_val.shape, subfeed_t.name,
                               tuple(subfeed_t.get_shape().dims)))
-      feed_dict_string[str(subfeed_t.name)] = np_val
+      feed_dict_string[_as_bytes(subfeed_t.name)] = np_val
 
     # Run request and get response.
     results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
@@ -373,11 +422,12 @@ class BaseSession(SessionInterface):
     """Runs a step based on the given fetches and feeds.
 
     Args:
-      target_list: A list of strings corresponding to names of tensors
+      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run to, but not fetched.
-      fetch_list: A list of strings corresponding to names of tensors to be
-        fetched and operations to be run.
-      feed_dict: A dictionary that maps tensor names to numpy ndarrays.
+      fetch_list: A list of byte arrays corresponding to names of tensors to
+        be fetched and operations to be run.
+      feed_dict: A dictionary that maps tensor names (as byte arrays) to
+        numpy ndarrays.
 
     Returns:
       A list of numpy ndarrays, corresponding to the elements of
@@ -397,7 +447,7 @@ class BaseSession(SessionInterface):
         tf_session.TF_ExtendGraph(
             self._session, graph_def.SerializeToString(), status)
         if tf_session.TF_GetCode(status) != 0:
-          raise RuntimeError(tf_session.TF_Message(status))
+          raise RuntimeError(_as_text(tf_session.TF_Message(status)))
         self._opened = True
       finally:
         tf_session.TF_DeleteStatus(status)
@@ -409,7 +459,8 @@ class BaseSession(SessionInterface):
 
     except tf_session.StatusNotOK as e:
       e_type, e_value, e_traceback = sys.exc_info()
-      m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
+      error_message = _as_text(e.error_message)
+      m = BaseSession._NODEDEF_NAME_RE.search(error_message)
       if m is not None:
         node_name = m.group(1)
         node_def = None
@@ -419,7 +470,7 @@ class BaseSession(SessionInterface):
         except KeyError:
           op = None
         # pylint: disable=protected-access
-        raise errors._make_specific_exception(node_def, op, e.error_message,
+        raise errors._make_specific_exception(node_def, op, error_message,
                                               e.code)
         # pylint: enable=protected-access
       raise e_type, e_value, e_traceback
@@ -9,6 +9,7 @@ import time
 import tensorflow.python.platform
 
 import numpy as np
+import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import config_pb2
@@ -551,9 +552,55 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertEqual(c_list[0], out[0])
       self.assertEqual(c_list[1], out[1])
 
+  def testStringFeedWithUnicode(self):
+    with session.Session():
+      c_list = [u'\n\x01\x00', u'\n\x00\x01']
+      feed_t = array_ops.placeholder(dtype=types.string, shape=[2])
+      c = array_ops.identity(feed_t)
+
+      out = c.eval(feed_dict={feed_t: c_list})
+      self.assertEqual(c_list[0], out[0].decode('utf-8'))
+      self.assertEqual(c_list[1], out[1].decode('utf-8'))
+
+      out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)})
+      self.assertEqual(c_list[0], out[0].decode('utf-8'))
+      self.assertEqual(c_list[1], out[1].decode('utf-8'))
+
   def testInvalidTargetFails(self):
     with self.assertRaises(RuntimeError):
-      session.Session("INVALID_TARGET")
+      session.Session('INVALID_TARGET')
+
+  def testFetchByNameDifferentStringTypes(self):
+    with session.Session() as sess:
+      c = constant_op.constant(42.0, name='c')
+      d = constant_op.constant(43.0, name=u'd')
+      e = constant_op.constant(44.0, name=b'e')
+      f = constant_op.constant(45.0, name=r'f')
+
+      self.assertTrue(isinstance(c.name, six.text_type))
+      self.assertTrue(isinstance(d.name, six.text_type))
+      self.assertTrue(isinstance(e.name, six.text_type))
+      self.assertTrue(isinstance(f.name, six.text_type))
+
+      self.assertEqual(42.0, sess.run('c:0'))
+      self.assertEqual(42.0, sess.run(u'c:0'))
+      self.assertEqual(42.0, sess.run(b'c:0'))
+      self.assertEqual(42.0, sess.run(r'c:0'))
+
+      self.assertEqual(43.0, sess.run('d:0'))
+      self.assertEqual(43.0, sess.run(u'd:0'))
+      self.assertEqual(43.0, sess.run(b'd:0'))
+      self.assertEqual(43.0, sess.run(r'd:0'))
+
+      self.assertEqual(44.0, sess.run('e:0'))
+      self.assertEqual(44.0, sess.run(u'e:0'))
+      self.assertEqual(44.0, sess.run(b'e:0'))
+      self.assertEqual(44.0, sess.run(r'e:0'))
+
+      self.assertEqual(45.0, sess.run('f:0'))
+      self.assertEqual(45.0, sess.run(u'f:0'))
+      self.assertEqual(45.0, sess.run(b'f:0'))
+      self.assertEqual(45.0, sess.run(r'f:0'))
 
 
 if __name__ == '__main__':
@@ -68,7 +68,7 @@ import_array();
     PyObject* value;
     Py_ssize_t pos = 0;
     while (PyDict_Next($input, &pos, &key, &value)) {
-      const char* key_string = PyString_AsString(key);
+      char* key_string = PyBytes_AsString(key);
       if (!key_string) {
         SWIG_fail;
       }
@@ -121,7 +121,7 @@ import_array();
       PyList_SET_ITEM(temp_string_list.get(), i, elem);
       Py_INCREF(elem);
 
-      const char* fetch_name = PyString_AsString(elem);
+      char* fetch_name = PyBytes_AsString(elem);
       if (!fetch_name) {
         PyErr_SetString(PyExc_TypeError,
                         "a fetch or target name was not a string");
@@ -39,7 +39,7 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
   PyObject* value;
   Py_ssize_t pos = 0;
   if (PyDict_Next(descr->fields, &pos, &key, &value)) {
-    const char* key_string = PyString_AsString(key);
+    const char* key_string = PyBytes_AsString(key);
     if (!key_string) {
       return errors::Internal("Corrupt numpy type descriptor");
     }
@@ -166,7 +166,7 @@ Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
 // Iterate over the string array 'array', extract the ptr and len of each string
 // element and call f(ptr, len).
 template <typename F>
-Status PyStringArrayMap(PyArrayObject* array, F f) {
+Status PyBytesArrayMap(PyArrayObject* array, F f) {
   Safe_PyObjectPtr iter = tensorflow::make_safe(
       PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
   while (PyArray_ITER_NOTDONE(iter.get())) {
@@ -177,10 +177,23 @@ Status PyStringArrayMap(PyArrayObject* array, F f) {
     }
     char* ptr;
     Py_ssize_t len;
-    int success = PyString_AsStringAndSize(item.get(), &ptr, &len);
-    if (success != 0) {
-      return errors::Internal("Unable to get element from the feed.");
-    }
+
+#if PY_VERSION_HEX >= 0x03030000
+    // Accept unicode in Python 3, by converting to UTF-8 bytes.
+    if (PyUnicode_Check(item.get())) {
+      ptr = PyUnicode_AsUTF8AndSize(item.get(), &len);
+      if (!ptr) {
+        return errors::Internal("Unable to get element from the feed.");
+      }
+    } else {
+#endif
+      int success = PyBytes_AsStringAndSize(item.get(), &ptr, &len);
+      if (success != 0) {
+        return errors::Internal("Unable to get element from the feed.");
+      }
+#if PY_VERSION_HEX >= 0x03030000
+    }
+#endif
     f(ptr, len);
     PyArray_ITER_NEXT(iter.get());
   }
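Seen from Python, the hunk above means an element of a `dtype=object` feed array may now be a unicode string in Python 3, and the C++ path converts it exactly as a UTF-8 encode would. A rough sketch of the equivalent Python-level conversion (the function name is made up; the logic mirrors the C++ branch above):

import six

def feed_element_to_bytes(item):
  # Mirror the C++ logic: accept unicode by UTF-8 encoding it,
  # otherwise require a bytes object.
  if isinstance(item, six.text_type):
    return item.encode('utf-8')
  elif isinstance(item, six.binary_type):
    return item
  raise TypeError('feed elements must be bytes or unicode')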
@@ -189,15 +202,14 @@ Status PyStringArrayMap(PyArrayObject* array, F f) {
 
 // Encode the strings in 'array' into a contiguous buffer and return the base of
 // the buffer. The caller takes ownership of the buffer.
-Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
-                           size_t* size, void** buffer) {
+Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
+                          size_t* size, void** buffer) {
   // Compute bytes needed for encoding.
   *size = 0;
-  TF_RETURN_IF_ERROR(
-      PyStringArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
-        *size += sizeof(tensorflow::uint64) +
-                 tensorflow::core::VarintLength(len) + len;
-      }));
+  TF_RETURN_IF_ERROR(PyBytesArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
+    *size +=
+        sizeof(tensorflow::uint64) + tensorflow::core::VarintLength(len) + len;
+  }));
   // Encode all strings.
   std::unique_ptr<char[]> base_ptr(new char[*size]);
   char* base = base_ptr.get();
@@ -205,7 +217,7 @@ Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
   char* dst = data_start;  // Where next string is encoded.
   tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
 
-  TF_RETURN_IF_ERROR(PyStringArrayMap(
+  TF_RETURN_IF_ERROR(PyBytesArrayMap(
       array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) {
         *offsets = (dst - data_start);
         offsets++;
@@ -251,7 +263,7 @@ static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
   tensorflow::uint64 len;
   TF_RETURN_IF_ERROR(
       TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
-  auto py_string = tensorflow::make_safe(PyString_FromStringAndSize(ptr, len));
+  auto py_string = tensorflow::make_safe(PyBytes_FromStringAndSize(ptr, len));
   int success =
       PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get());
   if (success != 0) {
@@ -446,7 +458,7 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
     } else {
       size_t size;
       void* encoded;
-      Status s = EncodePyStringArray(array, nelems, &size, &encoded);
+      Status s = EncodePyBytesArray(array, nelems, &size, &encoded);
       if (!s.ok()) {
         *out_status = s;
         return;
@@ -7,6 +7,8 @@ import contextlib
 
 import tensorflow.python.platform
 
+import six
+
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.framework import op_def_registry
@@ -170,12 +172,13 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
     input_map = {}
   else:
     if not (isinstance(input_map, dict)
-            and all(isinstance(k, basestring) for k in input_map.keys())):
+            and all(isinstance(k, six.string_types) for k in input_map.keys())):
       raise TypeError('input_map must be a dictionary mapping strings to '
                       'Tensor objects.')
   if (return_elements is not None
       and not (isinstance(return_elements, (list, tuple))
-               and all(isinstance(x, basestring) for x in return_elements))):
+               and all(isinstance(x, six.string_types)
+                       for x in return_elements))):
     raise TypeError('return_elements must be a list of strings.')
 
   # Use a canonical representation for all tensor names.
@@ -176,6 +176,58 @@ class ImportGraphDefTest(tf.test.TestCase):
       self.assertEqual(d.inputs[0], a.outputs[1])
       self.assertEqual(d.inputs[1], feed_b_1)
 
+  def testInputMapBytes(self):
+    with tf.Graph().as_default():
+      feed_a_0 = tf.constant(0, dtype=tf.int32)
+      feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+      a, b, c, d = tf.import_graph_def(
+          self._MakeGraphDef("""
+          node { name: 'A' op: 'Oii' }
+          node { name: 'B' op: 'Oii' }
+          node { name: 'C' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:0' input: 'B:0' }
+          node { name: 'D' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:1' input: 'B:1' }
+          """),
+          input_map={b'A:0': feed_a_0, b'B:1': feed_b_1},
+          return_elements=[b'A', b'B', b'C', b'D'])
+
+      self.assertEqual(c.inputs[0], feed_a_0)
+      self.assertEqual(c.inputs[1], b.outputs[0])
+      self.assertEqual(d.inputs[0], a.outputs[1])
+      self.assertEqual(d.inputs[1], feed_b_1)
+
+  def testInputMapUnicode(self):
+    with tf.Graph().as_default():
+      feed_a_0 = tf.constant(0, dtype=tf.int32)
+      feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+      a, b, c, d = tf.import_graph_def(
+          self._MakeGraphDef("""
+          node { name: 'A' op: 'Oii' }
+          node { name: 'B' op: 'Oii' }
+          node { name: 'C' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:0' input: 'B:0' }
+          node { name: 'D' op: 'In'
+                 attr { key: 'N' value { i: 2 } }
+                 attr { key: 'T' value { type: DT_INT32 } }
+                 input: 'A:1' input: 'B:1' }
+          """),
+          input_map={u'A:0': feed_a_0, u'B:1': feed_b_1},
+          return_elements=[u'A', u'B', u'C', u'D'])
+
+      self.assertEqual(c.inputs[0], feed_a_0)
+      self.assertEqual(c.inputs[1], b.outputs[0])
+      self.assertEqual(d.inputs[0], a.outputs[1])
+      self.assertEqual(d.inputs[1], feed_b_1)
+
   def testImplicitZerothOutput(self):
     with tf.Graph().as_default():
       a, b = tf.import_graph_def(
@@ -1327,7 +1327,7 @@ class RegisterGradient(object):
       op_type: The string type of an operation. This corresponds to the
         `OpDef.name` field for the proto that defines the operation.
     """
-    if not isinstance(op_type, basestring):
+    if not isinstance(op_type, six.string_types):
       raise TypeError("op_type must be a string")
     self._op_type = op_type
 
@@ -1356,7 +1356,7 @@ def NoGradient(op_type):
     TypeError: If `op_type` is not a string.
 
   """
-  if not isinstance(op_type, basestring):
+  if not isinstance(op_type, six.string_types):
    raise TypeError("op_type must be a string")
  _gradient_registry.register(None, op_type)
 
@@ -1400,7 +1400,7 @@ class RegisterShape(object):
 
   def __init__(self, op_type):
     """Saves the "op_type" as the Operation type."""
-    if not isinstance(op_type, basestring):
+    if not isinstance(op_type, six.string_types):
       raise TypeError("op_type must be a string")
     self._op_type = op_type
 
@@ -1818,7 +1818,7 @@ class Graph(object):
         obj = conv_fn()
 
     # If obj appears to be a name...
-    if isinstance(obj, basestring):
+    if isinstance(obj, six.string_types):
       name = obj
 
       if ":" in name and allow_tensor:
@@ -1909,7 +1909,7 @@ class Graph(object):
       KeyError: If `name` does not correspond to an operation in this graph.
     """
 
-    if not isinstance(name, basestring):
+    if not isinstance(name, six.string_types):
       raise TypeError("Operation names are strings (or similar), not %s."
                       % type(name).__name__)
     return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
@@ -1930,7 +1930,7 @@ class Graph(object):
       KeyError: If `name` does not correspond to a tensor in this graph.
     """
     # Names should be strings.
-    if not isinstance(name, basestring):
+    if not isinstance(name, six.string_types):
       raise TypeError("Tensor names are strings (or similar), not %s."
                       % type(name).__name__)
     return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
@@ -2495,8 +2495,8 @@ class Graph(object):
     saved_labels = {}
     # Install the given label
     for op_type, label in op_to_kernel_label_map.items():
-      if not (isinstance(op_type, basestring)
-              and isinstance(label, basestring)):
+      if not (isinstance(op_type, six.string_types)
+              and isinstance(label, six.string_types)):
         raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                         "strings to strings")
       try:
@@ -2558,8 +2558,8 @@ class Graph(object):
     saved_mappings = {}
     # Install the given label
     for op_type, mapped_op_type in op_type_map.items():
-      if not (isinstance(op_type, basestring)
-              and isinstance(mapped_op_type, basestring)):
+      if not (isinstance(op_type, six.string_types)
+              and isinstance(mapped_op_type, six.string_types)):
         raise TypeError("op_type_map must be a dictionary mapping "
                         "strings to strings")
       try:
@@ -207,7 +207,10 @@ def _FilterComplex(v):
 def _FilterStr(v):
   if isinstance(v, (list, tuple)):
     return _FirstNotNone([_FilterStr(x) for x in v])
-  return None if isinstance(v, basestring) else _NotNone(v)
+  if isinstance(v, (six.string_types, six.binary_type)):
+    return None
+  else:
+    return _NotNone(v)
 
 
 def _FilterBool(v):
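The `_FilterStr` rewrite above is needed because in Python 3 `bytes` is not a subclass of `str`, so a `basestring`-style check would wrongly treat byte strings as non-strings. A standalone illustration of the type relationships being relied on:

import six

# On Python 3: isinstance(b'x', str) is False,
# but isinstance(b'x', six.binary_type) is True.
# On Python 2 both 'x' and b'x' are plain str.
for v in ['x', u'x', b'x']:
  assert isinstance(v, (six.string_types, six.binary_type))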
@@ -25,31 +25,15 @@
 %typemap(typecheck) tensorflow::StringPiece = char *;
 %typemap(typecheck) const tensorflow::StringPiece & = char *;
 
-// "tensorflow::StringPiece" arguments can be provided by a simple Python 'str' string
-// or a 'unicode' object. If 'unicode', it's translated using the default
-// encoding, i.e., sys.getdefaultencoding(). If passed None, a tensorflow::StringPiece
-// of zero length with a NULL pointer is provided.
+// "tensorflow::StringPiece" arguments must be specified as a 'str' or 'bytes' object.
 %typemap(in) tensorflow::StringPiece {
   if ($input != Py_None) {
     char * buf;
     Py_ssize_t len;
-%#if PY_VERSION_HEX >= 0x03030000
-    /* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
-    if (PyUnicode_Check($input)) {
-      buf = PyUnicode_AsUTF8AndSize($input, &len);
-      if (buf == NULL)
-        SWIG_fail;
-    } else {
-%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
-%# error "Unsupported Python 3.x C API version (3.3 or later required)."
-%#endif
-      if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
-        // Python has raised an error (likely TypeError or UnicodeEncodeError).
-        SWIG_fail;
-      }
-%#if PY_VERSION_HEX >= 0x03030000
-    }
-%#endif
+    if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
+      // Python has raised an error (likely TypeError or UnicodeEncodeError).
+      SWIG_fail;
+    }
     $1.set(buf, len);
   }
 }
@@ -60,23 +44,10 @@
   if ($input != Py_None) {
     char * buf;
     Py_ssize_t len;
-%#if PY_VERSION_HEX >= 0x03030000
-    /* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
-    if (PyUnicode_Check($input)) {
-      buf = PyUnicode_AsUTF8AndSize($input, &len);
-      if (buf == NULL)
-        SWIG_fail;
-    } else {
-%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
-%# error "Unsupported Python 3.x C API version (3.3 or later required)."
-%#endif
-      if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
-        // Python has raised an error (likely TypeError or UnicodeEncodeError).
-        SWIG_fail;
-      }
-%#if PY_VERSION_HEX >= 0x03030000
-    }
-%#endif
+    if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
+      // Python has raised an error (likely TypeError).
+      SWIG_fail;
+    }
     temp.set(buf, len);
   }
   $1 = &temp;
@@ -86,7 +57,7 @@
 // or None if the StringPiece contained a NULL pointer.
 %typemap(out) tensorflow::StringPiece {
   if ($1.data()) {
-    $result = PyString_FromStringAndSize($1.data(), $1.size());
+    $result = PyBytes_FromStringAndSize($1.data(), $1.size());
   } else {
     Py_INCREF(Py_None);
     $result = Py_None;
@@ -5,6 +5,8 @@ from __future__ import print_function
 
 import collections
 
+import six
+
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import types
 from tensorflow.python.ops import array_ops
@@ -102,7 +104,7 @@ def global_norm(t_list, name=None):
     TypeError: If `t_list` is not a sequence.
   """
   if (not isinstance(t_list, collections.Sequence)
-      or isinstance(t_list, basestring)):
+      or isinstance(t_list, six.string_types)):
     raise TypeError("t_list should be a sequence")
   t_list = list(t_list)
   with ops.op_scope(t_list, name, "global_norm") as name:
@@ -164,7 +166,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
     TypeError: If `t_list` is not a sequence.
   """
   if (not isinstance(t_list, collections.Sequence)
-      or isinstance(t_list, basestring)):
+      or isinstance(t_list, six.string_types)):
     raise TypeError("t_list should be a sequence")
   t_list = list(t_list)
   if use_norm is None:
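Both checks above guard against the same pitfall: a string is itself a sequence, so without the explicit `six.string_types` exclusion a bare string passed as `t_list` would be silently iterated character by character instead of rejected. A minimal illustration (using the 2015-era `collections.Sequence` name that this file imports; the helper name is made up):

import collections
import six

def is_proper_sequence(t_list):
  # A str/unicode is a Sequence too, so it must be excluded explicitly.
  return (isinstance(t_list, collections.Sequence)
          and not isinstance(t_list, six.string_types))

assert is_proper_sequence(['a', 'b'])
assert not is_proper_sequence('ab')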
@@ -5,6 +5,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numbers
+import six
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import op_def_pb2
@@ -128,7 +129,7 @@ def _MakeFloat(v, arg_name):
 
 
 def _MakeInt(v, arg_name):
-  if isinstance(v, basestring):
+  if isinstance(v, six.string_types):
     raise TypeError("Expected int for argument '%s' not %s." %
                     (arg_name, repr(v)))
   try:
@@ -139,7 +140,7 @@ def _MakeInt(v, arg_name):
 
 
 def _MakeStr(v, arg_name):
-  if not isinstance(v, basestring):
+  if not isinstance(v, six.string_types):
     raise TypeError("Expected string for argument '%s' not %s." %
                     (arg_name, repr(v)))
   # TODO(irving): Figure out what to do here from Python 3
@@ -5,6 +5,7 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
+import six
 
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -296,14 +297,14 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
       a reuse scope, or if reuse is not `None` or `True`.
     TypeError: when the types of some arguments are not appropriate.
   """
-  if not isinstance(name_or_scope, (_VariableScope, basestring)):
+  if not isinstance(name_or_scope, (_VariableScope,) + six.string_types):
     raise TypeError("VariableScope: name_scope must be a string or "
                     "VariableScope.")
   if reuse not in [None, True]:
     raise ValueError("VariableScope reuse parameter must be True or None.")
   if not reuse and isinstance(name_or_scope, (_VariableScope)):
     logging.info("Passing VariableScope to a non-reusing scope, intended?")
-  if reuse and isinstance(name_or_scope, (basestring)):
+  if reuse and isinstance(name_or_scope, six.string_types):
     logging.info("Re-using string-named scope, consider capturing as object.")
   get_variable_scope()  # Ensure that a default exists, then get a pointer.
   default_varscope = ops.get_collection(_VARSCOPE_KEY)
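One subtlety in the hunk above: `isinstance` takes a type or a flat tuple of types, and `six.string_types` is already a tuple, so it is concatenated onto `(_VariableScope,)` rather than nested inside it. Illustration with a stand-in class (not the real one):

import six

class _VariableScope(object):  # stand-in for the real class
  pass

ok_types = (_VariableScope,) + six.string_types
assert isinstance('scope_name', ok_types)
assert isinstance(_VariableScope(), ok_types)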
@@ -21,13 +21,7 @@
 bool _PyObjAs(PyObject *pystr, ::string* cstr) {
   char *buf;
   Py_ssize_t len;
-#if PY_VERSION_HEX >= 0x03030000
-  if (PyUnicode_Check(pystr)) {
-    buf = PyUnicode_AsUTF8AndSize(pystr, &len);
-    if (!buf) return false;
-  } else  // NOLINT
-#endif
   if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
   if (cstr) cstr->assign(buf, len);
   return true;
 }
@@ -36,29 +30,23 @@
 bool _PyObjAs(PyObject *pystr, std::string* cstr) {
   char *buf;
   Py_ssize_t len;
-#if PY_VERSION_HEX >= 0x03030000
-  if (PyUnicode_Check(pystr)) {
-    buf = PyUnicode_AsUTF8AndSize(pystr, &len);
-    if (!buf) return false;
-  } else  // NOLINT
-#endif
   if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
   if (cstr) cstr->assign(buf, len);
   return true;
 }
 #ifdef HAS_GLOBAL_STRING
 template<>
 PyObject* _PyObjFrom(const ::string& c) {
-  return PyString_FromStringAndSize(c.data(), c.size());
+  return PyBytes_FromStringAndSize(c.data(), c.size());
 }
 #endif
 template<>
 PyObject* _PyObjFrom(const std::string& c) {
-  return PyString_FromStringAndSize(c.data(), c.size());
+  return PyBytes_FromStringAndSize(c.data(), c.size());
 }
 
 PyObject* _SwigString_FromString(const string& s) {
-  return PyUnicode_FromStringAndSize(s.data(), s.size());
+  return PyBytes_FromStringAndSize(s.data(), s.size());
 }
 %}
@@ -72,11 +60,11 @@
 }
 
 %typemap(out) string {
-  $result = PyString_FromStringAndSize($1.data(), $1.size());
+  $result = PyBytes_FromStringAndSize($1.data(), $1.size());
 }
 
 %typemap(out) const string& {
-  $result = PyString_FromStringAndSize($1->data(), $1->size());
+  $result = PyBytes_FromStringAndSize($1->data(), $1->size());
 }
 
 %typemap(in, numinputs = 0) string* OUTPUT (string temp) {
@@ -84,7 +72,7 @@
 }
 
 %typemap(argout) string * OUTPUT {
-  PyObject *str = PyString_FromStringAndSize($1->data(), $1->length());
+  PyObject *str = PyBytes_FromStringAndSize($1->data(), $1->length());
   if (!str) SWIG_fail;
   %append_output(str);
 }
@@ -92,7 +80,7 @@
 %typemap(argout) string* INOUT = string* OUTPUT;
 
 %typemap(varout) string {
-  $result = PyString_FromStringAndSize($1.data(), $1.size());
+  $result = PyBytes_FromStringAndSize($1.data(), $1.size());
 }
 
 %define _LIST_OUTPUT_TYPEMAP(type, py_converter)
@@ -9,6 +9,8 @@ import numbers
 import os.path
 import time
 
+import six
+
 from google.protobuf import text_format
 
 from tensorflow.python.client import graph_util
@@ -299,7 +301,7 @@ class BaseSaverBuilder(object):
     vars_to_save = []
     seen_variables = set()
     for name in sorted(names_to_variables.keys()):
-      if not isinstance(name, basestring):
+      if not isinstance(name, six.string_types):
         raise TypeError("names_to_variables must be a dict mapping string "
                         "names to variable Tensors. Name is not a string: %s" %
                         name)
@@ -10,6 +10,7 @@ import tensorflow.python.platform
 
 import tensorflow as tf
 import numpy as np
+import six
 
 from tensorflow.python.platform import gfile
 
@@ -33,7 +34,7 @@ class SaverTest(tf.test.TestCase):
 
       # Save the initialized values in the file at "save_path"
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
       # Start a second session. In that session the parameter nodes
@@ -84,7 +85,7 @@ class SaverTest(tf.test.TestCase):
 
       # Save the initialized values in the file at "save_path"
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
     with self.test_session() as sess:
@@ -131,7 +132,7 @@ class SaverTest(tf.test.TestCase):
 
       # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
       # Start a second session. In that session the variables
@@ -517,7 +518,7 @@ class SaveRestoreWithVariableNameMap(tf.test.TestCase):
       # Save the initialized values in the file at "save_path"
       # Use a variable name map to set the saved tensor names
       val = save.save(sess, save_path)
-      self.assertTrue(isinstance(val, basestring))
+      self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
       # Verify that the original names are not in the Saved file
@@ -9,6 +9,8 @@ import Queue
 import threading
 import time
 
+import six
+
 from tensorflow.core.framework import summary_pb2
 from tensorflow.core.util import event_pb2
 from tensorflow.python import pywrap_tensorflow
@@ -103,7 +105,7 @@ class SummaryWriter(object):
       global_step: Number. Optional global step value to record with the
         summary.
     """
-    if isinstance(summary, basestring):
+    if isinstance(summary, six.binary_type):
      summ = summary_pb2.Summary()
      summ.ParseFromString(summary)
      summary = summ
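The check above switches to `six.binary_type` rather than `six.string_types` because a serialized `Summary` protocol buffer is a byte string on both Python versions. A hedged usage sketch (assuming this is the SummaryWriter method that takes `summary` and `global_step`, and a `writer` instance of it):

# summ is a summary_pb2.Summary protocol buffer.
# SerializeToString() returns bytes, so the six.binary_type branch
# re-parses it into a Summary proto:
writer.add_summary(summ.SerializeToString(), global_step=10)
# Passing the proto itself skips the parse:
writer.add_summary(summ, global_step=10)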
@@ -82,7 +82,7 @@ def assertProto2Equal(self, a, b, check_initialized=True,
       numbers before comparison.
     msg: if specified, is used as the error message on failure
   """
-  if isinstance(a, basestring):
+  if isinstance(a, six.string_types):
     a = text_format.Merge(a, b.__class__())
 
   for pb in a, b:
@@ -121,7 +121,7 @@ def assertProto2SameElements(self, a, b, number_matters=False,
       numbers before comparison.
     msg: if specified, is used as the error message on failure
   """
-  if isinstance(a, basestring):
+  if isinstance(a, six.string_types):
     a = text_format.Merge(a, b.__class__())
   else:
     a = copy.deepcopy(a)
@@ -152,7 +152,7 @@ def assertProto2Contains(self, a, b,  # pylint: disable=invalid-name
     check_initialized: boolean, whether to fail if b isn't initialized
     msg: if specified, is used as the error message on failure
   """
-  if isinstance(a, basestring):
+  if isinstance(a, six.string_types):
     a = text_format.Merge(a, b.__class__())
   else:
     a = copy.deepcopy(a)
@@ -299,7 +299,7 @@ def NormalizeNumberFields(pb):
 
 
 def _IsRepeatedContainer(value):
-  if isinstance(value, basestring):
+  if isinstance(value, six.string_types):
     return False
   try:
     iter(value)
@@ -342,12 +342,12 @@ class NormalizeNumbersTest(googletest.TestCase):
 class AssertTest(googletest.TestCase):
   """Tests both assertProto2Equal() and assertProto2SameElements()."""
   def assertProto2Equal(self, a, b, **kwargs):
-    if isinstance(a, basestring) and isinstance(b, basestring):
+    if isinstance(a, six.string_types) and isinstance(b, six.string_types):
       a, b = LargePbs(a, b)
     compare.assertProto2Equal(self, a, b, **kwargs)
 
   def assertProto2SameElements(self, a, b, **kwargs):
-    if isinstance(a, basestring) and isinstance(b, basestring):
+    if isinstance(a, six.string_types) and isinstance(b, six.string_types):
       a, b = LargePbs(a, b)
     compare.assertProto2SameElements(self, a, b, **kwargs)
@@ -16,9 +16,9 @@
   "private": true,
   "dependencies": {
     "d3": "3.5.6",
-    "dagre": "~0.7.4",
-    "es6-promise": "~3.0.2",
-    "graphlib": "~1.0.7",
+    "dagre": "0.7.4",
+    "es6-promise": "3.0.2",
+    "graphlib": "1.0.7",
     "iron-ajax": "PolymerElements/iron-ajax#1.0.7",
     "iron-collapse": "PolymerElements/iron-collapse#1.0.4",
     "iron-list": "PolymerElements/iron-list#1.1.5",
@@ -38,7 +38,7 @@
     "paper-styles": "PolymerElements/paper-styles#1.0.12",
     "paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11",
     "paper-toolbar": "PolymerElements/paper-toolbar#1.0.4",
-    "plottable": "~1.16.1",
+    "plottable": "1.16.2",
     "polymer": "1.1.5"
   },
   "devDependencies": {
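For context on the pinning above: a bower tilde range such as `~1.16.1` accepts any patch release at or above it within the same minor version, so builds could silently pick up new patch releases; exact versions make TensorBoard builds reproducible. A rough sketch of what a tilde range matches (simplified semver, illustration only; the helper is made up and ignores the full bower/semver grammar):

def tilde_matches(spec, version):
  # '~1.16.1' matches >= 1.16.1 and < 1.17.0 (simplified).
  base = tuple(int(x) for x in spec.lstrip('~').split('.'))
  v = tuple(int(x) for x in version.split('.'))
  return v[:2] == base[:2] and v >= base

assert tilde_matches('~1.16.1', '1.16.2')
assert not tilde_matches('~1.16.1', '1.17.0')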