Add SerializeIterator op that serializes an IteratorResource into a variant tensor.
Add DeserializeIterator op that builds an IteratorResource from a variant tensor.

Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc.
Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter
for reading/writing the state of iterators. Get rid of IteratorBundleReader and
IteratorBundleWriter.

PiperOrigin-RevId: 173140858
commit 1038927c09
parent 57f3e529d9
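In place of the removed SaveIterator/RestoreIterator ops (which wrote straight to a
checkpoint path), iterator state now round-trips through a scalar variant tensor that
callers persist however they like. A minimal sketch of the new flow, mirroring the
test helpers updated below; the checkpoint location is illustrative, and `iterator`
stands for any initializable iterator:

    import os

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import gen_dataset_ops
    from tensorflow.python.ops import io_ops
    from tensorflow.python.ops import parsing_ops

    path = os.path.join("/tmp", "iterator")  # hypothetical location

    # SerializeIterator: iterator state -> variant tensor -> bytes on disk.
    state_variant = gen_dataset_ops.serialize_iterator(
        iterator._iterator_resource)
    save_op = io_ops.write_file(
        path, parsing_ops.serialize_tensor(state_variant))

    # DeserializeIterator: bytes on disk -> variant tensor -> iterator state.
    restored = parsing_ops.parse_tensor(io_ops.read_file(path), dtypes.variant)
    restore_op = gen_dataset_ops.deserialize_iterator(
        iterator._iterator_resource, restored)

Running save_op after some get_next() calls snapshots the iterator's position;
running restore_op (in the same or a later session) resumes from that snapshot.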
@@ -185,6 +185,7 @@ py_test(
         "//tensorflow/python:function",
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:gradients",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:script_ops",
@@ -252,6 +253,8 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variables",
@@ -274,6 +277,7 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
        "//tensorflow/python:lib",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:tensor_shape",
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):

   def testIncorrectIteratorRestore(self):

-    def _iterator_checkpoint_prefix():
+    def _path():
       return os.path.join(self.get_temp_dir(), "iterator")

+    def _save_op(iterator_resource):
+      iterator_state_variant = gen_dataset_ops.serialize_iterator(
+          iterator_resource)
+      save_op = io_ops.write_file(
+          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+      return save_op
+
+    def _restore_op(iterator_resource):
+      iterator_state_variant = parsing_ops.parse_tensor(
+          io_ops.read_file(_path()), dtypes.variant)
+      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                        iterator_state_variant)
+      return restore_op
+
     def _build_range_dataset_graph():
       start = 1
       stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = _iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     def _build_reader_dataset_graph():
       filenames = ["test"]  # Does not exist but we don't care in this test.
-      path = _iterator_checkpoint_prefix()
       iterator = readers.FixedLengthRecordDataset(
           filenames, 1, 0, 0).make_initializable_iterator()
       init_op = iterator.initializer
       get_next_op = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next_op, save_op, restore_op

     # Saving iterator for RangeDataset graph.
@@ -29,6 +29,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -193,6 +195,21 @@ class RangeDatasetTest(test.TestCase):
   def _iterator_checkpoint_prefix(self):
     return os.path.join(self.get_temp_dir(), "iterator")

+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_prefix(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def testSaveRestore(self):

     def _build_graph(start, stop):
@@ -200,10 +217,8 @@ class RangeDatasetTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -246,14 +261,13 @@ class RangeDatasetTest(test.TestCase):

   def testRestoreWithoutBuildingDatasetGraph(self):

-    def _build_graph(start, stop, num_epochs, path):
+    def _build_graph(start, stop, num_epochs):
       dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -262,10 +276,8 @@ class RangeDatasetTest(test.TestCase):
     num_epochs = 5
     break_point = 5
     break_epoch = 3
-    path = self._iterator_checkpoint_prefix()
     with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
-                                                   path)
+      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
       with self.test_session(graph=g) as sess:
         sess.run(variables.global_variables_initializer())
         sess.run(init_op)
@@ -282,8 +294,7 @@ class RangeDatasetTest(test.TestCase):
       output_shapes = tensor_shape.scalar()
       iterator = iterator_ops.Iterator.from_structure(output_types,
                                                       output_shapes)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      restore_op = self._restore_op(iterator._iterator_resource)
       get_next = iterator.get_next()
       with self.test_session(graph=g) as sess:
         sess.run(restore_op)
@@ -302,10 +313,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -343,10 +352,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -379,10 +386,8 @@ class RangeDatasetTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     start = 2
@@ -424,10 +429,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     start = 2
@@ -471,10 +474,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     start = 2
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -276,18 +277,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
   def _iterator_checkpoint_path(self):
     return os.path.join(self.get_temp_dir(), "iterator")

+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_path(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def _build_iterator_graph(self, num_epochs):
     filenames = self._createFiles()
-    path = self._iterator_checkpoint_path()
     dataset = (readers.FixedLengthRecordDataset(
         filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                .repeat(num_epochs))
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next_op = iterator.get_next()
-    save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-    restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                  path)
+    save_op = self._save_op(iterator._iterator_resource)
+    restore_op = self._restore_op(iterator._iterator_resource)
     return init_op, get_next_op, save_op, restore_op

   def _restore_iterator(self):
@@ -295,8 +309,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
     output_shapes = tensor_shape.scalar()
     iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
     get_next = iterator.get_next()
-    restore_op = gen_dataset_ops.restore_iterator(
-        iterator._iterator_resource, self._iterator_checkpoint_path())
+    restore_op = self._restore_op(iterator._iterator_resource)
     return restore_op, get_next

   def testSaveRestore(self):
@@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [
     "framework/function.proto",
     "framework/graph.proto",
     "framework/graph_transfer_info.proto",
+    "framework/iterator.proto",
     "framework/kernel_def.proto",
     "framework/log_memory.proto",
     "framework/node_def.proto",
tensorflow/core/framework/iterator.proto (new file, 17 lines)
@@ -0,0 +1,17 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "IteratorProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util";
+
+// Protocol buffer representing the metadata for an iterator's state stored
+// as a Variant tensor.
+message IteratorStateMetadata {
+  // A user-specified version string.
+  string version = 1;
+
+  // Keys for tensors in the VariantTensorDataProto.
+  repeated string keys = 2;
+}
@@ -6061,6 +6061,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
     ],
 )

@@ -17,12 +17,14 @@ limitations under the License.

 #include <memory>

+#include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -39,54 +41,25 @@ namespace tensorflow {

 class ResourceMgr;

-class BundleReaderWrapper {
- public:
-  BundleReaderWrapper(BundleReader* bundle_reader)
-      : bundle_reader_(bundle_reader) {}
-
-  // Reads a scalar value.
-  template <typename T>
-  Status ReadScalar(StringPiece key, T* val) {
-    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
-    TF_RETURN_IF_ERROR(Lookup(key, &val_t));
-    *val = val_t.scalar<T>()();
-    return Status::OK();
-  }
-
-  bool Contains(StringPiece key) { return bundle_reader_->Contains(key); }
-
- private:
-  Status Lookup(StringPiece key, Tensor* val) {
-    return bundle_reader_->Lookup(key, val);
-  }
-
-  BundleReader* bundle_reader_;
+// Interface for reading values from a key-value store.
+// Used for restoring iterator state.
+class IteratorStateReader {
+ public:
+  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
+  virtual Status ReadScalar(StringPiece key, string* val) = 0;
+  virtual bool Contains(StringPiece key) = 0;
+
+  virtual ~IteratorStateReader() {}
 };

-class BundleWriterWrapper {
- public:
-  // Note: We intentionally do not provide a constructor that builds a
-  // BundleWriter from the checkpoint path because we want the caller to be
-  // in-charge of calling BundleWriter::Finish(). If we expose the Finish()
-  // method here it may be called pre-maturely by users of this object.
-  explicit BundleWriterWrapper(BundleWriter* bundle_writer)
-      : bundle_writer_(bundle_writer) {}
-
-  // Writes a scalar value.
-  template <typename T>
-  Status WriteScalar(StringPiece key, const T val) {
-    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
-    val_t.scalar<T>()() = val;
-    TF_RETURN_IF_ERROR(Add(key, val_t));
-    return Status::OK();
-  }
-
- private:
-  Status Add(StringPiece key, const Tensor& val) {
-    return bundle_writer_->Add(key, val);
-  }
-
-  BundleWriter* bundle_writer_;
+// Interface for writing values to a key-value store.
+// Used for saving iterator state.
+class IteratorStateWriter {
+ public:
+  virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
+  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
+
+  virtual ~IteratorStateWriter() {}
 };

 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
@@ -249,10 +222,6 @@ class IteratorContext {
 // range of outputs is typically represented by an `DatasetBase`,
 // defined below.
 class IteratorBase {
- protected:
-  class IteratorBundleReader;
-  class IteratorBundleWriter;
-
  public:
   virtual ~IteratorBase() {}

@@ -284,75 +253,17 @@ class IteratorBase {
   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

   // Saves the state of this iterator.
-  virtual Status Save(OpKernelContext* ctx, const string& path) {
-    BundleWriter bundle_writer(ctx->env(), path);
-    TF_RETURN_IF_ERROR(bundle_writer.status());
-    IteratorBundleWriter writer(&bundle_writer);
-    TF_RETURN_IF_ERROR(Save(ctx, &writer));
-    return bundle_writer.Finish();
-  }
-
-  virtual Status Restore(OpKernelContext* ctx, const string& path) {
-    if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
-      return errors::NotFound(
-          "Failed to restore Iterator state. No file found at ",
-          MetaFilename(path));
-    }
-    BundleReader bundle_reader(ctx->env(), path);
-    TF_RETURN_IF_ERROR(bundle_reader.status());
-    IteratorBundleReader reader(&bundle_reader);
-    return Restore(ctx, &reader);
-  }
-
-  static const char kIteratorExhausted[];
-
- protected:
-  // This is needed so that sub-classes of IteratorBase can call
-  // `RestoreInternal` on their parent iterators, e.g., in
-  // `RepeatDataasetOp::Dataset`.
-  class IteratorBundleReader : public BundleReaderWrapper {
-   public:
-    IteratorBundleReader(BundleReader* bundle_reader)
-        : BundleReaderWrapper(bundle_reader) {}
-
-    // Restores the state of a parent iterator recursively.
-    Status RestoreParent(OpKernelContext* ctx,
-                         const std::unique_ptr<IteratorBase>& parent) {
-      return parent->RestoreInternal(ctx, this);
-    }
-  };
-
-  // This is needed so that sub-classes of IteratorBase can call
-  // `SaveInternal` on their parent iterators, e.g., in
-  // `RepeatDataasetOp::Dataset`.
-  class IteratorBundleWriter : public BundleWriterWrapper {
-   public:
-    IteratorBundleWriter(BundleWriter* bundle_writer)
-        : BundleWriterWrapper(bundle_writer) {}
-    // Saves the state of a parent iterator recursively.
-    Status SaveParent(OpKernelContext* ctx,
-                      const std::unique_ptr<IteratorBase>& parent) {
-      return parent->SaveInternal(ctx, this);
-    }
-  };
-
-  virtual Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) {
+  virtual Status Save(IteratorStateWriter* writer) {
     if (is_exhausted_) {
       LOG(INFO) << "Iterator exhausted.";
-      return writer->WriteScalar<string>(kIteratorExhausted,
-                                         kIteratorExhausted);
+      return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
     } else {
-      return SaveInternal(ctx, writer);
+      return SaveInternal(writer);
     }
   }

-  // Saves the state of this iterator.
-  virtual Status SaveInternal(OpKernelContext* ctx,
-                              IteratorBundleWriter* writer) {
-    return errors::Unimplemented("SaveInternal");
-  }
-
-  virtual Status Restore(OpKernelContext* ctx, IteratorBundleReader* reader) {
+  // Restores the state of this iterator.
+  virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
     if (reader->Contains(kIteratorExhausted)) {
       LOG(INFO) << "Iterator exhausted. Nothing to restore.";
       is_exhausted_ = true;
@@ -362,9 +273,33 @@ class IteratorBase {
     }
   }

-  // Restores the state of this iterator.
+  static const char kIteratorExhausted[];
+
+ protected:
+  // This is needed so that sub-classes of IteratorBase can call
+  // `SaveInternal` on their parent iterators, e.g., in
+  // `RepeatDataasetOp::Dataset`.
+  Status SaveParent(IteratorStateWriter* writer,
+                    const std::unique_ptr<IteratorBase>& parent) {
+    return parent->SaveInternal(writer);
+  }
+
+  // This is needed so that sub-classes of IteratorBase can call
+  // `RestoreInternal` on their parent iterators, e.g., in
+  // `RepeatDataasetOp::Dataset`.
+  Status RestoreParent(OpKernelContext* ctx, IteratorStateReader* reader,
+                       const std::unique_ptr<IteratorBase>& parent) {
+    return parent->RestoreInternal(ctx, reader);
+  }
+
+  // Saves the state of this iterator recursively.
+  virtual Status SaveInternal(IteratorStateWriter* writer) {
+    return errors::Unimplemented("SaveInternal");
+  }
+
+  // Restores the state of this iterator recursively.
   virtual Status RestoreInternal(OpKernelContext* ctx,
-                                 IteratorBundleReader* reader) {
+                                 IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
   }

@@ -404,7 +339,7 @@ class DatasetBase : public core::RefCounted {
   virtual string DebugString() = 0;

   // Serializes the dataset and writes it to the `writer`.
-  virtual Status Save(BundleWriterWrapper* writer) const {
+  virtual Status Save(IteratorStateWriter* writer) const {
     return errors::Unimplemented("DatasetBase::Save");
   }

@@ -435,20 +370,14 @@ class GraphDatasetBase : public DatasetBase {

   const string op_name() const { return op_name_; }

-  Status Save(BundleWriterWrapper* writer) const override {
-    GraphDefBuilder b;
-    DatasetGraphDefBuilder db(&b);
-    Node* node = nullptr;
-    TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
-    string output_name = node->name();
-    GraphDef graph_def;
-    TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+  Status Save(IteratorStateWriter* writer) const override {
     string serialized_graph_def;
-    graph_def.SerializeToString(&serialized_graph_def);
+    string output_node;
+    TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node));
     TF_RETURN_IF_ERROR(
-        writer->WriteScalar<string>(kDatasetGraphKey, serialized_graph_def));
+        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
     TF_RETURN_IF_ERROR(
-        writer->WriteScalar<string>(kDatasetGraphOutputNodeKey, output_name));
+        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
     return Status::OK();
   }

@@ -460,6 +389,18 @@ class GraphDatasetBase : public DatasetBase {
   static const char kDatasetGraphOutputNodeKey[];

  private:
+  Status Serialize(string* serialized_graph_def, string* output_node) const {
+    GraphDefBuilder b;
+    DatasetGraphDefBuilder db(&b);
+    Node* node = nullptr;
+    TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
+    *output_node = node->name();
+    GraphDef graph_def;
+    TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+    graph_def.SerializeToString(serialized_graph_def);
+    return Status::OK();
+  }
+
   const string op_name_;
 };

@@ -505,18 +446,18 @@ class DatasetIterator : public IteratorBase {
     return GetNextInternal(ctx, out_tensors, end_of_sequence);
   }

- protected:
-  Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) final {
+  Status Save(IteratorStateWriter* writer) final {
     TF_RETURN_IF_ERROR(dataset()->Save(writer));
-    return IteratorBase::Save(ctx, writer);
+    return IteratorBase::Save(writer);
   }

+ protected:
   // Internal implementation of GetNext that is wrapped in tracing logic.
   virtual Status GetNextInternal(IteratorContext* ctx,
                                  std::vector<Tensor>* out_tensors,
                                  bool* end_of_sequence) = 0;

-  string full_name(const string& name) {
+  string full_name(const string& name) const {
     return strings::StrCat(prefix(), ":", name);
   }

@ -16,9 +16,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||||
|
#include "tensorflow/core/framework/iterator.pb.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
@ -35,6 +37,8 @@ namespace {
|
|||||||
// See documentation in ../ops/dataset_ops.cc for a high-level
|
// See documentation in ../ops/dataset_ops.cc for a high-level
|
||||||
// description of the following ops.
|
// description of the following ops.
|
||||||
|
|
||||||
|
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
|
||||||
|
|
||||||
Status VerifyTypesMatch(const DataTypeVector& expected,
|
Status VerifyTypesMatch(const DataTypeVector& expected,
|
||||||
const DataTypeVector& received) {
|
const DataTypeVector& received) {
|
||||||
if (expected.size() != received.size()) {
|
if (expected.size() != received.size()) {
|
||||||
@ -93,10 +97,10 @@ class IteratorResource : public ResourceBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Save(OpKernelContext* ctx, const string& path) {
|
Status Save(IteratorStateWriter* writer) {
|
||||||
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
|
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
|
||||||
if (captured_iterator) {
|
if (captured_iterator) {
|
||||||
return captured_iterator->Save(ctx, path);
|
return captured_iterator->Save(writer);
|
||||||
} else {
|
} else {
|
||||||
return errors::FailedPrecondition(
|
return errors::FailedPrecondition(
|
||||||
"Save() failed because the iterator has not been initialized. "
|
"Save() failed because the iterator has not been initialized. "
|
||||||
@ -105,49 +109,34 @@ class IteratorResource : public ResourceBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Restore(OpKernelContext* ctx, const string& path) {
|
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
|
||||||
if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
|
string serialized_graph_def;
|
||||||
return errors::NotFound(
|
TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
|
||||||
"Failed to restore Iterator state. No file found at ",
|
&serialized_graph_def));
|
||||||
MetaFilename(path));
|
GraphDef graph_def;
|
||||||
|
if (!graph_def.ParseFromString(serialized_graph_def)) {
|
||||||
|
return errors::Internal("Error parsing dataset GraphDef.");
|
||||||
}
|
}
|
||||||
|
string output_node;
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||||
|
GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
|
||||||
|
DatasetBase* dataset = nullptr;
|
||||||
|
Graph graph(OpRegistry::Global());
|
||||||
|
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
GraphRunner graph_runner(ctx->env());
|
||||||
|
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
|
||||||
|
{output_node}, &outputs));
|
||||||
|
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
|
||||||
|
|
||||||
BundleReader bundle_reader(ctx->env(), path);
|
TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
|
||||||
TF_RETURN_IF_ERROR(bundle_reader.status());
|
|
||||||
BundleReaderWrapper reader(&bundle_reader);
|
|
||||||
if (reader.Contains(GraphDatasetBase::kDatasetGraphKey)) {
|
|
||||||
string serialized_graph_def;
|
|
||||||
TF_RETURN_IF_ERROR(reader.ReadScalar(GraphDatasetBase::kDatasetGraphKey,
|
|
||||||
&serialized_graph_def));
|
|
||||||
GraphDef graph_def;
|
|
||||||
graph_def.ParseFromString(serialized_graph_def);
|
|
||||||
// TODO(srbs): Is there a way of getting the op registry of the original
|
|
||||||
// graph.
|
|
||||||
Graph graph(OpRegistry::Global());
|
|
||||||
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
|
||||||
string output_node;
|
|
||||||
TF_RETURN_IF_ERROR(reader.ReadScalar(
|
|
||||||
GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
|
|
||||||
std::vector<Tensor> outputs;
|
|
||||||
GraphRunner graph_runner(ctx->env());
|
|
||||||
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
|
|
||||||
{output_node}, &outputs));
|
|
||||||
DatasetBase* dataset;
|
|
||||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
|
|
||||||
TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
|
|
||||||
} else if (reader.Contains(IteratorBase::kIteratorExhausted)) {
|
|
||||||
TF_RETURN_IF_ERROR(set_iterator(std::unique_ptr<IteratorBase>(
|
|
||||||
new ExhaustedIterator(output_dtypes_, output_shapes_))));
|
|
||||||
}
|
|
||||||
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
|
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
|
||||||
|
|
||||||
if (captured_iterator) {
|
if (captured_iterator) {
|
||||||
// TODO(srbs): Figure a way to pass bundle_reader here.
|
return captured_iterator->Restore(ctx, reader);
|
||||||
return captured_iterator->Restore(ctx, path);
|
|
||||||
} else {
|
} else {
|
||||||
return errors::FailedPrecondition(
|
return errors::FailedPrecondition(
|
||||||
"Failed to restore iterator from ", path,
|
"Failed to restore iterator. Make sure the checkpoint ",
|
||||||
". Make sure the checkpoint ",
|
|
||||||
"is not corrupt. If the checkpoint does not contain the GraphDef, ",
|
"is not corrupt. If the checkpoint does not contain the GraphDef, ",
|
||||||
"you will need to initialize your iterator before restoring.");
|
"you will need to initialize your iterator before restoring.");
|
||||||
}
|
}
|
||||||
@ -174,43 +163,194 @@ class IteratorResource : public ResourceBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// A no-op iterator which always sets end_of_sequence = true. An instance of
|
|
||||||
// this is returned when attempting to restore an exhausted iterator. This is
|
|
||||||
// needed because the Dataset GraphDef may not have been saved for exhausted
|
|
||||||
// iterators so the actual Iterator can not be built.
|
|
||||||
class ExhaustedIterator : public IteratorBase {
|
|
||||||
public:
|
|
||||||
ExhaustedIterator(const DataTypeVector& output_dtypes,
|
|
||||||
const std::vector<PartialTensorShape>& output_shapes)
|
|
||||||
: output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
|
|
||||||
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
|
||||||
bool* end_of_sequence) final {
|
|
||||||
*end_of_sequence = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
const DataTypeVector& output_dtypes() const override {
|
|
||||||
return output_dtypes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
|
||||||
return output_shapes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual const std::vector<PartialTensorShape>& output_shapes() {
|
|
||||||
return output_shapes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const DataTypeVector output_dtypes_;
|
|
||||||
const std::vector<PartialTensorShape> output_shapes_;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::shared_ptr<IteratorBase> iterator_;
|
std::shared_ptr<IteratorBase> iterator_;
|
||||||
const DataTypeVector output_dtypes_;
|
const DataTypeVector output_dtypes_;
|
||||||
const std::vector<PartialTensorShape> output_shapes_;
|
const std::vector<PartialTensorShape> output_shapes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Helper class for reading data from a VariantTensorData object.
|
||||||
|
class VariantTensorDataReader : public IteratorStateReader {
|
||||||
|
public:
|
||||||
|
explicit VariantTensorDataReader(const VariantTensorData* data)
|
||||||
|
: data_(data) {
|
||||||
|
PreProcess();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns OK iff the initialization was successful, i.e.,
|
||||||
|
// pre-processing did not have errors.
|
||||||
|
Status status() const { return status_; }
|
||||||
|
|
||||||
|
Status ReadScalar(StringPiece key, int64* val) override {
|
||||||
|
return ReadScalarInternal(key, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ReadScalar(StringPiece key, string* val) override {
|
||||||
|
return ReadScalarInternal(key, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Contains(StringPiece key) override {
|
||||||
|
return map_.find(key.ToString()) != map_.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void PreProcess() {
|
||||||
|
string metadata;
|
||||||
|
data_->get_metadata(&metadata);
|
||||||
|
IteratorStateMetadata proto;
|
||||||
|
if (!proto.ParseFromString(metadata)) {
|
||||||
|
status_ = errors::Internal("Error parsing IteratorStateMetadata.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
size_t num_entries = proto.keys_size();
|
||||||
|
CHECK_EQ(num_entries, data_->tensors_size());
|
||||||
|
for (size_t i = 0; i < num_entries; i++) {
|
||||||
|
map_[proto.keys(i)] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status ReadScalarInternal(StringPiece key, T* val) {
|
||||||
|
if (map_.find(key.ToString()) == map_.end()) {
|
||||||
|
return errors::NotFound(key);
|
||||||
|
}
|
||||||
|
*val = data_->tensors(map_[key.ToString()]).scalar<T>()();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<string, size_t> map_;
|
||||||
|
const VariantTensorData* data_; // Not owned.
|
||||||
|
Status status_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper class for writing data to a VariantTensorData object.
|
||||||
|
class VariantTensorDataWriter : public IteratorStateWriter {
|
||||||
|
public:
|
||||||
|
// Does not take ownership of data.
|
||||||
|
explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
|
||||||
|
|
||||||
|
Status WriteScalar(StringPiece key, const int64& val) override {
|
||||||
|
return WriteScalarInternal(key, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status WriteScalar(StringPiece key, const string& val) override {
|
||||||
|
return WriteScalarInternal(key, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes the metadata to `data_`.
|
||||||
|
Status Flush() {
|
||||||
|
string metadata;
|
||||||
|
if (!metadata_proto_.SerializeToString(&metadata)) {
|
||||||
|
return errors::Internal("Unable to serialize IteratorStateMetadata.");
|
||||||
|
}
|
||||||
|
data_->set_metadata(metadata);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T>
|
||||||
|
Status WriteScalarInternal(StringPiece key, const T& val) {
|
||||||
|
// Write key to the metadata proto. This gets written to `data_`
|
||||||
|
// when `Flush()` is called. We do this lazily to avoid multiple
|
||||||
|
// serialization calls.
|
||||||
|
metadata_proto_.add_keys(key.ToString());
|
||||||
|
|
||||||
|
// Update tensors.
|
||||||
|
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
|
||||||
|
val_t.scalar<T>()() = val;
|
||||||
|
*(data_->add_tensors()) = std::move(val_t);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
VariantTensorData* data_;
|
||||||
|
// TODO(srbs): Set the version string.
|
||||||
|
IteratorStateMetadata metadata_proto_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
|
||||||
|
// The get() method returns an IteratorStateReader which can be used
|
||||||
|
// to restore iterator state.
|
||||||
|
//
|
||||||
|
// Usage example:
|
||||||
|
//
|
||||||
|
// Encoding:
|
||||||
|
//
|
||||||
|
// Tensor t(DT_VARIANT, TensorShape({}));
|
||||||
|
// t->scalar<Variant>()() = IteratorStateVariant(iterator_resource);
|
||||||
|
//
|
||||||
|
// Encode() sets the type_name of the VariantTensorData object to
|
||||||
|
// IteratorStateVariant::TypeName().
|
||||||
|
//
|
||||||
|
// Decoding:
|
||||||
|
//
|
||||||
|
// Variant v = <VariantTensorDataProto object>;
|
||||||
|
// DecodeUnaryVariant(&v);
|
||||||
|
// IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
|
||||||
|
// iterator_resource->Restore(ctx, wrapper->get())
|
||||||
|
//
|
||||||
|
// The type_name of the VariantTensorData object to be decoded must
|
||||||
|
// match IteratorStateVariant::TypeName().
|
||||||
|
class IteratorStateVariant {
|
||||||
|
public:
|
||||||
|
IteratorStateVariant() : data_(nullptr) {}
|
||||||
|
IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
|
||||||
|
if (other.data_) {
|
||||||
|
Decode(*other.data_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Initializes this object with the current state of the iterator so
|
||||||
|
// that it can be written on the next call to Encode().
|
||||||
|
Status InitializeFromIterator(IteratorResource* iterator_resource) {
|
||||||
|
data_.reset(new VariantTensorData());
|
||||||
|
data_->set_type_name(TypeName());
|
||||||
|
VariantTensorDataWriter writer(data_.get());
|
||||||
|
TF_RETURN_IF_ERROR(iterator_resource->Save(&writer));
|
||||||
|
TF_RETURN_IF_ERROR(writer.Flush());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
string TypeName() const { return kIteratorVariantTypeName; }
|
||||||
|
void Encode(VariantTensorData* data) const { *data = *data_; }
|
||||||
|
bool Decode(const VariantTensorData& data) {
|
||||||
|
if (data.type_name() != TypeName()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
|
||||||
|
*tensor_data = data;
|
||||||
|
std::unique_ptr<VariantTensorDataReader> reader(
|
||||||
|
new VariantTensorDataReader(tensor_data.get()));
|
||||||
|
status_ = reader->status();
|
||||||
|
if (!status_.ok()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
data_ = std::move(tensor_data);
|
||||||
|
reader_ = std::move(reader);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
IteratorStateReader* get() { return reader_.get(); }
|
||||||
|
Status status() const { return status_; }
|
||||||
|
string DebugString() const {
|
||||||
|
if (data_) {
|
||||||
|
return strings::StrCat("IteratorStateVariant<",
|
||||||
|
"data: ", data_->DebugString(),
|
||||||
|
" status: ", status_.ToString(), ">");
|
||||||
|
} else {
|
||||||
|
return strings::StrCat("IteratorStateVariant<empty>");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<IteratorStateReader> reader_;
|
||||||
|
Status status_;
|
||||||
|
std::unique_ptr<VariantTensorData> data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register the reader class in the global variant decode_fn registry
|
||||||
|
// so that a Variant containing a serialized representation of iterator state
|
||||||
|
// can be decoded using DecodeUnaryVariant. If we don't do this we will need
|
||||||
|
// to manually decode the returned Variant using MaybeDecodeAndCopy in
|
||||||
|
// DeserializeIteratorOp which is not recommended.
|
||||||
|
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
|
||||||
|
kIteratorVariantTypeName);
|
||||||
|
|
||||||
// TODO(mrry): Can we simply use the template kernel here?
|
// TODO(mrry): Can we simply use the template kernel here?
|
||||||
class IteratorHandleOp : public ResourceOpKernel<IteratorResource> {
|
class IteratorHandleOp : public ResourceOpKernel<IteratorResource> {
|
||||||
public:
|
public:
|
||||||
@ -294,37 +434,6 @@ class ToSingleElementOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class SaveIteratorOp : public OpKernel {
|
|
||||||
public:
|
|
||||||
explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
IteratorResource* iterator_resource;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
|
||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
|
|
||||||
errors::InvalidArgument("SaveIteratorOp: path must be scalar"));
|
|
||||||
const string& path = ctx->input(1).scalar<string>()();
|
|
||||||
OP_REQUIRES_OK(ctx, iterator_resource->Save(ctx, path));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class RestoreIteratorOp : public OpKernel {
|
|
||||||
public:
|
|
||||||
explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
IteratorResource* iterator_resource;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
|
|
||||||
errors::InvalidArgument("RestoreIteratorOp: path must be scalar"));
|
|
||||||
const string& path = ctx->input(1).scalar<string>()();
|
|
||||||
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, path));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class OneShotIteratorOp : public AsyncOpKernel {
|
class OneShotIteratorOp : public AsyncOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
|
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
|
||||||
@ -644,15 +753,55 @@ class IteratorFromStringHandleOp : public OpKernel {
|
|||||||
std::vector<PartialTensorShape> output_shapes_;
|
std::vector<PartialTensorShape> output_shapes_;
 };

+class SerializeIteratorOp : public OpKernel {
+ public:
+  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& resource_handle_t = ctx->input(0);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+                errors::InvalidArgument("resource_handle must be a scalar"));
+
+    // Validate that the handle corresponds to a real resource, and
+    // that it is an IteratorResource.
+    IteratorResource* iterator_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+    iterator_resource->Unref();
+    Tensor* variant_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
+    IteratorStateVariant v;
+    OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource));
+    variant_t->scalar<Variant>()() = v;
+  }
+};
+
+class DeserializeIteratorOp : public OpKernel {
+ public:
+  explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // Validate that the handle corresponds to a real resource, and
+    // that it is an IteratorResource.
+    IteratorResource* iterator_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+
+    Variant variant = ctx->input(1).scalar<Variant>()();
+    auto* wrapper = variant.get<IteratorStateVariant>();
+    OP_REQUIRES(ctx, wrapper != nullptr,
+                errors::InvalidArgument(
+                    "DeserializeIteratorOp: Unable to parse variant tensor."));
+    OP_REQUIRES_OK(ctx, wrapper->status());
+    OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
+  }
+};
+
 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                         MakeIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                         ToSingleElementOp);
-REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU),
-                        SaveIteratorOp);
-REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU),
-                        RestoreIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                         OneShotIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
@@ -661,6 +810,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                         IteratorToStringHandleOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                         IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
+                        SerializeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
+                        DeserializeIteratorOp);

 }  // namespace
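For orientation, a minimal Python sketch of how the two kernels above compose, assuming the generated gen_dataset_ops.serialize_iterator / deserialize_iterator wrappers from this change; the dataset and values are illustrative, not part of the change:

# Sketch only: copies one iterator's state into another in-graph.
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import gen_dataset_ops

dataset = dataset_ops.Dataset.range(10)
src = dataset.make_initializable_iterator()
dst = dataset.make_initializable_iterator()
src_next = src.get_next()
dst_next = dst.get_next()

# serialize_iterator yields a scalar variant holding src's state;
# deserialize_iterator writes that state into another iterator resource.
state = gen_dataset_ops.serialize_iterator(src._iterator_resource)
copy_state_op = gen_dataset_ops.deserialize_iterator(
    dst._iterator_resource, state)

with session.Session() as sess:
  sess.run([src.initializer, dst.initializer])
  sess.run(src_next)        # Advance src past element 0.
  sess.run(copy_state_op)   # dst now resumes where src left off.
  assert sess.run(dst_next) == 1

Because the serialized state is an ordinary variant tensor, it can also be persisted; the tests below do exactly that via serialize_tensor and write_file.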
@@ -92,6 +92,7 @@ class SerializeTensorOp : public OpKernel {
      Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SerializeTensorOp<T>);
 TF_CALL_ALL_TYPES(REGISTER)
+TF_CALL_variant(REGISTER)
 #undef REGISTER

 }  // namespace tensorflow
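The new TF_CALL_variant(REGISTER) line is what lets the iterator-state variant pass through SerializeTensor on its way to a file. A small sketch of that conversion, assuming the wrappers from this change (the iterator construction is illustrative):

# Sketch: DT_VARIANT -> DT_STRING scalar (for write_file) and back.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops

iterator = dataset_ops.Dataset.range(5).make_initializable_iterator()
state_variant = gen_dataset_ops.serialize_iterator(iterator._iterator_resource)
state_bytes = parsing_ops.serialize_tensor(state_variant)  # DT_STRING scalar.
state_back = parsing_ops.parse_tensor(state_bytes, dtypes.variant)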
@@ -112,19 +112,16 @@ class RangeDatasetOp : public DatasetOpKernel {
     }

    protected:
-    Status SaveInternal(OpKernelContext* ctx,
-                        IteratorBundleWriter* writer) override {
+    Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(
-          writer->WriteScalar<int64>(full_name("next"), next_));
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_));
       return Status::OK();
     }

     Status RestoreInternal(OpKernelContext* ctx,
-                           IteratorBundleReader* reader) override {
+                           IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(
-          reader->ReadScalar<int64>(full_name("next"), &next_));
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_));
       return Status::OK();
     }

@@ -356,31 +356,30 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
     }

    protected:
-    Status SaveInternal(OpKernelContext* ctx,
-                        IteratorBundleWriter* writer) override {
+    Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(
-          full_name("current_file_index"), current_file_index_));
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                             current_file_index_));

       // `input_buffer_` is empty if
       // 1. GetNext has not been called even once.
       // 2. All files have been read and iterator has been exhausted.
       int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
       TF_RETURN_IF_ERROR(
-          writer->WriteScalar<int64>(full_name("current_pos"), current_pos));
+          writer->WriteScalar(full_name("current_pos"), current_pos));
       return Status::OK();
     }

     Status RestoreInternal(OpKernelContext* ctx,
-                           IteratorBundleReader* reader) override {
+                           IteratorStateReader* reader) override {
       mutex_lock l(mu_);
       int64 current_file_index;
-      TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(
-          full_name("current_file_index"), &current_file_index));
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                            &current_file_index));
       current_file_index_ = size_t(current_file_index);
       int64 current_pos;
       TF_RETURN_IF_ERROR(
-          reader->ReadScalar<int64>(full_name("current_pos"), &current_pos));
+          reader->ReadScalar(full_name("current_pos"), &current_pos));

       // Seek to current_pos.
       input_buffer_.reset();
@@ -124,19 +124,18 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
     }

    protected:
-    Status SaveInternal(OpKernelContext* ctx,
-                        IteratorBundleWriter* writer) override {
+    Status SaveInternal(IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(full_name("i"), i_));
-      TF_RETURN_IF_ERROR(writer->SaveParent(ctx, input_impl_));
+      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+      TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
       return Status::OK();
     }

     Status RestoreInternal(OpKernelContext* ctx,
-                           IteratorBundleReader* reader) override {
+                           IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(full_name("i"), &i_));
-      TF_RETURN_IF_ERROR(reader->RestoreParent(ctx, input_impl_));
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
       return Status::OK();
     }

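Note the nesting contract here: RepeatDataset writes its epoch counter `i` and then hands `input_impl_` to SaveParent, so a single SerializeIterator call captures the whole iterator stack. A sketch of the observable effect, assuming the file round trip used by the tests later in this change (the tempfile path is illustrative):

# Sketch: repeat()'s epoch counter and its parent's position both survive
# one serialize/deserialize round trip through a file.
import os
import tempfile

from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops

path = os.path.join(tempfile.mkdtemp(), "iterator")  # Illustrative path.
iterator = dataset_ops.Dataset.range(3).repeat(2).make_initializable_iterator()
get_next = iterator.get_next()
state = gen_dataset_ops.serialize_iterator(iterator._iterator_resource)
save_op = io_ops.write_file(path, parsing_ops.serialize_tensor(state))
restore_op = gen_dataset_ops.deserialize_iterator(
    iterator._iterator_resource,
    parsing_ops.parse_tensor(io_ops.read_file(path), dtypes.variant))

with session.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(4):              # Consume 0, 1, 2, 0 -- crosses an epoch.
    sess.run(get_next)
  sess.run(save_op)
  sess.run(iterator.initializer)  # Rewind, then restore mid-stream state.
  sess.run(restore_op)
  assert sess.run(get_next) == 1  # Resumes in epoch 2 at element 1.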
@@ -28753,18 +28753,6 @@ op {
   }
   is_stateful: true
 }
-op {
-  name: "RestoreIterator"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
 op {
   name: "RestoreSlice"
   input_arg {
@@ -29548,18 +29536,6 @@ op {
   }
   is_stateful: true
 }
-op {
-  name: "SaveIterator"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
 op {
   name: "SaveSlices"
   input_arg {
@@ -598,24 +598,6 @@ This operation may be executed multiple times. Each execution will reset the
 iterator in `iterator` to the first element of `dataset`.
 )doc");

-REGISTER_OP("SaveIterator")
-    .Input("iterator: resource")
-    .Input("path: string")
-    .SetShapeFn(shape_inference::NoOutputs)
-    .Doc(R"doc(
-Saves the state of the `iterator` at `path`.
-
-This state can be restored using "RestoreIterator".
-)doc");
-
-REGISTER_OP("RestoreIterator")
-    .Input("iterator: resource")
-    .Input("path: string")
-    .SetShapeFn(shape_inference::NoOutputs)
-    .Doc(R"doc(
-Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator".
-)doc");
-
 REGISTER_OP("OneShotIterator")
     .Output("handle: resource")
     .Attr("dataset_factory: func")
@@ -737,4 +719,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
 element produced by the resulting iterator.
 )doc");

+REGISTER_OP("SerializeIterator")
+    .Input("resource_handle: resource")
+    .Output("serialized: variant")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+  resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+    .Input("resource_handle: resource")
+    .Input("serialized: variant")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+  resource.
+)doc");
+
 }  // namespace tensorflow
@@ -2886,7 +2886,9 @@ tf_py_test(
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variables",
@@ -2907,7 +2909,9 @@ tf_py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
        "//tensorflow/python:lib",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:util",
         "//tensorflow/python/data/ops:iterator_ops",
@@ -3022,6 +3026,7 @@ tf_py_test(
         "//tensorflow/python:function",
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:gradients",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:script_ops",
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):

   def testIncorrectIteratorRestore(self):

-    def _iterator_checkpoint_prefix():
+    def _path():
       return os.path.join(self.get_temp_dir(), "iterator")

+    def _save_op(iterator_resource):
+      iterator_state_variant = gen_dataset_ops.serialize_iterator(
+          iterator_resource)
+      save_op = io_ops.write_file(
+          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+      return save_op
+
+    def _restore_op(iterator_resource):
+      iterator_state_variant = parsing_ops.parse_tensor(
+          io_ops.read_file(_path()), dtypes.variant)
+      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                        iterator_state_variant)
+      return restore_op
+
     def _build_range_dataset_graph():
       start = 1
       stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
           stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = _iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     def _build_reader_dataset_graph():
       filenames = ["test"]  # Does not exist but we don't care in this test.
-      path = _iterator_checkpoint_prefix()
       iterator = readers.FixedLengthRecordDataset(
           filenames, 1, 0, 0).make_initializable_iterator()
       init_op = iterator.initializer
       get_next_op = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next_op, save_op, restore_op

     # Saving iterator for RangeDataset graph.
@@ -27,6 +27,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase):
   def _iterator_checkpoint_prefix(self):
     return os.path.join(self.get_temp_dir(), "iterator")

+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_prefix(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def testSaveRestore(self):

     def _build_graph(start, stop):
@@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase):
           stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase):

   def testRestoreWithoutBuildingDatasetGraph(self):

-    def _build_graph(start, stop, num_epochs, path):
+    def _build_graph(start, stop, num_epochs):
       dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase):
     num_epochs = 5
     break_point = 5
     break_epoch = 3
-    path = self._iterator_checkpoint_prefix()
     with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
-                                                   path)
+      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
      with self.test_session(graph=g) as sess:
         sess.run(variables.global_variables_initializer())
         sess.run(init_op)
@@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase):
       output_shapes = tensor_shape.scalar()
       iterator = iterator_ops.Iterator.from_structure(output_types,
                                                       output_shapes)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      restore_op = self._restore_op(iterator._iterator_resource)
       get_next = iterator.get_next()
       with self.test_session(graph=g) as sess:
         sess.run(restore_op)
@@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     # Saving and restoring in different sessions.
@@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase):
           stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     start = 2
@@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     start = 2
@@ -447,10 +450,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op

     start = 2
@@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat

@@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
   def _iterator_checkpoint_path(self):
     return os.path.join(self.get_temp_dir(), "iterator")

+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_path(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def _build_iterator_graph(self, num_epochs):
     filenames = self._createFiles()
-    path = self._iterator_checkpoint_path()
     dataset = (readers.FixedLengthRecordDataset(
         filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                .repeat(num_epochs))
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next_op = iterator.get_next()
-    save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-    restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                  path)
+    save_op = self._save_op(iterator._iterator_resource)
+    restore_op = self._restore_op(iterator._iterator_resource)
     return init_op, get_next_op, save_op, restore_op

   def _restore_iterator(self):
@@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
     output_shapes = tensor_shape.scalar()
     iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
     get_next = iterator.get_next()
-    restore_op = gen_dataset_ops.restore_iterator(
-        iterator._iterator_resource, self._iterator_checkpoint_path())
+    restore_op = self._restore_op(iterator._iterator_resource)
     return restore_op, get_next

   def testSaveRestore(self):