From dff4559ac3abca11bfad3400195b2f5a78420366 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Tue, 18 Feb 2020 10:42:39 -0800 Subject: [PATCH] [tf.data] Internal cleanup. PiperOrigin-RevId: 295768875 Change-Id: I77da989a9eb2c74706e64bdc5e863d13fa76832a --- .../core/kernels/data/cache_dataset_ops_test.cc | 8 ++++---- tensorflow/core/kernels/data/dataset_test_base.cc | 12 +++++++----- .../kernels/data/experimental/to_tf_record_op.cc | 4 ++-- tensorflow/core/kernels/data/iterator_ops.cc | 11 ++++++----- .../core/kernels/data/shuffle_dataset_op_test.cc | 4 ++-- .../core/kernels/data/window_dataset_op_test.cc | 3 ++- 6 files changed, 23 insertions(+), 19 deletions(-) diff --git a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc index 9faf92b83da..c6bc70b4c94 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc @@ -182,8 +182,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) { // Test the read mode. TF_ASSERT_OK(dataset_->MakeIterator( - iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), - &iterator_)); + iterator_ctx_.get(), /*parent=*/nullptr, + test_case.dataset_params.iterator_prefix(), &iterator_)); end_of_sequence = false; out_tensors.clear(); while (!end_of_sequence) { @@ -322,8 +322,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) { end_of_sequence = false; out_tensors.clear(); TF_ASSERT_OK(dataset_->MakeIterator( - iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), - &iterator_)); + iterator_ctx_.get(), /*parent=*/nullptr, + test_case.dataset_params.iterator_prefix(), &iterator_)); } std::unique_ptr serialization_ctx; diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 38652753066..7c5d0c3f679 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -654,8 +654,8 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore( const string& iterator_prefix, const std::vector& expected_outputs, const std::vector& breakpoints, bool compare_order) { std::unique_ptr iterator; - TF_RETURN_IF_ERROR( - dataset_->MakeIterator(iterator_ctx_.get(), iterator_prefix, &iterator)); + TF_RETURN_IF_ERROR(dataset_->MakeIterator( + iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator)); std::unique_ptr serialization_ctx; TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; @@ -704,8 +704,9 @@ Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) { TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, ¶ms_, &dataset_ctx_, &tensors_, &dataset_)); TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_)); - TF_RETURN_IF_ERROR(dataset_->MakeIterator( - iterator_ctx_.get(), dataset_params.iterator_prefix(), &iterator_)); + TF_RETURN_IF_ERROR( + dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr, + dataset_params.iterator_prefix(), &iterator_)); initialized_ = true; return Status::OK(); } @@ -791,7 +792,8 @@ Status DatasetOpsTestBase::MakeIterator( CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx)); std::unique_ptr iterator_base; TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator( - iterator_ctx.get(), dataset_params.iterator_prefix(), &iterator_base)); + iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(), + &iterator_base)); *iterator = std::make_unique(std::move(iterator_ctx), std::move(iterator_base)); return Status::OK(); diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc index 1f7576cbc75..6a910145b53 100644 --- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc +++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc @@ -84,8 +84,8 @@ class ToTFRecordOp : public AsyncOpKernel { IteratorContext iter_ctx(std::move(params)); std::unique_ptr iterator; - TF_RETURN_IF_ERROR( - dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator)); + TF_RETURN_IF_ERROR(dataset->MakeIterator( + &iter_ctx, /*parent=*/nullptr, "ToTFRecordOpIterator", &iterator)); std::vector components; components.reserve(dataset->output_dtypes().size()); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 7a1f12b044a..4adf7f64fba 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -191,7 +191,8 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, { auto cleanup = gtl::MakeCleanup(std::move(deregister_fn)); TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)), - "Iterator", &iterator)); + /*parent=*/nullptr, "Iterator", + &iterator)); TF_RETURN_IF_ERROR( VerifyTypesMatch(output_dtypes_, iterator->output_dtypes())); TF_RETURN_IF_ERROR( @@ -565,8 +566,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel { IteratorContext iter_ctx(std::move(params)); std::unique_ptr iterator; - TF_RETURN_IF_ERROR( - dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator)); + TF_RETURN_IF_ERROR(dataset->MakeIterator( + &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator)); std::vector components; components.reserve(dataset->output_dtypes().size()); @@ -636,8 +637,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel { captured_func->Instantiate(&iter_ctx, &instantiated_captured_func)); std::unique_ptr iterator; - TF_RETURN_IF_ERROR( - dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator)); + TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr, + "ReduceIterator", &iterator)); // Iterate through the input dataset. while (true) { diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc index 20fb2912f5b..ca9afce7fc1 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc @@ -344,8 +344,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) { // Reshuffle the dataset. end_of_sequence = false; TF_ASSERT_OK(dataset_->MakeIterator( - iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), - &iterator_)); + iterator_ctx_.get(), /*parent=*/nullptr, + test_case.dataset_params.iterator_prefix(), &iterator_)); std::vector reshuffled_out_tensors; while (!end_of_sequence) { std::vector next; diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc index bef42f761ac..31839e5d88d 100644 --- a/tensorflow/core/kernels/data/window_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc @@ -302,7 +302,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) { &window_dataset)); std::unique_ptr window_dataset_iterator; TF_ASSERT_OK(window_dataset->MakeIterator( - iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(), + iterator_ctx_.get(), /*parent=*/nullptr, + test_case.dataset_params.iterator_prefix(), &window_dataset_iterator)); bool end_of_window_dataset = false; std::vector window_elements;