[tf.data] Fixing a bug in tf.data.Dataset.window
handling of drop_remainder=True
.
Fixes: https://github.com/tensorflow/tensorflow/issues/43703 PiperOrigin-RevId: 335119334 Change-Id: Ie91ef58abac04a756e37724891d70b63ef4765d6
This commit is contained in:
parent
36314ba572
commit
9e1ee89c7f
@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
||||
std::vector<std::vector<Tensor>> window_elements;
|
||||
Status status = Status::OK();
|
||||
{
|
||||
const size_t target_size = TargetBufferSize(window_size, window_stride);
|
||||
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_ && buffer_.empty()) {
|
||||
if (!input_impl_ &&
|
||||
(buffer_.empty() ||
|
||||
(dataset()->drop_remainder_ && buffer_.size() < target_size))) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add elements to the buffer.
|
||||
size_t target_size = TargetBufferSize(window_size, window_stride);
|
||||
if (input_impl_) {
|
||||
*end_of_sequence = false;
|
||||
for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
|
||||
|
@ -239,6 +239,17 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(x, range(i*10, (i+1)*10))
|
||||
self.assertDatasetProduces(y, range(i*10, (i+1)*10))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDropRemainderOutput(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.window(30, drop_remainder=True)
|
||||
dataset = dataset.flat_map(lambda x: x.batch(30))
|
||||
dataset = dataset.batch(4)
|
||||
|
||||
self.assertDatasetProduces(
|
||||
dataset,
|
||||
expected_output=[[[y + 30 * x for y in range(30)] for x in range(3)]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user