diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index c915f80c2c6..cfeb63a4242 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/batch_dataset_op.h"
 
+#include <algorithm>
 #include <utility>
 
 #include "tensorflow/core/framework/op_kernel.h"
@@ -49,6 +50,12 @@ class BatchDatasetOp::Dataset : public DatasetBase {
         bool parallel_copy, const DatasetBase* input, int op_version)
       : DatasetBase(DatasetContext(ctx)),
         batch_size_(batch_size),
+        // Dataset batch is sometimes used to stack all elements in the
+        // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
+        // is passed with drop_remainder set to false. Avoid OOM in such case
+        // by limiting `reserve()` size by 2**16.
+        reserve_size_(drop_remainder ? batch_size
+                                     : std::min<int64>(batch_size, 1 << 16)),
         drop_remainder_(drop_remainder),
         parallel_copy_(parallel_copy),
         input_(input),
@@ -153,7 +160,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
           *end_of_sequence = true;
           return Status::OK();
         }
-        batch_elements.reserve(dataset()->batch_size_);
+        batch_elements.reserve(dataset()->reserve_size_);
         *end_of_sequence = false;
         for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
          std::vector<Tensor> batch_element_tuple;
@@ -289,6 +296,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
   };
 
   const int64 batch_size_;
+  const int64 reserve_size_;
   const bool drop_remainder_;
   const bool parallel_copy_;
   const DatasetBase* const input_;