[tf.data] Fixing a bug in tf.data.Dataset.window
handling of drop_remainder=True
.
Fixes: https://github.com/tensorflow/tensorflow/issues/43703 PiperOrigin-RevId: 335119334 Change-Id: Ie91ef58abac04a756e37724891d70b63ef4765d6
This commit is contained in:
parent
36314ba572
commit
9e1ee89c7f
@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
|||||||
std::vector<std::vector<Tensor>> window_elements;
|
std::vector<std::vector<Tensor>> window_elements;
|
||||||
Status status = Status::OK();
|
Status status = Status::OK();
|
||||||
{
|
{
|
||||||
|
const size_t target_size = TargetBufferSize(window_size, window_stride);
|
||||||
|
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (!input_impl_ && buffer_.empty()) {
|
if (!input_impl_ &&
|
||||||
|
(buffer_.empty() ||
|
||||||
|
(dataset()->drop_remainder_ && buffer_.size() < target_size))) {
|
||||||
*end_of_sequence = true;
|
*end_of_sequence = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add elements to the buffer.
|
// Add elements to the buffer.
|
||||||
size_t target_size = TargetBufferSize(window_size, window_stride);
|
|
||||||
if (input_impl_) {
|
if (input_impl_) {
|
||||||
*end_of_sequence = false;
|
*end_of_sequence = false;
|
||||||
for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
|
for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
|
||||||
|
@ -239,6 +239,17 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertDatasetProduces(x, range(i*10, (i+1)*10))
|
self.assertDatasetProduces(x, range(i*10, (i+1)*10))
|
||||||
self.assertDatasetProduces(y, range(i*10, (i+1)*10))
|
self.assertDatasetProduces(y, range(i*10, (i+1)*10))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testDropRemainderOutput(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(100)
|
||||||
|
dataset = dataset.window(30, drop_remainder=True)
|
||||||
|
dataset = dataset.flat_map(lambda x: x.batch(30))
|
||||||
|
dataset = dataset.batch(4)
|
||||||
|
|
||||||
|
self.assertDatasetProduces(
|
||||||
|
dataset,
|
||||||
|
expected_output=[[[y + 30 * x for y in range(30)] for x in range(3)]])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user