[tf.data] Internal cleanup.

PiperOrigin-RevId: 295768875
Change-Id: I77da989a9eb2c74706e64bdc5e863d13fa76832a
This commit is contained in:
Jiri Simsa 2020-02-18 10:42:39 -08:00 committed by TensorFlower Gardener
parent b04371bc95
commit dff4559ac3
6 changed files with 23 additions and 19 deletions

View File

@ -182,8 +182,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
// Test the read mode.
TF_ASSERT_OK(dataset_->MakeIterator(
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
&iterator_));
iterator_ctx_.get(), /*parent=*/nullptr,
test_case.dataset_params.iterator_prefix(), &iterator_));
end_of_sequence = false;
out_tensors.clear();
while (!end_of_sequence) {
@ -322,8 +322,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) {
end_of_sequence = false;
out_tensors.clear();
TF_ASSERT_OK(dataset_->MakeIterator(
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
&iterator_));
iterator_ctx_.get(), /*parent=*/nullptr,
test_case.dataset_params.iterator_prefix(), &iterator_));
}
std::unique_ptr<SerializationContext> serialization_ctx;

View File

@ -654,8 +654,8 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
const std::vector<int>& breakpoints, bool compare_order) {
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(
dataset_->MakeIterator(iterator_ctx_.get(), iterator_prefix, &iterator));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator));
std::unique_ptr<SerializationContext> serialization_ctx;
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false;
@ -704,8 +704,9 @@ Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, &params_,
&dataset_ctx_, &tensors_, &dataset_));
TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
iterator_ctx_.get(), dataset_params.iterator_prefix(), &iterator_));
TF_RETURN_IF_ERROR(
dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
dataset_params.iterator_prefix(), &iterator_));
initialized_ = true;
return Status::OK();
}
@ -791,7 +792,8 @@ Status DatasetOpsTestBase::MakeIterator(
CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
std::unique_ptr<IteratorBase> iterator_base;
TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
iterator_ctx.get(), dataset_params.iterator_prefix(), &iterator_base));
iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
&iterator_base));
*iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
std::move(iterator_base));
return Status::OK();

View File

@ -84,8 +84,8 @@ class ToTFRecordOp : public AsyncOpKernel {
IteratorContext iter_ctx(std::move(params));
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(
dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator));
TF_RETURN_IF_ERROR(dataset->MakeIterator(
&iter_ctx, /*parent=*/nullptr, "ToTFRecordOpIterator", &iterator));
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());

View File

@ -191,7 +191,8 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
{
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
"Iterator", &iterator));
/*parent=*/nullptr, "Iterator",
&iterator));
TF_RETURN_IF_ERROR(
VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
TF_RETURN_IF_ERROR(
@ -565,8 +566,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
IteratorContext iter_ctx(std::move(params));
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(
dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator));
TF_RETURN_IF_ERROR(dataset->MakeIterator(
&iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
@ -636,8 +637,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(
dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator));
TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
"ReduceIterator", &iterator));
// Iterate through the input dataset.
while (true) {

View File

@ -344,8 +344,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
// Reshuffle the dataset.
end_of_sequence = false;
TF_ASSERT_OK(dataset_->MakeIterator(
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
&iterator_));
iterator_ctx_.get(), /*parent=*/nullptr,
test_case.dataset_params.iterator_prefix(), &iterator_));
std::vector<Tensor> reshuffled_out_tensors;
while (!end_of_sequence) {
std::vector<Tensor> next;

View File

@ -302,7 +302,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
&window_dataset));
std::unique_ptr<IteratorBase> window_dataset_iterator;
TF_ASSERT_OK(window_dataset->MakeIterator(
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
iterator_ctx_.get(), /*parent=*/nullptr,
test_case.dataset_params.iterator_prefix(),
&window_dataset_iterator));
bool end_of_window_dataset = false;
std::vector<Tensor> window_elements;