[tf.data] Internal minor code restructure

PiperOrigin-RevId: 170787468
This commit is contained in:
A. Unique TensorFlower 2017-10-02 18:23:21 -07:00 committed by TensorFlower Gardener
parent 991dea6bed
commit df3dbbadbc
23 changed files with 85 additions and 71 deletions

View File

@ -13,7 +13,7 @@ py_library(
"//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)

View File

@ -54,7 +54,7 @@ from tensorflow.contrib.data.python.ops.readers import TextLineDataset
from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
from tensorflow.python.data.ops.dataset_ops import Iterator
from tensorflow.python.data.ops.iterator_ops import Iterator
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented

View File

@ -62,6 +62,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],
)
@ -160,6 +161,7 @@ py_test(
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:session",
"//tensorflow/python/data/ops:iterator_ops",
],
)
@ -188,6 +190,7 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python:session",
"//tensorflow/python:training",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],
)
@ -252,6 +255,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:iterator_ops",
],
)
@ -261,7 +265,6 @@ py_test(
srcs = ["reader_dataset_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@ -275,6 +278,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
],
)
@ -338,6 +342,7 @@ py_test(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],
)

View File

@ -24,6 +24,7 @@ import tempfile
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -59,8 +60,8 @@ class FilesystemCacheDatasetTest(test.TestCase):
# Create initialization ops for iterators without and with
# caching, respectively.
iterator = dataset_ops.Iterator.from_structure(cache_dataset.output_types,
cache_dataset.output_shapes)
iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
cache_dataset.output_shapes)
init_fifo_op = iterator.make_initializer(repeat_dataset)
init_cache_op = iterator.make_initializer(cache_dataset)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@ -44,7 +45,7 @@ class IteratorClusterTest(test.TestCase):
iterator_3_handle = iterator_3.string_handle()
with ops.device("/job:worker/replica:0/task:0/cpu:0"):
remote_it = dataset_ops.Iterator.from_string_handle(
remote_it = iterator_ops.Iterator.from_string_handle(
iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
get_next_op = remote_it.get_next()
@ -60,7 +61,7 @@ class IteratorClusterTest(test.TestCase):
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()

View File

@ -24,6 +24,7 @@ from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -239,7 +240,7 @@ class IteratorTest(test.TestCase):
# functions in this graph, to ensure that we are not
# accidentally redefining functions with the same names in the
# new graph.
iterator = dataset_ops.Iterator.from_structure(
iterator = iterator_ops.Iterator.from_structure(
shared_name="shared_iterator",
output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
output_shapes=([], [3], []))
@ -269,8 +270,8 @@ class IteratorTest(test.TestCase):
constant_op.constant([1, 2, 3]))
dataset_4 = dataset_ops.Dataset.from_tensors(
constant_op.constant([4, 5, 6, 7]))
iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types,
[None])
iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
[None])
dataset_3_init_op = iterator.make_initializer(dataset_3)
dataset_4_init_op = iterator.make_initializer(dataset_4)
@ -306,12 +307,12 @@ class IteratorTest(test.TestCase):
def testReinitializableIteratorStaticErrors(self):
# Non-matching structure for types and shapes.
with self.assertRaises(TypeError):
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64), [None])
iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64), [None])
# Test validation of dataset argument.
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64))
iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64))
# Incompatible structure.
with self.assertRaises(ValueError):
@ -328,7 +329,7 @@ class IteratorTest(test.TestCase):
[4., 5., 6., 7.], dtype=dtypes.float32))))
# Incompatible shapes.
iterator = dataset_ops.Iterator.from_structure(
iterator = iterator_ops.Iterator.from_structure(
(dtypes.int64, dtypes.float64), ([None], []))
with self.assertRaises(TypeError):
iterator.make_initializer(
@ -344,7 +345,7 @@ class IteratorTest(test.TestCase):
iterator_4 = dataset_4.make_one_shot_iterator()
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
feedable_iterator = dataset_ops.Iterator.from_string_handle(
feedable_iterator = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
next_element = feedable_iterator.get_next()
@ -391,11 +392,11 @@ class IteratorTest(test.TestCase):
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
feedable_int_scalar = dataset_ops.Iterator.from_string_handle(
feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32, [])
feedable_int_vector = dataset_ops.Iterator.from_string_handle(
feedable_int_vector = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32, [None])
feedable_int_any = dataset_ops.Iterator.from_string_handle(
feedable_int_any = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32)
with self.test_session() as sess:
@ -435,7 +436,7 @@ class IteratorTest(test.TestCase):
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
@ -495,7 +496,7 @@ class IteratorTest(test.TestCase):
@function.Defun(dtypes.uint8)
def _remote_fn(h):
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
remote_iterator = dataset_ops.Iterator.from_string_handle(
remote_iterator = iterator_ops.Iterator.from_string_handle(
handle, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()

View File

@ -21,6 +21,7 @@ import os
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import enumerate_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -279,8 +280,8 @@ class RangeDatasetTest(test.TestCase):
# Create an empty IteratorResource and restore the Iterator into it.
output_types = dtypes.int64
output_shapes = tensor_shape.scalar()
iterator = dataset_ops.Iterator.from_structure(output_types,
output_shapes)
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
get_next = iterator.get_next()

View File

@ -21,10 +21,10 @@ import gzip
import os
import zlib
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -87,7 +87,7 @@ class TextLineDatasetTest(test.TestCase):
filenames, compression_type=compression_type).repeat(num_epochs)
batch_dataset = repeat_dataset.batch(batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
init_op = iterator.make_initializer(repeat_dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
@ -199,7 +199,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
.repeat(num_epochs))
batch_dataset = repeat_dataset.batch(batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
init_op = iterator.make_initializer(repeat_dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _restore_iterator(self):
output_types = dtypes.string
output_shapes = tensor_shape.scalar()
iterator = dataset_ops.Iterator.from_structure(output_types, output_shapes)
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
restore_op = gen_dataset_ops.restore_iterator(
iterator._iterator_resource, self._iterator_checkpoint_path())
@ -575,7 +575,7 @@ class TFRecordDatasetTest(test.TestCase):
self.num_epochs)
batch_dataset = repeat_dataset.batch(self.batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
self.init_op = iterator.make_initializer(repeat_dataset)
self.init_batch_op = iterator.make_initializer(batch_dataset)
self.get_next = iterator.get_next()

View File

@ -22,6 +22,7 @@ import collections
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -52,7 +53,7 @@ class ShuffleDatasetTest(test.TestCase):
# Create initialization ops for iterators without and with
# shuffling, respectively.
iterator = dataset_ops.Iterator.from_structure(
iterator = iterator_ops.Iterator.from_structure(
shuffle_dataset.output_types, shuffle_dataset.output_shapes)
init_fifo_op = iterator.make_initializer(repeat_dataset)
init_shuffle_op = iterator.make_initializer(shuffle_dataset)

View File

@ -16,7 +16,6 @@ py_library(
"//tensorflow/python:script_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator",
"//tensorflow/python/data/util:nest",
],
)

View File

@ -23,9 +23,6 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
# pylint: disable=unused-import
from tensorflow.python.data.ops.iterator import Iterator
# pylint: enable=unused-import
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape

View File

@ -11,7 +11,7 @@ py_library(
deps = [
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)

View File

@ -29,7 +29,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.ops.iterator import Iterator
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
from tensorflow.python.data.ops.readers import TextLineDataset
from tensorflow.python.data.ops.readers import TFRecordDataset

View File

@ -9,7 +9,7 @@ py_library(
srcs = ["dataset_ops.py"],
srcs_version = "PY2AND3",
deps = [
":iterator",
":iterator_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
@ -41,8 +41,8 @@ py_library(
)
py_library(
name = "iterator",
srcs = ["iterator.py"],
name = "iterator_ops",
srcs = ["iterator_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dataset_ops_gen",

View File

@ -23,8 +23,7 @@ import threading
import numpy as np
from tensorflow.python.data.ops import iterator
from tensorflow.python.data.ops.iterator import Iterator # pylint: disable=unused-import
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -92,9 +91,8 @@ class Dataset(object):
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(
self._as_variant_tensor(), iterator_resource)
return iterator.Iterator(
iterator_resource, initializer, self.output_types,
self.output_shapes)
return iterator_ops.Iterator(iterator_resource, initializer,
self.output_types, self.output_shapes)
def make_one_shot_iterator(self):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@ -113,7 +111,7 @@ class Dataset(object):
_make_dataset.add_to_graph(ops.get_default_graph())
return iterator.Iterator(
return iterator_ops.Iterator(
gen_dataset_ops.one_shot_iterator(
dataset_factory=_make_dataset,
output_types=nest.flatten(self.output_types),

View File

@ -2960,6 +2960,7 @@ tf_py_test(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
@ -2978,7 +2979,7 @@ tf_py_test(
"//tensorflow/python:lib",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
@ -3009,6 +3010,7 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
@ -3036,6 +3038,7 @@ tf_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
@ -3076,6 +3079,7 @@ tf_py_test(
"//tensorflow/python/data/ops:readers",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@ -3111,6 +3115,7 @@ tf_py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
tags = ["no_windows"],
)

View File

@ -24,6 +24,7 @@ import tempfile
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -59,8 +60,8 @@ class FilesystemCacheDatasetTest(test.TestCase):
# Create initialization ops for iterators without and with
# caching, respectively.
iterator = dataset_ops.Iterator.from_structure(cache_dataset.output_types,
cache_dataset.output_shapes)
iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
cache_dataset.output_shapes)
init_fifo_op = iterator.make_initializer(repeat_dataset)
init_cache_op = iterator.make_initializer(cache_dataset)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@ -44,7 +45,7 @@ class IteratorClusterTest(test.TestCase):
iterator_3_handle = iterator_3.string_handle()
with ops.device("/job:worker/replica:0/task:0/cpu:0"):
remote_it = dataset_ops.Iterator.from_string_handle(
remote_it = iterator_ops.Iterator.from_string_handle(
iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
get_next_op = remote_it.get_next()
@ -65,7 +66,7 @@ class IteratorClusterTest(test.TestCase):
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -239,7 +240,7 @@ class IteratorTest(test.TestCase):
# functions in this graph, to ensure that we are not
# accidentally redefining functions with the same names in the
# new graph.
iterator = dataset_ops.Iterator.from_structure(
iterator = iterator_ops.Iterator.from_structure(
shared_name="shared_iterator",
output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
output_shapes=([], [3], []))
@ -269,8 +270,8 @@ class IteratorTest(test.TestCase):
constant_op.constant([1, 2, 3]))
dataset_4 = dataset_ops.Dataset.from_tensors(
constant_op.constant([4, 5, 6, 7]))
iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types,
[None])
iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
[None])
dataset_3_init_op = iterator.make_initializer(dataset_3)
dataset_4_init_op = iterator.make_initializer(dataset_4)
@ -306,12 +307,12 @@ class IteratorTest(test.TestCase):
def testReinitializableIteratorStaticErrors(self):
# Non-matching structure for types and shapes.
with self.assertRaises(TypeError):
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64), [None])
iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64), [None])
# Test validation of dataset argument.
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64))
iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64))
# Incompatible structure.
with self.assertRaises(ValueError):
@ -328,7 +329,7 @@ class IteratorTest(test.TestCase):
[4., 5., 6., 7.], dtype=dtypes.float32))))
# Incompatible shapes.
iterator = dataset_ops.Iterator.from_structure(
iterator = iterator_ops.Iterator.from_structure(
(dtypes.int64, dtypes.float64), ([None], []))
with self.assertRaises(TypeError):
iterator.make_initializer(
@ -344,7 +345,7 @@ class IteratorTest(test.TestCase):
iterator_4 = dataset_4.make_one_shot_iterator()
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
feedable_iterator = dataset_ops.Iterator.from_string_handle(
feedable_iterator = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
next_element = feedable_iterator.get_next()
@ -391,11 +392,11 @@ class IteratorTest(test.TestCase):
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
feedable_int_scalar = dataset_ops.Iterator.from_string_handle(
feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32, [])
feedable_int_vector = dataset_ops.Iterator.from_string_handle(
feedable_int_vector = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32, [None])
feedable_int_any = dataset_ops.Iterator.from_string_handle(
feedable_int_any = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32)
with self.test_session() as sess:
@ -435,7 +436,7 @@ class IteratorTest(test.TestCase):
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = dataset_ops.Iterator.from_string_handle(
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
@ -495,7 +496,7 @@ class IteratorTest(test.TestCase):
@function.Defun(dtypes.uint8)
def _remote_fn(h):
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
remote_iterator = dataset_ops.Iterator.from_string_handle(
remote_iterator = iterator_ops.Iterator.from_string_handle(
handle, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import os
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -255,8 +256,8 @@ class RangeDatasetTest(test.TestCase):
# Create an empty IteratorResource and restore the Iterator into it.
output_types = dtypes.int64
output_shapes = tensor_shape.scalar()
iterator = dataset_ops.Iterator.from_structure(output_types,
output_shapes)
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
get_next = iterator.get_next()

View File

@ -21,7 +21,7 @@ import gzip
import os
import zlib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -84,7 +84,7 @@ class TextLineDatasetTest(test.TestCase):
filenames, compression_type=compression_type).repeat(num_epochs)
batch_dataset = repeat_dataset.batch(batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
init_op = iterator.make_initializer(repeat_dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
@ -196,7 +196,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
.repeat(num_epochs))
batch_dataset = repeat_dataset.batch(batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
init_op = iterator.make_initializer(repeat_dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
@ -290,7 +290,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _restore_iterator(self):
output_types = dtypes.string
output_shapes = tensor_shape.scalar()
iterator = dataset_ops.Iterator.from_structure(output_types, output_shapes)
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
restore_op = gen_dataset_ops.restore_iterator(
iterator._iterator_resource, self._iterator_checkpoint_path())
@ -572,7 +572,7 @@ class TFRecordDatasetTest(test.TestCase):
self.num_epochs)
batch_dataset = repeat_dataset.batch(self.batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
self.init_op = iterator.make_initializer(repeat_dataset)
self.init_batch_op = iterator.make_initializer(batch_dataset)
self.get_next = iterator.get_next()

View File

@ -22,6 +22,7 @@ import collections
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -52,7 +53,7 @@ class ShuffleDatasetTest(test.TestCase):
# Create initialization ops for iterators without and with
# shuffling, respectively.
iterator = dataset_ops.Iterator.from_structure(
iterator = iterator_ops.Iterator.from_structure(
shuffle_dataset.output_types, shuffle_dataset.output_shapes)
init_fifo_op = iterator.make_initializer(repeat_dataset)
init_shuffle_op = iterator.make_initializer(shuffle_dataset)