diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 4e239d0895c..69ad7ea3bb5 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase { std::vector> window_elements; Status status = Status::OK(); { + const size_t target_size = TargetBufferSize(window_size, window_stride); + mutex_lock l(mu_); - if (!input_impl_ && buffer_.empty()) { + if (!input_impl_ && + (buffer_.empty() || + (dataset()->drop_remainder_ && buffer_.size() < target_size))) { *end_of_sequence = true; return Status::OK(); } // Add elements to the buffer. - size_t target_size = TargetBufferSize(window_size, window_stride); if (input_impl_) { *end_of_sequence = false; for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence; diff --git a/tensorflow/python/data/kernel_tests/window_test.py b/tensorflow/python/data/kernel_tests/window_test.py index 98b453a5900..2515bd52f60 100644 --- a/tensorflow/python/data/kernel_tests/window_test.py +++ b/tensorflow/python/data/kernel_tests/window_test.py @@ -239,6 +239,17 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(x, range(i*10, (i+1)*10)) self.assertDatasetProduces(y, range(i*10, (i+1)*10)) + @combinations.generate(test_base.default_test_combinations()) + def testDropRemainderOutput(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.window(30, drop_remainder=True) + dataset = dataset.flat_map(lambda x: x.batch(30)) + dataset = dataset.batch(4) + + self.assertDatasetProduces( + dataset, + expected_output=[[[y + 30 * x for y in range(30)] for x in range(3)]]) + if __name__ == "__main__": test.main()