Delete tf.contrib.training.python_input.
It has been replaced by tf.contrib.data.Dataset.from_generator. PiperOrigin-RevId: 167004190
This commit is contained in:
parent
f9c5e921dd
commit
48e3b62541
@ -9,6 +9,9 @@
|
||||
for LSTMs and stacked LSTMs. This bug fix follows recommendations from
|
||||
published literature, but is a behavioral change. State dropout behavior
|
||||
may be customized via the new `dropout_state_filter_visitor` argument.
|
||||
* Removed `tf.contrib.training.python_input`. The same behavior, in a more
|
||||
flexible and reproducible package, is available via the new
|
||||
`tf.contrib.data.Dataset.from_generator` method!
|
||||
|
||||
# Release 1.3.0
|
||||
|
||||
|
@ -595,6 +595,23 @@ class Dataset(object):
|
||||
The elements generated by `generator` must be compatible with the given
|
||||
`output_types` and (optional) `output_shapes` arguments.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
import itertools
|
||||
|
||||
def gen():
|
||||
for i in itertools.count(1):
|
||||
yield (i, [1] * i)
|
||||
|
||||
ds = Dataset.from_generator(
|
||||
gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
|
||||
value = ds.make_one_shot_iterator().get_next()
|
||||
|
||||
sess.run(value) # (1, array([1]))
|
||||
sess.run(value) # (2, array([1, 1]))
|
||||
```
|
||||
|
||||
Args:
|
||||
generator: A callable object that takes no arguments and returns an
|
||||
object that supports the `iter()` protocol.
|
||||
|
@ -23,7 +23,6 @@ py_library(
|
||||
"python/training/evaluation.py",
|
||||
"python/training/feeding_queue_runner.py",
|
||||
"python/training/hparam.py",
|
||||
"python/training/python_input.py",
|
||||
"python/training/resample.py",
|
||||
"python/training/sampling_ops.py",
|
||||
"python/training/sequence_queueing_state_saver.py",
|
||||
@ -226,23 +225,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "python_input_test",
|
||||
size = "medium",
|
||||
srcs = ["python/training/python_input_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "evaluation_test",
|
||||
size = "small",
|
||||
|
@ -36,7 +36,6 @@ See @{$python/contrib.training} guide.
|
||||
@@HParams
|
||||
@@HParamDef
|
||||
@@parse_values
|
||||
@@python_input
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -55,7 +54,6 @@ from tensorflow.contrib.training.python.training.evaluation import SummaryAtEndH
|
||||
from tensorflow.contrib.training.python.training.evaluation import wait_for_new_checkpoint
|
||||
from tensorflow.contrib.training.python.training.feeding_queue_runner import FeedingQueueRunner
|
||||
from tensorflow.contrib.training.python.training.hparam import *
|
||||
from tensorflow.contrib.training.python.training.python_input import python_input
|
||||
from tensorflow.contrib.training.python.training.resample import *
|
||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||
|
@ -1,178 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Operations for asynchronously reading data from python into queues.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
|
||||
|
||||
def _process_yielded_dict(feature_values, keys, features, dtypes, shapes):
|
||||
"""Read feature_values from the generator and emit a proper output dict."""
|
||||
if not isinstance(feature_values, dict):
|
||||
raise TypeError("generator must return dict, saw: %s" % feature_values)
|
||||
|
||||
processed_values = {}
|
||||
for pk in keys:
|
||||
if feature_values.get(pk, None) is not None:
|
||||
processed_values[pk] = np.asarray(
|
||||
feature_values[pk], dtype=dtypes[pk].as_numpy_dtype)
|
||||
check_shape = tensor_shape.TensorShape(processed_values[pk].shape)
|
||||
if not shapes[pk].is_compatible_with(check_shape):
|
||||
raise ValueError(
|
||||
"Feature '%s' has shape %s that is incompatible with declared "
|
||||
"shape: %s" % (pk, shapes[pk], check_shape))
|
||||
continue
|
||||
if isinstance(features[pk], parsing_ops.FixedLenFeature):
|
||||
if features[pk].default_value is not None:
|
||||
processed_values[pk] = np.asarray(
|
||||
features[pk].default_value, dtype=dtypes[pk].as_numpy_dtype)
|
||||
elif isinstance(features[pk], parsing_ops.FixedLenSequenceFeature):
|
||||
processed_values[pk] = np.empty(
|
||||
[0] + features[pk].shape.aslist(), dtype=dtypes[pk].as_numpy_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected generator to return key '%s' with non-empty value" % pk)
|
||||
|
||||
return processed_values
|
||||
|
||||
|
||||
def python_input(generator, features, name=None):
|
||||
"""Easily feed data from a python generator into TensorFlow queues.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
def generator():
|
||||
for i in range(3):
|
||||
yield {"value": i}
|
||||
|
||||
features = {
|
||||
"value": tf.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
|
||||
tensor_dict = tf.contrib.training.python_input(generator, features)
|
||||
batched_dict = tf.train.batch(
|
||||
tensor_dict, batch_size=2, allow_smaller_final_batch=True)
|
||||
|
||||
s = tf.Session()
|
||||
tf.train.start_queue_runners()
|
||||
|
||||
batch1 = s.run(batched_dict) # returns {"value": np.array([0, 1])}
|
||||
batch2 = s.run(batched_dict) # returns {"value": np.array([2])}
|
||||
s.run(batched_dict) # error: Queue is closed (generator finished at i==3)
|
||||
```
|
||||
|
||||
Args:
|
||||
generator: A python generator that takes no arguments, and yields dicts
|
||||
containing a single minibatch entry one at a time.
|
||||
features: A python `dict` mapping keys expected from the generator to
|
||||
instances of `tf.FixedLenFeature`, or `tf.FixedLenSequenceFeature`.
|
||||
name: (Optional) A name for the operations.
|
||||
|
||||
Returns:
|
||||
A dict mapping keys of the `features` dict to `Tensor` objects.
|
||||
These `Tensor` objects are outputs of a queue that is fed by `generator`.
|
||||
|
||||
Raises:
|
||||
TypeError: If generator is not callable or features is not a dict.
|
||||
TypeError: If any of features' values are not a Feature object.
|
||||
NotImplementedError: If any of features' values are instances of
|
||||
`SparseFeature` or `VarLenFeature` (these are not currently supported).
|
||||
ValueError: If any FixedLenSequenceFeatures contain a default value
|
||||
(this field is not supported).
|
||||
ValueError: if any FixedLenSequenceFeatures have allow_missing=False
|
||||
(this field is not supported).
|
||||
"""
|
||||
if not callable(generator):
|
||||
raise TypeError("generator must be callable, saw: %s" % generator)
|
||||
if not isinstance(features, dict):
|
||||
raise TypeError("features must be a dict, saw: %s"
|
||||
% type(features).__name__)
|
||||
|
||||
with ops.name_scope(name, "python_input"):
|
||||
shapes = {}
|
||||
dtypes = {}
|
||||
for k, v in features.items():
|
||||
if isinstance(v, parsing_ops.FixedLenFeature):
|
||||
if v.default_value is not None:
|
||||
value = ops.convert_to_tensor(v.default_value, dtype=v.dtype, name=k)
|
||||
shapes[k] = value.shape
|
||||
dtypes[k] = value.dtype
|
||||
else:
|
||||
tensor_shape.TensorShape(v.shape).assert_is_fully_defined()
|
||||
shapes[k] = tensor_shape.TensorShape(v.shape)
|
||||
dtypes[k] = v.dtype
|
||||
elif isinstance(v, parsing_ops.VarLenFeature):
|
||||
raise NotImplementedError("VarLenFeature not supported")
|
||||
elif isinstance(v, parsing_ops.SparseFeature):
|
||||
raise NotImplementedError("SparseFeature not supported")
|
||||
elif isinstance(v, parsing_ops.FixedLenSequenceFeature):
|
||||
if v.default_value is not None:
|
||||
raise ValueError("FixedLenSequenceFeature with default value not "
|
||||
"supported")
|
||||
if not v.allow_missing:
|
||||
raise ValueError("FixedLenSequenceFeature with allow_missing=False "
|
||||
"not supported")
|
||||
tensor_shape.TensorShape(v.shape).assert_is_fully_defined()
|
||||
shapes[k] = tensor_shape.TensorShape([None]).concatenate(v.shape)
|
||||
dtypes[k] = v.dtype
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected value for features key '%s' to be one of "
|
||||
"FixedLenFeature, VarLenFeature, SparseFeature, or "
|
||||
"FixedLenSequenceFeature. Got: %s" % (k, v))
|
||||
|
||||
keys = list(shapes.keys())
|
||||
dtypes_list = [dtypes[pk] for pk in keys]
|
||||
|
||||
counter = [0]
|
||||
lock = threading.Lock()
|
||||
iterator = iter(generator())
|
||||
|
||||
def generator_iter():
|
||||
"""Iterate through generator output and return np.arrays to py_func."""
|
||||
with lock:
|
||||
try:
|
||||
feature_values = next(iterator)
|
||||
counter[0] += 1
|
||||
except StopIteration as e:
|
||||
raise StopIteration("Iteration finished. Processed %d entries (%s)"
|
||||
% (counter[0], e))
|
||||
|
||||
processed_dict = _process_yielded_dict(
|
||||
feature_values, keys, features, dtypes, shapes)
|
||||
return [processed_dict[pk] for pk in keys]
|
||||
|
||||
generator_pyfunc_values = script_ops.py_func(
|
||||
generator_iter, inp=[], Tout=dtypes_list, stateful=True)
|
||||
|
||||
pyfunc_input = {k: v for (k, v) in zip(keys, generator_pyfunc_values)}
|
||||
for k, v in shapes.items():
|
||||
pyfunc_input[k].set_shape(v)
|
||||
|
||||
return pyfunc_input
|
||||
|
||||
|
||||
__all__ = ["python_input"]
|
@ -1,191 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.training.python_input."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.contrib.training.python.training import bucket_ops
|
||||
from tensorflow.contrib.training.python.training import python_input
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import input as core_input
|
||||
from tensorflow.python.training import queue_runner_impl
|
||||
|
||||
|
||||
class PythonInputTest(test.TestCase):
|
||||
|
||||
def testGenerator(self):
|
||||
def simple_generator():
|
||||
for i in range(2):
|
||||
yield {"value": i, "ignored": 3}
|
||||
|
||||
simple_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
tensors = python_input.python_input(simple_generator, simple_features)
|
||||
self.assertEqual(["value"], tensors.keys())
|
||||
self.assertEqual(dtypes.int32, tensors["value"].dtype)
|
||||
self.assertEqual((), tensors["value"].shape)
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual({"value": 0}, sess.run(tensors))
|
||||
self.assertEqual({"value": 1}, sess.run(tensors))
|
||||
with self.assertRaisesOpError("Iteration finished"):
|
||||
sess.run(tensors)
|
||||
|
||||
def testInvalidGenerator(self):
|
||||
generator1 = lambda: iter([{"value": "a"}])
|
||||
int_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
tensors1 = python_input.python_input(generator1, int_features)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesOpError("invalid literal"):
|
||||
# Can't convert a string to an integer
|
||||
sess.run(tensors1)
|
||||
|
||||
generator2 = lambda: iter([None])
|
||||
tensors2 = python_input.python_input(generator2, int_features)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesOpError("generator must return dict"):
|
||||
sess.run(tensors2)
|
||||
|
||||
generator3 = lambda: iter([{"value": [1, 2]}])
|
||||
tensors3 = python_input.python_input(generator3, int_features)
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesOpError("incompatible with declared shape"):
|
||||
sess.run(tensors3)
|
||||
|
||||
def testGeneratorWorksWithBatching(self):
|
||||
def simple_generator():
|
||||
for i in range(5):
|
||||
yield {"value": i, "ignored": 3}
|
||||
|
||||
simple_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
tensors = python_input.python_input(simple_generator, simple_features)
|
||||
|
||||
# Request batches of size 4 at a time, the final batch may be smaller.
|
||||
batched_tensors = core_input.batch(tensors, batch_size=4,
|
||||
allow_smaller_final_batch=True)
|
||||
|
||||
self.assertEqual(["value"], batched_tensors.keys())
|
||||
self.assertEqual(dtypes.int32, batched_tensors["value"].dtype)
|
||||
self.assertEqual([None], batched_tensors["value"].shape.as_list())
|
||||
|
||||
with self.test_session() as sess:
|
||||
# The generator emits 5 items total. The first 4 are returned in
|
||||
# the first session run; the final one is returned in the
|
||||
# second. This works because allow_smaller_final_batch=True.
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
|
||||
r1 = sess.run(batched_tensors)
|
||||
r2 = sess.run(batched_tensors)
|
||||
self.assertAllEqual([0, 1, 2, 3], r1["value"])
|
||||
self.assertEqual([4], r2["value"])
|
||||
with self.assertRaisesOpError("Iteration finished"):
|
||||
sess.run(tensors)
|
||||
coord.request_stop()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
def testGeneratorWorksWithManyBatchingThreads(self):
|
||||
def simple_generator():
|
||||
for i in range(5000):
|
||||
yield {"value": i, "ignored": 3}
|
||||
|
||||
simple_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32)
|
||||
}
|
||||
tensors = python_input.python_input(simple_generator, simple_features)
|
||||
|
||||
# Request batches of size 20 at a time, the final batch may be smaller.
|
||||
_, batched_tensors = bucket_ops.bucket(
|
||||
tensors, which_bucket=tensors["value"] % 5,
|
||||
batch_size=20, num_buckets=5, num_threads=7, capacity=17,
|
||||
allow_smaller_final_batch=True)
|
||||
|
||||
self.assertEqual(["value"], batched_tensors.keys())
|
||||
self.assertEqual(dtypes.int32, batched_tensors["value"].dtype)
|
||||
self.assertEqual([None], batched_tensors["value"].shape.as_list())
|
||||
|
||||
with self.test_session() as sess:
|
||||
# The generator emits 5 items total. The first 4 are returned in
|
||||
# the first session run; the final one is returned in the
|
||||
# second. This works because allow_smaller_final_batch=True.
|
||||
coord = coordinator.Coordinator()
|
||||
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
|
||||
results = []
|
||||
while True:
|
||||
try:
|
||||
r = sess.run(batched_tensors)
|
||||
results.extend(r["value"].tolist())
|
||||
except errors.OutOfRangeError:
|
||||
break
|
||||
coord.request_stop()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
self.assertEqual(sorted(results),
|
||||
list(range(5000)))
|
||||
|
||||
def testVaryingFieldsInGenerator(self):
|
||||
def simple_generator():
|
||||
for i in range(2):
|
||||
yield {"value": i,
|
||||
"seqlen_value": np.ones((i, 1))}
|
||||
|
||||
simple_features = {
|
||||
"value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32),
|
||||
"seqlen_value": parsing_ops.FixedLenSequenceFeature(
|
||||
shape=[1], dtype=dtypes.float32, allow_missing=True),
|
||||
"empty_value": parsing_ops.FixedLenFeature(
|
||||
default_value=[-1, -2], dtype=dtypes.int32, shape=[2])
|
||||
}
|
||||
tensors = python_input.python_input(simple_generator, simple_features)
|
||||
self.assertEqual(
|
||||
set(["value", "seqlen_value", "empty_value"]), set(tensors.keys()))
|
||||
self.assertEqual(dtypes.int32, tensors["value"].dtype)
|
||||
self.assertEqual((), tensors["value"].shape)
|
||||
self.assertEqual(dtypes.float32, tensors["seqlen_value"].dtype)
|
||||
self.assertEqual([None, 1], tensors["seqlen_value"].shape.as_list())
|
||||
self.assertEqual(dtypes.int32, tensors["empty_value"].dtype)
|
||||
self.assertEqual([2], tensors["empty_value"].shape)
|
||||
|
||||
with self.test_session() as sess:
|
||||
r1 = sess.run(tensors)
|
||||
self.assertAllEqual(0, r1["value"])
|
||||
self.assertAllEqual(np.ones((0, 1)), r1["seqlen_value"])
|
||||
self.assertAllEqual([-1, -2], r1["empty_value"])
|
||||
|
||||
r2 = sess.run(tensors)
|
||||
self.assertAllEqual(1, r2["value"])
|
||||
self.assertAllEqual([[1]], r2["seqlen_value"])
|
||||
self.assertAllEqual([-1, -2], r2["empty_value"])
|
||||
|
||||
with self.assertRaisesOpError("Iteration finished"):
|
||||
sess.run(tensors)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user