From 72060b46e9a7f46f4c9e1605ae7532307700a539 Mon Sep 17 00:00:00 2001 From: Cheng Ren <1428327+chengren311@users.noreply.github.com> Date: Thu, 21 Nov 2019 23:20:46 -0800 Subject: [PATCH] avoid doing reset when position is still in buffer. --- .../core/lib/io/buffered_inputstream.cc | 17 ++++++++++---- .../core/lib/io/buffered_inputstream_test.cc | 23 +++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/lib/io/buffered_inputstream.cc b/tensorflow/core/lib/io/buffered_inputstream.cc index 6f268de8cac..b471e50167c 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.cc +++ b/tensorflow/core/lib/io/buffered_inputstream.cc @@ -157,15 +157,22 @@ Status BufferedInputStream::Seek(int64 position) { position); } - // Position of the buffer within file. - const int64 bufpos = Tell(); - if (position < bufpos) { - // Reset input stream and skip 'position' bytes. + // Position of the buffer's lower limit within file. + const int64 buf_lower_limit = input_stream_->Tell() - limit_ ; + if (position < buf_lower_limit) { + // Seek before buffer, reset input stream and skip 'position' bytes. TF_RETURN_IF_ERROR(Reset()); return SkipNBytes(position); } - return SkipNBytes(position - bufpos); + if (position < Tell()) { + // Seek within buffer before 'pos_' + pos_ -= Tell() - position; + return Status::OK(); + } + + // Seek after 'pos_' + return SkipNBytes(position - Tell()); } template <typename T> diff --git a/tensorflow/core/lib/io/buffered_inputstream_test.cc b/tensorflow/core/lib/io/buffered_inputstream_test.cc index d6c07344ba3..04156967897 100644 --- a/tensorflow/core/lib/io/buffered_inputstream_test.cc +++ b/tensorflow/core/lib/io/buffered_inputstream_test.cc @@ -394,6 +394,29 @@ TEST(BufferedInputStream, Seek) { } } +TEST(BufferedInputStream, Seek_NotReset) { + // This test verifies seek backwards within the buffer doesn't reset input_stream + Env* env = Env::Default(); + string fname; + ASSERT_TRUE(env->LocalTempFilename(&fname)); + TF_ASSERT_OK(WriteStringToFile(env, fname, "0123456789")); + std::unique_ptr<RandomAccessFile> file; + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file)); + + std::unique_ptr<RandomAccessInputStream> input_stream( + new RandomAccessInputStream(file.get())); + tstring read; + BufferedInputStream in(input_stream.get(), 3); + + TF_ASSERT_OK(in.ReadNBytes(4, &read)); + int before_tell = input_stream.get()->Tell(); + EXPECT_EQ(before_tell, 6); + // Seek backwards + TF_ASSERT_OK(in.Seek(3)); + int after_tell = input_stream.get()->Tell(); + EXPECT_EQ(before_tell, after_tell); +} + TEST(BufferedInputStream, ReadAll_Empty) { Env* env = Env::Default(); string fname;