diff --git a/RELEASE.md b/RELEASE.md index ec66e555add..d120f068cae 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -9,6 +9,9 @@ for LSTMs and stacked LSTMs. This bug fix follows recommendations from published literature, but is a behavioral change. State dropout behavior may be customized via the new `dropout_state_filter_visitor` argument. +* Removed `tf.contrib.training.python_input`. The same behavior, in a more + flexible and reproducible package, is available via the new + `tf.contrib.data.Dataset.from_generator` method! # Release 1.3.0 diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index e692822aaff..5b77129a487 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -595,6 +595,23 @@ class Dataset(object): The elements generated by `generator` must be compatible with the given `output_types` and (optional) `output_shapes` arguments. + For example: + + ```python + import itertools + + def gen(): + for i in itertools.count(1): + yield (i, [1] * i) + + ds = Dataset.from_generator( + gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) + value = ds.make_one_shot_iterator().get_next() + + sess.run(value) # (1, array([1])) + sess.run(value) # (2, array([1, 1])) + ``` + Args: generator: A callable object that takes no arguments and returns an object that supports the `iter()` protocol. diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index e8c6c349c8c..8e3d869a51c 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -23,7 +23,6 @@ py_library( "python/training/evaluation.py", "python/training/feeding_queue_runner.py", "python/training/hparam.py", - "python/training/python_input.py", "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", @@ -226,23 +225,6 @@ py_test( ], ) -py_test( - name = "python_input_test", - size = "medium", - srcs = ["python/training/python_input_test.py"], - srcs_version = "PY2AND3", - tags = ["manual"], - deps = [ - ":training_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - py_test( name = "evaluation_test", size = "small", diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 87a70e6d164..da2de3e421b 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -36,7 +36,6 @@ See @{$python/contrib.training} guide. 
@@HParams @@HParamDef @@parse_values -@@python_input """ from __future__ import absolute_import @@ -55,7 +54,6 @@ from tensorflow.contrib.training.python.training.evaluation import SummaryAtEndH from tensorflow.contrib.training.python.training.evaluation import wait_for_new_checkpoint from tensorflow.contrib.training.python.training.feeding_queue_runner import FeedingQueueRunner from tensorflow.contrib.training.python.training.hparam import * -from tensorflow.contrib.training.python.training.python_input import python_input from tensorflow.contrib.training.python.training.resample import * from tensorflow.contrib.training.python.training.sampling_ops import * from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import * diff --git a/tensorflow/contrib/training/python/training/python_input.py b/tensorflow/contrib/training/python/training/python_input.py deleted file mode 100644 index 7f5420a98a1..00000000000 --- a/tensorflow/contrib/training/python/training/python_input.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Operations for asynchronously reading data from python into queues. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import threading - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import script_ops - - -def _process_yielded_dict(feature_values, keys, features, dtypes, shapes): - """Read feature_values from the generator and emit a proper output dict.""" - if not isinstance(feature_values, dict): - raise TypeError("generator must return dict, saw: %s" % feature_values) - - processed_values = {} - for pk in keys: - if feature_values.get(pk, None) is not None: - processed_values[pk] = np.asarray( - feature_values[pk], dtype=dtypes[pk].as_numpy_dtype) - check_shape = tensor_shape.TensorShape(processed_values[pk].shape) - if not shapes[pk].is_compatible_with(check_shape): - raise ValueError( - "Feature '%s' has shape %s that is incompatible with declared " - "shape: %s" % (pk, shapes[pk], check_shape)) - continue - if isinstance(features[pk], parsing_ops.FixedLenFeature): - if features[pk].default_value is not None: - processed_values[pk] = np.asarray( - features[pk].default_value, dtype=dtypes[pk].as_numpy_dtype) - elif isinstance(features[pk], parsing_ops.FixedLenSequenceFeature): - processed_values[pk] = np.empty( - [0] + features[pk].shape.aslist(), dtype=dtypes[pk].as_numpy_dtype) - else: - raise ValueError( - "Expected generator to return key '%s' with non-empty value" % pk) - - return processed_values - - -def python_input(generator, features, name=None): - """Easily feed data from a python generator into TensorFlow queues. 
- - Example usage: - - ```python - def generator(): - for i in range(3): - yield {"value": i} - - features = { - "value": tf.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - - tensor_dict = tf.contrib.training.python_input(generator, features) - batched_dict = tf.train.batch( - tensor_dict, batch_size=2, allow_smaller_final_batch=True) - - s = tf.Session() - tf.train.start_queue_runners() - - batch1 = s.run(batched_dict) # returns {"value": np.array([0, 1])} - batch2 = s.run(batched_dict) # returns {"value": np.array([2])} - s.run(batched_dict) # error: Queue is closed (generator finished at i==3) - ``` - - Args: - generator: A python generator that takes no arguments, and yields dicts - containing a single minibatch entry one at a time. - features: A python `dict` mapping keys expected from the generator to - instances of `tf.FixedLenFeature`, or `tf.FixedLenSequenceFeature`. - name: (Optional) A name for the operations. - - Returns: - A dict mapping keys of the `features` dict to `Tensor` objects. - These `Tensor` objects are outputs of a queue that is fed by `generator`. - - Raises: - TypeError: If generator is not callable or features is not a dict. - TypeError: If any of features' values are not a Feature object. - NotImplementedError: If any of features' values are instances of - `SparseFeature` or `VarLenFeature` (these are not currently supported). - ValueError: If any FixedLenSequenceFeatures contain a default value - (this field is not supported). - ValueError: if any FixedLenSequenceFeatures have allow_missing=False - (this field is not supported). - """ - if not callable(generator): - raise TypeError("generator must be callable, saw: %s" % generator) - if not isinstance(features, dict): - raise TypeError("features must be a dict, saw: %s" - % type(features).__name__) - - with ops.name_scope(name, "python_input"): - shapes = {} - dtypes = {} - for k, v in features.items(): - if isinstance(v, parsing_ops.FixedLenFeature): - if v.default_value is not None: - value = ops.convert_to_tensor(v.default_value, dtype=v.dtype, name=k) - shapes[k] = value.shape - dtypes[k] = value.dtype - else: - tensor_shape.TensorShape(v.shape).assert_is_fully_defined() - shapes[k] = tensor_shape.TensorShape(v.shape) - dtypes[k] = v.dtype - elif isinstance(v, parsing_ops.VarLenFeature): - raise NotImplementedError("VarLenFeature not supported") - elif isinstance(v, parsing_ops.SparseFeature): - raise NotImplementedError("SparseFeature not supported") - elif isinstance(v, parsing_ops.FixedLenSequenceFeature): - if v.default_value is not None: - raise ValueError("FixedLenSequenceFeature with default value not " - "supported") - if not v.allow_missing: - raise ValueError("FixedLenSequenceFeature with allow_missing=False " - "not supported") - tensor_shape.TensorShape(v.shape).assert_is_fully_defined() - shapes[k] = tensor_shape.TensorShape([None]).concatenate(v.shape) - dtypes[k] = v.dtype - else: - raise TypeError( - "Expected value for features key '%s' to be one of " - "FixedLenFeature, VarLenFeature, SparseFeature, or " - "FixedLenSequenceFeature. Got: %s" % (k, v)) - - keys = list(shapes.keys()) - dtypes_list = [dtypes[pk] for pk in keys] - - counter = [0] - lock = threading.Lock() - iterator = iter(generator()) - - def generator_iter(): - """Iterate through generator output and return np.arrays to py_func.""" - with lock: - try: - feature_values = next(iterator) - counter[0] += 1 - except StopIteration as e: - raise StopIteration("Iteration finished. 
Processed %d entries (%s)" - % (counter[0], e)) - - processed_dict = _process_yielded_dict( - feature_values, keys, features, dtypes, shapes) - return [processed_dict[pk] for pk in keys] - - generator_pyfunc_values = script_ops.py_func( - generator_iter, inp=[], Tout=dtypes_list, stateful=True) - - pyfunc_input = {k: v for (k, v) in zip(keys, generator_pyfunc_values)} - for k, v in shapes.items(): - pyfunc_input[k].set_shape(v) - - return pyfunc_input - - -__all__ = ["python_input"] diff --git a/tensorflow/contrib/training/python/training/python_input_test.py b/tensorflow/contrib/training/python/training/python_input_test.py deleted file mode 100644 index afd0f38c2cd..00000000000 --- a/tensorflow/contrib/training/python/training/python_input_test.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.training.python_input.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib.training.python.training import bucket_ops -from tensorflow.contrib.training.python.training import python_input -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import parsing_ops -from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import input as core_input -from tensorflow.python.training import queue_runner_impl - - -class PythonInputTest(test.TestCase): - - def testGenerator(self): - def simple_generator(): - for i in range(2): - yield {"value": i, "ignored": 3} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors = python_input.python_input(simple_generator, simple_features) - self.assertEqual(["value"], tensors.keys()) - self.assertEqual(dtypes.int32, tensors["value"].dtype) - self.assertEqual((), tensors["value"].shape) - - with self.test_session() as sess: - self.assertEqual({"value": 0}, sess.run(tensors)) - self.assertEqual({"value": 1}, sess.run(tensors)) - with self.assertRaisesOpError("Iteration finished"): - sess.run(tensors) - - def testInvalidGenerator(self): - generator1 = lambda: iter([{"value": "a"}]) - int_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors1 = python_input.python_input(generator1, int_features) - - with self.test_session() as sess: - with self.assertRaisesOpError("invalid literal"): - # Can't convert a string to an integer - sess.run(tensors1) - - generator2 = lambda: iter([None]) - tensors2 = python_input.python_input(generator2, int_features) - - with self.test_session() as sess: - with self.assertRaisesOpError("generator must return dict"): - sess.run(tensors2) - - generator3 = lambda: iter([{"value": [1, 2]}]) - tensors3 = 
python_input.python_input(generator3, int_features) - - with self.test_session() as sess: - with self.assertRaisesOpError("incompatible with declared shape"): - sess.run(tensors3) - - def testGeneratorWorksWithBatching(self): - def simple_generator(): - for i in range(5): - yield {"value": i, "ignored": 3} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors = python_input.python_input(simple_generator, simple_features) - - # Request batches of size 4 at a time, the final batch may be smaller. - batched_tensors = core_input.batch(tensors, batch_size=4, - allow_smaller_final_batch=True) - - self.assertEqual(["value"], batched_tensors.keys()) - self.assertEqual(dtypes.int32, batched_tensors["value"].dtype) - self.assertEqual([None], batched_tensors["value"].shape.as_list()) - - with self.test_session() as sess: - # The generator emits 5 items total. The first 4 are returned in - # the first session run; the final one is returned in the - # second. This works because allow_smaller_final_batch=True. - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) - r1 = sess.run(batched_tensors) - r2 = sess.run(batched_tensors) - self.assertAllEqual([0, 1, 2, 3], r1["value"]) - self.assertEqual([4], r2["value"]) - with self.assertRaisesOpError("Iteration finished"): - sess.run(tensors) - coord.request_stop() - for thread in threads: - thread.join() - - def testGeneratorWorksWithManyBatchingThreads(self): - def simple_generator(): - for i in range(5000): - yield {"value": i, "ignored": 3} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors = python_input.python_input(simple_generator, simple_features) - - # Request batches of size 20 at a time, the final batch may be smaller. - _, batched_tensors = bucket_ops.bucket( - tensors, which_bucket=tensors["value"] % 5, - batch_size=20, num_buckets=5, num_threads=7, capacity=17, - allow_smaller_final_batch=True) - - self.assertEqual(["value"], batched_tensors.keys()) - self.assertEqual(dtypes.int32, batched_tensors["value"].dtype) - self.assertEqual([None], batched_tensors["value"].shape.as_list()) - - with self.test_session() as sess: - # The generator emits 5 items total. The first 4 are returned in - # the first session run; the final one is returned in the - # second. This works because allow_smaller_final_batch=True. 
- coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) - results = [] - while True: - try: - r = sess.run(batched_tensors) - results.extend(r["value"].tolist()) - except errors.OutOfRangeError: - break - coord.request_stop() - for thread in threads: - thread.join() - self.assertEqual(sorted(results), - list(range(5000))) - - def testVaryingFieldsInGenerator(self): - def simple_generator(): - for i in range(2): - yield {"value": i, - "seqlen_value": np.ones((i, 1))} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32), - "seqlen_value": parsing_ops.FixedLenSequenceFeature( - shape=[1], dtype=dtypes.float32, allow_missing=True), - "empty_value": parsing_ops.FixedLenFeature( - default_value=[-1, -2], dtype=dtypes.int32, shape=[2]) - } - tensors = python_input.python_input(simple_generator, simple_features) - self.assertEqual( - set(["value", "seqlen_value", "empty_value"]), set(tensors.keys())) - self.assertEqual(dtypes.int32, tensors["value"].dtype) - self.assertEqual((), tensors["value"].shape) - self.assertEqual(dtypes.float32, tensors["seqlen_value"].dtype) - self.assertEqual([None, 1], tensors["seqlen_value"].shape.as_list()) - self.assertEqual(dtypes.int32, tensors["empty_value"].dtype) - self.assertEqual([2], tensors["empty_value"].shape) - - with self.test_session() as sess: - r1 = sess.run(tensors) - self.assertAllEqual(0, r1["value"]) - self.assertAllEqual(np.ones((0, 1)), r1["seqlen_value"]) - self.assertAllEqual([-1, -2], r1["empty_value"]) - - r2 = sess.run(tensors) - self.assertAllEqual(1, r2["value"]) - self.assertAllEqual([[1]], r2["seqlen_value"]) - self.assertAllEqual([-1, -2], r2["empty_value"]) - - with self.assertRaisesOpError("Iteration finished"): - sess.run(tensors) - - -if __name__ == "__main__": - test.main()
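
Migration note: a minimal sketch, assuming the contrib `Dataset` API documented in the `from_generator` docstring above, of how the scalar `"value"` example from the removed `python_input` docstring can be expressed with `tf.contrib.data.Dataset.from_generator`. The generator and batch size below are illustrative, not part of this patch; `Dataset.batch` stands in for `tf.train.batch`, so no queue runners or coordinator threads are involved and iteration order is deterministic.

```python
import tensorflow as tf

# Illustrative generator mirroring the removed python_input docstring:
# it yields one scalar "value" per element.
def generator():
  for i in range(3):
    yield i

# from_generator requires the element dtype (and optionally shape) up front.
ds = tf.contrib.data.Dataset.from_generator(
    generator, tf.int32, tf.TensorShape([]))

# Batch into minibatches of 2; the final batch may be smaller.
ds = ds.batch(2)

value = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(value))  # [0 1]
  print(sess.run(value))  # [2]
  # A further sess.run(value) raises tf.errors.OutOfRangeError once the
  # generator is exhausted, rather than a closed-queue error.
```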