[tf.data] Internal cleanup.
PiperOrigin-RevId: 295768875 Change-Id: I77da989a9eb2c74706e64bdc5e863d13fa76832a
This commit is contained in:
parent
b04371bc95
commit
dff4559ac3
@ -182,8 +182,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
|
||||
|
||||
// Test the read mode.
|
||||
TF_ASSERT_OK(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
||||
&iterator_));
|
||||
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||
test_case.dataset_params.iterator_prefix(), &iterator_));
|
||||
end_of_sequence = false;
|
||||
out_tensors.clear();
|
||||
while (!end_of_sequence) {
|
||||
@ -322,8 +322,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) {
|
||||
end_of_sequence = false;
|
||||
out_tensors.clear();
|
||||
TF_ASSERT_OK(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
||||
&iterator_));
|
||||
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||
test_case.dataset_params.iterator_prefix(), &iterator_));
|
||||
}
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
|
||||
@ -654,8 +654,8 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
|
||||
const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
|
||||
const std::vector<int>& breakpoints, bool compare_order) {
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset_->MakeIterator(iterator_ctx_.get(), iterator_prefix, &iterator));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator));
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
|
||||
bool end_of_sequence = false;
|
||||
@ -704,8 +704,9 @@ Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
|
||||
TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, ¶ms_,
|
||||
&dataset_ctx_, &tensors_, &dataset_));
|
||||
TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), dataset_params.iterator_prefix(), &iterator_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
|
||||
dataset_params.iterator_prefix(), &iterator_));
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -791,7 +792,8 @@ Status DatasetOpsTestBase::MakeIterator(
|
||||
CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
|
||||
std::unique_ptr<IteratorBase> iterator_base;
|
||||
TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
|
||||
iterator_ctx.get(), dataset_params.iterator_prefix(), &iterator_base));
|
||||
iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
|
||||
&iterator_base));
|
||||
*iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
|
||||
std::move(iterator_base));
|
||||
return Status::OK();
|
||||
|
||||
@ -84,8 +84,8 @@ class ToTFRecordOp : public AsyncOpKernel {
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator));
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(
|
||||
&iter_ctx, /*parent=*/nullptr, "ToTFRecordOpIterator", &iterator));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
|
||||
@ -191,7 +191,8 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
||||
{
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
|
||||
"Iterator", &iterator));
|
||||
/*parent=*/nullptr, "Iterator",
|
||||
&iterator));
|
||||
TF_RETURN_IF_ERROR(
|
||||
VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -565,8 +566,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator));
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(
|
||||
&iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
@ -636,8 +637,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
|
||||
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
|
||||
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator));
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
|
||||
"ReduceIterator", &iterator));
|
||||
|
||||
// Iterate through the input dataset.
|
||||
while (true) {
|
||||
|
||||
@ -344,8 +344,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
|
||||
// Reshuffle the dataset.
|
||||
end_of_sequence = false;
|
||||
TF_ASSERT_OK(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
||||
&iterator_));
|
||||
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||
test_case.dataset_params.iterator_prefix(), &iterator_));
|
||||
std::vector<Tensor> reshuffled_out_tensors;
|
||||
while (!end_of_sequence) {
|
||||
std::vector<Tensor> next;
|
||||
|
||||
@ -302,7 +302,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
|
||||
&window_dataset));
|
||||
std::unique_ptr<IteratorBase> window_dataset_iterator;
|
||||
TF_ASSERT_OK(window_dataset->MakeIterator(
|
||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
||||
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||
test_case.dataset_params.iterator_prefix(),
|
||||
&window_dataset_iterator));
|
||||
bool end_of_window_dataset = false;
|
||||
std::vector<Tensor> window_elements;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user