From 816bb157e9b77c3eac597f5991bac3d569564256 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 10 Dec 2020 15:58:00 -0800 Subject: [PATCH] Add absl::Cord support for builds with TF_CORD_SUPPORT enabled Also fixes various bugs within TF's absl::Cord handling. PiperOrigin-RevId: 346884244 Change-Id: I04cec023bedb5d772833614e19c766a7557bef5e --- .../data/experimental/snapshot_util.cc | 73 ++++++++++--------- .../kernels/data/experimental/snapshot_util.h | 6 +- tensorflow/core/lib/io/random_inputstream.cc | 3 +- .../core/lib/io/random_inputstream_test.cc | 33 +++++++++ tensorflow/core/lib/io/zlib_inputstream.cc | 11 +++ tensorflow/core/lib/io/zlib_inputstream.h | 4 + .../platform/default/posix_file_system.cc | 43 +++++++++++ .../platform/windows/windows_file_system.cc | 46 ++++++++++++ 8 files changed, 180 insertions(+), 39 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc index 12ffd55722b..d3c5e4bbcd5 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc @@ -117,11 +117,19 @@ Status TFRecordWriter::WriteTensors(const std::vector& tensors) { for (const auto& tensor : tensors) { TensorProto proto; tensor.AsProtoTensorContent(&proto); -#if defined(PLATFORM_GOOGLE) - TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsCord())); -#else // PLATFORM_GOOGLE +#if defined(TF_CORD_SUPPORT) + // Creating raw pointer here because std::move() in a releases in OSS TF + // will result in a smart pointer being moved upon function creation, which + // will result in proto_buffer == nullptr when WriteRecord happens. + auto proto_buffer = new std::string(); + proto.SerializeToString(proto_buffer); + absl::Cord proto_serialized = absl::MakeCordFromExternal( + *proto_buffer, + [proto_buffer](absl::string_view) { delete proto_buffer; }); + TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized)); +#else // TF_CORD_SUPPORT TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString())); -#endif // PLATFORM_GOOGLE +#endif // TF_CORD_SUPPORT } return Status::OK(); } @@ -197,16 +205,16 @@ Status CustomWriter::WriteTensors(const std::vector& tensors) { TensorProto* t = record.add_tensor(); tensor.AsProtoTensorContent(t); } -#if defined(PLATFORM_GOOGLE) - return WriteRecord(record.SerializeAsCord()); -#else // PLATFORM_GOOGLE +#if defined(TF_CORD_SUPPORT) + auto record_buffer = new std::string(); + record.SerializeToString(record_buffer); + absl::Cord record_serialized = absl::MakeCordFromExternal( + *record_buffer, + [record_buffer](absl::string_view) { delete record_buffer; }); + return WriteRecord(record_serialized); +#else // TF_CORD_SUPPORT return WriteRecord(record.SerializeAsString()); -#endif // PLATFORM_GOOGLE - } - - if (compression_type_ != io::compression::kSnappy) { - return errors::InvalidArgument("Compression ", compression_type_, - " is not supported."); +#endif // TF_CORD_SUPPORT } std::vector tensor_buffers; @@ -258,11 +266,16 @@ Status CustomWriter::WriteTensors(const std::vector& tensors) { if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) { return errors::Internal("Failed to compress using snappy."); } -#if defined(PLATFORM_GOOGLE) - absl::Cord metadata_serialized = metadata.SerializeAsCord(); -#else // PLATFORM_GOOGLE + +#if defined(TF_CORD_SUPPORT) + auto metadata_buffer = new std::string(); + metadata.SerializeToString(metadata_buffer); + absl::Cord metadata_serialized = absl::MakeCordFromExternal( + *metadata_buffer, + [metadata_buffer](absl::string_view) { delete metadata_buffer; }); +#else std::string metadata_serialized = metadata.SerializeAsString(); -#endif // PLATFORM_GOOGLE +#endif // TF_CORD_SUPPORT TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized)); TF_RETURN_IF_ERROR(WriteRecord(output)); return Status::OK(); @@ -296,14 +309,14 @@ Status CustomWriter::WriteRecord(const StringPiece& data) { return dest_->Append(data); } -#if defined(PLATFORM_GOOGLE) +#if defined(TF_CORD_SUPPORT) Status CustomWriter::WriteRecord(const absl::Cord& data) { char header[kHeaderSize]; core::EncodeFixed64(header, data.size()); TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); return dest_->Append(data); } -#endif // PLATFORM_GOOGLE +#endif // TF_CORD_SUPPORT Status Reader::Create(Env* env, const std::string& filename, const string& compression_type, int version, @@ -722,19 +735,9 @@ Status CustomReader::ReadTensors(std::vector* read_tensors) { auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first); size_t tensor_proto_size = tensor_proto_strs[complex_index].second; TensorProto tp; -#if defined(PLATFORM_GOOGLE) - absl::string_view tensor_proto_view(tensor_proto_str.get(), - tensor_proto_size); - absl::Cord c = absl::MakeCordFromExternal( - tensor_proto_view, [s = std::move(tensor_proto_str)] {}); - if (!tp.ParseFromCord(c)) { - return errors::Internal("Could not parse TensorProto"); - } -#else // PLATFORM_GOOGLE if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) { return errors::Internal("Could not parse TensorProto"); } -#endif // PLATFORM_GOOGLE Tensor t; if (!t.FromProto(tp)) { return errors::Internal("Could not parse Tensor"); @@ -824,7 +827,7 @@ Status CustomReader::ReadRecord(tstring* record) { return input_stream_->ReadNBytes(length, record); } -#if defined(PLATFORM_GOOGLE) +#if defined(TF_CORD_SUPPORT) Status CustomReader::ReadRecord(absl::Cord* record) { tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); @@ -832,15 +835,15 @@ Status CustomReader::ReadRecord(absl::Cord* record) { if (compression_type_ == io::compression::kNone) { return input_stream_->ReadNBytes(length, record); } else { - auto tmp_str = absl::make_unique(); - TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str.get())); + auto tmp_str = new tstring(); + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str)); absl::string_view tmp_str_view(*tmp_str); - record->Append( - absl::MakeCordFromExternal(tmp_str_view, [s = std::move(tmp_str)] {})); + record->Append(absl::MakeCordFromExternal( + tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; })); return Status::OK(); } } -#endif +#endif // TF_CORD_SUPPORT Status WriteMetadataFile(Env* env, const string& dir, const experimental::SnapshotMetadataRecord* metadata) { diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h index 5b228468861..35bd1f599eb 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.h +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h @@ -146,9 +146,9 @@ class CustomWriter : public Writer { private: Status WriteRecord(const StringPiece& data); -#if defined(PLATFORM_GOOGLE) +#if defined(TF_CORD_SUPPORT) Status WriteRecord(const absl::Cord& data); -#endif // PLATFORM_GOOGLE +#endif // TF_CORD_SUPPORT std::unique_ptr dest_; const std::string filename_; @@ -265,7 +265,7 @@ class CustomReader : public Reader { Status ReadRecord(tstring* record); -#if defined(PLATFORM_GOOGLE) +#if defined(TF_CORD_SUPPORT) Status ReadRecord(absl::Cord* record); #endif diff --git a/tensorflow/core/lib/io/random_inputstream.cc b/tensorflow/core/lib/io/random_inputstream.cc index 0f07b5f58c3..6f931a83bc4 100644 --- a/tensorflow/core/lib/io/random_inputstream.cc +++ b/tensorflow/core/lib/io/random_inputstream.cc @@ -55,9 +55,10 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read, if (bytes_to_read < 0) { return errors::InvalidArgument("Cannot read negative number of bytes"); } + int64 current_size = result->size(); Status s = file_->Read(pos_, bytes_to_read, result); if (s.ok() || errors::IsOutOfRange(s)) { - pos_ += result->size(); + pos_ += result->size() - current_size; } return s; } diff --git a/tensorflow/core/lib/io/random_inputstream_test.cc b/tensorflow/core/lib/io/random_inputstream_test.cc index 2fb325b6e76..58d4b9b3efa 100644 --- a/tensorflow/core/lib/io/random_inputstream_test.cc +++ b/tensorflow/core/lib/io/random_inputstream_test.cc @@ -52,6 +52,39 @@ TEST(RandomInputStream, ReadNBytes) { EXPECT_EQ(10, in.Tell()); } +#if defined(TF_CORD_SUPPORT) +TEST(RandomInputStream, ReadNBytesWithCords) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/random_inputbuffer_test"; + TF_ASSERT_OK(WriteStringToFile(env, fname, "0123456789")); + + std::unique_ptr file; + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file)); + absl::Cord read; + RandomAccessInputStream in(file.get()); + + // Reading into `absl::Cord`s does not clear existing data from the cord. + TF_ASSERT_OK(in.ReadNBytes(3, &read)); + EXPECT_EQ(read, "012"); + EXPECT_EQ(3, in.Tell()); + TF_ASSERT_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, "012"); + EXPECT_EQ(3, in.Tell()); + TF_ASSERT_OK(in.ReadNBytes(5, &read)); + EXPECT_EQ(read, "01234567"); + EXPECT_EQ(8, in.Tell()); + TF_ASSERT_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, "01234567"); + EXPECT_EQ(8, in.Tell()); + EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(20, &read))); + EXPECT_EQ(read, "0123456789"); + EXPECT_EQ(10, in.Tell()); + TF_ASSERT_OK(in.ReadNBytes(0, &read)); + EXPECT_EQ(read, "0123456789"); + EXPECT_EQ(10, in.Tell()); +} +#endif + TEST(RandomInputStream, SkipNBytes) { Env* env = Env::Default(); string fname = testing::TmpDir() + "/random_inputbuffer_test"; diff --git a/tensorflow/core/lib/io/zlib_inputstream.cc b/tensorflow/core/lib/io/zlib_inputstream.cc index 7ea8508c569..293934605c7 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.cc +++ b/tensorflow/core/lib/io/zlib_inputstream.cc @@ -228,6 +228,17 @@ Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) { return Status::OK(); } +#if defined(TF_CORD_SUPPORT) +Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, absl::Cord* result) { + // TODO(frankchn): Optimize this instead of bouncing through the buffer. + tstring buf; + TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf)); + result->Clear(); + result->Append(buf.data()); + return Status::OK(); +} +#endif + int64 ZlibInputStream::Tell() const { return bytes_read_; } Status ZlibInputStream::Inflate() { diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h index 427daa74c8f..da9c3dee518 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.h +++ b/tensorflow/core/lib/io/zlib_inputstream.h @@ -68,6 +68,10 @@ class ZlibInputStream : public InputStreamInterface { // others: If reading from stream failed. Status ReadNBytes(int64 bytes_to_read, tstring* result) override; +#if defined(TF_CORD_SUPPORT) + Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override; +#endif + int64 Tell() const override; Status Reset() override; diff --git a/tensorflow/core/platform/default/posix_file_system.cc b/tensorflow/core/platform/default/posix_file_system.cc index 18fea3fe15d..29f9bbab28f 100644 --- a/tensorflow/core/platform/default/posix_file_system.cc +++ b/tensorflow/core/platform/default/posix_file_system.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include + #if defined(__linux__) #include #endif @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/default/posix_file_system.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/error.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" @@ -92,6 +94,34 @@ class PosixRandomAccessFile : public RandomAccessFile { *result = StringPiece(scratch, dst - scratch); return s; } + +#if defined(TF_CORD_SUPPORT) + Status Read(uint64 offset, size_t n, absl::Cord* cord) const override { + if (n == 0) { + return Status::OK(); + } + if (n < 0) { + return errors::InvalidArgument( + "Attempting to read ", n, + " bytes. You cannot read a negative number of bytes."); + } + + char* scratch = new char[n]; + if (scratch == nullptr) { + return errors::ResourceExhausted("Unable to allocate ", n, + " bytes for file reading."); + } + + StringPiece tmp; + Status s = Read(offset, n, &tmp, scratch); + + absl::Cord tmp_cord = absl::MakeCordFromExternal( + absl::string_view(static_cast(scratch), tmp.size()), + [scratch](absl::string_view) { delete[] scratch; }); + cord->Append(tmp_cord); + return s; + } +#endif }; class PosixWritableFile : public WritableFile { @@ -118,6 +148,19 @@ class PosixWritableFile : public WritableFile { return Status::OK(); } +#if defined(TF_CORD_SUPPORT) + // \brief Append 'cord' to the file. + Status Append(const absl::Cord& cord) override { + for (const auto& chunk : cord.Chunks()) { + size_t r = fwrite(chunk.data(), 1, chunk.size(), file_); + if (r != chunk.size()) { + return IOError(filename_, errno); + } + } + return Status::OK(); + } +#endif + Status Close() override { if (file_ == nullptr) { return IOError(filename_, EBADF); diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc index 475f8791144..519ff727f86 100644 --- a/tensorflow/core/platform/windows/windows_file_system.cc +++ b/tensorflow/core/platform/windows/windows_file_system.cc @@ -147,6 +147,34 @@ class WindowsRandomAccessFile : public RandomAccessFile { *result = StringPiece(scratch, dst - scratch); return s; } + +#if defined(TF_CORD_SUPPORT) + Status Read(uint64 offset, size_t n, absl::Cord* cord) const override { + if (n == 0) { + return Status::OK(); + } + if (n < 0) { + return errors::InvalidArgument( + "Attempting to read ", n, + " bytes. You cannot read a negative number of bytes."); + } + + char* scratch = new char[n]; + if (scratch == nullptr) { + return errors::ResourceExhausted("Unable to allocate ", n, + " bytes for file reading."); + } + + StringPiece tmp; + Status s = Read(offset, n, &tmp, scratch); + + absl::Cord tmp_cord = absl::MakeCordFromExternal( + absl::string_view(static_cast(scratch), tmp.size()), + [scratch](absl::string_view) { delete[] scratch; }); + cord->Append(tmp_cord); + return s; + } +#endif }; class WindowsWritableFile : public WritableFile { @@ -177,6 +205,24 @@ class WindowsWritableFile : public WritableFile { return Status::OK(); } +#if defined(TF_CORD_SUPPORT) + // \brief Append 'data' to the file. + Status Append(const absl::Cord& cord) override { + for (const auto& chunk : cord.Chunks()) { + DWORD bytes_written = 0; + DWORD data_size = static_cast(chunk.size()); + BOOL write_result = + ::WriteFile(hfile_, chunk.data(), data_size, &bytes_written, NULL); + if (FALSE == write_result) { + return IOErrorFromWindowsError("Failed to WriteFile: " + filename_); + } + + assert(size_t(bytes_written) == chunk.size()); + } + return Status::OK(); + } +#endif + Status Tell(int64* position) override { Status result = Flush(); if (!result.ok()) {