[tf.data] Internal cleanup.
PiperOrigin-RevId: 295768875 Change-Id: I77da989a9eb2c74706e64bdc5e863d13fa76832a
This commit is contained in:
		
							parent
							
								
									b04371bc95
								
							
						
					
					
						commit
						dff4559ac3
					
				| @ -182,8 +182,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) { | ||||
| 
 | ||||
|   // Test the read mode.
 | ||||
|   TF_ASSERT_OK(dataset_->MakeIterator( | ||||
|       iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), | ||||
|       &iterator_)); | ||||
|       iterator_ctx_.get(), /*parent=*/nullptr, | ||||
|       test_case.dataset_params.iterator_prefix(), &iterator_)); | ||||
|   end_of_sequence = false; | ||||
|   out_tensors.clear(); | ||||
|   while (!end_of_sequence) { | ||||
| @ -322,8 +322,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) { | ||||
|     end_of_sequence = false; | ||||
|     out_tensors.clear(); | ||||
|     TF_ASSERT_OK(dataset_->MakeIterator( | ||||
|         iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), | ||||
|         &iterator_)); | ||||
|         iterator_ctx_.get(), /*parent=*/nullptr, | ||||
|         test_case.dataset_params.iterator_prefix(), &iterator_)); | ||||
|   } | ||||
| 
 | ||||
|   std::unique_ptr<SerializationContext> serialization_ctx; | ||||
|  | ||||
| @ -654,8 +654,8 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( | ||||
|     const string& iterator_prefix, const std::vector<Tensor>& expected_outputs, | ||||
|     const std::vector<int>& breakpoints, bool compare_order) { | ||||
|   std::unique_ptr<IteratorBase> iterator; | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       dataset_->MakeIterator(iterator_ctx_.get(), iterator_prefix, &iterator)); | ||||
|   TF_RETURN_IF_ERROR(dataset_->MakeIterator( | ||||
|       iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator)); | ||||
|   std::unique_ptr<SerializationContext> serialization_ctx; | ||||
|   TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx)); | ||||
|   bool end_of_sequence = false; | ||||
| @ -704,8 +704,9 @@ Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) { | ||||
|   TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, ¶ms_, | ||||
|                                  &dataset_ctx_, &tensors_, &dataset_)); | ||||
|   TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_)); | ||||
|   TF_RETURN_IF_ERROR(dataset_->MakeIterator( | ||||
|       iterator_ctx_.get(), dataset_params.iterator_prefix(), &iterator_)); | ||||
|   TF_RETURN_IF_ERROR( | ||||
|       dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr, | ||||
|                              dataset_params.iterator_prefix(), &iterator_)); | ||||
|   initialized_ = true; | ||||
|   return Status::OK(); | ||||
| } | ||||
| @ -791,7 +792,8 @@ Status DatasetOpsTestBase::MakeIterator( | ||||
|       CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx)); | ||||
|   std::unique_ptr<IteratorBase> iterator_base; | ||||
|   TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator( | ||||
|       iterator_ctx.get(), dataset_params.iterator_prefix(), &iterator_base)); | ||||
|       iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(), | ||||
|       &iterator_base)); | ||||
|   *iterator = std::make_unique<TestIterator>(std::move(iterator_ctx), | ||||
|                                              std::move(iterator_base)); | ||||
|   return Status::OK(); | ||||
|  | ||||
| @ -84,8 +84,8 @@ class ToTFRecordOp : public AsyncOpKernel { | ||||
| 
 | ||||
|     IteratorContext iter_ctx(std::move(params)); | ||||
|     std::unique_ptr<IteratorBase> iterator; | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator)); | ||||
|     TF_RETURN_IF_ERROR(dataset->MakeIterator( | ||||
|         &iter_ctx, /*parent=*/nullptr, "ToTFRecordOpIterator", &iterator)); | ||||
| 
 | ||||
|     std::vector<Tensor> components; | ||||
|     components.reserve(dataset->output_dtypes().size()); | ||||
|  | ||||
| @ -191,7 +191,8 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, | ||||
|   { | ||||
|     auto cleanup = gtl::MakeCleanup(std::move(deregister_fn)); | ||||
|     TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), | ||||
|                                              "Iterator", &iterator)); | ||||
|                                              /*parent=*/nullptr, "Iterator", | ||||
|                                              &iterator)); | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         VerifyTypesMatch(output_dtypes_, iterator->output_dtypes())); | ||||
|     TF_RETURN_IF_ERROR( | ||||
| @ -565,8 +566,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel { | ||||
| 
 | ||||
|     IteratorContext iter_ctx(std::move(params)); | ||||
|     std::unique_ptr<IteratorBase> iterator; | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator)); | ||||
|     TF_RETURN_IF_ERROR(dataset->MakeIterator( | ||||
|         &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator)); | ||||
| 
 | ||||
|     std::vector<Tensor> components; | ||||
|     components.reserve(dataset->output_dtypes().size()); | ||||
| @ -636,8 +637,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel { | ||||
|         captured_func->Instantiate(&iter_ctx, &instantiated_captured_func)); | ||||
| 
 | ||||
|     std::unique_ptr<IteratorBase> iterator; | ||||
|     TF_RETURN_IF_ERROR( | ||||
|         dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator)); | ||||
|     TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr, | ||||
|                                              "ReduceIterator", &iterator)); | ||||
| 
 | ||||
|     // Iterate through the input dataset.
 | ||||
|     while (true) { | ||||
|  | ||||
| @ -344,8 +344,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) { | ||||
|   // Reshuffle the dataset.
 | ||||
|   end_of_sequence = false; | ||||
|   TF_ASSERT_OK(dataset_->MakeIterator( | ||||
|       iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), | ||||
|       &iterator_)); | ||||
|       iterator_ctx_.get(), /*parent=*/nullptr, | ||||
|       test_case.dataset_params.iterator_prefix(), &iterator_)); | ||||
|   std::vector<Tensor> reshuffled_out_tensors; | ||||
|   while (!end_of_sequence) { | ||||
|     std::vector<Tensor> next; | ||||
|  | ||||
| @ -302,7 +302,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) { | ||||
|                                                  &window_dataset)); | ||||
|         std::unique_ptr<IteratorBase> window_dataset_iterator; | ||||
|         TF_ASSERT_OK(window_dataset->MakeIterator( | ||||
|             iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), | ||||
|             iterator_ctx_.get(), /*parent=*/nullptr, | ||||
|             test_case.dataset_params.iterator_prefix(), | ||||
|             &window_dataset_iterator)); | ||||
|         bool end_of_window_dataset = false; | ||||
|         std::vector<Tensor> window_elements; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user