Add SerializeIterator op that serializes an IteratorResource into a variant tensor.

Add DeserializeIterator op that builds IteratorResource from a variant tensor.
Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc.
Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter for reading/writing state of iterators.
Get rid of IteratorBundleReader and IteratorBundleWriter.

PiperOrigin-RevId: 173140858
This commit is contained in:
Saurabh Saxena 2017-10-23 11:07:10 -07:00 committed by TensorFlower Gardener
parent 57f3e529d9
commit 1038927c09
19 changed files with 544 additions and 393 deletions

View File

@ -185,6 +185,7 @@ py_test(
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
@ -252,6 +253,8 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
@ -274,6 +277,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
def testIncorrectIteratorRestore(self):
def _iterator_checkpoint_prefix():
def _path():
return os.path.join(self.get_temp_dir(), "iterator")
def _save_op(iterator_resource):
iterator_state_variant = gen_dataset_ops.serialize_iterator(
iterator_resource)
save_op = io_ops.write_file(
_path(), parsing_ops.serialize_tensor(iterator_state_variant))
return save_op
def _restore_op(iterator_resource):
iterator_state_variant = parsing_ops.parse_tensor(
io_ops.read_file(_path()), dtypes.variant)
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
iterator_state_variant)
return restore_op
def _build_range_dataset_graph():
start = 1
stop = 10
@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = _iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = _save_op(iterator._iterator_resource)
restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
def _build_reader_dataset_graph():
filenames = ["test"] # Does not exist but we don't care in this test.
path = _iterator_checkpoint_prefix()
iterator = readers.FixedLengthRecordDataset(
filenames, 1, 0, 0).make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = _save_op(iterator._iterator_resource)
restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
# Saving iterator for RangeDataset graph.

View File

@ -29,6 +29,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@ -193,6 +195,21 @@ class RangeDatasetTest(test.TestCase):
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _save_op(self, iterator_resource):
iterator_state_variant = gen_dataset_ops.serialize_iterator(
iterator_resource)
save_op = io_ops.write_file(
self._iterator_checkpoint_prefix(),
parsing_ops.serialize_tensor(iterator_state_variant))
return save_op
def _restore_op(self, iterator_resource):
iterator_state_variant = parsing_ops.parse_tensor(
io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
iterator_state_variant)
return restore_op
def testSaveRestore(self):
def _build_graph(start, stop):
@ -200,10 +217,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -246,14 +261,13 @@ class RangeDatasetTest(test.TestCase):
def testRestoreWithoutBuildingDatasetGraph(self):
def _build_graph(start, stop, num_epochs, path):
def _build_graph(start, stop, num_epochs):
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -262,10 +276,8 @@ class RangeDatasetTest(test.TestCase):
num_epochs = 5
break_point = 5
break_epoch = 3
path = self._iterator_checkpoint_prefix()
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
path)
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
@ -282,8 +294,7 @@ class RangeDatasetTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.test_session(graph=g) as sess:
sess.run(restore_op)
@ -302,10 +313,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -343,10 +352,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -379,10 +386,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@ -424,10 +429,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@ -471,10 +474,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@ -276,18 +277,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _iterator_checkpoint_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _save_op(self, iterator_resource):
iterator_state_variant = gen_dataset_ops.serialize_iterator(
iterator_resource)
save_op = io_ops.write_file(
self._iterator_checkpoint_path(),
parsing_ops.serialize_tensor(iterator_state_variant))
return save_op
def _restore_op(self, iterator_resource):
iterator_state_variant = parsing_ops.parse_tensor(
io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
iterator_state_variant)
return restore_op
def _build_iterator_graph(self, num_epochs):
filenames = self._createFiles()
path = self._iterator_checkpoint_path()
dataset = (readers.FixedLengthRecordDataset(
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
.repeat(num_epochs))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
def _restore_iterator(self):
@ -295,8 +309,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
restore_op = gen_dataset_ops.restore_iterator(
iterator._iterator_resource, self._iterator_checkpoint_path())
restore_op = self._restore_op(iterator._iterator_resource)
return restore_op, get_next
def testSaveRestore(self):

View File

@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [
"framework/function.proto",
"framework/graph.proto",
"framework/graph_transfer_info.proto",
"framework/iterator.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
"framework/node_def.proto",

View File

@ -0,0 +1,17 @@
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "IteratorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.util";
// Protocol buffer representing the metadata for an iterator's state stored
// as a Variant tensor.
message IteratorStateMetadata {
// A user-specified version string.
string version = 1;
// Keys for tensors in the VariantTensorDataProto.
// The i-th key names the i-th tensor, so the number of keys must equal the
// number of tensors in the enclosing VariantTensorDataProto.
repeated string keys = 2;
}

View File

@ -6061,6 +6061,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -17,12 +17,14 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -39,54 +41,25 @@ namespace tensorflow {
class ResourceMgr;
class BundleReaderWrapper {
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
public:
BundleReaderWrapper(BundleReader* bundle_reader)
: bundle_reader_(bundle_reader) {}
virtual Status ReadScalar(StringPiece key, int64* val) = 0;
virtual Status ReadScalar(StringPiece key, string* val) = 0;
virtual bool Contains(StringPiece key) = 0;
// Reads a scalar value.
template <typename T>
Status ReadScalar(StringPiece key, T* val) {
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
TF_RETURN_IF_ERROR(Lookup(key, &val_t));
*val = val_t.scalar<T>()();
return Status::OK();
}
bool Contains(StringPiece key) { return bundle_reader_->Contains(key); }
private:
Status Lookup(StringPiece key, Tensor* val) {
return bundle_reader_->Lookup(key, val);
}
BundleReader* bundle_reader_;
virtual ~IteratorStateReader() {}
};
class BundleWriterWrapper {
// Interface for writing values to a key-value store.
// Used for saving iterator state.
class IteratorStateWriter {
public:
// Note: We intentionally do not provide a constructor that builds a
// BundleWriter from the checkpoint path because we want the caller to be
// in-charge of calling BundleWriter::Finish(). If we expose the Finish()
method here it may be called prematurely by users of this object.
explicit BundleWriterWrapper(BundleWriter* bundle_writer)
: bundle_writer_(bundle_writer) {}
virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
virtual Status WriteScalar(StringPiece key, const string& val) = 0;
// Writes a scalar value.
template <typename T>
Status WriteScalar(StringPiece key, const T val) {
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
val_t.scalar<T>()() = val;
TF_RETURN_IF_ERROR(Add(key, val_t));
return Status::OK();
}
private:
Status Add(StringPiece key, const Tensor& val) {
return bundle_writer_->Add(key, val);
}
BundleWriter* bundle_writer_;
virtual ~IteratorStateWriter() {}
};
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
@ -249,10 +222,6 @@ class IteratorContext {
// range of outputs is typically represented by an `DatasetBase`,
// defined below.
class IteratorBase {
protected:
class IteratorBundleReader;
class IteratorBundleWriter;
public:
virtual ~IteratorBase() {}
@ -284,75 +253,17 @@ class IteratorBase {
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
// Saves the state of this iterator.
virtual Status Save(OpKernelContext* ctx, const string& path) {
BundleWriter bundle_writer(ctx->env(), path);
TF_RETURN_IF_ERROR(bundle_writer.status());
IteratorBundleWriter writer(&bundle_writer);
TF_RETURN_IF_ERROR(Save(ctx, &writer));
return bundle_writer.Finish();
}
virtual Status Restore(OpKernelContext* ctx, const string& path) {
if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
return errors::NotFound(
"Failed to restore Iterator state. No file found at ",
MetaFilename(path));
}
BundleReader bundle_reader(ctx->env(), path);
TF_RETURN_IF_ERROR(bundle_reader.status());
IteratorBundleReader reader(&bundle_reader);
return Restore(ctx, &reader);
}
static const char kIteratorExhausted[];
protected:
// This is needed so that sub-classes of IteratorBase can call
// `RestoreInternal` on their parent iterators, e.g., in
`RepeatDatasetOp::Dataset`.
class IteratorBundleReader : public BundleReaderWrapper {
public:
IteratorBundleReader(BundleReader* bundle_reader)
: BundleReaderWrapper(bundle_reader) {}
// Restores the state of a parent iterator recursively.
Status RestoreParent(OpKernelContext* ctx,
const std::unique_ptr<IteratorBase>& parent) {
return parent->RestoreInternal(ctx, this);
}
};
// This is needed so that sub-classes of IteratorBase can call
// `SaveInternal` on their parent iterators, e.g., in
// `RepeatDatasetOp::Dataset`.
class IteratorBundleWriter : public BundleWriterWrapper {
public:
IteratorBundleWriter(BundleWriter* bundle_writer)
: BundleWriterWrapper(bundle_writer) {}
// Saves the state of a parent iterator recursively.
Status SaveParent(OpKernelContext* ctx,
const std::unique_ptr<IteratorBase>& parent) {
return parent->SaveInternal(ctx, this);
}
};
virtual Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) {
virtual Status Save(IteratorStateWriter* writer) {
if (is_exhausted_) {
LOG(INFO) << "Iterator exhausted.";
return writer->WriteScalar<string>(kIteratorExhausted,
kIteratorExhausted);
return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
} else {
return SaveInternal(ctx, writer);
return SaveInternal(writer);
}
}
// Saves the state of this iterator.
virtual Status SaveInternal(OpKernelContext* ctx,
IteratorBundleWriter* writer) {
return errors::Unimplemented("SaveInternal");
}
virtual Status Restore(OpKernelContext* ctx, IteratorBundleReader* reader) {
// Restores the state of this iterator.
virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
if (reader->Contains(kIteratorExhausted)) {
LOG(INFO) << "Iterator exhausted. Nothing to restore.";
is_exhausted_ = true;
@ -362,9 +273,33 @@ class IteratorBase {
}
}
// Restores the state of this iterator.
static const char kIteratorExhausted[];
protected:
// This is needed so that sub-classes of IteratorBase can call
// `SaveInternal` on their parent iterators, e.g., in
// `RepeatDatasetOp::Dataset`.
Status SaveParent(IteratorStateWriter* writer,
const std::unique_ptr<IteratorBase>& parent) {
return parent->SaveInternal(writer);
}
// This is needed so that sub-classes of IteratorBase can call
// `RestoreInternal` on their parent iterators, e.g., in
// `RepeatDataasetOp::Dataset`.
Status RestoreParent(OpKernelContext* ctx, IteratorStateReader* reader,
const std::unique_ptr<IteratorBase>& parent) {
return parent->RestoreInternal(ctx, reader);
}
// Saves the state of this iterator recursively.
virtual Status SaveInternal(IteratorStateWriter* writer) {
return errors::Unimplemented("SaveInternal");
}
// Restores the state of this iterator recursively.
virtual Status RestoreInternal(OpKernelContext* ctx,
IteratorBundleReader* reader) {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
@ -404,7 +339,7 @@ class DatasetBase : public core::RefCounted {
virtual string DebugString() = 0;
// Serializes the dataset and writes it to the `writer`.
virtual Status Save(BundleWriterWrapper* writer) const {
virtual Status Save(IteratorStateWriter* writer) const {
return errors::Unimplemented("DatasetBase::Save");
}
@ -435,20 +370,14 @@ class GraphDatasetBase : public DatasetBase {
const string op_name() const { return op_name_; }
Status Save(BundleWriterWrapper* writer) const override {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* node = nullptr;
TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
string output_name = node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
Status Save(IteratorStateWriter* writer) const override {
string serialized_graph_def;
graph_def.SerializeToString(&serialized_graph_def);
string output_node;
TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node));
TF_RETURN_IF_ERROR(
writer->WriteScalar<string>(kDatasetGraphKey, serialized_graph_def));
writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
TF_RETURN_IF_ERROR(
writer->WriteScalar<string>(kDatasetGraphOutputNodeKey, output_name));
writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
return Status::OK();
}
@ -460,6 +389,18 @@ class GraphDatasetBase : public DatasetBase {
static const char kDatasetGraphOutputNodeKey[];
private:
Status Serialize(string* serialized_graph_def, string* output_node) const {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* node = nullptr;
TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
*output_node = node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
graph_def.SerializeToString(serialized_graph_def);
return Status::OK();
}
const string op_name_;
};
@ -505,18 +446,18 @@ class DatasetIterator : public IteratorBase {
return GetNextInternal(ctx, out_tensors, end_of_sequence);
}
protected:
Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) final {
Status Save(IteratorStateWriter* writer) final {
TF_RETURN_IF_ERROR(dataset()->Save(writer));
return IteratorBase::Save(ctx, writer);
return IteratorBase::Save(writer);
}
protected:
// Internal implementation of GetNext that is wrapped in tracing logic.
virtual Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) = 0;
string full_name(const string& name) {
string full_name(const string& name) const {
return strings::StrCat(prefix(), ":", name);
}

View File

@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
@ -35,6 +37,8 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
Status VerifyTypesMatch(const DataTypeVector& expected,
const DataTypeVector& received) {
if (expected.size() != received.size()) {
@ -93,10 +97,10 @@ class IteratorResource : public ResourceBase {
}
}
Status Save(OpKernelContext* ctx, const string& path) {
Status Save(IteratorStateWriter* writer) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
return captured_iterator->Save(ctx, path);
return captured_iterator->Save(writer);
} else {
return errors::FailedPrecondition(
"Save() failed because the iterator has not been initialized. "
@ -105,49 +109,34 @@ class IteratorResource : public ResourceBase {
}
}
Status Restore(OpKernelContext* ctx, const string& path) {
if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
return errors::NotFound(
"Failed to restore Iterator state. No file found at ",
MetaFilename(path));
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
string serialized_graph_def;
TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
&serialized_graph_def));
GraphDef graph_def;
if (!graph_def.ParseFromString(serialized_graph_def)) {
return errors::Internal("Error parsing dataset GraphDef.");
}
string output_node;
TF_RETURN_IF_ERROR(reader->ReadScalar(
GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
DatasetBase* dataset = nullptr;
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
GraphRunner graph_runner(ctx->env());
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
{output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
BundleReader bundle_reader(ctx->env(), path);
TF_RETURN_IF_ERROR(bundle_reader.status());
BundleReaderWrapper reader(&bundle_reader);
if (reader.Contains(GraphDatasetBase::kDatasetGraphKey)) {
string serialized_graph_def;
TF_RETURN_IF_ERROR(reader.ReadScalar(GraphDatasetBase::kDatasetGraphKey,
&serialized_graph_def));
GraphDef graph_def;
graph_def.ParseFromString(serialized_graph_def);
// TODO(srbs): Is there a way of getting the op registry of the original
// graph.
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
string output_node;
TF_RETURN_IF_ERROR(reader.ReadScalar(
GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
std::vector<Tensor> outputs;
GraphRunner graph_runner(ctx->env());
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
{output_node}, &outputs));
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
} else if (reader.Contains(IteratorBase::kIteratorExhausted)) {
TF_RETURN_IF_ERROR(set_iterator(std::unique_ptr<IteratorBase>(
new ExhaustedIterator(output_dtypes_, output_shapes_))));
}
TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
// TODO(srbs): Figure a way to pass bundle_reader here.
return captured_iterator->Restore(ctx, path);
return captured_iterator->Restore(ctx, reader);
} else {
return errors::FailedPrecondition(
"Failed to restore iterator from ", path,
". Make sure the checkpoint ",
"Failed to restore iterator. Make sure the checkpoint ",
"is not corrupt. If the checkpoint does not contain the GraphDef, ",
"you will need to initialize your iterator before restoring.");
}
@ -174,43 +163,194 @@ class IteratorResource : public ResourceBase {
}
private:
// A no-op iterator which always sets end_of_sequence = true. An instance of
// this is returned when attempting to restore an exhausted iterator. This is
// needed because the Dataset GraphDef may not have been saved for exhausted
// iterators so the actual Iterator can not be built.
class ExhaustedIterator : public IteratorBase {
 public:
  ExhaustedIterator(const DataTypeVector& output_dtypes,
                    const std::vector<PartialTensorShape>& output_shapes)
      : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}

  // Always reports end-of-sequence and produces no tensors.
  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    *end_of_sequence = true;
    return Status::OK();
  }

  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }

  // NOTE: only the const override below is kept; the previous extra
  // non-const `output_shapes()` overload did not override the base class's
  // const pure-virtual and was unreachable through the IteratorBase
  // interface, so it has been removed.
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

 private:
  const DataTypeVector output_dtypes_;
  const std::vector<PartialTensorShape> output_shapes_;
};
std::shared_ptr<IteratorBase> iterator_;
const DataTypeVector output_dtypes_;
const std::vector<PartialTensorShape> output_shapes_;
};
// Helper class for reading data from a VariantTensorData object.
class VariantTensorDataReader : public IteratorStateReader {
public:
explicit VariantTensorDataReader(const VariantTensorData* data)
: data_(data) {
PreProcess();
}
// Returns OK iff the initialization was successful, i.e.,
// pre-processing did not have errors.
Status status() const { return status_; }
Status ReadScalar(StringPiece key, int64* val) override {
return ReadScalarInternal(key, val);
}
Status ReadScalar(StringPiece key, string* val) override {
return ReadScalarInternal(key, val);
}
bool Contains(StringPiece key) override {
return map_.find(key.ToString()) != map_.end();
}
private:
void PreProcess() {
string metadata;
data_->get_metadata(&metadata);
IteratorStateMetadata proto;
if (!proto.ParseFromString(metadata)) {
status_ = errors::Internal("Error parsing IteratorStateMetadata.");
return;
}
size_t num_entries = proto.keys_size();
CHECK_EQ(num_entries, data_->tensors_size());
for (size_t i = 0; i < num_entries; i++) {
map_[proto.keys(i)] = i;
}
}
template <typename T>
Status ReadScalarInternal(StringPiece key, T* val) {
if (map_.find(key.ToString()) == map_.end()) {
return errors::NotFound(key);
}
*val = data_->tensors(map_[key.ToString()]).scalar<T>()();
return Status::OK();
}
std::map<string, size_t> map_;
const VariantTensorData* data_; // Not owned.
Status status_;
};
// Helper class for writing data to a VariantTensorData object.
// Each WriteScalar() call appends one scalar tensor to `data_` and records
// its key in an IteratorStateMetadata proto; the proto is serialized into
// the data's metadata field only when Flush() is called.
class VariantTensorDataWriter : public IteratorStateWriter {
public:
// Does not take ownership of data.
explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
Status WriteScalar(StringPiece key, const int64& val) override {
return WriteScalarInternal(key, val);
}
Status WriteScalar(StringPiece key, const string& val) override {
return WriteScalarInternal(key, val);
}
// Writes the metadata to `data_`.
// Must be called after the last WriteScalar(); until then the keys written
// so far are only buffered in `metadata_proto_` and are not visible in
// `data_`'s metadata field.
Status Flush() {
string metadata;
if (!metadata_proto_.SerializeToString(&metadata)) {
return errors::Internal("Unable to serialize IteratorStateMetadata.");
}
data_->set_metadata(metadata);
return Status::OK();
}
private:
// Appends `val` as a scalar tensor and records `key` at the matching index.
// NOTE(review): duplicate keys are not detected here; the i-th key always
// names the i-th tensor.
template <typename T>
Status WriteScalarInternal(StringPiece key, const T& val) {
// Write key to the metadata proto. This gets written to `data_`
// when `Flush()` is called. We do this lazily to avoid multiple
// serialization calls.
metadata_proto_.add_keys(key.ToString());
// Update tensors.
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
val_t.scalar<T>()() = val;
*(data_->add_tensors()) = std::move(val_t);
return Status::OK();
}
VariantTensorData* data_;  // Not owned.
// TODO(srbs): Set the version string.
IteratorStateMetadata metadata_proto_;
};
// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
// The get() method returns an IteratorStateReader which can be used
// to restore iterator state.
//
// Usage example:
//
// Encoding:
//
// Tensor t(DT_VARIANT, TensorShape({}));
// t->scalar<Variant>()() = IteratorStateVariant(iterator_resource);
//
// Encode() sets the type_name of the VariantTensorData object to
// IteratorStateVariant::TypeName().
//
// Decoding:
//
// Variant v = <VariantTensorDataProto object>;
// DecodeUnaryVariant(&v);
// IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
// iterator_resource->Restore(ctx, wrapper->get())
//
// The type_name of the VariantTensorData object to be decoded must
// match IteratorStateVariant::TypeName().
class IteratorStateVariant {
public:
// Default-constructed instances hold no state; Encode() on such an
// instance would dereference a null `data_`, so InitializeFromIterator()
// or Decode() must be called first.
IteratorStateVariant() : data_(nullptr) {}
// Copying re-decodes the other instance's VariantTensorData so that this
// copy owns an independent `data_`/`reader_` pair.
IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
if (other.data_) {
Decode(*other.data_);
}
}
// Initializes this object with the current state of the iterator so
// that it can be written on the next call to Encode().
Status InitializeFromIterator(IteratorResource* iterator_resource) {
data_.reset(new VariantTensorData());
data_->set_type_name(TypeName());
VariantTensorDataWriter writer(data_.get());
TF_RETURN_IF_ERROR(iterator_resource->Save(&writer));
TF_RETURN_IF_ERROR(writer.Flush());
return Status::OK();
}
string TypeName() const { return kIteratorVariantTypeName; }
void Encode(VariantTensorData* data) const { *data = *data_; }
// Returns false (leaving this object unchanged) if `data` has the wrong
// type name or if pre-processing of the stored metadata fails; on success
// takes a private copy of `data` and builds a reader over it.
bool Decode(const VariantTensorData& data) {
if (data.type_name() != TypeName()) {
return false;
}
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
*tensor_data = data;
std::unique_ptr<VariantTensorDataReader> reader(
new VariantTensorDataReader(tensor_data.get()));
status_ = reader->status();
if (!status_.ok()) {
return false;
}
// `reader_` holds a raw pointer into `data_`, so both are kept alive
// together for the lifetime of this object.
data_ = std::move(tensor_data);
reader_ = std::move(reader);
return true;
}
// Returns the reader over the decoded state, or nullptr if Decode() has
// not succeeded yet.
IteratorStateReader* get() { return reader_.get(); }
Status status() const { return status_; }
string DebugString() const {
if (data_) {
return strings::StrCat("IteratorStateVariant<",
"data: ", data_->DebugString(),
" status: ", status_.ToString(), ">");
} else {
return strings::StrCat("IteratorStateVariant<empty>");
}
}
private:
std::unique_ptr<IteratorStateReader> reader_;
Status status_;
std::unique_ptr<VariantTensorData> data_;
};
// Register the reader class in the global variant decode_fn registry
// so that a Variant containing a serialized representation of iterator state
// can be decoded using DecodeUnaryVariant. If we don't do this we will need
// to manually decode the returned Variant using MaybeDecodeAndCopy in
// DeserializeIteratorOp which is not recommended.
// (The registered type name must match IteratorStateVariant::TypeName(),
// which returns kIteratorVariantTypeName.)
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);
// TODO(mrry): Can we simply use the template kernel here?
class IteratorHandleOp : public ResourceOpKernel<IteratorResource> {
public:
@ -294,37 +434,6 @@ class ToSingleElementOp : public OpKernel {
}
};
class SaveIteratorOp : public OpKernel {
public:
explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
IteratorResource* iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
errors::InvalidArgument("SaveIteratorOp: path must be scalar"));
const string& path = ctx->input(1).scalar<string>()();
OP_REQUIRES_OK(ctx, iterator_resource->Save(ctx, path));
}
};
class RestoreIteratorOp : public OpKernel {
public:
explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
IteratorResource* iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
errors::InvalidArgument("RestoreIteratorOp: path must be scalar"));
const string& path = ctx->input(1).scalar<string>()();
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, path));
}
};
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@ -644,15 +753,55 @@ class IteratorFromStringHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
class SerializeIteratorOp : public OpKernel {
public:
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
// Validate that the handle corresponds to a real resource, and
// that it is an IteratorResource.
IteratorResource* iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
iterator_resource->Unref();
Tensor* variant_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
IteratorStateVariant v;
OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource));
variant_t->scalar<Variant>()() = v;
}
};
class DeserializeIteratorOp : public OpKernel {
public:
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// Validate that the handle corresponds to a real resource, and
// that it is an IteratorResource.
IteratorResource* iterator_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
Variant variant = ctx->input(1).scalar<Variant>()();
auto* wrapper = variant.get<IteratorStateVariant>();
OP_REQUIRES(ctx, wrapper != nullptr,
errors::InvalidArgument(
"DeserializeIteratorOp: Unable to parse variant tensor."));
OP_REQUIRES_OK(ctx, wrapper->status());
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
}
};
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU),
SaveIteratorOp);
REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU),
RestoreIteratorOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
@ -661,6 +810,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
} // namespace

View File

@ -92,6 +92,7 @@ class SerializeTensorOp : public OpKernel {
Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SerializeTensorOp<T>);
TF_CALL_ALL_TYPES(REGISTER)
TF_CALL_variant(REGISTER)
#undef REGISTER
} // namespace tensorflow

View File

@ -112,19 +112,16 @@ class RangeDatasetOp : public DatasetOpKernel {
}
protected:
Status SaveInternal(OpKernelContext* ctx,
IteratorBundleWriter* writer) override {
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
writer->WriteScalar<int64>(full_name("next"), next_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorBundleReader* reader) override {
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(
reader->ReadScalar<int64>(full_name("next"), &next_));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_));
return Status::OK();
}

View File

@ -356,31 +356,30 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
}
protected:
Status SaveInternal(OpKernelContext* ctx,
IteratorBundleWriter* writer) override {
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(
full_name("current_file_index"), current_file_index_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
current_file_index_));
// `input_buffer_` is empty if
// 1. GetNext has not been called even once.
// 2. All files have been read and iterator has been exhausted.
int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
TF_RETURN_IF_ERROR(
writer->WriteScalar<int64>(full_name("current_pos"), current_pos));
writer->WriteScalar(full_name("current_pos"), current_pos));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorBundleReader* reader) override {
IteratorStateReader* reader) override {
mutex_lock l(mu_);
int64 current_file_index;
TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(
full_name("current_file_index"), &current_file_index));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
&current_file_index));
current_file_index_ = size_t(current_file_index);
int64 current_pos;
TF_RETURN_IF_ERROR(
reader->ReadScalar<int64>(full_name("current_pos"), &current_pos));
reader->ReadScalar(full_name("current_pos"), &current_pos));
// Seek to current_pos.
input_buffer_.reset();

View File

@ -124,19 +124,18 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
}
protected:
Status SaveInternal(OpKernelContext* ctx,
IteratorBundleWriter* writer) override {
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(full_name("i"), i_));
TF_RETURN_IF_ERROR(writer->SaveParent(ctx, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorBundleReader* reader) override {
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(full_name("i"), &i_));
TF_RETURN_IF_ERROR(reader->RestoreParent(ctx, input_impl_));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
return Status::OK();
}

View File

@ -28753,18 +28753,6 @@ op {
}
is_stateful: true
}
op {
name: "RestoreIterator"
input_arg {
name: "iterator"
type: DT_RESOURCE
}
input_arg {
name: "path"
type: DT_STRING
}
is_stateful: true
}
op {
name: "RestoreSlice"
input_arg {
@ -29548,18 +29536,6 @@ op {
}
is_stateful: true
}
op {
name: "SaveIterator"
input_arg {
name: "iterator"
type: DT_RESOURCE
}
input_arg {
name: "path"
type: DT_STRING
}
is_stateful: true
}
op {
name: "SaveSlices"
input_arg {

View File

@ -598,24 +598,6 @@ This operation may be executed multiple times. Each execution will reset the
iterator in `iterator` to the first element of `dataset`.
)doc");
REGISTER_OP("SaveIterator")
.Input("iterator: resource")
.Input("path: string")
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Saves the state of the `iterator` at `path`.
This state can be restored using "RestoreIterator".
)doc");
REGISTER_OP("RestoreIterator")
.Input("iterator: resource")
.Input("path: string")
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator".
)doc");
REGISTER_OP("OneShotIterator")
.Output("handle: resource")
.Attr("dataset_factory: func")
@ -737,4 +719,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
element produced by the resulting iterator.
)doc");
// Op that snapshots an iterator resource's state into a scalar DT_VARIANT
// tensor; implemented by SerializeIteratorOp.
REGISTER_OP("SerializeIterator")
    .Input("resource_handle: resource")
    .Output("serialized: variant")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Converts the given `resource_handle` representing an iterator to a variant tensor.
resource_handle: A handle to an iterator resource.
serialized: A variant tensor storing the state of the iterator contained in the
resource.
)doc");

// Op that restores an iterator resource's state from a scalar DT_VARIANT
// tensor; implemented by DeserializeIteratorOp. Produces no outputs.
REGISTER_OP("DeserializeIterator")
    .Input("resource_handle: resource")
    .Input("serialized: variant")
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Converts the given variant tensor to an iterator and stores it in the given resource.
resource_handle: A handle to an iterator resource.
serialized: A variant tensor storing the state of the iterator contained in the
resource.
)doc");
} // namespace tensorflow

View File

@ -2886,7 +2886,9 @@ tf_py_test(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:io_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
@ -2907,7 +2909,9 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
@ -3022,6 +3026,7 @@ tf_py_test(
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
def testIncorrectIteratorRestore(self):
def _iterator_checkpoint_prefix():
def _path():
return os.path.join(self.get_temp_dir(), "iterator")
def _save_op(iterator_resource):
  # Serialize the iterator's state into a variant tensor, then persist it
  # to _path() as a serialized TensorProto via WriteFile.
  iterator_state_variant = gen_dataset_ops.serialize_iterator(
      iterator_resource)
  save_op = io_ops.write_file(
      _path(), parsing_ops.serialize_tensor(iterator_state_variant))
  return save_op
def _restore_op(iterator_resource):
  # Read the serialized TensorProto back from _path(), parse it into a
  # variant tensor, and feed it to the iterator resource to restore state.
  iterator_state_variant = parsing_ops.parse_tensor(
      io_ops.read_file(_path()), dtypes.variant)
  restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
                                                    iterator_state_variant)
  return restore_op
def _build_range_dataset_graph():
start = 1
stop = 10
@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = _iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = _save_op(iterator._iterator_resource)
restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
def _build_reader_dataset_graph():
filenames = ["test"] # Does not exist but we don't care in this test.
path = _iterator_checkpoint_prefix()
iterator = readers.FixedLengthRecordDataset(
filenames, 1, 0, 0).make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = _save_op(iterator._iterator_resource)
restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
# Saving iterator for RangeDataset graph.

View File

@ -27,6 +27,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase):
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _save_op(self, iterator_resource):
  """Returns an op that writes `iterator_resource`'s serialized state.

  The state is serialized to a variant tensor and written to the
  checkpoint prefix as a serialized TensorProto.
  """
  iterator_state_variant = gen_dataset_ops.serialize_iterator(
      iterator_resource)
  save_op = io_ops.write_file(
      self._iterator_checkpoint_prefix(),
      parsing_ops.serialize_tensor(iterator_state_variant))
  return save_op
def _restore_op(self, iterator_resource):
  """Returns an op that restores `iterator_resource` from the checkpoint.

  Reads the serialized TensorProto written by `_save_op`, parses it back
  into a variant tensor, and deserializes it into the resource.
  """
  iterator_state_variant = parsing_ops.parse_tensor(
      io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
  restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
                                                    iterator_state_variant)
  return restore_op
def testSaveRestore(self):
def _build_graph(start, stop):
@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase):
def testRestoreWithoutBuildingDatasetGraph(self):
def _build_graph(start, stop, num_epochs, path):
def _build_graph(start, stop, num_epochs):
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase):
num_epochs = 5
break_point = 5
break_epoch = 3
path = self._iterator_checkpoint_prefix()
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
path)
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.test_session(graph=g) as sess:
sess.run(restore_op)
@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@ -447,10 +450,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
path = self._iterator_checkpoint_prefix()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2

View File

@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _iterator_checkpoint_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _save_op(self, iterator_resource):
  """Returns an op that checkpoints `iterator_resource`'s state to disk."""
  iterator_state_variant = gen_dataset_ops.serialize_iterator(
      iterator_resource)
  save_op = io_ops.write_file(
      self._iterator_checkpoint_path(),
      parsing_ops.serialize_tensor(iterator_state_variant))
  return save_op
def _restore_op(self, iterator_resource):
  """Returns an op that restores `iterator_resource` from the checkpoint."""
  iterator_state_variant = parsing_ops.parse_tensor(
      io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
  restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
                                                    iterator_state_variant)
  return restore_op
def _build_iterator_graph(self, num_epochs):
filenames = self._createFiles()
path = self._iterator_checkpoint_path()
dataset = (readers.FixedLengthRecordDataset(
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
.repeat(num_epochs))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
path)
save_op = self._save_op(iterator._iterator_resource)
restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
def _restore_iterator(self):
@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
restore_op = gen_dataset_ops.restore_iterator(
iterator._iterator_resource, self._iterator_checkpoint_path())
restore_op = self._restore_op(iterator._iterator_resource)
return restore_op, get_next
def testSaveRestore(self):