[tf.data] Internal minor code restructure
PiperOrigin-RevId: 170787468
This commit is contained in:
parent 991dea6bed
commit df3dbbadbc
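In effect, the restructure renames tensorflow/python/data/ops/iterator.py to iterator_ops.py and moves the Iterator class out of dataset_ops: internal callers switch from dataset_ops.Iterator to iterator_ops.Iterator, and BUILD deps move to //tensorflow/python/data/ops:iterator_ops. A minimal sketch of the new spelling (the concrete Dataset below is illustrative, not from this commit):

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.ops import iterator_ops

    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
    # After this change, Iterator lives in iterator_ops, not dataset_ops.
    iterator = iterator_ops.Iterator.from_structure(dataset.output_types,
                                                    dataset.output_shapes)
    init_op = iterator.make_initializer(dataset)
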
@@ -13,7 +13,7 @@ py_library(
         "//tensorflow/contrib/data/python/ops:readers",
         "//tensorflow/contrib/data/python/ops:transformation_ops",
         "//tensorflow/python:util",
-        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -54,7 +54,7 @@ from tensorflow.contrib.data.python.ops.readers import TextLineDataset
 from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
 from tensorflow.contrib.data.python.ops.resampling import rejection_resample
 from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
-from tensorflow.python.data.ops.dataset_ops import Iterator
+from tensorflow.python.data.ops.iterator_ops import Iterator
 # pylint: enable=unused-import

 from tensorflow.python.util.all_util import remove_undocumented

@@ -62,6 +62,7 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//third_party/py/numpy",
     ],
 )

@@ -160,6 +161,7 @@ py_test(
         "//tensorflow/python:function",
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:session",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -188,6 +190,7 @@ py_test(
         "//tensorflow/python:script_ops",
         "//tensorflow/python:session",
         "//tensorflow/python:training",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//third_party/py/numpy",
     ],
 )

@@ -252,6 +255,7 @@ py_test(
         "//tensorflow/python:platform",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -261,7 +265,6 @@ py_test(
     srcs = ["reader_dataset_ops_test.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/contrib/data/python/ops:dataset_ops",
         "//tensorflow/contrib/data/python/ops:readers",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
@@ -275,6 +278,7 @@ py_test(
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -338,6 +342,7 @@ py_test(
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//third_party/py/numpy",
     ],
 )

@@ -24,6 +24,7 @@ import tempfile
 import numpy as np

 from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -59,8 +60,8 @@ class FilesystemCacheDatasetTest(test.TestCase):

     # Create initialization ops for iterators without and with
     # caching, respectively.
-    iterator = dataset_ops.Iterator.from_structure(cache_dataset.output_types,
-                                                   cache_dataset.output_shapes)
+    iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
+                                                    cache_dataset.output_shapes)
     init_fifo_op = iterator.make_initializer(repeat_dataset)
     init_cache_op = iterator.make_initializer(cache_dataset)

@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function

@@ -44,7 +45,7 @@ class IteratorClusterTest(test.TestCase):
     iterator_3_handle = iterator_3.string_handle()

     with ops.device("/job:worker/replica:0/task:0/cpu:0"):
-      remote_it = dataset_ops.Iterator.from_string_handle(
+      remote_it = iterator_ops.Iterator.from_string_handle(
           iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
       get_next_op = remote_it.get_next()

@@ -60,7 +61,7 @@ class IteratorClusterTest(test.TestCase):

     @function.Defun(dtypes.string)
     def _remote_fn(h):
-      remote_iterator = dataset_ops.Iterator.from_string_handle(
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
           h, dataset_3.output_types, dataset_3.output_shapes)
       return remote_iterator.get_next()

@@ -24,6 +24,7 @@ from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.contrib.data.python.ops import readers
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -239,7 +240,7 @@ class IteratorTest(test.TestCase):
       # functions in this graph, to ensure that we are not
       # accidentally redefining functions with the same names in the
       # new graph.
-      iterator = dataset_ops.Iterator.from_structure(
+      iterator = iterator_ops.Iterator.from_structure(
           shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))

@@ -269,8 +270,8 @@ class IteratorTest(test.TestCase):
         constant_op.constant([1, 2, 3]))
     dataset_4 = dataset_ops.Dataset.from_tensors(
         constant_op.constant([4, 5, 6, 7]))
-    iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types,
-                                                   [None])
+    iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
+                                                    [None])

     dataset_3_init_op = iterator.make_initializer(dataset_3)
     dataset_4_init_op = iterator.make_initializer(dataset_4)

@@ -306,12 +307,12 @@ class IteratorTest(test.TestCase):
   def testReinitializableIteratorStaticErrors(self):
     # Non-matching structure for types and shapes.
     with self.assertRaises(TypeError):
-      iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
-                                                      dtypes.float64), [None])
+      iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
+                                                       dtypes.float64), [None])

     # Test validation of dataset argument.
-    iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
-                                                    dtypes.float64))
+    iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
+                                                     dtypes.float64))

     # Incompatible structure.
     with self.assertRaises(ValueError):

@@ -328,7 +329,7 @@ class IteratorTest(test.TestCase):
         [4., 5., 6., 7.], dtype=dtypes.float32))))

     # Incompatible shapes.
-    iterator = dataset_ops.Iterator.from_structure(
+    iterator = iterator_ops.Iterator.from_structure(
         (dtypes.int64, dtypes.float64), ([None], []))
     with self.assertRaises(TypeError):
       iterator.make_initializer(

@@ -344,7 +345,7 @@ class IteratorTest(test.TestCase):
     iterator_4 = dataset_4.make_one_shot_iterator()

     handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
-    feedable_iterator = dataset_ops.Iterator.from_string_handle(
+    feedable_iterator = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
     next_element = feedable_iterator.get_next()

@@ -391,11 +392,11 @@ class IteratorTest(test.TestCase):

     handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

-    feedable_int_scalar = dataset_ops.Iterator.from_string_handle(
+    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32, [])
-    feedable_int_vector = dataset_ops.Iterator.from_string_handle(
+    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32, [None])
-    feedable_int_any = dataset_ops.Iterator.from_string_handle(
+    feedable_int_any = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32)

     with self.test_session() as sess:

@@ -435,7 +436,7 @@ class IteratorTest(test.TestCase):

     @function.Defun(dtypes.string)
     def _remote_fn(h):
-      remote_iterator = dataset_ops.Iterator.from_string_handle(
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
       return remote_iterator.get_next()

@@ -495,7 +496,7 @@ class IteratorTest(test.TestCase):
     @function.Defun(dtypes.uint8)
     def _remote_fn(h):
       handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
-      remote_iterator = dataset_ops.Iterator.from_string_handle(
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_3.output_types, dataset_3.output_shapes)
       return remote_iterator.get_next()

@@ -21,6 +21,7 @@ import os

 from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -279,8 +280,8 @@ class RangeDatasetTest(test.TestCase):
     # Create an empty IteratorResource and restore the Iterator into it.
     output_types = dtypes.int64
     output_shapes = tensor_shape.scalar()
-    iterator = dataset_ops.Iterator.from_structure(output_types,
-                                                   output_shapes)
+    iterator = iterator_ops.Iterator.from_structure(output_types,
+                                                    output_shapes)
     restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
                                                   path)
     get_next = iterator.get_next()

@@ -21,10 +21,10 @@ import gzip
 import os
 import zlib

-from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.contrib.data.python.ops import readers
 from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -87,7 +87,7 @@ class TextLineDatasetTest(test.TestCase):
         filenames, compression_type=compression_type).repeat(num_epochs)
     batch_dataset = repeat_dataset.batch(batch_size)

-    iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     init_op = iterator.make_initializer(repeat_dataset)
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()

@@ -199,7 +199,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
         .repeat(num_epochs))
     batch_dataset = repeat_dataset.batch(batch_size)

-    iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     init_op = iterator.make_initializer(repeat_dataset)
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()

@@ -293,7 +293,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
   def _restore_iterator(self):
     output_types = dtypes.string
     output_shapes = tensor_shape.scalar()
-    iterator = dataset_ops.Iterator.from_structure(output_types, output_shapes)
+    iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
     get_next = iterator.get_next()
     restore_op = gen_dataset_ops.restore_iterator(
         iterator._iterator_resource, self._iterator_checkpoint_path())

@@ -575,7 +575,7 @@ class TFRecordDatasetTest(test.TestCase):
                                                      self.num_epochs)
     batch_dataset = repeat_dataset.batch(self.batch_size)

-    iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     self.init_op = iterator.make_initializer(repeat_dataset)
     self.init_batch_op = iterator.make_initializer(batch_dataset)
     self.get_next = iterator.get_next()

@@ -22,6 +22,7 @@ import collections
 import numpy as np

 from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -52,7 +53,7 @@ class ShuffleDatasetTest(test.TestCase):

     # Create initialization ops for iterators without and with
     # shuffling, respectively.
-    iterator = dataset_ops.Iterator.from_structure(
+    iterator = iterator_ops.Iterator.from_structure(
         shuffle_dataset.output_types, shuffle_dataset.output_shapes)
     init_fifo_op = iterator.make_initializer(repeat_dataset)
     init_shuffle_op = iterator.make_initializer(shuffle_dataset)

@@ -16,7 +16,6 @@ py_library(
         "//tensorflow/python:script_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/ops:iterator",
         "//tensorflow/python/data/util:nest",
     ],
 )

@@ -23,9 +23,6 @@ from tensorflow.contrib.data.python.ops import error_ops
 from tensorflow.contrib.data.python.ops import grouping

 from tensorflow.python.data.ops import dataset_ops
-# pylint: disable=unused-import
-from tensorflow.python.data.ops.iterator import Iterator
-# pylint: enable=unused-import
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape

@@ -11,7 +11,7 @@ py_library(
     deps = [
         "//tensorflow/python:util",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/ops:iterator",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//tensorflow/python/data/ops:readers",
     ],
 )

@@ -29,7 +29,7 @@ from __future__ import print_function

 # pylint: disable=unused-import
 from tensorflow.python.data.ops.dataset_ops import Dataset
-from tensorflow.python.data.ops.iterator import Iterator
+from tensorflow.python.data.ops.iterator_ops import Iterator
 from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
 from tensorflow.python.data.ops.readers import TextLineDataset
 from tensorflow.python.data.ops.readers import TFRecordDataset

@@ -9,7 +9,7 @@ py_library(
     srcs = ["dataset_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":iterator",
+        ":iterator_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",

@@ -41,8 +41,8 @@ py_library(
 )

 py_library(
-    name = "iterator",
-    srcs = ["iterator.py"],
+    name = "iterator_ops",
+    srcs = ["iterator_ops.py"],
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:dataset_ops_gen",

@@ -23,8 +23,7 @@ import threading

 import numpy as np

-from tensorflow.python.data.ops import iterator
-from tensorflow.python.data.ops.iterator import Iterator  # pylint: disable=unused-import
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes

@@ -92,9 +91,8 @@ class Dataset(object):
     with ops.colocate_with(iterator_resource):
       initializer = gen_dataset_ops.make_iterator(
           self._as_variant_tensor(), iterator_resource)
-    return iterator.Iterator(
-        iterator_resource, initializer, self.output_types,
-        self.output_shapes)
+    return iterator_ops.Iterator(iterator_resource, initializer,
+                                 self.output_types, self.output_shapes)

   def make_one_shot_iterator(self):
     """Creates an `Iterator` for enumerating the elements of this dataset.

@@ -113,7 +111,7 @@ class Dataset(object):

     _make_dataset.add_to_graph(ops.get_default_graph())

-    return iterator.Iterator(
+    return iterator_ops.Iterator(
         gen_dataset_ops.one_shot_iterator(
             dataset_factory=_make_dataset,
             output_types=nest.flatten(self.output_types),

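The dataset_ops.py hunks above are the heart of the change: Dataset's initializable and one-shot paths now construct iterator_ops.Iterator. A short sketch of the reinitializable-iterator pattern the tests in this commit exercise, adapted from the test diff (the names dataset_3 and dataset_4 come from those tests):

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.ops import iterator_ops
    from tensorflow.python.framework import constant_op

    dataset_3 = dataset_ops.Dataset.from_tensors(constant_op.constant([1, 2, 3]))
    dataset_4 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([4, 5, 6, 7]))

    # One reinitializable iterator serves both datasets: the dtypes match and
    # the declared shape [None] is compatible with vectors of length 3 and 4.
    iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
                                                    [None])
    dataset_3_init_op = iterator.make_initializer(dataset_3)
    dataset_4_init_op = iterator.make_initializer(dataset_4)
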
@@ -2960,6 +2960,7 @@ tf_py_test(
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -2978,7 +2979,7 @@ tf_py_test(
         "//tensorflow/python:lib",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:util",
-        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//tensorflow/python/data/ops:readers",
     ],
 )

@@ -3009,6 +3010,7 @@ tf_py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -3036,6 +3038,7 @@ tf_py_test(
         "//tensorflow/python:errors",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
 )

@@ -3076,6 +3079,7 @@ tf_py_test(
         "//tensorflow/python/data/ops:readers",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",

@@ -3111,6 +3115,7 @@ tf_py_test(
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:session",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
     ],
     tags = ["no_windows"],
 )

@@ -24,6 +24,7 @@ import tempfile
 import numpy as np

 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -59,8 +60,8 @@ class FilesystemCacheDatasetTest(test.TestCase):

     # Create initialization ops for iterators without and with
     # caching, respectively.
-    iterator = dataset_ops.Iterator.from_structure(cache_dataset.output_types,
-                                                   cache_dataset.output_shapes)
+    iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
+                                                    cache_dataset.output_shapes)
     init_fifo_op = iterator.make_initializer(repeat_dataset)
     init_cache_op = iterator.make_initializer(cache_dataset)

@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function

@@ -44,7 +45,7 @@ class IteratorClusterTest(test.TestCase):
     iterator_3_handle = iterator_3.string_handle()

     with ops.device("/job:worker/replica:0/task:0/cpu:0"):
-      remote_it = dataset_ops.Iterator.from_string_handle(
+      remote_it = iterator_ops.Iterator.from_string_handle(
           iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
       get_next_op = remote_it.get_next()

@@ -65,7 +66,7 @@ class IteratorClusterTest(test.TestCase):

     @function.Defun(dtypes.string)
     def _remote_fn(h):
-      remote_iterator = dataset_ops.Iterator.from_string_handle(
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
           h, dataset_3.output_types, dataset_3.output_shapes)
       return remote_iterator.get_next()

@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import readers
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes

@@ -239,7 +240,7 @@ class IteratorTest(test.TestCase):
       # functions in this graph, to ensure that we are not
       # accidentally redefining functions with the same names in the
       # new graph.
-      iterator = dataset_ops.Iterator.from_structure(
+      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))

@@ -269,8 +270,8 @@ class IteratorTest(test.TestCase):
         constant_op.constant([1, 2, 3]))
     dataset_4 = dataset_ops.Dataset.from_tensors(
         constant_op.constant([4, 5, 6, 7]))
-    iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types,
-                                                   [None])
+    iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
+                                                    [None])

     dataset_3_init_op = iterator.make_initializer(dataset_3)
     dataset_4_init_op = iterator.make_initializer(dataset_4)

@@ -306,12 +307,12 @@ class IteratorTest(test.TestCase):
   def testReinitializableIteratorStaticErrors(self):
     # Non-matching structure for types and shapes.
     with self.assertRaises(TypeError):
-      iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
-                                                      dtypes.float64), [None])
+      iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
+                                                       dtypes.float64), [None])

     # Test validation of dataset argument.
-    iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
-                                                    dtypes.float64))
+    iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
+                                                     dtypes.float64))

     # Incompatible structure.
     with self.assertRaises(ValueError):

@@ -328,7 +329,7 @@ class IteratorTest(test.TestCase):
         [4., 5., 6., 7.], dtype=dtypes.float32))))

     # Incompatible shapes.
-    iterator = dataset_ops.Iterator.from_structure(
+    iterator = iterator_ops.Iterator.from_structure(
         (dtypes.int64, dtypes.float64), ([None], []))
     with self.assertRaises(TypeError):
       iterator.make_initializer(

@@ -344,7 +345,7 @@ class IteratorTest(test.TestCase):
     iterator_4 = dataset_4.make_one_shot_iterator()

     handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
-    feedable_iterator = dataset_ops.Iterator.from_string_handle(
+    feedable_iterator = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
     next_element = feedable_iterator.get_next()

@@ -391,11 +392,11 @@ class IteratorTest(test.TestCase):

     handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

-    feedable_int_scalar = dataset_ops.Iterator.from_string_handle(
+    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32, [])
-    feedable_int_vector = dataset_ops.Iterator.from_string_handle(
+    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32, [None])
-    feedable_int_any = dataset_ops.Iterator.from_string_handle(
+    feedable_int_any = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32)

     with self.test_session() as sess:

@@ -435,7 +436,7 @@ class IteratorTest(test.TestCase):

     @function.Defun(dtypes.string)
     def _remote_fn(h):
-      remote_iterator = dataset_ops.Iterator.from_string_handle(
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
       return remote_iterator.get_next()

@@ -495,7 +496,7 @@ class IteratorTest(test.TestCase):
     @function.Defun(dtypes.uint8)
     def _remote_fn(h):
       handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
-      remote_iterator = dataset_ops.Iterator.from_string_handle(
+      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_3.output_types, dataset_3.output_shapes)
       return remote_iterator.get_next()

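These iterator tests also cover Iterator.from_string_handle, which moved to iterator_ops as well. A self-contained sketch of the feedable-iterator pattern they exercise, adapted from the test diff:

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.ops import iterator_ops
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    dataset_3 = dataset_ops.Dataset.from_tensors(constant_op.constant([1, 2, 3]))
    iterator_3 = dataset_3.make_one_shot_iterator()

    # A feedable iterator defers the choice of concrete iterator to run time:
    # feed the value of iterator_3.string_handle() into handle_placeholder
    # when evaluating next_element.
    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
    next_element = feedable_iterator.get_next()
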
@@ -20,6 +20,7 @@ from __future__ import print_function
 import os

 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops

@@ -255,8 +256,8 @@ class RangeDatasetTest(test.TestCase):
     # Create an empty IteratorResource and restore the Iterator into it.
     output_types = dtypes.int64
     output_shapes = tensor_shape.scalar()
-    iterator = dataset_ops.Iterator.from_structure(output_types,
-                                                   output_shapes)
+    iterator = iterator_ops.Iterator.from_structure(output_types,
+                                                    output_shapes)
     restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
                                                   path)
     get_next = iterator.get_next()

@@ -21,7 +21,7 @@ import gzip
 import os
 import zlib

-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import readers
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes

@@ -84,7 +84,7 @@ class TextLineDatasetTest(test.TestCase):
         filenames, compression_type=compression_type).repeat(num_epochs)
     batch_dataset = repeat_dataset.batch(batch_size)

-    iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     init_op = iterator.make_initializer(repeat_dataset)
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()

@@ -196,7 +196,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
         .repeat(num_epochs))
     batch_dataset = repeat_dataset.batch(batch_size)

-    iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     init_op = iterator.make_initializer(repeat_dataset)
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()

@@ -290,7 +290,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
   def _restore_iterator(self):
     output_types = dtypes.string
     output_shapes = tensor_shape.scalar()
-    iterator = dataset_ops.Iterator.from_structure(output_types, output_shapes)
+    iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
     get_next = iterator.get_next()
     restore_op = gen_dataset_ops.restore_iterator(
         iterator._iterator_resource, self._iterator_checkpoint_path())

@@ -572,7 +572,7 @@ class TFRecordDatasetTest(test.TestCase):
                                                      self.num_epochs)
     batch_dataset = repeat_dataset.batch(self.batch_size)

-    iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
+    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     self.init_op = iterator.make_initializer(repeat_dataset)
     self.init_batch_op = iterator.make_initializer(batch_dataset)
     self.get_next = iterator.get_next()

@@ -22,6 +22,7 @@ import collections
 import numpy as np

 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -52,7 +53,7 @@ class ShuffleDatasetTest(test.TestCase):

     # Create initialization ops for iterators without and with
     # shuffling, respectively.
-    iterator = dataset_ops.Iterator.from_structure(
+    iterator = iterator_ops.Iterator.from_structure(
         shuffle_dataset.output_types, shuffle_dataset.output_shapes)
     init_fifo_op = iterator.make_initializer(repeat_dataset)
     init_shuffle_op = iterator.make_initializer(shuffle_dataset)