diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 1b86ed39295..0781347fd6e 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -5614,6 +5614,40 @@ func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataTyp return components } +// Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator". +// +// Returns the created operation. +func RestoreIterator(scope *Scope, iterator tf.Output, path tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RestoreIterator", + Input: []tf.Input{ + iterator, path, + }, + } + return scope.AddOperation(opspec) +} + +// Saves the state of the `iterator` at `path`. +// +// This state can be restored using "RestoreIterator". +// +// Returns the created operation. +func SaveIterator(scope *Scope, iterator tf.Output, path tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SaveIterator", + Input: []tf.Input{ + iterator, path, + }, + } + return scope.AddOperation(opspec) +} + // Makes a new iterator from the given `dataset` and stores it in `iterator`. // // This operation may be executed multiple times. Each execution will reset the