Add absl::Cord support for builds with TF_CORD_SUPPORT enabled
Also fixes various bugs within TF's absl::Cord handling. PiperOrigin-RevId: 346884244 Change-Id: I04cec023bedb5d772833614e19c766a7557bef5e
This commit is contained in:
parent
b938722556
commit
816bb157e9
@ -117,11 +117,19 @@ Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
||||
for (const auto& tensor : tensors) {
|
||||
TensorProto proto;
|
||||
tensor.AsProtoTensorContent(&proto);
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsCord()));
|
||||
#else // PLATFORM_GOOGLE
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
// Creating raw pointer here because std::move() in a releases in OSS TF
|
||||
// will result in a smart pointer being moved upon function creation, which
|
||||
// will result in proto_buffer == nullptr when WriteRecord happens.
|
||||
auto proto_buffer = new std::string();
|
||||
proto.SerializeToString(proto_buffer);
|
||||
absl::Cord proto_serialized = absl::MakeCordFromExternal(
|
||||
*proto_buffer,
|
||||
[proto_buffer](absl::string_view) { delete proto_buffer; });
|
||||
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
|
||||
#else // TF_CORD_SUPPORT
|
||||
TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
|
||||
#endif // PLATFORM_GOOGLE
|
||||
#endif // TF_CORD_SUPPORT
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -197,16 +205,16 @@ Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
||||
TensorProto* t = record.add_tensor();
|
||||
tensor.AsProtoTensorContent(t);
|
||||
}
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
return WriteRecord(record.SerializeAsCord());
|
||||
#else // PLATFORM_GOOGLE
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
auto record_buffer = new std::string();
|
||||
record.SerializeToString(record_buffer);
|
||||
absl::Cord record_serialized = absl::MakeCordFromExternal(
|
||||
*record_buffer,
|
||||
[record_buffer](absl::string_view) { delete record_buffer; });
|
||||
return WriteRecord(record_serialized);
|
||||
#else // TF_CORD_SUPPORT
|
||||
return WriteRecord(record.SerializeAsString());
|
||||
#endif // PLATFORM_GOOGLE
|
||||
}
|
||||
|
||||
if (compression_type_ != io::compression::kSnappy) {
|
||||
return errors::InvalidArgument("Compression ", compression_type_,
|
||||
" is not supported.");
|
||||
#endif // TF_CORD_SUPPORT
|
||||
}
|
||||
|
||||
std::vector<const TensorBuffer*> tensor_buffers;
|
||||
@ -258,11 +266,16 @@ Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
||||
if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
|
||||
return errors::Internal("Failed to compress using snappy.");
|
||||
}
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
absl::Cord metadata_serialized = metadata.SerializeAsCord();
|
||||
#else // PLATFORM_GOOGLE
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
auto metadata_buffer = new std::string();
|
||||
metadata.SerializeToString(metadata_buffer);
|
||||
absl::Cord metadata_serialized = absl::MakeCordFromExternal(
|
||||
*metadata_buffer,
|
||||
[metadata_buffer](absl::string_view) { delete metadata_buffer; });
|
||||
#else
|
||||
std::string metadata_serialized = metadata.SerializeAsString();
|
||||
#endif // PLATFORM_GOOGLE
|
||||
#endif // TF_CORD_SUPPORT
|
||||
TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
|
||||
TF_RETURN_IF_ERROR(WriteRecord(output));
|
||||
return Status::OK();
|
||||
@ -296,14 +309,14 @@ Status CustomWriter::WriteRecord(const StringPiece& data) {
|
||||
return dest_->Append(data);
|
||||
}
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status CustomWriter::WriteRecord(const absl::Cord& data) {
|
||||
char header[kHeaderSize];
|
||||
core::EncodeFixed64(header, data.size());
|
||||
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
||||
return dest_->Append(data);
|
||||
}
|
||||
#endif // PLATFORM_GOOGLE
|
||||
#endif // TF_CORD_SUPPORT
|
||||
|
||||
Status Reader::Create(Env* env, const std::string& filename,
|
||||
const string& compression_type, int version,
|
||||
@ -722,19 +735,9 @@ Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
|
||||
auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
|
||||
size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
|
||||
TensorProto tp;
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
absl::string_view tensor_proto_view(tensor_proto_str.get(),
|
||||
tensor_proto_size);
|
||||
absl::Cord c = absl::MakeCordFromExternal(
|
||||
tensor_proto_view, [s = std::move(tensor_proto_str)] {});
|
||||
if (!tp.ParseFromCord(c)) {
|
||||
return errors::Internal("Could not parse TensorProto");
|
||||
}
|
||||
#else // PLATFORM_GOOGLE
|
||||
if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
|
||||
return errors::Internal("Could not parse TensorProto");
|
||||
}
|
||||
#endif // PLATFORM_GOOGLE
|
||||
Tensor t;
|
||||
if (!t.FromProto(tp)) {
|
||||
return errors::Internal("Could not parse Tensor");
|
||||
@ -824,7 +827,7 @@ Status CustomReader::ReadRecord(tstring* record) {
|
||||
return input_stream_->ReadNBytes(length, record);
|
||||
}
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status CustomReader::ReadRecord(absl::Cord* record) {
|
||||
tstring header;
|
||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
||||
@ -832,15 +835,15 @@ Status CustomReader::ReadRecord(absl::Cord* record) {
|
||||
if (compression_type_ == io::compression::kNone) {
|
||||
return input_stream_->ReadNBytes(length, record);
|
||||
} else {
|
||||
auto tmp_str = absl::make_unique<tstring>();
|
||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str.get()));
|
||||
auto tmp_str = new tstring();
|
||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
|
||||
absl::string_view tmp_str_view(*tmp_str);
|
||||
record->Append(
|
||||
absl::MakeCordFromExternal(tmp_str_view, [s = std::move(tmp_str)] {}));
|
||||
record->Append(absl::MakeCordFromExternal(
|
||||
tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // TF_CORD_SUPPORT
|
||||
|
||||
Status WriteMetadataFile(Env* env, const string& dir,
|
||||
const experimental::SnapshotMetadataRecord* metadata) {
|
||||
|
@ -146,9 +146,9 @@ class CustomWriter : public Writer {
|
||||
private:
|
||||
Status WriteRecord(const StringPiece& data);
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status WriteRecord(const absl::Cord& data);
|
||||
#endif // PLATFORM_GOOGLE
|
||||
#endif // TF_CORD_SUPPORT
|
||||
|
||||
std::unique_ptr<WritableFile> dest_;
|
||||
const std::string filename_;
|
||||
@ -265,7 +265,7 @@ class CustomReader : public Reader {
|
||||
|
||||
Status ReadRecord(tstring* record);
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status ReadRecord(absl::Cord* record);
|
||||
#endif
|
||||
|
||||
|
@ -55,9 +55,10 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
|
||||
if (bytes_to_read < 0) {
|
||||
return errors::InvalidArgument("Cannot read negative number of bytes");
|
||||
}
|
||||
int64 current_size = result->size();
|
||||
Status s = file_->Read(pos_, bytes_to_read, result);
|
||||
if (s.ok() || errors::IsOutOfRange(s)) {
|
||||
pos_ += result->size();
|
||||
pos_ += result->size() - current_size;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
@ -52,6 +52,39 @@ TEST(RandomInputStream, ReadNBytes) {
|
||||
EXPECT_EQ(10, in.Tell());
|
||||
}
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
TEST(RandomInputStream, ReadNBytesWithCords) {
|
||||
Env* env = Env::Default();
|
||||
string fname = testing::TmpDir() + "/random_inputbuffer_test";
|
||||
TF_ASSERT_OK(WriteStringToFile(env, fname, "0123456789"));
|
||||
|
||||
std::unique_ptr<RandomAccessFile> file;
|
||||
TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
|
||||
absl::Cord read;
|
||||
RandomAccessInputStream in(file.get());
|
||||
|
||||
// Reading into `absl::Cord`s does not clear existing data from the cord.
|
||||
TF_ASSERT_OK(in.ReadNBytes(3, &read));
|
||||
EXPECT_EQ(read, "012");
|
||||
EXPECT_EQ(3, in.Tell());
|
||||
TF_ASSERT_OK(in.ReadNBytes(0, &read));
|
||||
EXPECT_EQ(read, "012");
|
||||
EXPECT_EQ(3, in.Tell());
|
||||
TF_ASSERT_OK(in.ReadNBytes(5, &read));
|
||||
EXPECT_EQ(read, "01234567");
|
||||
EXPECT_EQ(8, in.Tell());
|
||||
TF_ASSERT_OK(in.ReadNBytes(0, &read));
|
||||
EXPECT_EQ(read, "01234567");
|
||||
EXPECT_EQ(8, in.Tell());
|
||||
EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(20, &read)));
|
||||
EXPECT_EQ(read, "0123456789");
|
||||
EXPECT_EQ(10, in.Tell());
|
||||
TF_ASSERT_OK(in.ReadNBytes(0, &read));
|
||||
EXPECT_EQ(read, "0123456789");
|
||||
EXPECT_EQ(10, in.Tell());
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(RandomInputStream, SkipNBytes) {
|
||||
Env* env = Env::Default();
|
||||
string fname = testing::TmpDir() + "/random_inputbuffer_test";
|
||||
|
@ -228,6 +228,17 @@ Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, absl::Cord* result) {
|
||||
// TODO(frankchn): Optimize this instead of bouncing through the buffer.
|
||||
tstring buf;
|
||||
TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf));
|
||||
result->Clear();
|
||||
result->Append(buf.data());
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
int64 ZlibInputStream::Tell() const { return bytes_read_; }
|
||||
|
||||
Status ZlibInputStream::Inflate() {
|
||||
|
@ -68,6 +68,10 @@ class ZlibInputStream : public InputStreamInterface {
|
||||
// others: If reading from stream failed.
|
||||
Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override;
|
||||
#endif
|
||||
|
||||
int64 Tell() const override;
|
||||
|
||||
Status Reset() override;
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <sys/mman.h>
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/sendfile.h>
|
||||
#endif
|
||||
@ -31,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/default/posix_file_system.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/error.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/file_system_helper.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
@ -92,6 +94,34 @@ class PosixRandomAccessFile : public RandomAccessFile {
|
||||
*result = StringPiece(scratch, dst - scratch);
|
||||
return s;
|
||||
}
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status Read(uint64 offset, size_t n, absl::Cord* cord) const override {
|
||||
if (n == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (n < 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Attempting to read ", n,
|
||||
" bytes. You cannot read a negative number of bytes.");
|
||||
}
|
||||
|
||||
char* scratch = new char[n];
|
||||
if (scratch == nullptr) {
|
||||
return errors::ResourceExhausted("Unable to allocate ", n,
|
||||
" bytes for file reading.");
|
||||
}
|
||||
|
||||
StringPiece tmp;
|
||||
Status s = Read(offset, n, &tmp, scratch);
|
||||
|
||||
absl::Cord tmp_cord = absl::MakeCordFromExternal(
|
||||
absl::string_view(static_cast<char*>(scratch), tmp.size()),
|
||||
[scratch](absl::string_view) { delete[] scratch; });
|
||||
cord->Append(tmp_cord);
|
||||
return s;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class PosixWritableFile : public WritableFile {
|
||||
@ -118,6 +148,19 @@ class PosixWritableFile : public WritableFile {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
// \brief Append 'cord' to the file.
|
||||
Status Append(const absl::Cord& cord) override {
|
||||
for (const auto& chunk : cord.Chunks()) {
|
||||
size_t r = fwrite(chunk.data(), 1, chunk.size(), file_);
|
||||
if (r != chunk.size()) {
|
||||
return IOError(filename_, errno);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status Close() override {
|
||||
if (file_ == nullptr) {
|
||||
return IOError(filename_, EBADF);
|
||||
|
@ -147,6 +147,34 @@ class WindowsRandomAccessFile : public RandomAccessFile {
|
||||
*result = StringPiece(scratch, dst - scratch);
|
||||
return s;
|
||||
}
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
Status Read(uint64 offset, size_t n, absl::Cord* cord) const override {
|
||||
if (n == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (n < 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Attempting to read ", n,
|
||||
" bytes. You cannot read a negative number of bytes.");
|
||||
}
|
||||
|
||||
char* scratch = new char[n];
|
||||
if (scratch == nullptr) {
|
||||
return errors::ResourceExhausted("Unable to allocate ", n,
|
||||
" bytes for file reading.");
|
||||
}
|
||||
|
||||
StringPiece tmp;
|
||||
Status s = Read(offset, n, &tmp, scratch);
|
||||
|
||||
absl::Cord tmp_cord = absl::MakeCordFromExternal(
|
||||
absl::string_view(static_cast<char*>(scratch), tmp.size()),
|
||||
[scratch](absl::string_view) { delete[] scratch; });
|
||||
cord->Append(tmp_cord);
|
||||
return s;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class WindowsWritableFile : public WritableFile {
|
||||
@ -177,6 +205,24 @@ class WindowsWritableFile : public WritableFile {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(TF_CORD_SUPPORT)
|
||||
// \brief Append 'data' to the file.
|
||||
Status Append(const absl::Cord& cord) override {
|
||||
for (const auto& chunk : cord.Chunks()) {
|
||||
DWORD bytes_written = 0;
|
||||
DWORD data_size = static_cast<DWORD>(chunk.size());
|
||||
BOOL write_result =
|
||||
::WriteFile(hfile_, chunk.data(), data_size, &bytes_written, NULL);
|
||||
if (FALSE == write_result) {
|
||||
return IOErrorFromWindowsError("Failed to WriteFile: " + filename_);
|
||||
}
|
||||
|
||||
assert(size_t(bytes_written) == chunk.size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status Tell(int64* position) override {
|
||||
Status result = Flush();
|
||||
if (!result.ok()) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user