Add dataset API functionality to snapshot_util::Reader
PiperOrigin-RevId: 307561078
Change-Id: I7b78c70dcd0ad66ee5f2c05f564d0f29ca8b34fc

parent 2d79e9922f
commit 46646156db
@@ -526,6 +526,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels/data:name_utils",
        "//tensorflow/core/platform:coding",
        "//tensorflow/core/platform:random",
        "//tensorflow/core/profiler/lib:traceme",

@@ -15,10 +15,14 @@ limitations under the License.

#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"

#include <queue>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
@@ -219,6 +223,198 @@ Status Reader::Create(Env* env, const std::string& filename,
  return (*out_reader)->Initialize(env);
}

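// Reads the elements written to a single snapshot data file as a dataset.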
class Reader::Dataset : public DatasetBase {
 public:
  explicit Dataset(const std::string& filename, const std::string& compression,
                   const int64 version, const DataTypeVector& dtypes,
                   const std::vector<PartialTensorShape>& shapes,
                   DatasetContext::Params params)
      : DatasetBase(DatasetContext(std::move(params))),
        filename_(filename),
        compression_(compression),
        version_(version),
        dtypes_(dtypes),
        shapes_(shapes) {}

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  std::string DebugString() const override {
    return "snapshot_util::Reader::Dataset";
  }

  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    // TODO(frankchn): Implement for serialization and checkpointing.
    return Status::OK();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  std::string filename_;
  std::string compression_;
  int64 version_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> shapes_;

  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return Reader::Create(ctx->env(), dataset()->filename_,
                            dataset()->compression_, dataset()->version_,
                            dataset()->dtypes_, &reader_);
    }

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = false;
      Status s = reader_->ReadTensors(out_tensors);
      if (errors::IsOutOfRange(s)) {
        *end_of_sequence = true;
        return Status::OK();
      }
      return s;
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

   private:
    std::unique_ptr<Reader> reader_;
  };
};

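// Wraps a set of per-file datasets and yields each of them as a single
// DT_VARIANT tensor element.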
class Reader::NestedDataset : public DatasetBase {
 public:
  explicit NestedDataset(std::vector<DatasetBase*> datasets,
                         DatasetContext::Params params)
      : DatasetBase(DatasetContext(std::move(params))), datasets_(datasets) {
    dtypes_.push_back(DT_VARIANT);
    gtl::InlinedVector<int64, 1> element_dim_sizes;
    element_dim_sizes.push_back(1);
    partial_shapes_.emplace_back(element_dim_sizes);
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return partial_shapes_;
  }

  std::string DebugString() const override {
    return "snapshot_util::Reader::NestedDataset";
  }

  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    // TODO(frankchn): Implement for serialization and checkpointing.
    return Status::OK();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  std::vector<DatasetBase*> datasets_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> partial_shapes_;

  class Iterator : public DatasetIterator<NestedDataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<NestedDataset>(params), index_(0) {}

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = dataset()->datasets_.size() == index_;
      if (!*end_of_sequence) {
        Tensor tensor(DT_VARIANT, TensorShape({}));

        TF_RETURN_IF_ERROR(
            StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
        out_tensors->clear();
        out_tensors->push_back(std::move(tensor));

        index_++;
      }
      return Status::OK();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // TODO(frankchn): Implement for serialization and checkpointing.
      return Status::OK();
    }

   private:
    int64 index_;
  };
};

Status Reader::MakeNestedDataset(Env* env,
                                 const std::vector<std::string>& filenames,
                                 const string& compression_type, int version,
                                 const DataTypeVector& dtypes,
                                 const std::vector<PartialTensorShape>& shapes,
                                 DatasetBase** output) {
  std::vector<DatasetBase*> datasets;

  datasets.reserve(filenames.size());
  for (const auto& filename : filenames) {
    datasets.push_back(new Dataset(
        filename, compression_type, version, dtypes, shapes,
        DatasetContext::Params{.type_string = "snapshot_util::Reader::Dataset",
                               .node_name = "snapshot_util_reader_Dataset"}));
  }

  *output = new NestedDataset(
      datasets, DatasetContext::Params{
                    .type_string = "snapshot_util::Reader::NestedDataset",
                    .node_name = "snapshot_util_reader_NestedDataset"});
  return Status::OK();
}

Reader::Reader(const std::string& filename, const string& compression_type,
               int version, const DataTypeVector& dtypes)
    : filename_(filename),

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
@@ -119,6 +120,18 @@ class Reader {
                       const DataTypeVector& dtypes,
                       std::unique_ptr<Reader>* out_reader);

  // Returns a nested dataset for a set of given snapshot file names.
  //
  // This function takes a vector of snapshot files, and returns a nested
  // dataset. Each element within the nested dataset is itself a dataset, and
  // contains all the elements written out to each individual snapshot file.
  static Status MakeNestedDataset(Env* env,
                                  const std::vector<std::string>& filenames,
                                  const string& compression_type, int version,
                                  const DataTypeVector& dtypes,
                                  const std::vector<PartialTensorShape>& shapes,
                                  DatasetBase** output);

  Status ReadTensors(std::vector<Tensor>* read_tensors);

 private:
@@ -150,6 +163,9 @@ class Reader {
  int num_simple_ = 0;
  int num_complex_ = 0;
  std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.

  class Dataset;
  class NestedDataset;
};

Status WriteMetadataFile(const string& hash_dir,

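For context, here is a minimal sketch of how a caller might exercise the new entry point; it is not part of the commit. The helper names BuildNestedSnapshotDataset and UnwrapPerFileDataset are invented for illustration, the enclosing tensorflow::data namespace and the placeholder compression/version arguments are assumptions, and reference-count management of the returned datasets is elided. Each element produced by the nested dataset is a scalar DT_VARIANT tensor wrapping one per-file dataset, which GetDatasetFromVariantTensor (from tensorflow/core/framework/dataset.h) can recover.

#include <string>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {

// Builds a nested dataset over a set of snapshot data files. Uncompressed
// data and version 1 are placeholder choices here; real callers would pass
// the values recorded in the snapshot metadata.
Status BuildNestedSnapshotDataset(
    Env* env, const std::vector<std::string>& filenames,
    const DataTypeVector& dtypes,
    const std::vector<PartialTensorShape>& shapes, DatasetBase** nested) {
  return snapshot_util::Reader::MakeNestedDataset(
      env, filenames, /*compression_type=*/io::compression::kNone,
      /*version=*/1, dtypes, shapes, nested);
}

// Recovers the per-file dataset stored in one element of the nested dataset.
// `element` is assumed to be a scalar DT_VARIANT tensor obtained by iterating
// the dataset returned above.
Status UnwrapPerFileDataset(const Tensor& element, DatasetBase** per_file) {
  return GetDatasetFromVariantTensor(element, per_file);
}

}  // namespace data
}  // namespace tensorflow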