Add SerializeIterator op that serializes an IteratorResource into a variant tensor.
Add DeserializeIterator op that builds IteratorResource from a variant tensor. Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc. Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter for reading/writing state of iterators. Get rid of IteratorBundleReader and IteratorBundleWriter. PiperOrigin-RevId: 173140858
This commit is contained in:
parent
57f3e529d9
commit
1038927c09
@ -185,6 +185,7 @@ py_test(
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
@ -252,6 +253,8 @@ py_test(
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:variables",
|
||||
@ -274,6 +277,7 @@ py_test(
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
|
||||
|
||||
def testIncorrectIteratorRestore(self):
|
||||
|
||||
def _iterator_checkpoint_prefix():
|
||||
def _path():
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _save_op(iterator_resource):
|
||||
iterator_state_variant = gen_dataset_ops.serialize_iterator(
|
||||
iterator_resource)
|
||||
save_op = io_ops.write_file(
|
||||
_path(), parsing_ops.serialize_tensor(iterator_state_variant))
|
||||
return save_op
|
||||
|
||||
def _restore_op(iterator_resource):
|
||||
iterator_state_variant = parsing_ops.parse_tensor(
|
||||
io_ops.read_file(_path()), dtypes.variant)
|
||||
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
|
||||
iterator_state_variant)
|
||||
return restore_op
|
||||
|
||||
def _build_range_dataset_graph():
|
||||
start = 1
|
||||
stop = 10
|
||||
@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = _iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = _save_op(iterator._iterator_resource)
|
||||
restore_op = _restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
def _build_reader_dataset_graph():
|
||||
filenames = ["test"] # Does not exist but we don't care in this test.
|
||||
path = _iterator_checkpoint_prefix()
|
||||
iterator = readers.FixedLengthRecordDataset(
|
||||
filenames, 1, 0, 0).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next_op = iterator.get_next()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = _save_op(iterator._iterator_resource)
|
||||
restore_op = _restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next_op, save_op, restore_op
|
||||
|
||||
# Saving iterator for RangeDataset graph.
|
||||
|
@ -29,6 +29,8 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -193,6 +195,21 @@ class RangeDatasetTest(test.TestCase):
|
||||
def _iterator_checkpoint_prefix(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _save_op(self, iterator_resource):
|
||||
iterator_state_variant = gen_dataset_ops.serialize_iterator(
|
||||
iterator_resource)
|
||||
save_op = io_ops.write_file(
|
||||
self._iterator_checkpoint_prefix(),
|
||||
parsing_ops.serialize_tensor(iterator_state_variant))
|
||||
return save_op
|
||||
|
||||
def _restore_op(self, iterator_resource):
|
||||
iterator_state_variant = parsing_ops.parse_tensor(
|
||||
io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
|
||||
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
|
||||
iterator_state_variant)
|
||||
return restore_op
|
||||
|
||||
def testSaveRestore(self):
|
||||
|
||||
def _build_graph(start, stop):
|
||||
@ -200,10 +217,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -246,14 +261,13 @@ class RangeDatasetTest(test.TestCase):
|
||||
|
||||
def testRestoreWithoutBuildingDatasetGraph(self):
|
||||
|
||||
def _build_graph(start, stop, num_epochs, path):
|
||||
def _build_graph(start, stop, num_epochs):
|
||||
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -262,10 +276,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
num_epochs = 5
|
||||
break_point = 5
|
||||
break_epoch = 3
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
|
||||
path)
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
@ -282,8 +294,7 @@ class RangeDatasetTest(test.TestCase):
|
||||
output_shapes = tensor_shape.scalar()
|
||||
iterator = iterator_ops.Iterator.from_structure(output_types,
|
||||
output_shapes)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
get_next = iterator.get_next()
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
@ -302,10 +313,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -343,10 +352,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -379,10 +386,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
start = 2
|
||||
@ -424,10 +429,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
start, stop).repeat(num_epochs).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
start = 2
|
||||
@ -471,10 +474,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
start, stop).repeat(num_epochs).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
start = 2
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
@ -276,18 +277,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
def _iterator_checkpoint_path(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _save_op(self, iterator_resource):
|
||||
iterator_state_variant = gen_dataset_ops.serialize_iterator(
|
||||
iterator_resource)
|
||||
save_op = io_ops.write_file(
|
||||
self._iterator_checkpoint_path(),
|
||||
parsing_ops.serialize_tensor(iterator_state_variant))
|
||||
return save_op
|
||||
|
||||
def _restore_op(self, iterator_resource):
|
||||
iterator_state_variant = parsing_ops.parse_tensor(
|
||||
io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
|
||||
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
|
||||
iterator_state_variant)
|
||||
return restore_op
|
||||
|
||||
def _build_iterator_graph(self, num_epochs):
|
||||
filenames = self._createFiles()
|
||||
path = self._iterator_checkpoint_path()
|
||||
dataset = (readers.FixedLengthRecordDataset(
|
||||
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
|
||||
.repeat(num_epochs))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next_op = iterator.get_next()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next_op, save_op, restore_op
|
||||
|
||||
def _restore_iterator(self):
|
||||
@ -295,8 +309,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
output_shapes = tensor_shape.scalar()
|
||||
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
|
||||
get_next = iterator.get_next()
|
||||
restore_op = gen_dataset_ops.restore_iterator(
|
||||
iterator._iterator_resource, self._iterator_checkpoint_path())
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return restore_op, get_next
|
||||
|
||||
def testSaveRestore(self):
|
||||
|
@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [
|
||||
"framework/function.proto",
|
||||
"framework/graph.proto",
|
||||
"framework/graph_transfer_info.proto",
|
||||
"framework/iterator.proto",
|
||||
"framework/kernel_def.proto",
|
||||
"framework/log_memory.proto",
|
||||
"framework/node_def.proto",
|
||||
|
17
tensorflow/core/framework/iterator.proto
Normal file
17
tensorflow/core/framework/iterator.proto
Normal file
@ -0,0 +1,17 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "IteratorProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.util";
|
||||
|
||||
// Protocol buffer representing the metadata for an iterator's state stored
|
||||
// as a Variant tensor.
|
||||
message IteratorStateMetadata {
|
||||
// A user-specified version string.
|
||||
string version = 1;
|
||||
|
||||
// Keys for tensors in the VariantTensorDataProto.
|
||||
repeated string keys = 2;
|
||||
}
|
@ -6061,6 +6061,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -17,12 +17,14 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
@ -39,54 +41,25 @@ namespace tensorflow {
|
||||
|
||||
class ResourceMgr;
|
||||
|
||||
class BundleReaderWrapper {
|
||||
// Interface for reading values from a key-value store.
|
||||
// Used for restoring iterator state.
|
||||
class IteratorStateReader {
|
||||
public:
|
||||
BundleReaderWrapper(BundleReader* bundle_reader)
|
||||
: bundle_reader_(bundle_reader) {}
|
||||
virtual Status ReadScalar(StringPiece key, int64* val) = 0;
|
||||
virtual Status ReadScalar(StringPiece key, string* val) = 0;
|
||||
virtual bool Contains(StringPiece key) = 0;
|
||||
|
||||
// Reads a scalar value.
|
||||
template <typename T>
|
||||
Status ReadScalar(StringPiece key, T* val) {
|
||||
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
|
||||
TF_RETURN_IF_ERROR(Lookup(key, &val_t));
|
||||
*val = val_t.scalar<T>()();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool Contains(StringPiece key) { return bundle_reader_->Contains(key); }
|
||||
|
||||
private:
|
||||
Status Lookup(StringPiece key, Tensor* val) {
|
||||
return bundle_reader_->Lookup(key, val);
|
||||
}
|
||||
|
||||
BundleReader* bundle_reader_;
|
||||
virtual ~IteratorStateReader() {}
|
||||
};
|
||||
|
||||
class BundleWriterWrapper {
|
||||
// Interface for writing values to a key-value store.
|
||||
// Used for saving iterator state.
|
||||
class IteratorStateWriter {
|
||||
public:
|
||||
// Note: We intentionally do not provide a constructor that builds a
|
||||
// BundleWriter from the checkpoint path because we want the caller to be
|
||||
// in-charge of calling BundleWriter::Finish(). If we expose the Finish()
|
||||
// method here it may be called pre-maturely by users of this object.
|
||||
explicit BundleWriterWrapper(BundleWriter* bundle_writer)
|
||||
: bundle_writer_(bundle_writer) {}
|
||||
virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
|
||||
virtual Status WriteScalar(StringPiece key, const string& val) = 0;
|
||||
|
||||
// Writes a scalar value.
|
||||
template <typename T>
|
||||
Status WriteScalar(StringPiece key, const T val) {
|
||||
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
|
||||
val_t.scalar<T>()() = val;
|
||||
TF_RETURN_IF_ERROR(Add(key, val_t));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
Status Add(StringPiece key, const Tensor& val) {
|
||||
return bundle_writer_->Add(key, val);
|
||||
}
|
||||
|
||||
BundleWriter* bundle_writer_;
|
||||
virtual ~IteratorStateWriter() {}
|
||||
};
|
||||
|
||||
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
|
||||
@ -249,10 +222,6 @@ class IteratorContext {
|
||||
// range of outputs is typically represented by an `DatasetBase`,
|
||||
// defined below.
|
||||
class IteratorBase {
|
||||
protected:
|
||||
class IteratorBundleReader;
|
||||
class IteratorBundleWriter;
|
||||
|
||||
public:
|
||||
virtual ~IteratorBase() {}
|
||||
|
||||
@ -284,75 +253,17 @@ class IteratorBase {
|
||||
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
|
||||
|
||||
// Saves the state of this iterator.
|
||||
virtual Status Save(OpKernelContext* ctx, const string& path) {
|
||||
BundleWriter bundle_writer(ctx->env(), path);
|
||||
TF_RETURN_IF_ERROR(bundle_writer.status());
|
||||
IteratorBundleWriter writer(&bundle_writer);
|
||||
TF_RETURN_IF_ERROR(Save(ctx, &writer));
|
||||
return bundle_writer.Finish();
|
||||
}
|
||||
|
||||
virtual Status Restore(OpKernelContext* ctx, const string& path) {
|
||||
if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
|
||||
return errors::NotFound(
|
||||
"Failed to restore Iterator state. No file found at ",
|
||||
MetaFilename(path));
|
||||
}
|
||||
BundleReader bundle_reader(ctx->env(), path);
|
||||
TF_RETURN_IF_ERROR(bundle_reader.status());
|
||||
IteratorBundleReader reader(&bundle_reader);
|
||||
return Restore(ctx, &reader);
|
||||
}
|
||||
|
||||
static const char kIteratorExhausted[];
|
||||
|
||||
protected:
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
// `RestoreInternal` on their parent iterators, e.g., in
|
||||
// `RepeatDataasetOp::Dataset`.
|
||||
class IteratorBundleReader : public BundleReaderWrapper {
|
||||
public:
|
||||
IteratorBundleReader(BundleReader* bundle_reader)
|
||||
: BundleReaderWrapper(bundle_reader) {}
|
||||
|
||||
// Restores the state of a parent iterator recursively.
|
||||
Status RestoreParent(OpKernelContext* ctx,
|
||||
const std::unique_ptr<IteratorBase>& parent) {
|
||||
return parent->RestoreInternal(ctx, this);
|
||||
}
|
||||
};
|
||||
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
// `SaveInternal` on their parent iterators, e.g., in
|
||||
// `RepeatDataasetOp::Dataset`.
|
||||
class IteratorBundleWriter : public BundleWriterWrapper {
|
||||
public:
|
||||
IteratorBundleWriter(BundleWriter* bundle_writer)
|
||||
: BundleWriterWrapper(bundle_writer) {}
|
||||
// Saves the state of a parent iterator recursively.
|
||||
Status SaveParent(OpKernelContext* ctx,
|
||||
const std::unique_ptr<IteratorBase>& parent) {
|
||||
return parent->SaveInternal(ctx, this);
|
||||
}
|
||||
};
|
||||
|
||||
virtual Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) {
|
||||
virtual Status Save(IteratorStateWriter* writer) {
|
||||
if (is_exhausted_) {
|
||||
LOG(INFO) << "Iterator exhausted.";
|
||||
return writer->WriteScalar<string>(kIteratorExhausted,
|
||||
kIteratorExhausted);
|
||||
return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
|
||||
} else {
|
||||
return SaveInternal(ctx, writer);
|
||||
return SaveInternal(writer);
|
||||
}
|
||||
}
|
||||
|
||||
// Saves the state of this iterator.
|
||||
virtual Status SaveInternal(OpKernelContext* ctx,
|
||||
IteratorBundleWriter* writer) {
|
||||
return errors::Unimplemented("SaveInternal");
|
||||
}
|
||||
|
||||
virtual Status Restore(OpKernelContext* ctx, IteratorBundleReader* reader) {
|
||||
// Restores the state of this iterator.
|
||||
virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
|
||||
if (reader->Contains(kIteratorExhausted)) {
|
||||
LOG(INFO) << "Iterator exhausted. Nothing to restore.";
|
||||
is_exhausted_ = true;
|
||||
@ -362,9 +273,33 @@ class IteratorBase {
|
||||
}
|
||||
}
|
||||
|
||||
// Restores the state of this iterator.
|
||||
static const char kIteratorExhausted[];
|
||||
|
||||
protected:
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
// `SaveInternal` on their parent iterators, e.g., in
|
||||
// `RepeatDataasetOp::Dataset`.
|
||||
Status SaveParent(IteratorStateWriter* writer,
|
||||
const std::unique_ptr<IteratorBase>& parent) {
|
||||
return parent->SaveInternal(writer);
|
||||
}
|
||||
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
// `RestoreInternal` on their parent iterators, e.g., in
|
||||
// `RepeatDataasetOp::Dataset`.
|
||||
Status RestoreParent(OpKernelContext* ctx, IteratorStateReader* reader,
|
||||
const std::unique_ptr<IteratorBase>& parent) {
|
||||
return parent->RestoreInternal(ctx, reader);
|
||||
}
|
||||
|
||||
// Saves the state of this iterator recursively.
|
||||
virtual Status SaveInternal(IteratorStateWriter* writer) {
|
||||
return errors::Unimplemented("SaveInternal");
|
||||
}
|
||||
|
||||
// Restores the state of this iterator recursively.
|
||||
virtual Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorBundleReader* reader) {
|
||||
IteratorStateReader* reader) {
|
||||
return errors::Unimplemented("RestoreInternal");
|
||||
}
|
||||
|
||||
@ -404,7 +339,7 @@ class DatasetBase : public core::RefCounted {
|
||||
virtual string DebugString() = 0;
|
||||
|
||||
// Serializes the dataset and writes it to the `writer`.
|
||||
virtual Status Save(BundleWriterWrapper* writer) const {
|
||||
virtual Status Save(IteratorStateWriter* writer) const {
|
||||
return errors::Unimplemented("DatasetBase::Save");
|
||||
}
|
||||
|
||||
@ -435,20 +370,14 @@ class GraphDatasetBase : public DatasetBase {
|
||||
|
||||
const string op_name() const { return op_name_; }
|
||||
|
||||
Status Save(BundleWriterWrapper* writer) const override {
|
||||
GraphDefBuilder b;
|
||||
DatasetGraphDefBuilder db(&b);
|
||||
Node* node = nullptr;
|
||||
TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
|
||||
string output_name = node->name();
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
|
||||
Status Save(IteratorStateWriter* writer) const override {
|
||||
string serialized_graph_def;
|
||||
graph_def.SerializeToString(&serialized_graph_def);
|
||||
string output_node;
|
||||
TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar<string>(kDatasetGraphKey, serialized_graph_def));
|
||||
writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar<string>(kDatasetGraphOutputNodeKey, output_name));
|
||||
writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -460,6 +389,18 @@ class GraphDatasetBase : public DatasetBase {
|
||||
static const char kDatasetGraphOutputNodeKey[];
|
||||
|
||||
private:
|
||||
Status Serialize(string* serialized_graph_def, string* output_node) const {
|
||||
GraphDefBuilder b;
|
||||
DatasetGraphDefBuilder db(&b);
|
||||
Node* node = nullptr;
|
||||
TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
|
||||
*output_node = node->name();
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
|
||||
graph_def.SerializeToString(serialized_graph_def);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const string op_name_;
|
||||
};
|
||||
|
||||
@ -505,18 +446,18 @@ class DatasetIterator : public IteratorBase {
|
||||
return GetNextInternal(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
protected:
|
||||
Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) final {
|
||||
Status Save(IteratorStateWriter* writer) final {
|
||||
TF_RETURN_IF_ERROR(dataset()->Save(writer));
|
||||
return IteratorBase::Save(ctx, writer);
|
||||
return IteratorBase::Save(writer);
|
||||
}
|
||||
|
||||
protected:
|
||||
// Internal implementation of GetNext that is wrapped in tracing logic.
|
||||
virtual Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) = 0;
|
||||
|
||||
string full_name(const string& name) {
|
||||
string full_name(const string& name) const {
|
||||
return strings::StrCat(prefix(), ":", name);
|
||||
}
|
||||
|
||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
#include "tensorflow/core/framework/iterator.pb.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
@ -35,6 +37,8 @@ namespace {
|
||||
// See documentation in ../ops/dataset_ops.cc for a high-level
|
||||
// description of the following ops.
|
||||
|
||||
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
|
||||
|
||||
Status VerifyTypesMatch(const DataTypeVector& expected,
|
||||
const DataTypeVector& received) {
|
||||
if (expected.size() != received.size()) {
|
||||
@ -93,10 +97,10 @@ class IteratorResource : public ResourceBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status Save(OpKernelContext* ctx, const string& path) {
|
||||
Status Save(IteratorStateWriter* writer) {
|
||||
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
|
||||
if (captured_iterator) {
|
||||
return captured_iterator->Save(ctx, path);
|
||||
return captured_iterator->Save(writer);
|
||||
} else {
|
||||
return errors::FailedPrecondition(
|
||||
"Save() failed because the iterator has not been initialized. "
|
||||
@ -105,49 +109,34 @@ class IteratorResource : public ResourceBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status Restore(OpKernelContext* ctx, const string& path) {
|
||||
if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
|
||||
return errors::NotFound(
|
||||
"Failed to restore Iterator state. No file found at ",
|
||||
MetaFilename(path));
|
||||
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
|
||||
string serialized_graph_def;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
|
||||
&serialized_graph_def));
|
||||
GraphDef graph_def;
|
||||
if (!graph_def.ParseFromString(serialized_graph_def)) {
|
||||
return errors::Internal("Error parsing dataset GraphDef.");
|
||||
}
|
||||
string output_node;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
|
||||
DatasetBase* dataset = nullptr;
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
||||
std::vector<Tensor> outputs;
|
||||
GraphRunner graph_runner(ctx->env());
|
||||
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
|
||||
{output_node}, &outputs));
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
|
||||
|
||||
BundleReader bundle_reader(ctx->env(), path);
|
||||
TF_RETURN_IF_ERROR(bundle_reader.status());
|
||||
BundleReaderWrapper reader(&bundle_reader);
|
||||
if (reader.Contains(GraphDatasetBase::kDatasetGraphKey)) {
|
||||
string serialized_graph_def;
|
||||
TF_RETURN_IF_ERROR(reader.ReadScalar(GraphDatasetBase::kDatasetGraphKey,
|
||||
&serialized_graph_def));
|
||||
GraphDef graph_def;
|
||||
graph_def.ParseFromString(serialized_graph_def);
|
||||
// TODO(srbs): Is there a way of getting the op registry of the original
|
||||
// graph.
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
||||
string output_node;
|
||||
TF_RETURN_IF_ERROR(reader.ReadScalar(
|
||||
GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
|
||||
std::vector<Tensor> outputs;
|
||||
GraphRunner graph_runner(ctx->env());
|
||||
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
|
||||
{output_node}, &outputs));
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
|
||||
TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
|
||||
} else if (reader.Contains(IteratorBase::kIteratorExhausted)) {
|
||||
TF_RETURN_IF_ERROR(set_iterator(std::unique_ptr<IteratorBase>(
|
||||
new ExhaustedIterator(output_dtypes_, output_shapes_))));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
|
||||
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
|
||||
|
||||
if (captured_iterator) {
|
||||
// TODO(srbs): Figure a way to pass bundle_reader here.
|
||||
return captured_iterator->Restore(ctx, path);
|
||||
return captured_iterator->Restore(ctx, reader);
|
||||
} else {
|
||||
return errors::FailedPrecondition(
|
||||
"Failed to restore iterator from ", path,
|
||||
". Make sure the checkpoint ",
|
||||
"Failed to restore iterator. Make sure the checkpoint ",
|
||||
"is not corrupt. If the checkpoint does not contain the GraphDef, ",
|
||||
"you will need to initialize your iterator before restoring.");
|
||||
}
|
||||
@ -174,43 +163,194 @@ class IteratorResource : public ResourceBase {
|
||||
}
|
||||
|
||||
private:
|
||||
// A no-op iterator which always sets end_of_sequence = true. An instance of
|
||||
// this is returned when attempting to restore an exhausted iterator. This is
|
||||
// needed because the Dataset GraphDef may not have been saved for exhausted
|
||||
// iterators so the actual Iterator can not be built.
|
||||
class ExhaustedIterator : public IteratorBase {
|
||||
public:
|
||||
ExhaustedIterator(const DataTypeVector& output_dtypes,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
|
||||
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) final {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return output_dtypes_;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
virtual const std::vector<PartialTensorShape>& output_shapes() {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
private:
|
||||
const DataTypeVector output_dtypes_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
std::shared_ptr<IteratorBase> iterator_;
|
||||
const DataTypeVector output_dtypes_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
// Helper class for reading data from a VariantTensorData object.
|
||||
class VariantTensorDataReader : public IteratorStateReader {
|
||||
public:
|
||||
explicit VariantTensorDataReader(const VariantTensorData* data)
|
||||
: data_(data) {
|
||||
PreProcess();
|
||||
}
|
||||
|
||||
// Returns OK iff the initialization was successful, i.e.,
|
||||
// pre-processing did not have errors.
|
||||
Status status() const { return status_; }
|
||||
|
||||
Status ReadScalar(StringPiece key, int64* val) override {
|
||||
return ReadScalarInternal(key, val);
|
||||
}
|
||||
|
||||
Status ReadScalar(StringPiece key, string* val) override {
|
||||
return ReadScalarInternal(key, val);
|
||||
}
|
||||
|
||||
bool Contains(StringPiece key) override {
|
||||
return map_.find(key.ToString()) != map_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
void PreProcess() {
|
||||
string metadata;
|
||||
data_->get_metadata(&metadata);
|
||||
IteratorStateMetadata proto;
|
||||
if (!proto.ParseFromString(metadata)) {
|
||||
status_ = errors::Internal("Error parsing IteratorStateMetadata.");
|
||||
return;
|
||||
}
|
||||
size_t num_entries = proto.keys_size();
|
||||
CHECK_EQ(num_entries, data_->tensors_size());
|
||||
for (size_t i = 0; i < num_entries; i++) {
|
||||
map_[proto.keys(i)] = i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ReadScalarInternal(StringPiece key, T* val) {
|
||||
if (map_.find(key.ToString()) == map_.end()) {
|
||||
return errors::NotFound(key);
|
||||
}
|
||||
*val = data_->tensors(map_[key.ToString()]).scalar<T>()();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::map<string, size_t> map_;
|
||||
const VariantTensorData* data_; // Not owned.
|
||||
Status status_;
|
||||
};
|
||||
|
||||
// Helper class for writing data to a VariantTensorData object.
|
||||
class VariantTensorDataWriter : public IteratorStateWriter {
|
||||
public:
|
||||
// Does not take ownership of data.
|
||||
explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
|
||||
|
||||
Status WriteScalar(StringPiece key, const int64& val) override {
|
||||
return WriteScalarInternal(key, val);
|
||||
}
|
||||
|
||||
Status WriteScalar(StringPiece key, const string& val) override {
|
||||
return WriteScalarInternal(key, val);
|
||||
}
|
||||
|
||||
// Writes the metadata to `data_`.
|
||||
Status Flush() {
|
||||
string metadata;
|
||||
if (!metadata_proto_.SerializeToString(&metadata)) {
|
||||
return errors::Internal("Unable to serialize IteratorStateMetadata.");
|
||||
}
|
||||
data_->set_metadata(metadata);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
Status WriteScalarInternal(StringPiece key, const T& val) {
|
||||
// Write key to the metadata proto. This gets written to `data_`
|
||||
// when `Flush()` is called. We do this lazily to avoid multiple
|
||||
// serialization calls.
|
||||
metadata_proto_.add_keys(key.ToString());
|
||||
|
||||
// Update tensors.
|
||||
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
|
||||
val_t.scalar<T>()() = val;
|
||||
*(data_->add_tensors()) = std::move(val_t);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VariantTensorData* data_;
|
||||
// TODO(srbs): Set the version string.
|
||||
IteratorStateMetadata metadata_proto_;
|
||||
};
|
||||
|
||||
// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
|
||||
// The get() method returns an IteratorStateReader which can be used
|
||||
// to restore iterator state.
|
||||
//
|
||||
// Usage example:
|
||||
//
|
||||
// Encoding:
|
||||
//
|
||||
// Tensor t(DT_VARIANT, TensorShape({}));
|
||||
// t->scalar<Variant>()() = IteratorStateVariant(iterator_resource);
|
||||
//
|
||||
// Encode() sets the type_name of the VariantTensorData object to
|
||||
// IteratorStateVariant::TypeName().
|
||||
//
|
||||
// Decoding:
|
||||
//
|
||||
// Variant v = <VariantTensorDataProto object>;
|
||||
// DecodeUnaryVariant(&v);
|
||||
// IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
|
||||
// iterator_resource->Restore(ctx, wrapper->get())
|
||||
//
|
||||
// The type_name of the VariantTensorData object to be decoded must
|
||||
// match IteratorStateVariant::TypeName().
|
||||
class IteratorStateVariant {
|
||||
public:
|
||||
IteratorStateVariant() : data_(nullptr) {}
|
||||
IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
|
||||
if (other.data_) {
|
||||
Decode(*other.data_);
|
||||
}
|
||||
}
|
||||
// Initializes this object with the current state of the iterator so
|
||||
// that it can be written on the next call to Encode().
|
||||
Status InitializeFromIterator(IteratorResource* iterator_resource) {
|
||||
data_.reset(new VariantTensorData());
|
||||
data_->set_type_name(TypeName());
|
||||
VariantTensorDataWriter writer(data_.get());
|
||||
TF_RETURN_IF_ERROR(iterator_resource->Save(&writer));
|
||||
TF_RETURN_IF_ERROR(writer.Flush());
|
||||
return Status::OK();
|
||||
}
|
||||
string TypeName() const { return kIteratorVariantTypeName; }
|
||||
void Encode(VariantTensorData* data) const { *data = *data_; }
|
||||
bool Decode(const VariantTensorData& data) {
|
||||
if (data.type_name() != TypeName()) {
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
|
||||
*tensor_data = data;
|
||||
std::unique_ptr<VariantTensorDataReader> reader(
|
||||
new VariantTensorDataReader(tensor_data.get()));
|
||||
status_ = reader->status();
|
||||
if (!status_.ok()) {
|
||||
return false;
|
||||
}
|
||||
data_ = std::move(tensor_data);
|
||||
reader_ = std::move(reader);
|
||||
return true;
|
||||
}
|
||||
IteratorStateReader* get() { return reader_.get(); }
|
||||
Status status() const { return status_; }
|
||||
string DebugString() const {
|
||||
if (data_) {
|
||||
return strings::StrCat("IteratorStateVariant<",
|
||||
"data: ", data_->DebugString(),
|
||||
" status: ", status_.ToString(), ">");
|
||||
} else {
|
||||
return strings::StrCat("IteratorStateVariant<empty>");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IteratorStateReader> reader_;
|
||||
Status status_;
|
||||
std::unique_ptr<VariantTensorData> data_;
|
||||
};
|
||||
|
||||
// Register the reader class in the global variant decode_fn registry
|
||||
// so that a Variant containing a serialized representation of iterator state
|
||||
// can be decoded using DecodeUnaryVariant. If we don't do this we will need
|
||||
// to manually decode the returned Variant using MaybeDecodeAndCopy in
|
||||
// DeserializeIteratorOp which is not recommended.
|
||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
|
||||
kIteratorVariantTypeName);
|
||||
|
||||
// TODO(mrry): Can we simply use the template kernel here?
|
||||
class IteratorHandleOp : public ResourceOpKernel<IteratorResource> {
|
||||
public:
|
||||
@ -294,37 +434,6 @@ class ToSingleElementOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
class SaveIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
IteratorResource* iterator_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
|
||||
errors::InvalidArgument("SaveIteratorOp: path must be scalar"));
|
||||
const string& path = ctx->input(1).scalar<string>()();
|
||||
OP_REQUIRES_OK(ctx, iterator_resource->Save(ctx, path));
|
||||
}
|
||||
};
|
||||
|
||||
class RestoreIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
IteratorResource* iterator_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
|
||||
errors::InvalidArgument("RestoreIteratorOp: path must be scalar"));
|
||||
const string& path = ctx->input(1).scalar<string>()();
|
||||
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, path));
|
||||
}
|
||||
};
|
||||
|
||||
class OneShotIteratorOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
|
||||
@ -644,15 +753,55 @@ class IteratorFromStringHandleOp : public OpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
class SerializeIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& resource_handle_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||
|
||||
// Validate that the handle corresponds to a real resource, and
|
||||
// that it is an IteratorResource.
|
||||
IteratorResource* iterator_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||
iterator_resource->Unref();
|
||||
Tensor* variant_t;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
|
||||
IteratorStateVariant v;
|
||||
OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource));
|
||||
variant_t->scalar<Variant>()() = v;
|
||||
}
|
||||
};
|
||||
|
||||
class DeserializeIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// Validate that the handle corresponds to a real resource, and
|
||||
// that it is an IteratorResource.
|
||||
IteratorResource* iterator_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||
|
||||
Variant variant = ctx->input(1).scalar<Variant>()();
|
||||
auto* wrapper = variant.get<IteratorStateVariant>();
|
||||
OP_REQUIRES(ctx, wrapper != nullptr,
|
||||
errors::InvalidArgument(
|
||||
"DeserializeIteratorOp: Unable to parse variant tensor."));
|
||||
OP_REQUIRES_OK(ctx, wrapper->status());
|
||||
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
|
||||
MakeIteratorOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
|
||||
ToSingleElementOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU),
|
||||
SaveIteratorOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU),
|
||||
RestoreIteratorOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
|
||||
OneShotIteratorOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
|
||||
@ -661,6 +810,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
|
||||
IteratorToStringHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
|
||||
IteratorFromStringHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
|
||||
SerializeIteratorOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
|
||||
DeserializeIteratorOp);
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -92,6 +92,7 @@ class SerializeTensorOp : public OpKernel {
|
||||
Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
SerializeTensorOp<T>);
|
||||
TF_CALL_ALL_TYPES(REGISTER)
|
||||
TF_CALL_variant(REGISTER)
|
||||
#undef REGISTER
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -112,19 +112,16 @@ class RangeDatasetOp : public DatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(OpKernelContext* ctx,
|
||||
IteratorBundleWriter* writer) override {
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar<int64>(full_name("next"), next_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorBundleReader* reader) override {
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar<int64>(full_name("next"), &next_));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -356,31 +356,30 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(OpKernelContext* ctx,
|
||||
IteratorBundleWriter* writer) override {
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(
|
||||
full_name("current_file_index"), current_file_index_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
|
||||
current_file_index_));
|
||||
|
||||
// `input_buffer_` is empty if
|
||||
// 1. GetNext has not been called even once.
|
||||
// 2. All files have been read and iterator has been exhausted.
|
||||
int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar<int64>(full_name("current_pos"), current_pos));
|
||||
writer->WriteScalar(full_name("current_pos"), current_pos));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorBundleReader* reader) override {
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
int64 current_file_index;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(
|
||||
full_name("current_file_index"), ¤t_file_index));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
|
||||
¤t_file_index));
|
||||
current_file_index_ = size_t(current_file_index);
|
||||
int64 current_pos;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar<int64>(full_name("current_pos"), ¤t_pos));
|
||||
reader->ReadScalar(full_name("current_pos"), ¤t_pos));
|
||||
|
||||
// Seek to current_pos.
|
||||
input_buffer_.reset();
|
||||
|
@ -124,19 +124,18 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(OpKernelContext* ctx,
|
||||
IteratorBundleWriter* writer) override {
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(writer->SaveParent(ctx, input_impl_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorBundleReader* reader) override {
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(full_name("i"), &i_));
|
||||
TF_RETURN_IF_ERROR(reader->RestoreParent(ctx, input_impl_));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -28753,18 +28753,6 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "RestoreIterator"
|
||||
input_arg {
|
||||
name: "iterator"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
input_arg {
|
||||
name: "path"
|
||||
type: DT_STRING
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "RestoreSlice"
|
||||
input_arg {
|
||||
@ -29548,18 +29536,6 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "SaveIterator"
|
||||
input_arg {
|
||||
name: "iterator"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
input_arg {
|
||||
name: "path"
|
||||
type: DT_STRING
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "SaveSlices"
|
||||
input_arg {
|
||||
|
@ -598,24 +598,6 @@ This operation may be executed multiple times. Each execution will reset the
|
||||
iterator in `iterator` to the first element of `dataset`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SaveIterator")
|
||||
.Input("iterator: resource")
|
||||
.Input("path: string")
|
||||
.SetShapeFn(shape_inference::NoOutputs)
|
||||
.Doc(R"doc(
|
||||
Saves the state of the `iterator` at `path`.
|
||||
|
||||
This state can be restored using "RestoreIterator".
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("RestoreIterator")
|
||||
.Input("iterator: resource")
|
||||
.Input("path: string")
|
||||
.SetShapeFn(shape_inference::NoOutputs)
|
||||
.Doc(R"doc(
|
||||
Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator".
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("OneShotIterator")
|
||||
.Output("handle: resource")
|
||||
.Attr("dataset_factory: func")
|
||||
@ -737,4 +719,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
|
||||
element produced by the resulting iterator.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SerializeIterator")
|
||||
.Input("resource_handle: resource")
|
||||
.Output("serialized: variant")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Converts the given `resource_handle` representing an iterator to a variant tensor.
|
||||
|
||||
resource_handle: A handle to an iterator resource.
|
||||
serialized: A variant tensor storing the state of the iterator contained in the
|
||||
resource.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("DeserializeIterator")
|
||||
.Input("resource_handle: resource")
|
||||
.Input("serialized: variant")
|
||||
.SetShapeFn(shape_inference::NoOutputs)
|
||||
.Doc(R"doc(
|
||||
Converts the given variant tensor to an iterator and stores it in the given resource.
|
||||
|
||||
resource_handle: A handle to an iterator resource.
|
||||
serialized: A variant tensor storing the state of the iterator contained in the
|
||||
resource.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -2886,7 +2886,9 @@ tf_py_test(
|
||||
"//tensorflow/python:dataset_ops_gen",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:variables",
|
||||
@ -2907,7 +2909,9 @@ tf_py_test(
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
@ -3022,6 +3026,7 @@ tf_py_test(
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
|
||||
|
||||
def testIncorrectIteratorRestore(self):
|
||||
|
||||
def _iterator_checkpoint_prefix():
|
||||
def _path():
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _save_op(iterator_resource):
|
||||
iterator_state_variant = gen_dataset_ops.serialize_iterator(
|
||||
iterator_resource)
|
||||
save_op = io_ops.write_file(
|
||||
_path(), parsing_ops.serialize_tensor(iterator_state_variant))
|
||||
return save_op
|
||||
|
||||
def _restore_op(iterator_resource):
|
||||
iterator_state_variant = parsing_ops.parse_tensor(
|
||||
io_ops.read_file(_path()), dtypes.variant)
|
||||
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
|
||||
iterator_state_variant)
|
||||
return restore_op
|
||||
|
||||
def _build_range_dataset_graph():
|
||||
start = 1
|
||||
stop = 10
|
||||
@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = _iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = _save_op(iterator._iterator_resource)
|
||||
restore_op = _restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
def _build_reader_dataset_graph():
|
||||
filenames = ["test"] # Does not exist but we don't care in this test.
|
||||
path = _iterator_checkpoint_prefix()
|
||||
iterator = readers.FixedLengthRecordDataset(
|
||||
filenames, 1, 0, 0).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next_op = iterator.get_next()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = _save_op(iterator._iterator_resource)
|
||||
restore_op = _restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next_op, save_op, restore_op
|
||||
|
||||
# Saving iterator for RangeDataset graph.
|
||||
|
@ -27,6 +27,8 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase):
|
||||
def _iterator_checkpoint_prefix(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _save_op(self, iterator_resource):
|
||||
iterator_state_variant = gen_dataset_ops.serialize_iterator(
|
||||
iterator_resource)
|
||||
save_op = io_ops.write_file(
|
||||
self._iterator_checkpoint_prefix(),
|
||||
parsing_ops.serialize_tensor(iterator_state_variant))
|
||||
return save_op
|
||||
|
||||
def _restore_op(self, iterator_resource):
|
||||
iterator_state_variant = parsing_ops.parse_tensor(
|
||||
io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
|
||||
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
|
||||
iterator_state_variant)
|
||||
return restore_op
|
||||
|
||||
def testSaveRestore(self):
|
||||
|
||||
def _build_graph(start, stop):
|
||||
@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase):
|
||||
|
||||
def testRestoreWithoutBuildingDatasetGraph(self):
|
||||
|
||||
def _build_graph(start, stop, num_epochs, path):
|
||||
def _build_graph(start, stop, num_epochs):
|
||||
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
num_epochs = 5
|
||||
break_point = 5
|
||||
break_epoch = 3
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
|
||||
path)
|
||||
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
sess.run(init_op)
|
||||
@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase):
|
||||
output_shapes = tensor_shape.scalar()
|
||||
iterator = iterator_ops.Iterator.from_structure(output_types,
|
||||
output_shapes)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
get_next = iterator.get_next()
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(restore_op)
|
||||
@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
# Saving and restoring in different sessions.
|
||||
@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
start = 2
|
||||
@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
start, stop).repeat(num_epochs).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
start = 2
|
||||
@ -447,10 +450,8 @@ class RangeDatasetTest(test.TestCase):
|
||||
start, stop).repeat(num_epochs).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next, save_op, restore_op
|
||||
|
||||
start = 2
|
||||
|
@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
def _iterator_checkpoint_path(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _save_op(self, iterator_resource):
|
||||
iterator_state_variant = gen_dataset_ops.serialize_iterator(
|
||||
iterator_resource)
|
||||
save_op = io_ops.write_file(
|
||||
self._iterator_checkpoint_path(),
|
||||
parsing_ops.serialize_tensor(iterator_state_variant))
|
||||
return save_op
|
||||
|
||||
def _restore_op(self, iterator_resource):
|
||||
iterator_state_variant = parsing_ops.parse_tensor(
|
||||
io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
|
||||
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
|
||||
iterator_state_variant)
|
||||
return restore_op
|
||||
|
||||
def _build_iterator_graph(self, num_epochs):
|
||||
filenames = self._createFiles()
|
||||
path = self._iterator_checkpoint_path()
|
||||
dataset = (readers.FixedLengthRecordDataset(
|
||||
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
|
||||
.repeat(num_epochs))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next_op = iterator.get_next()
|
||||
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
|
||||
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
|
||||
path)
|
||||
save_op = self._save_op(iterator._iterator_resource)
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return init_op, get_next_op, save_op, restore_op
|
||||
|
||||
def _restore_iterator(self):
|
||||
@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
output_shapes = tensor_shape.scalar()
|
||||
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
|
||||
get_next = iterator.get_next()
|
||||
restore_op = gen_dataset_ops.restore_iterator(
|
||||
iterator._iterator_resource, self._iterator_checkpoint_path())
|
||||
restore_op = self._restore_op(iterator._iterator_resource)
|
||||
return restore_op, get_next
|
||||
|
||||
def testSaveRestore(self):
|
||||
|
Loading…
Reference in New Issue
Block a user