[tf.data] Addressing TODOs regarding deprecated APIs.
PiperOrigin-RevId: 299470038 Change-Id: If445f10866b9356e35abbd139a929d5a0f77c0bc
This commit is contained in:
		
							parent
							
								
									4831d4f42f
								
							
						
					
					
						commit
						3874b28288
					
				@ -619,12 +619,8 @@ class IteratorBase {
 | 
				
			|||||||
  // Saves the state of this iterator.
 | 
					  // Saves the state of this iterator.
 | 
				
			||||||
  //
 | 
					  //
 | 
				
			||||||
  // This method is used to store the state of the iterator in a checkpoint.
 | 
					  // This method is used to store the state of the iterator in a checkpoint.
 | 
				
			||||||
  //
 | 
					 | 
				
			||||||
  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
 | 
					 | 
				
			||||||
  // implementations have an override.
 | 
					  // implementations have an override.
 | 
				
			||||||
  virtual Status SaveInternal(IteratorStateWriter* writer) {
 | 
					  virtual Status SaveInternal(IteratorStateWriter* writer) = 0;
 | 
				
			||||||
    return errors::Unimplemented("SaveInternal");
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Restores the state of this iterator.
 | 
					  // Restores the state of this iterator.
 | 
				
			||||||
  //
 | 
					  //
 | 
				
			||||||
@ -633,13 +629,9 @@ class IteratorBase {
 | 
				
			|||||||
  // Implementations may assume that the iterator is in a clean state. That is,
 | 
					  // Implementations may assume that the iterator is in a clean state. That is,
 | 
				
			||||||
  // its `Initialize` method has been called, but its `GetNext` method has
 | 
					  // its `Initialize` method has been called, but its `GetNext` method has
 | 
				
			||||||
  // never been called.
 | 
					  // never been called.
 | 
				
			||||||
  //
 | 
					 | 
				
			||||||
  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
 | 
					 | 
				
			||||||
  // implementations have an override.
 | 
					  // implementations have an override.
 | 
				
			||||||
  virtual Status RestoreInternal(IteratorContext* ctx,
 | 
					  virtual Status RestoreInternal(IteratorContext* ctx,
 | 
				
			||||||
                                 IteratorStateReader* reader) {
 | 
					                                 IteratorStateReader* reader) = 0;
 | 
				
			||||||
    return errors::Unimplemented("RestoreInternal");
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Returns the number of elements produced by this iterator.
 | 
					  // Returns the number of elements produced by this iterator.
 | 
				
			||||||
  int64 num_elements() const {
 | 
					  int64 num_elements() const {
 | 
				
			||||||
@ -749,22 +741,6 @@ class DatasetBase : public core::RefCounted {
 | 
				
			|||||||
    return MakeIterator(&ctx, parent, output_prefix, iterator);
 | 
					    return MakeIterator(&ctx, parent, output_prefix, iterator);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // TODO(jsimsa): Remove this overlead once all callers are migrated to the API
 | 
					 | 
				
			||||||
  // that passes in the parent iterator pointer.
 | 
					 | 
				
			||||||
  ABSL_DEPRECATED("Use the overload that passes the parent iterator pointer.")
 | 
					 | 
				
			||||||
  Status MakeIterator(IteratorContext* ctx, const string& output_prefix,
 | 
					 | 
				
			||||||
                      std::unique_ptr<IteratorBase>* iterator) const {
 | 
					 | 
				
			||||||
    return MakeIterator(ctx, /*parent=*/nullptr, output_prefix, iterator);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // TODO(jsimsa): Remove this overlead once all callers are migrated to the API
 | 
					 | 
				
			||||||
  // that passes in the parent iterator pointer.
 | 
					 | 
				
			||||||
  ABSL_DEPRECATED("Use the overload that passes the parent iterator pointer.")
 | 
					 | 
				
			||||||
  Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
 | 
					 | 
				
			||||||
                      std::unique_ptr<IteratorBase>* iterator) const {
 | 
					 | 
				
			||||||
    return MakeIterator(&ctx, output_prefix, iterator);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Returns a new iterator restored from the checkpoint data in `reader`.
 | 
					  // Returns a new iterator restored from the checkpoint data in `reader`.
 | 
				
			||||||
  Status MakeIteratorFromCheckpoint(
 | 
					  Status MakeIteratorFromCheckpoint(
 | 
				
			||||||
      IteratorContext* ctx, const string& output_prefix,
 | 
					      IteratorContext* ctx, const string& output_prefix,
 | 
				
			||||||
@ -807,27 +783,11 @@ class DatasetBase : public core::RefCounted {
 | 
				
			|||||||
  // A human-readable debug string for this dataset.
 | 
					  // A human-readable debug string for this dataset.
 | 
				
			||||||
  virtual string DebugString() const = 0;
 | 
					  virtual string DebugString() const = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // If the dataset is stateful it will not be possible to save its graph or
 | 
					 | 
				
			||||||
  // checkpoint the state of its iterators.
 | 
					 | 
				
			||||||
  //
 | 
					 | 
				
			||||||
  // TODO(jsimsa): Remove this method once all `DatasetBase` implementations are
 | 
					 | 
				
			||||||
  // migrated over to `CheckExternalState`.
 | 
					 | 
				
			||||||
  ABSL_DEPRECATED("Use CheckExternalState instead.")
 | 
					 | 
				
			||||||
  virtual bool IsStateful() const { return false; }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Indicates whether the dataset depends on any external state which would
 | 
					  // Indicates whether the dataset depends on any external state which would
 | 
				
			||||||
  // prevent it from being serializable. If so, the method returns
 | 
					  // prevent it from being serializable. If so, the method returns
 | 
				
			||||||
  // `errors::FailedPrecondition` with a message that identifies the external
 | 
					  // `errors::FailedPrecondition` with a message that identifies the external
 | 
				
			||||||
  // state. Otherwise, the method returns `Status::OK()`.
 | 
					  // state. Otherwise, the method returns `Status::OK()`.
 | 
				
			||||||
  //
 | 
					  virtual Status CheckExternalState() const = 0;
 | 
				
			||||||
  // TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
 | 
					 | 
				
			||||||
  // implementations have an override.
 | 
					 | 
				
			||||||
  virtual Status CheckExternalState() const {
 | 
					 | 
				
			||||||
    if (IsStateful()) {
 | 
					 | 
				
			||||||
      return errors::FailedPrecondition("Dataset cannot be serialized.");
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return Status::OK();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  friend Status AsGraphDef(
 | 
					  friend Status AsGraphDef(
 | 
				
			||||||
 | 
				
			|||||||
@ -432,15 +432,6 @@ Status MakeIteratorFromInputElement(
 | 
				
			|||||||
      out_iterator);
 | 
					      out_iterator);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Status MakeIteratorFromInputElement(
 | 
					 | 
				
			||||||
    IteratorContext* ctx, const std::vector<Tensor>& input_element,
 | 
					 | 
				
			||||||
    int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
 | 
					 | 
				
			||||||
    StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
 | 
					 | 
				
			||||||
  return MakeIteratorFromInputElement(ctx, /*parent=*/nullptr, input_element,
 | 
					 | 
				
			||||||
                                      thread_index, inst_captured_func, prefix,
 | 
					 | 
				
			||||||
                                      out_iterator);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/* static */
 | 
					/* static */
 | 
				
			||||||
Status FunctionMetadata::Create(
 | 
					Status FunctionMetadata::Create(
 | 
				
			||||||
    OpKernelConstruction* ctx, const string& func_name, Params params,
 | 
					    OpKernelConstruction* ctx, const string& func_name, Params params,
 | 
				
			||||||
 | 
				
			|||||||
@ -47,17 +47,6 @@ Status MakeIteratorFromInputElement(
 | 
				
			|||||||
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
 | 
					    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
 | 
				
			||||||
    std::unique_ptr<IteratorBase>* out_iterator);
 | 
					    std::unique_ptr<IteratorBase>* out_iterator);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Creates an iterator for a dataset which is created by applying the given
 | 
					 | 
				
			||||||
// function to the given input element.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// TODO(jsimsa): Remove this overload once all callers are migrated to the API
 | 
					 | 
				
			||||||
// that passes in the parent iterator pointer.
 | 
					 | 
				
			||||||
ABSL_DEPRECATED("Use the overload that passes the parent iterator pointer.")
 | 
					 | 
				
			||||||
Status MakeIteratorFromInputElement(
 | 
					 | 
				
			||||||
    IteratorContext* ctx, const std::vector<Tensor>& input_element,
 | 
					 | 
				
			||||||
    int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
 | 
					 | 
				
			||||||
    StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Determines whether the given node is stateful.
 | 
					// Determines whether the given node is stateful.
 | 
				
			||||||
Status IsNodeStateful(const FunctionLibraryDefinition& library,
 | 
					Status IsNodeStateful(const FunctionLibraryDefinition& library,
 | 
				
			||||||
                      const NodeDef& node);
 | 
					                      const NodeDef& node);
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,8 @@ class WrapperDataset : public DatasetBase {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  string DebugString() const override { return "WrapperDataset"; }
 | 
					  string DebugString() const override { return "WrapperDataset"; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Status CheckExternalState() const override { return Status::OK(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  Status AsGraphDefInternal(SerializationContext* ctx,
 | 
					  Status AsGraphDefInternal(SerializationContext* ctx,
 | 
				
			||||||
                            DatasetGraphDefBuilder* b,
 | 
					                            DatasetGraphDefBuilder* b,
 | 
				
			||||||
 | 
				
			|||||||
@ -64,6 +64,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  int64 Cardinality() const override { return kInfiniteCardinality; }
 | 
					  int64 Cardinality() const override { return kInfiniteCardinality; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Status CheckExternalState() const override { return Status::OK(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  Status AsGraphDefInternal(SerializationContext* ctx,
 | 
					  Status AsGraphDefInternal(SerializationContext* ctx,
 | 
				
			||||||
                            DatasetGraphDefBuilder* b,
 | 
					                            DatasetGraphDefBuilder* b,
 | 
				
			||||||
 | 
				
			|||||||
@ -37,7 +37,7 @@ class RandomDatasetOp : public DatasetOpKernel {
 | 
				
			|||||||
  explicit RandomDatasetOp(OpKernelConstruction* ctx);
 | 
					  explicit RandomDatasetOp(OpKernelConstruction* ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 protected:
 | 
					 protected:
 | 
				
			||||||
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
 | 
					  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 private:
 | 
					 private:
 | 
				
			||||||
  class Dataset;
 | 
					  class Dataset;
 | 
				
			||||||
 | 
				
			|||||||
@ -20,6 +20,7 @@ limitations under the License.
 | 
				
			|||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
 | 
					#include "tensorflow/core/kernels/data/dataset_utils.h"
 | 
				
			||||||
#include "tensorflow/core/lib/core/refcount.h"
 | 
					#include "tensorflow/core/lib/core/refcount.h"
 | 
				
			||||||
#include "tensorflow/core/lib/core/threadpool.h"
 | 
					#include "tensorflow/core/lib/core/threadpool.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/platform/thread_annotations.h"
 | 
				
			||||||
#include "tensorflow/core/util/work_sharder.h"
 | 
					#include "tensorflow/core/util/work_sharder.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace tensorflow {
 | 
					namespace tensorflow {
 | 
				
			||||||
@ -203,6 +204,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
      Status GetNextInternal(IteratorContext* ctx,
 | 
					      Status GetNextInternal(IteratorContext* ctx,
 | 
				
			||||||
                             std::vector<Tensor>* out_tensors,
 | 
					                             std::vector<Tensor>* out_tensors,
 | 
				
			||||||
                             bool* end_of_sequence) override {
 | 
					                             bool* end_of_sequence) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
        return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
 | 
					        return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
 | 
				
			||||||
                                    out_tensors, end_of_sequence);
 | 
					                                    out_tensors, end_of_sequence);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
@ -214,6 +216,20 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
                                         /*ratio=*/1);
 | 
					                                         /*ratio=*/1);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Status SaveInternal(IteratorStateWriter* writer) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
 | 
					        DCHECK(input_impl_ != nullptr);
 | 
				
			||||||
 | 
					        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Status RestoreInternal(IteratorContext* ctx,
 | 
				
			||||||
 | 
					                             IteratorStateReader* reader) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
 | 
					        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
     private:
 | 
					     private:
 | 
				
			||||||
      IteratorContext::Params CreateParams(IteratorContext* ctx) {
 | 
					      IteratorContext::Params CreateParams(IteratorContext* ctx) {
 | 
				
			||||||
        ThreadPoolResource* pool = dataset()->threadpool_;
 | 
					        ThreadPoolResource* pool = dataset()->threadpool_;
 | 
				
			||||||
@ -225,7 +241,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
        return params;
 | 
					        return params;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      std::unique_ptr<IteratorBase> input_impl_;
 | 
					      mutex mu_;
 | 
				
			||||||
 | 
					      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const DatasetBase* const input_;
 | 
					    const DatasetBase* const input_;
 | 
				
			||||||
@ -319,6 +336,7 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
        auto max_parallelism = dataset()->max_intra_op_parallelism_;
 | 
					        auto max_parallelism = dataset()->max_intra_op_parallelism_;
 | 
				
			||||||
        params.runner =
 | 
					        params.runner =
 | 
				
			||||||
            RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
 | 
					            RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
        return input_impl_->GetNext(IteratorContext{std::move(params)},
 | 
					        return input_impl_->GetNext(IteratorContext{std::move(params)},
 | 
				
			||||||
                                    out_tensors, end_of_sequence);
 | 
					                                    out_tensors, end_of_sequence);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
@ -330,8 +348,23 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
                                         /*ratio=*/1);
 | 
					                                         /*ratio=*/1);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Status SaveInternal(IteratorStateWriter* writer) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
 | 
					        DCHECK(input_impl_ != nullptr);
 | 
				
			||||||
 | 
					        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Status RestoreInternal(IteratorContext* ctx,
 | 
				
			||||||
 | 
					                             IteratorStateReader* reader) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
 | 
					        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
     private:
 | 
					     private:
 | 
				
			||||||
      std::unique_ptr<IteratorBase> input_impl_;
 | 
					      mutex mu_;
 | 
				
			||||||
 | 
					      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const DatasetBase* const input_;
 | 
					    const DatasetBase* const input_;
 | 
				
			||||||
@ -425,6 +458,7 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
          pool->Schedule(std::move(c));
 | 
					          pool->Schedule(std::move(c));
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
        params.runner_threadpool_size = dataset()->num_threads_;
 | 
					        params.runner_threadpool_size = dataset()->num_threads_;
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
        return input_impl_->GetNext(IteratorContext{std::move(params)},
 | 
					        return input_impl_->GetNext(IteratorContext{std::move(params)},
 | 
				
			||||||
                                    out_tensors, end_of_sequence);
 | 
					                                    out_tensors, end_of_sequence);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
@ -436,8 +470,23 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 | 
				
			|||||||
                                         /*ratio=*/1);
 | 
					                                         /*ratio=*/1);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Status SaveInternal(IteratorStateWriter* writer) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
 | 
					        DCHECK(input_impl_ != nullptr);
 | 
				
			||||||
 | 
					        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Status RestoreInternal(IteratorContext* ctx,
 | 
				
			||||||
 | 
					                             IteratorStateReader* reader) override {
 | 
				
			||||||
 | 
					        mutex_lock l(mu_);
 | 
				
			||||||
 | 
					        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
 | 
				
			||||||
 | 
					        return Status::OK();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
     private:
 | 
					     private:
 | 
				
			||||||
      std::unique_ptr<IteratorBase> input_impl_;
 | 
					      mutex mu_;
 | 
				
			||||||
 | 
					      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const DatasetBase* const input_;
 | 
					    const DatasetBase* const input_;
 | 
				
			||||||
 | 
				
			|||||||
@ -158,6 +158,17 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
 | 
				
			|||||||
      return model::MakeSourceNode(std::move(args));
 | 
					      return model::MakeSourceNode(std::move(args));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Status SaveInternal(IteratorStateWriter* writer) override {
 | 
				
			||||||
 | 
					      return errors::Unimplemented(
 | 
				
			||||||
 | 
					          "GeneratorDataset does not support checkpointing.");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Status RestoreInternal(IteratorContext* ctx,
 | 
				
			||||||
 | 
					                           IteratorStateReader* reader) override {
 | 
				
			||||||
 | 
					      return errors::Unimplemented(
 | 
				
			||||||
 | 
					          "GeneratorDataset does not support checkpointing.");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   private:
 | 
					   private:
 | 
				
			||||||
    mutex mu_;
 | 
					    mutex mu_;
 | 
				
			||||||
    bool initialized_ TF_GUARDED_BY(mu_) = false;
 | 
					    bool initialized_ TF_GUARDED_BY(mu_) = false;
 | 
				
			||||||
 | 
				
			|||||||
@ -573,7 +573,8 @@ class MultiDeviceIteratorInitOp : public OpKernel {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    IteratorContext iter_ctx(std::move(params));
 | 
					    IteratorContext iter_ctx(std::move(params));
 | 
				
			||||||
    OP_REQUIRES_OK(
 | 
					    OP_REQUIRES_OK(
 | 
				
			||||||
        ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
 | 
					        ctx, dataset->MakeIterator(std::move(iter_ctx), /*parent=*/nullptr,
 | 
				
			||||||
 | 
					                                   "Iterator", &iterator));
 | 
				
			||||||
    int64 incarnation_id;
 | 
					    int64 incarnation_id;
 | 
				
			||||||
    OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
 | 
					    OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
 | 
				
			||||||
                                       &incarnation_id));
 | 
					                                       &incarnation_id));
 | 
				
			||||||
 | 
				
			|||||||
@ -496,7 +496,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
 | 
				
			|||||||
                                                     &window_dataset));
 | 
					                                                     &window_dataset));
 | 
				
			||||||
            std::unique_ptr<IteratorBase> window_dataset_iterator;
 | 
					            std::unique_ptr<IteratorBase> window_dataset_iterator;
 | 
				
			||||||
            TF_ASSERT_OK(window_dataset->MakeIterator(
 | 
					            TF_ASSERT_OK(window_dataset->MakeIterator(
 | 
				
			||||||
                iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
 | 
					                iterator_ctx_.get(), /*parent=*/nullptr,
 | 
				
			||||||
 | 
					                test_case.dataset_params.iterator_prefix(),
 | 
				
			||||||
                &window_dataset_iterator));
 | 
					                &window_dataset_iterator));
 | 
				
			||||||
            bool end_of_window_dataset = false;
 | 
					            bool end_of_window_dataset = false;
 | 
				
			||||||
            std::vector<Tensor> window_elements;
 | 
					            std::vector<Tensor> window_elements;
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user