[tf.data] Fixing a bug in tf.data.Dataset.window handling of drop_remainder=True: once the input is exhausted, a buffer smaller than the target size now ends the sequence.

Fixes: https://github.com/tensorflow/tensorflow/issues/43703
PiperOrigin-RevId: 335119334
Change-Id: Ie91ef58abac04a756e37724891d70b63ef4765d6
This commit is contained in:
Jiri Simsa 2020-10-02 16:07:12 -07:00 committed by TensorFlower Gardener
parent 36314ba572
commit 9e1ee89c7f
2 changed files with 16 additions and 2 deletions

View File

@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase {
std::vector<std::vector<Tensor>> window_elements;
Status status = Status::OK();
{
const size_t target_size = TargetBufferSize(window_size, window_stride);
mutex_lock l(mu_);
if (!input_impl_ && buffer_.empty()) {
if (!input_impl_ &&
(buffer_.empty() ||
(dataset()->drop_remainder_ && buffer_.size() < target_size))) {
*end_of_sequence = true;
return Status::OK();
}
// Add elements to the buffer.
size_t target_size = TargetBufferSize(window_size, window_stride);
if (input_impl_) {
*end_of_sequence = false;
for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;

View File

@ -239,6 +239,17 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(x, range(i*10, (i+1)*10))
self.assertDatasetProduces(y, range(i*10, (i+1)*10))
@combinations.generate(test_base.default_test_combinations())
def testDropRemainderOutput(self):
  """Checks the elements produced by `window` with `drop_remainder=True`.

  Windows of 30 over `range(100)` are each re-batched to 30 and then grouped
  by `batch(4)`; the expected output is a single batch holding the three
  windows that start at 0, 30 and 60.
  """
  windows = dataset_ops.Dataset.range(100).window(30, drop_remainder=True)
  batched = windows.flat_map(lambda w: w.batch(30)).batch(4)
  expected_windows = []
  for start in (0, 30, 60):
    expected_windows.append(list(range(start, start + 30)))
  self.assertDatasetProduces(batched, expected_output=[expected_windows])
# Run the tests through test.main() when this file is executed as a script.
if __name__ == "__main__":
test.main()