diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 26ada939da2..fc67e8d5c1c 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -2,13 +2,14 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "cuda_py_test") py_library( name = "tfe", srcs = ["tfe.py"], srcs_version = "PY2AND3", deps = [ + ":datasets", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:util", @@ -31,6 +32,34 @@ cuda_py_test( ], ) +py_library( + name = "datasets", + srcs = ["datasets.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/util:nest", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:errors", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "datasets_test", + srcs = ["datasets_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":datasets", + "//tensorflow/contrib/data", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:test", + "//third_party/py/numpy", + ], +) + py_library( name = "saver", srcs = ["saver.py"], diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py new file mode 100644 index 00000000000..f7f9ddd7b08 --- /dev/null +++ b/tensorflow/contrib/eager/python/datasets.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Process-wide counter used to give every eager iterator resource a distinct
# shared name. All reads and writes go through `_uid_lock`.
_uid_counter = 0
_uid_lock = threading.Lock()


def _iterator_shared_name():
  """Returns a fresh `shared_name` for an iterator resource.

  The name has the form "eager_iterator_<n>", where <n> is the current value
  of a module-level counter. The counter is read and incremented while holding
  `_uid_lock`, so concurrent callers never receive the same name.

  Returns:
    A `str` that has not been returned by a previous call in this process.
  """
  global _uid_counter
  with _uid_lock:
    uid = _uid_counter
    _uid_counter += 1
  return "eager_iterator_{}".format(uid)


class Iterator(object):
  """An iterator producing tf.Tensor objects from a tf.contrib.data.Dataset."""

  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.contrib.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Args:
      dataset: A `tf.contrib.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """
    # Assigned before anything that can raise: Python still runs __del__ on an
    # object whose __init__ failed, and __del__ inspects this attribute.
    self._resource = None

    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects only make sense when eager execution is enabled".format(
              type(self)))
    ds_variant = dataset.make_dataset_resource()
    self._output_types = dataset.output_types
    # The iterator ops take flat lists; the original (possibly nested)
    # structure is kept in `_output_types` so next() can restore it.
    self._flat_output_types = nest.flatten(dataset.output_types)
    self._flat_output_shapes = nest.flatten(dataset.output_shapes)
    self._resource = gen_dataset_ops.iterator(
        container="",
        shared_name=_iterator_shared_name(),
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
    gen_dataset_ops.make_iterator(ds_variant, self._resource)

  def __del__(self):
    """Destroys the underlying iterator resource, if one was created."""
    # getattr guards against instances that never got as far as assigning
    # `_resource` (e.g. built via __new__, or __init__ raised very early).
    resource = getattr(self, "_resource", None)
    if resource is not None:
      resource_variable_ops.destroy_resource_op(resource)
      self._resource = None

  def __iter__(self):
    """Returns `self`; the object is its own iterator."""
    return self

  def __next__(self):  # For Python 3 compatibility
    """Python 3 spelling of `next()`."""
    return self.next()

  def next(self):
    """Returns the next element of the dataset.

    Returns:
      The flat outputs of the iterator op, packed back into the structure of
      the dataset's `output_types` (a single value or a nested structure).

    Raises:
      StopIteration: When the iterator op raises `errors.OutOfRangeError`.
    """
    # Only the op call can signal end-of-sequence, so keep the try body to it.
    try:
      flat_values = gen_dataset_ops.iterator_get_next(
          self._resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
    except errors.OutOfRangeError:
      raise StopIteration
    return nest.pack_sequence_as(self._output_types, flat_values)
class IteratorTest(test.TestCase):
  """Exercises `datasets.Iterator` over a few small datasets."""

  def testBasic(self):
    # A plain range dataset should come back element by element, in order.
    values = [elem.numpy() for elem in datasets.Iterator(Dataset.range(4))]
    self.assertAllEqual([0, 1, 2, 3], values)

  def testMultipleIteratorsOnTheSameDataset(self):
    # Both iterators are created up front, then drained one after the other;
    # each must see the full sequence.
    dataset = Dataset.range(4)
    first = datasets.Iterator(dataset)
    second = datasets.Iterator(dataset)
    for iterator in (first, second):
      drained = [elem.numpy() for elem in iterator]
      self.assertAllEqual([0, 1, 2, 3], drained)

  def testNestedOutputs(self):
    outer_range = Dataset.range(4)
    inner_pair = Dataset.zip((Dataset.range(4), Dataset.range(4)))
    nested = Dataset.zip((outer_range, inner_pair))

    seen = 0
    for index, element in enumerate(datasets.Iterator(nested)):
      # Each element is a nested structure of Tensor objects; pull the values
      # out with numpy() so they can be compared against plain integers.
      expected = (index, (index, index))
      actual = (element[0].numpy(),
                (element[1][0].numpy(), element[1][1].numpy()))
      self.assertEqual(actual, expected)
      seen += 1
    self.assertEqual(4, seen)

  def testMapAndFilter(self):
    # TODO(ashankar): Address this
    self.skipTest('Not working yet, requires function attribute support')

    def is_even(value):
      remainder = math_ops.mod(value, 2)
      return math_ops.equal(remainder, 0)

    pipeline = Dataset.range(8).map(math_ops.square).filter(is_even)
    squares = [elem.numpy() for elem in datasets.Iterator(pipeline)]
    self.assertAllEqual([0, 4, 16, 36], squares)


if __name__ == '__main__':
  test.main()