Add absl::Cord support for builds with TF_CORD_SUPPORT enabled

Also fixes various bugs within TF's absl::Cord handling.

PiperOrigin-RevId: 346884244
Change-Id: I04cec023bedb5d772833614e19c766a7557bef5e
This commit is contained in:
Frank Chen 2020-12-10 15:58:00 -08:00 committed by TensorFlower Gardener
parent b938722556
commit 816bb157e9
8 changed files with 180 additions and 39 deletions

View File

@ -117,11 +117,19 @@ Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
for (const auto& tensor : tensors) {
TensorProto proto;
tensor.AsProtoTensorContent(&proto);
#if defined(PLATFORM_GOOGLE)
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsCord()));
#else // PLATFORM_GOOGLE
#if defined(TF_CORD_SUPPORT)
// Use a raw pointer here: in OSS TF, std::move()-ing a smart pointer into the
// releaser moves it out as soon as the releaser is created, which would leave
// proto_buffer == nullptr by the time WriteRecord happens.
auto proto_buffer = new std::string();
proto.SerializeToString(proto_buffer);
absl::Cord proto_serialized = absl::MakeCordFromExternal(
*proto_buffer,
[proto_buffer](absl::string_view) { delete proto_buffer; });
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
#else // TF_CORD_SUPPORT
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
#endif // PLATFORM_GOOGLE
#endif // TF_CORD_SUPPORT
}
return Status::OK();
}
@ -197,16 +205,16 @@ Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
TensorProto* t = record.add_tensor();
tensor.AsProtoTensorContent(t);
}
#if defined(PLATFORM_GOOGLE)
return WriteRecord(record.SerializeAsCord());
#else // PLATFORM_GOOGLE
#if defined(TF_CORD_SUPPORT)
auto record_buffer = new std::string();
record.SerializeToString(record_buffer);
absl::Cord record_serialized = absl::MakeCordFromExternal(
*record_buffer,
[record_buffer](absl::string_view) { delete record_buffer; });
return WriteRecord(record_serialized);
#else // TF_CORD_SUPPORT
return WriteRecord(record.SerializeAsString());
#endif // PLATFORM_GOOGLE
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument("Compression ", compression_type_,
" is not supported.");
#endif // TF_CORD_SUPPORT
}
std::vector<const TensorBuffer*> tensor_buffers;
@ -258,11 +266,16 @@ Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
return errors::Internal("Failed to compress using snappy.");
}
#if defined(PLATFORM_GOOGLE)
absl::Cord metadata_serialized = metadata.SerializeAsCord();
#else // PLATFORM_GOOGLE
#if defined(TF_CORD_SUPPORT)
auto metadata_buffer = new std::string();
metadata.SerializeToString(metadata_buffer);
absl::Cord metadata_serialized = absl::MakeCordFromExternal(
*metadata_buffer,
[metadata_buffer](absl::string_view) { delete metadata_buffer; });
#else
std::string metadata_serialized = metadata.SerializeAsString();
#endif // PLATFORM_GOOGLE
#endif // TF_CORD_SUPPORT
TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
TF_RETURN_IF_ERROR(WriteRecord(output));
return Status::OK();
@ -296,14 +309,14 @@ Status CustomWriter::WriteRecord(const StringPiece& data) {
return dest_->Append(data);
}
#if defined(PLATFORM_GOOGLE)
#if defined(TF_CORD_SUPPORT)
// Writes one record to `dest_`: a kHeaderSize-byte header holding the
// fixed64-encoded size of `data`, followed by the bytes of `data` itself.
// Returns the first non-OK status produced by either append.
Status CustomWriter::WriteRecord(const absl::Cord& data) {
  char length_prefix[kHeaderSize];
  core::EncodeFixed64(length_prefix, data.size());
  const StringPiece prefix_view(length_prefix, sizeof(length_prefix));
  TF_RETURN_IF_ERROR(dest_->Append(prefix_view));
  return dest_->Append(data);
}
#endif // PLATFORM_GOOGLE
#endif // TF_CORD_SUPPORT
Status Reader::Create(Env* env, const std::string& filename,
const string& compression_type, int version,
@ -722,19 +735,9 @@ Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
TensorProto tp;
#if defined(PLATFORM_GOOGLE)
absl::string_view tensor_proto_view(tensor_proto_str.get(),
tensor_proto_size);
absl::Cord c = absl::MakeCordFromExternal(
tensor_proto_view, [s = std::move(tensor_proto_str)] {});
if (!tp.ParseFromCord(c)) {
return errors::Internal("Could not parse TensorProto");
}
#else // PLATFORM_GOOGLE
if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
return errors::Internal("Could not parse TensorProto");
}
#endif // PLATFORM_GOOGLE
Tensor t;
if (!t.FromProto(tp)) {
return errors::Internal("Could not parse Tensor");
@ -824,7 +827,7 @@ Status CustomReader::ReadRecord(tstring* record) {
return input_stream_->ReadNBytes(length, record);
}
#if defined(PLATFORM_GOOGLE)
#if defined(TF_CORD_SUPPORT)
Status CustomReader::ReadRecord(absl::Cord* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
@ -832,15 +835,15 @@ Status CustomReader::ReadRecord(absl::Cord* record) {
if (compression_type_ == io::compression::kNone) {
return input_stream_->ReadNBytes(length, record);
} else {
auto tmp_str = absl::make_unique<tstring>();
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str.get()));
auto tmp_str = new tstring();
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
absl::string_view tmp_str_view(*tmp_str);
record->Append(
absl::MakeCordFromExternal(tmp_str_view, [s = std::move(tmp_str)] {}));
record->Append(absl::MakeCordFromExternal(
tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
return Status::OK();
}
}
#endif
#endif // TF_CORD_SUPPORT
Status WriteMetadataFile(Env* env, const string& dir,
const experimental::SnapshotMetadataRecord* metadata) {

View File

@ -146,9 +146,9 @@ class CustomWriter : public Writer {
private:
Status WriteRecord(const StringPiece& data);
#if defined(PLATFORM_GOOGLE)
#if defined(TF_CORD_SUPPORT)
Status WriteRecord(const absl::Cord& data);
#endif // PLATFORM_GOOGLE
#endif // TF_CORD_SUPPORT
std::unique_ptr<WritableFile> dest_;
const std::string filename_;
@ -265,7 +265,7 @@ class CustomReader : public Reader {
Status ReadRecord(tstring* record);
#if defined(PLATFORM_GOOGLE)
#if defined(TF_CORD_SUPPORT)
Status ReadRecord(absl::Cord* record);
#endif

View File

@ -55,9 +55,10 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
if (bytes_to_read < 0) {
return errors::InvalidArgument("Cannot read negative number of bytes");
}
int64 current_size = result->size();
Status s = file_->Read(pos_, bytes_to_read, result);
if (s.ok() || errors::IsOutOfRange(s)) {
pos_ += result->size();
pos_ += result->size() - current_size;
}
return s;
}

View File

@ -52,6 +52,39 @@ TEST(RandomInputStream, ReadNBytes) {
EXPECT_EQ(10, in.Tell());
}
#if defined(TF_CORD_SUPPORT)
// Checks that RandomAccessInputStream::ReadNBytes, when given an absl::Cord,
// appends to the cord's existing contents and advances Tell() only by the
// number of bytes newly read, including on zero-length and out-of-range reads.
TEST(RandomInputStream, ReadNBytesWithCords) {
  Env* const env = Env::Default();
  const string path = testing::TmpDir() + "/random_inputbuffer_test";
  TF_ASSERT_OK(WriteStringToFile(env, path, "0123456789"));

  std::unique_ptr<RandomAccessFile> file;
  TF_ASSERT_OK(env->NewRandomAccessFile(path, &file));
  RandomAccessInputStream stream(file.get());
  absl::Cord contents;

  // Reading into `absl::Cord`s does not clear existing data from the cord.
  TF_ASSERT_OK(stream.ReadNBytes(3, &contents));
  EXPECT_EQ(contents, "012");
  EXPECT_EQ(3, stream.Tell());

  // A zero-byte read changes neither the cord nor the position.
  TF_ASSERT_OK(stream.ReadNBytes(0, &contents));
  EXPECT_EQ(contents, "012");
  EXPECT_EQ(3, stream.Tell());

  // The next five bytes are appended after the first three.
  TF_ASSERT_OK(stream.ReadNBytes(5, &contents));
  EXPECT_EQ(contents, "01234567");
  EXPECT_EQ(8, stream.Tell());

  TF_ASSERT_OK(stream.ReadNBytes(0, &contents));
  EXPECT_EQ(contents, "01234567");
  EXPECT_EQ(8, stream.Tell());

  // Asking for more than remains reports OutOfRange but still appends the
  // bytes that were available.
  EXPECT_TRUE(errors::IsOutOfRange(stream.ReadNBytes(20, &contents)));
  EXPECT_EQ(contents, "0123456789");
  EXPECT_EQ(10, stream.Tell());

  // At end of file a zero-byte read still succeeds.
  TF_ASSERT_OK(stream.ReadNBytes(0, &contents));
  EXPECT_EQ(contents, "0123456789");
  EXPECT_EQ(10, stream.Tell());
}
#endif
TEST(RandomInputStream, SkipNBytes) {
Env* env = Env::Default();
string fname = testing::TmpDir() + "/random_inputbuffer_test";

View File

@ -228,6 +228,17 @@ Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
return Status::OK();
}
#if defined(TF_CORD_SUPPORT)
// Reads `bytes_to_read` bytes via the tstring overload of ReadNBytes and
// stores them in `*result`, replacing whatever the cord held before.
//
// If the tstring overload returns a non-OK status, that status is returned
// and `*result` is left untouched (the cord is only cleared after a
// successful read).
Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, absl::Cord* result) {
  // TODO(frankchn): Optimize this instead of bouncing through the buffer.
  tstring buf;
  TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf));
  result->Clear();
  // Pass an explicit length. Appending the bare `buf.data()` pointer would
  // build a string_view with strlen(), silently truncating the data at the
  // first NUL byte.
  result->Append(absl::string_view(buf.data(), buf.size()));
  return Status::OK();
}
#endif
int64 ZlibInputStream::Tell() const { return bytes_read_; }
Status ZlibInputStream::Inflate() {

View File

@ -68,6 +68,10 @@ class ZlibInputStream : public InputStreamInterface {
// others: If reading from stream failed.
Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
#if defined(TF_CORD_SUPPORT)
Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override;
#endif
int64 Tell() const override;
Status Reset() override;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>
#include <stdio.h>
#include <sys/mman.h>
#if defined(__linux__)
#include <sys/sendfile.h>
#endif
@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/default/posix_file_system.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/error.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
@ -92,6 +94,34 @@ class PosixRandomAccessFile : public RandomAccessFile {
*result = StringPiece(scratch, dst - scratch);
return s;
}
#if defined(TF_CORD_SUPPORT)
// Reads up to `n` bytes starting at `offset` and appends them to `*cord`.
//
// Existing contents of `*cord` are preserved. The bytes are read into a
// freshly allocated buffer through the StringPiece overload of Read; the
// status of that call is returned unchanged, and whatever bytes it reported
// are appended even when the status is not OK.
Status Read(uint64 offset, size_t n, absl::Cord* cord) const override {
  if (n == 0) {
    return Status::OK();
  }

  // No "negative n" or null-allocation checks: `n` is an unsigned size_t, and
  // a plain new[] expression never yields nullptr on failure.
  char* scratch = new char[n];

  StringPiece tmp;
  Status s = Read(offset, n, &tmp, scratch);

  // Hand ownership of `scratch` to the cord; the releaser frees the buffer.
  // NOTE(review): per the Abseil docs the releaser runs immediately when the
  // view is empty, so a zero-byte read should not leak — confirm.
  cord->Append(absl::MakeCordFromExternal(
      absl::string_view(scratch, tmp.size()),
      [scratch](absl::string_view) { delete[] scratch; }));
  return s;
}
#endif
};
class PosixWritableFile : public WritableFile {
@ -118,6 +148,19 @@ class PosixWritableFile : public WritableFile {
return Status::OK();
}
#if defined(TF_CORD_SUPPORT)
// \brief Append the contents of 'cord' to the file, one chunk at a time.
//
// Returns an IOError built from errno as soon as fwrite reports fewer bytes
// written than the chunk contains.
Status Append(const absl::Cord& cord) override {
  for (absl::string_view piece : cord.Chunks()) {
    const size_t written = fwrite(piece.data(), 1, piece.size(), file_);
    if (written == piece.size()) {
      continue;
    }
    return IOError(filename_, errno);
  }
  return Status::OK();
}
#endif
Status Close() override {
if (file_ == nullptr) {
return IOError(filename_, EBADF);

View File

@ -147,6 +147,34 @@ class WindowsRandomAccessFile : public RandomAccessFile {
*result = StringPiece(scratch, dst - scratch);
return s;
}
#if defined(TF_CORD_SUPPORT)
// Reads up to `n` bytes starting at `offset` and appends them to `*cord`.
//
// Existing contents of `*cord` are preserved. The bytes are read into a
// freshly allocated buffer through the StringPiece overload of Read; the
// status of that call is returned unchanged, and whatever bytes it reported
// are appended even when the status is not OK.
Status Read(uint64 offset, size_t n, absl::Cord* cord) const override {
  if (n == 0) {
    return Status::OK();
  }

  // No "negative n" or null-allocation checks: `n` is an unsigned size_t, and
  // a plain new[] expression never yields nullptr on failure.
  char* scratch = new char[n];

  StringPiece tmp;
  Status s = Read(offset, n, &tmp, scratch);

  // Hand ownership of `scratch` to the cord; the releaser frees the buffer.
  // NOTE(review): per the Abseil docs the releaser runs immediately when the
  // view is empty, so a zero-byte read should not leak — confirm.
  cord->Append(absl::MakeCordFromExternal(
      absl::string_view(scratch, tmp.size()),
      [scratch](absl::string_view) { delete[] scratch; }));
  return s;
}
#endif
};
class WindowsWritableFile : public WritableFile {
@ -177,6 +205,24 @@ class WindowsWritableFile : public WritableFile {
return Status::OK();
}
#if defined(TF_CORD_SUPPORT)
// \brief Append the contents of 'cord' to the file, one chunk at a time.
//
// Each chunk is written with a single WriteFile call. A failed call is
// reported as an IO error; a short write is only checked by assert.
Status Append(const absl::Cord& cord) override {
  for (absl::string_view piece : cord.Chunks()) {
    const DWORD piece_size = static_cast<DWORD>(piece.size());
    DWORD written = 0;
    const BOOL succeeded =
        ::WriteFile(hfile_, piece.data(), piece_size, &written, NULL);
    if (succeeded == FALSE) {
      return IOErrorFromWindowsError("Failed to WriteFile: " + filename_);
    }
    assert(size_t(written) == piece.size());
  }
  return Status::OK();
}
#endif
Status Tell(int64* position) override {
Status result = Flush();
if (!result.ok()) {