Limit reserve size by 2**16 in dataset batch op when drop_remainder is false.

Dataset batch is sometimes used to stack all the elements in the dataset. A
common pattern is to pass a very large batch size with drop_remainder == false.
The batch size should be larger than the dataset cardinality, but at the same
time, it should not be too large otherwise vector reserve in the batch op ends
up OOM.

This change limits the reserve size by 2**16 when drop_remainder is false. Then
the users may pass a large enough number like INT32_MAX or INT64_MAX to stack
all elements.

PiperOrigin-RevId: 316740746
Change-Id: I445ff66eff088363c802ec31c359e81f188d8047
This commit is contained in:
Sung Jin Hwang 2020-06-16 12:59:30 -07:00 committed by TensorFlower Gardener
parent ed1d7d09ae
commit 5f06c79985

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/batch_dataset_op.h"
#include <algorithm>
#include <utility>
#include "tensorflow/core/framework/op_kernel.h"
@ -49,6 +50,12 @@ class BatchDatasetOp::Dataset : public DatasetBase {
bool parallel_copy, const DatasetBase* input, int op_version)
: DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
// Dataset batch is sometimes used to stack all elements in the
// dataset. In such cases, a very large batch size (e.g., INT32_MAX)
// is passed with drop_remainder set to false. Avoid OOM in such case
// by limiting `reserve()` size by 2**16.
reserve_size_(drop_remainder ? batch_size
: std::min<int64>(batch_size, 1 << 16)),
drop_remainder_(drop_remainder),
parallel_copy_(parallel_copy),
input_(input),
@ -153,7 +160,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
*end_of_sequence = true;
return Status::OK();
}
batch_elements.reserve(dataset()->batch_size_);
batch_elements.reserve(dataset()->reserve_size_);
*end_of_sequence = false;
for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
std::vector<Tensor> batch_element_tuple;
@ -289,6 +296,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
};
const int64 batch_size_;
const int64 reserve_size_;
const bool drop_remainder_;
const bool parallel_copy_;
const DatasetBase* const input_;