[tf.data] Internal cleanup.
PiperOrigin-RevId: 295768875 Change-Id: I77da989a9eb2c74706e64bdc5e863d13fa76832a
This commit is contained in:
parent
b04371bc95
commit
dff4559ac3
@ -182,8 +182,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
|
|||||||
|
|
||||||
// Test the read mode.
|
// Test the read mode.
|
||||||
TF_ASSERT_OK(dataset_->MakeIterator(
|
TF_ASSERT_OK(dataset_->MakeIterator(
|
||||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||||
&iterator_));
|
test_case.dataset_params.iterator_prefix(), &iterator_));
|
||||||
end_of_sequence = false;
|
end_of_sequence = false;
|
||||||
out_tensors.clear();
|
out_tensors.clear();
|
||||||
while (!end_of_sequence) {
|
while (!end_of_sequence) {
|
||||||
@ -322,8 +322,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) {
|
|||||||
end_of_sequence = false;
|
end_of_sequence = false;
|
||||||
out_tensors.clear();
|
out_tensors.clear();
|
||||||
TF_ASSERT_OK(dataset_->MakeIterator(
|
TF_ASSERT_OK(dataset_->MakeIterator(
|
||||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||||
&iterator_));
|
test_case.dataset_params.iterator_prefix(), &iterator_));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
|
|||||||
@ -654,8 +654,8 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
|
|||||||
const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
|
const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
|
||||||
const std::vector<int>& breakpoints, bool compare_order) {
|
const std::vector<int>& breakpoints, bool compare_order) {
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||||
dataset_->MakeIterator(iterator_ctx_.get(), iterator_prefix, &iterator));
|
iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator));
|
||||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||||
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
|
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
|
||||||
bool end_of_sequence = false;
|
bool end_of_sequence = false;
|
||||||
@ -704,8 +704,9 @@ Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
|
|||||||
TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, ¶ms_,
|
TF_RETURN_IF_ERROR(MakeDataset(dataset_params, &dataset_kernel_, ¶ms_,
|
||||||
&dataset_ctx_, &tensors_, &dataset_));
|
&dataset_ctx_, &tensors_, &dataset_));
|
||||||
TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
|
TF_RETURN_IF_ERROR(CreateIteratorContext(dataset_ctx_.get(), &iterator_ctx_));
|
||||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
TF_RETURN_IF_ERROR(
|
||||||
iterator_ctx_.get(), dataset_params.iterator_prefix(), &iterator_));
|
dataset_->MakeIterator(iterator_ctx_.get(), /*parent=*/nullptr,
|
||||||
|
dataset_params.iterator_prefix(), &iterator_));
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -791,7 +792,8 @@ Status DatasetOpsTestBase::MakeIterator(
|
|||||||
CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
|
CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
|
||||||
std::unique_ptr<IteratorBase> iterator_base;
|
std::unique_ptr<IteratorBase> iterator_base;
|
||||||
TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
|
TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
|
||||||
iterator_ctx.get(), dataset_params.iterator_prefix(), &iterator_base));
|
iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
|
||||||
|
&iterator_base));
|
||||||
*iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
|
*iterator = std::make_unique<TestIterator>(std::move(iterator_ctx),
|
||||||
std::move(iterator_base));
|
std::move(iterator_base));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|||||||
@ -84,8 +84,8 @@ class ToTFRecordOp : public AsyncOpKernel {
|
|||||||
|
|
||||||
IteratorContext iter_ctx(std::move(params));
|
IteratorContext iter_ctx(std::move(params));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(dataset->MakeIterator(
|
||||||
dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator));
|
&iter_ctx, /*parent=*/nullptr, "ToTFRecordOpIterator", &iterator));
|
||||||
|
|
||||||
std::vector<Tensor> components;
|
std::vector<Tensor> components;
|
||||||
components.reserve(dataset->output_dtypes().size());
|
components.reserve(dataset->output_dtypes().size());
|
||||||
|
|||||||
@ -191,7 +191,8 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
|||||||
{
|
{
|
||||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
|
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
|
||||||
"Iterator", &iterator));
|
/*parent=*/nullptr, "Iterator",
|
||||||
|
&iterator));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
|
VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -565,8 +566,8 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
|||||||
|
|
||||||
IteratorContext iter_ctx(std::move(params));
|
IteratorContext iter_ctx(std::move(params));
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(dataset->MakeIterator(
|
||||||
dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator));
|
&iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
|
||||||
|
|
||||||
std::vector<Tensor> components;
|
std::vector<Tensor> components;
|
||||||
components.reserve(dataset->output_dtypes().size());
|
components.reserve(dataset->output_dtypes().size());
|
||||||
@ -636,8 +637,8 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
|
|||||||
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
|
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
|
||||||
|
|
||||||
std::unique_ptr<IteratorBase> iterator;
|
std::unique_ptr<IteratorBase> iterator;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, /*parent=*/nullptr,
|
||||||
dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator));
|
"ReduceIterator", &iterator));
|
||||||
|
|
||||||
// Iterate through the input dataset.
|
// Iterate through the input dataset.
|
||||||
while (true) {
|
while (true) {
|
||||||
|
|||||||
@ -344,8 +344,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
|
|||||||
// Reshuffle the dataset.
|
// Reshuffle the dataset.
|
||||||
end_of_sequence = false;
|
end_of_sequence = false;
|
||||||
TF_ASSERT_OK(dataset_->MakeIterator(
|
TF_ASSERT_OK(dataset_->MakeIterator(
|
||||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||||
&iterator_));
|
test_case.dataset_params.iterator_prefix(), &iterator_));
|
||||||
std::vector<Tensor> reshuffled_out_tensors;
|
std::vector<Tensor> reshuffled_out_tensors;
|
||||||
while (!end_of_sequence) {
|
while (!end_of_sequence) {
|
||||||
std::vector<Tensor> next;
|
std::vector<Tensor> next;
|
||||||
|
|||||||
@ -302,7 +302,8 @@ TEST_P(ParameterizedGetNextTest, GetNext) {
|
|||||||
&window_dataset));
|
&window_dataset));
|
||||||
std::unique_ptr<IteratorBase> window_dataset_iterator;
|
std::unique_ptr<IteratorBase> window_dataset_iterator;
|
||||||
TF_ASSERT_OK(window_dataset->MakeIterator(
|
TF_ASSERT_OK(window_dataset->MakeIterator(
|
||||||
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
|
iterator_ctx_.get(), /*parent=*/nullptr,
|
||||||
|
test_case.dataset_params.iterator_prefix(),
|
||||||
&window_dataset_iterator));
|
&window_dataset_iterator));
|
||||||
bool end_of_window_dataset = false;
|
bool end_of_window_dataset = false;
|
||||||
std::vector<Tensor> window_elements;
|
std::vector<Tensor> window_elements;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user