[tf.data] Make sure the batch_results_ size will not exceed num_parallel_calls for ParallelBatchDataset.

PiperOrigin-RevId: 356812046
Change-Id: I1f8179b32556792d64ee18276d92ed03f6ee09ab
This commit is contained in:
Jay Shi 2021-02-10 13:23:04 -08:00 committed by TensorFlower Gardener
parent 201e2a9f94
commit b57b89f7f5

View File

@ -406,7 +406,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
new_calls.reserve(num_parallel_calls_->value);
}
auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return num_calls_ >= num_parallel_calls_->value;
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
batch_results_.size() >= num_parallel_calls;
};
while (true) {
{