Updated ReaderInterface and subclasses to use tstring.

This is part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 265822025
Author: Dero Gharibian (2019-08-27 19:55:45 -07:00), committed by TensorFlower Gardener
Commit: 7ba3600c94, parent: 8df6f08527
48 changed files with 221 additions and 186 deletions
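
The pattern repeated across all 48 files: read buffers and string out-parameters change from std::string to tensorflow::tstring, a distinct type that the migration will later repoint at a tensor-friendly layout. A minimal sketch of why every call site has to change type; this tstring is a hypothetical stand-in, not the real tensorflow/core/platform/tstring.h:

#include <cstddef>
#include <string>

// Hypothetical stand-in for tensorflow::tstring. Today it wraps std::string,
// but it is a distinct type, so a function taking tstring* will not accept a
// std::string* and vice versa.
class tstring {
 public:
  tstring() = default;
  tstring(const char* s, std::size_t n) : str_(s, n) {}
  void clear() { str_.clear(); }
  void append(const char* s, std::size_t n) { str_.append(s, n); }
  std::size_t size() const { return str_.size(); }

 private:
  std::string str_;
};

// Post-migration reader signature: out-parameters are tstring*.
void ReadNBytes(std::size_t n, tstring* result) {
  result->clear();
  result->append("0123456789", n < 10 ? n : 10);  // fake I/O for the sketch
}

int main() {
  tstring record;  // callers must switch their local buffer types as well
  ReadNBytes(4, &record);
  return record.size() == 4 ? 0 : 1;
}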


@@ -186,7 +186,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt(
 /*input_buffer_bytes=*/k_buffer_size,
 /*output_buffer_bytes=*/k_buffer_size,
 io::ZlibCompressionOptions::GZIP());
-string decompressed_pbtxt_string;
+tstring decompressed_pbtxt_string;
 Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string);
 if (!s.ok() && !errors::IsOutOfRange(s)) {
 // OutOfRange is fine since we set the number of read bytes to INT_MAX.
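
The OutOfRange check above is a recurring idiom in this commit: requesting INT_MAX bytes deliberately overshoots the input, so end-of-file is the expected outcome and only other errors are fatal. A simplified sketch of that drain-the-stream pattern, with Status and the stream reduced to hypothetical stand-ins:

#include <climits>
#include <string>

enum class Code { kOk, kOutOfRange, kDataLoss };
struct Status {
  Code code = Code::kOk;
  bool ok() const { return code == Code::kOk; }
};

// Stand-in stream: reports kOutOfRange once the (short) input is exhausted.
struct Stream {
  std::string input = "payload";
  Status ReadNBytes(long n, std::string* out) {
    out->assign(input);
    return n > static_cast<long>(input.size()) ? Status{Code::kOutOfRange}
                                               : Status{Code::kOk};
  }
};

int main() {
  Stream in;
  std::string all;
  Status s = in.ReadNBytes(INT_MAX, &all);  // overshoot on purpose
  // OutOfRange just means EOF here; anything else is a real failure.
  if (!s.ok() && s.code != Code::kOutOfRange) return 1;
  return 0;
}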


@@ -121,7 +121,7 @@ class InitializeTRTResource : public OpKernel {
 uint64 offset = 0;
 int num_loaded_engine = 0;
 do {
-string record;
+tstring record;
 Status status = reader->ReadRecord(&offset, &record);
 if (errors::IsOutOfRange(status)) break;


@@ -66,7 +66,7 @@ class BigQueryReader : public ReaderBase {
 return Status::OK();
 }
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 *at_end = false;
 *produced = false;


@@ -31,12 +31,13 @@ class SequenceFileReader {
 new io::BufferedInputStream(file, kSequenceFileBufferSize)) {}
 Status ReadHeader() {
-string version;
+tstring version;
 TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version));
-if (version.substr(0, 3) != "SEQ" || version[3] != 6) {
+StringPiece version_view(version);
+if (version_view.substr(0, 3) != "SEQ" || version[3] != 6) {
 return errors::InvalidArgument(
 "sequence file header must starts with `SEQ6`, received \"",
-version.substr(0, 3), static_cast<int>(version[3]), "\"");
+version_view.substr(0, 3), static_cast<int>(version[3]), "\"");
 }
 TF_RETURN_IF_ERROR(ReadString(&key_class_name_));
 TF_RETURN_IF_ERROR(ReadString(&value_class_name_));
@@ -50,7 +51,7 @@ class SequenceFileReader {
 "' is currently not supported");
 }
-string buffer;
+tstring buffer;
 TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer));
 compression_ = buffer[0];
 block_compression_ = buffer[1];
@@ -84,12 +85,12 @@ class SequenceFileReader {
 return Status::OK();
 }
-Status ReadRecord(string* key, string* value) {
+Status ReadRecord(tstring* key, tstring* value) {
 uint32 length = 0;
 TF_RETURN_IF_ERROR(ReadUInt32(&length));
 if (length == static_cast<uint32>(-1)) {
 // Sync marker.
-string sync_marker;
+tstring sync_marker;
 TF_RETURN_IF_ERROR(
 input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker));
 if (sync_marker != sync_marker_) {
@@ -114,7 +115,7 @@ class SequenceFileReader {
 return Status::OK();
 }
-Status ReadString(string* value) {
+Status ReadString(tstring* value) {
 int64 length = 0;
 TF_RETURN_IF_ERROR(ReadVInt(&length));
 if (value == nullptr) {
@@ -124,7 +125,7 @@ class SequenceFileReader {
 }
 Status ReadUInt32(uint32* value) {
-string buffer;
+tstring buffer;
 TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer));
 *value = ((static_cast<uint32>(buffer[0]) << 24) |
 static_cast<uint32>(buffer[1]) << 16) |
@@ -134,7 +135,7 @@ class SequenceFileReader {
 }
 Status ReadVInt(int64* value) {
-string buffer;
+tstring buffer;
 TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer));
 if (buffer[0] >= -112) {
 *value = static_cast<int64>(buffer[0]);
@@ -167,12 +168,12 @@ class SequenceFileReader {
 private:
 std::unique_ptr<io::InputStreamInterface> input_stream_;
-string key_class_name_;
-string value_class_name_;
-string sync_marker_;
+tstring key_class_name_;
+tstring value_class_name_;
+tstring sync_marker_;
 bool compression_;
 bool block_compression_;
-string compression_codec_class_name_;
+tstring compression_codec_class_name_;
 TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader);
 };
 class SequenceFileDatasetOp : public DatasetOpKernel {
@@ -258,7 +259,7 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
 do {
 // We are currently processing a file, so try to read the next record.
 if (reader_) {
-string key, value;
+tstring key, value;
 Status status = reader_->ReadRecord(&key, &value);
 if (!errors::IsOutOfRange(status)) {
 TF_RETURN_IF_ERROR(status);
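
The version_view changes exist because tstring, at this stage of the migration, does not provide substr; wrapping the buffer in a StringPiece gives cheap, non-owning slicing instead. The same idea, with std::string_view standing in for StringPiece:

#include <cassert>
#include <string>
#include <string_view>

int main() {
  // Four-byte header: magic "SEQ" plus a one-byte version number.
  std::string header = std::string("SEQ") + '\x06' + "trailing-data";
  std::string_view view(header);  // no allocation, no copy
  assert(view.substr(0, 3) == "SEQ");  // substr on a view is another view
  assert(header[3] == 6);
  return 0;
}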


@@ -53,16 +53,16 @@ Status ReaderBase::ResetLocked() {
 return Status::OK();
 }
-Status ReaderBase::SerializeState(string* state) {
+Status ReaderBase::SerializeState(tstring* state) {
 mutex_lock lock(mu_);
 return SerializeStateLocked(state);
 }
-Status ReaderBase::SerializeStateLocked(string* state) {
+Status ReaderBase::SerializeStateLocked(tstring* state) {
 return errors::Unimplemented("Reader SerializeState");
 }
-Status ReaderBase::RestoreState(const string& state) {
+Status ReaderBase::RestoreState(const tstring& state) {
 mutex_lock lock(mu_);
 Status status = RestoreStateLocked(state);
 if (!status.ok()) {
@@ -71,13 +71,13 @@ Status ReaderBase::RestoreState(const string& state) {
 return status;
 }
-Status ReaderBase::RestoreStateLocked(const string& state) {
+Status ReaderBase::RestoreStateLocked(const tstring& state) {
 return errors::Unimplemented("Reader RestoreState");
 }
 int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
-std::vector<string>* keys,
-std::vector<string>* values,
+std::vector<tstring>* keys,
+std::vector<tstring>* values,
 OpKernelContext* context) {
 mutex_lock lock(mu_);
 int64 records_produced_this_call = 0;
@@ -133,16 +133,16 @@ int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
 }
 // Default implementation just reads one record at a time.
-Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
-std::vector<string>* values, int64* num_read,
+Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<tstring>* keys,
+std::vector<tstring>* values, int64* num_read,
 bool* at_end) {
 bool produced = false;
-string key;
-string value;
+tstring key;
+tstring value;
 Status status = ReadLocked(&key, &value, &produced, at_end);
 if (produced) {
-keys->emplace_back(key);
-values->emplace_back(value);
+keys->push_back(std::move(key));
+values->push_back(std::move(value));
 *num_read = 1;
 } else {
 *num_read = 0;
@@ -150,7 +150,7 @@ Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
 return status;
 }
-void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
+void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value,
 OpKernelContext* context) {
 mutex_lock lock(mu_);
 while (true) {
@@ -228,10 +228,19 @@ void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
 state->set_work_started(work_started_);
 state->set_work_finished(work_finished_);
 state->set_num_records_produced(num_records_produced_);
-state->set_current_work(work_);
+// Unfortunately, external proto does not accept string_view.
+#if defined(PLATFORM_GOOGLE)
+// TODO(dero): Remove NOLINT after USE_TSTRING is enabled. The external proto
+// compiler does not create an overloaded set method that accepts
+// absl::string_view, and string_view to std::string is an explicit
+// conversion.
+state->set_current_work(StringPiece(work_)); // NOLINT
+#else
+state->set_current_work(string(work_));
+#endif
 }
-string ReaderBase::KeyName(const string& key) const {
+tstring ReaderBase::KeyName(const tstring& key) const {
 return strings::StrCat(current_work(), ":", key);
 }
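
A small behavioral improvement rides along in ReadUpToLocked: emplace_back(key) on a named lvalue still copies, whereas push_back(std::move(key)) hands the heap buffer over to the vector. Demonstrated below on std::string; the pointer check relies on common library behavior for large (non-SSO) strings, so treat it as illustrative:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> values;
  std::string value(1 << 20, 'x');  // a large record buffer
  const char* before = value.data();
  values.push_back(std::move(value));  // transfers the buffer, no copy
  // On mainstream implementations the element now owns the same allocation;
  // `value` is left valid but unspecified.
  std::printf("%s\n", values[0].data() == before ? "moved" : "copied");
  return 0;
}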


@@ -52,15 +52,15 @@ class ReaderBase : public ReaderInterface {
 // d) If there was an error producing (e.g. an error reading the file,
 // data corruption), return a non-OK() status. ReadLocked may be
 // called again if the user reruns this part of the graph.
-virtual Status ReadLocked(string* key, string* value, bool* produced,
+virtual Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) = 0;
 // Descendants may optionally implement these -------------------------------
 // Produce up to num_records next key/value pairs from the current
 // work item, in the same manner of ReadLocked.
-virtual Status ReadUpToLocked(int64 num_records, std::vector<string>* keys,
-std::vector<string>* values, int64* num_read,
+virtual Status ReadUpToLocked(int64 num_records, std::vector<tstring>* keys,
+std::vector<tstring>* values, int64* num_read,
 bool* at_end);
 // Called when work starts / finishes.
@@ -72,8 +72,8 @@ class ReaderBase : public ReaderInterface {
 // Default implementation generates an Unimplemented error.
 // See the protected helper methods below.
-virtual Status SerializeStateLocked(string* state);
-virtual Status RestoreStateLocked(const string& state);
+virtual Status SerializeStateLocked(tstring* state);
+virtual Status RestoreStateLocked(const tstring& state);
 // Accessors ----------------------------------------------------------------
@@ -83,13 +83,13 @@ class ReaderBase : public ReaderInterface {
 // Returns the name of the current work item (valid if
 // work_in_progress() returns true). May change between calls to
 // ReadLocked().
-const string& current_work() const { return work_; }
+const tstring& current_work() const { return work_; }
 // What was passed to the constructor.
 const string& name() const { return name_; }
 // Produce the key name (from current_work and the actual key).
-string KeyName(const string& key) const;
+tstring KeyName(const tstring& key) const;
 protected:
 // For descendants wishing to implement serialize & restore state.
@@ -110,27 +110,27 @@ class ReaderBase : public ReaderInterface {
 // Implementations of ReaderInterface methods. These ensure thread-safety
 // and call the methods above to do the work.
-void Read(QueueInterface* queue, string* key, string* value,
+void Read(QueueInterface* queue, tstring* key, tstring* value,
 OpKernelContext* context) override;
 // Produces up to num_records.
 // In this implementation all the records come from the same work unit.
 int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
-std::vector<string>* keys, std::vector<string>* value,
+std::vector<tstring>* keys, std::vector<tstring>* value,
 OpKernelContext* context) override;
 Status Reset() override;
 int64 NumRecordsProduced() override;
 int64 NumWorkUnitsCompleted() override;
-Status SerializeState(string* state) override;
-Status RestoreState(const string& state) override;
+Status SerializeState(tstring* state) override;
+Status RestoreState(const tstring& state) override;
 mutable mutex mu_;
 const string name_;
 int64 work_started_ = 0;
 int64 work_finished_ = 0;
 int64 num_records_produced_ = 0;
-string work_;
+tstring work_;
 };
 } // namespace tensorflow


@@ -48,7 +48,7 @@ class ReaderInterface : public ResourceBase {
 // *context with an OutOfRange Status if the current work is
 // complete and the queue is done (closed and empty).
 // This method may block.
-virtual void Read(QueueInterface* queue, string* key, string* value,
+virtual void Read(QueueInterface* queue, tstring* key, tstring* value,
 OpKernelContext* context) = 0;
 // Read up to num_records records into keys / values. May get more work from
@@ -60,7 +60,8 @@ class ReaderInterface : public ResourceBase {
 // structures (that have most likely been reserve(num_records)).
 // Returns how many records were actually read.
 virtual int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
-std::vector<string>* keys, std::vector<string>* value,
+std::vector<tstring>* keys,
+std::vector<tstring>* value,
 OpKernelContext* context) = 0;
 // Restore this reader to its newly-constructed state.
@@ -72,9 +73,9 @@ class ReaderInterface : public ResourceBase {
 // -- Serialization/Restoration support --
 // Not all readers will support saving and restoring state.
-virtual Status SerializeState(string* state) = 0;
+virtual Status SerializeState(tstring* state) = 0;
 // Note: Must Reset on error.
-virtual Status RestoreState(const string& state) = 0;
+virtual Status RestoreState(const tstring& state) = 0;
 string DebugString() const override { return "a reader"; }


@@ -419,7 +419,7 @@ class CSVDatasetOp : public DatasetOpKernel {
 Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
 size_t* start, bool include)
 EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-string temp_buffer;
+tstring temp_buffer;
 buffer_.swap(temp_buffer);
 if (include && pos_ > *start) {
@@ -622,7 +622,7 @@ class CSVDatasetOp : public DatasetOpKernel {
 }
 }
-Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+Status FillBuffer(tstring* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 result->clear();
 ++num_buffer_reads_;
 Status s = input_stream_->ReadNBytes(
@@ -827,7 +827,7 @@ class CSVDatasetOp : public DatasetOpKernel {
 }
 mutex mu_;
-string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
+tstring buffer_ GUARDED_BY(mu_); // Maintain our own buffer
 size_t pos_ GUARDED_BY(
 mu_); // Index into the buffer must be maintained between iters
 size_t num_buffer_reads_ GUARDED_BY(mu_);


@@ -162,11 +162,11 @@ class SnapshotReader {
 }
 }
-Status ReadRecord(string* record) {
+Status ReadRecord(tstring* record) {
 profiler::TraceMe activity(
 absl::StrCat(kClassName, kSeparator, kReadString),
 profiler::TraceMeLevel::kInfo);
-string header;
+tstring header;
 TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
 uint64 length = core::DecodeFixed64(header.data());
 return input_stream_->ReadNBytes(length, record);
@@ -176,14 +176,14 @@ class SnapshotReader {
 Status ReadRecord(absl::Cord* record) {
 profiler::TraceMe activity(absl::StrCat(kClassName, kSeparator, kReadCord),
 profiler::TraceMeLevel::kInfo);
-string header;
+tstring header;
 TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
 uint64 length = core::DecodeFixed64(header.data());
 if (compression_type_ == io::compression::kNone) {
 return input_stream_->ReadNBytes(length, record);
 } else {
-string tmp_str;
+tstring tmp_str;
 Status s = input_stream_->ReadNBytes(length, &tmp_str);
 record->Append(tmp_str);
 return s;
@@ -224,7 +224,7 @@ Status ReadMetadataFile(const string& hash_dir,
 std::unique_ptr<RandomAccessFile> file;
 TF_CHECK_OK(Env::Default()->NewRandomAccessFile(metadata_filename, &file));
-string record_bytes;
+tstring record_bytes;
 auto reader = absl::make_unique<SnapshotReader>(file.get());
 TF_CHECK_OK(reader->ReadRecord(&record_bytes));


@@ -258,7 +258,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
 if (dataset()->compression_type_.empty()) {
 DCHECK_GE(file_pos_limit_, 0);
 if (current_pos < file_pos_limit_) {
-string record;
+tstring record;
 TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes(
 dataset()->record_bytes_, &record));
 metrics::RecordTFDataBytesRead(kDatasetType,
@@ -272,16 +272,18 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
 return Status::OK();
 }
 } else {
-string record;
+tstring record;
 Status s = buffered_input_stream_->ReadNBytes(
 dataset()->record_bytes_, &record);
 if (s.ok()) {
 metrics::RecordTFDataBytesRead(kDatasetType,
 dataset()->record_bytes_);
 lookahead_cache_.append(record);
-record = lookahead_cache_.substr(0, dataset()->record_bytes_);
-lookahead_cache_ =
-lookahead_cache_.substr(dataset()->record_bytes_);
+StringPiece lookahead_cache_view(lookahead_cache_);
+record = tstring(
+lookahead_cache_view.substr(0, dataset()->record_bytes_));
+lookahead_cache_ = tstring(
+lookahead_cache_view.substr(dataset()->record_bytes_));
 // Produce the record as output.
 Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
 record_tensor.scalar<tstring>()() = std::move(record);
@@ -433,7 +435,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
 std::unique_ptr<io::InputStreamInterface> buffered_input_stream_
 GUARDED_BY(mu_);
 int64 file_pos_limit_ GUARDED_BY(mu_) = -1;
-string lookahead_cache_ GUARDED_BY(mu_);
+tstring lookahead_cache_ GUARDED_BY(mu_);
 };
 const std::vector<string> filenames_;
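
The lookahead-cache rewrite slices a buffer through a view into itself. That is safe only because the right-hand side materializes a fresh tstring from the view before the assignment replaces the source buffer. The same shape with std::string_view:

#include <cassert>
#include <cstddef>
#include <string>
#include <string_view>

int main() {
  std::string cache = "AAAABBBBCCCC";
  const std::size_t record_bytes = 4;
  std::string_view view(cache);
  // Copy the front record out first...
  std::string record(view.substr(0, record_bytes));
  // ...then rebuild the cache from the remainder. The temporary string is
  // fully constructed from the view before operator= touches `cache`.
  cache = std::string(view.substr(record_bytes));
  // `view` now dangles and must not be used again.
  assert(record == "AAAA" && cache == "BBBBCCCC");
  return 0;
}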


@@ -107,7 +107,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
 out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
 TensorShape({}));
 Status s =
-reader_->ReadRecord(&out_tensors->back().scalar<string>()());
+reader_->ReadRecord(&out_tensors->back().scalar<tstring>()());
 if (s.ok()) {
 metrics::RecordTFDataBytesRead(
 kDatasetType, out_tensors->back().scalar<tstring>()().size());
@@ -208,7 +208,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
 };
 const std::vector<string> filenames_;
-const string compression_type_;
+const tstring compression_type_;
 io::RecordReaderOptions options_;
 };
@@ -230,9 +230,9 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
 filenames.push_back(filenames_tensor->flat<tstring>()(i));
 }
-string compression_type;
-OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, kCompressionType,
-&compression_type));
+tstring compression_type;
+OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kCompressionType,
+&compression_type));
 int64 buffer_size = -1;
 OP_REQUIRES_OK(ctx,


@@ -46,7 +46,7 @@ class TFRecordDatasetOpTest : public DatasetOpsTestBase {
 };
 struct TestCase {
-std::vector<string> filenames;
+std::vector<tstring> filenames;
 std::vector<std::vector<string>> contents;
 CompressionType compression_type;
 int64 buffer_size;
@@ -84,12 +84,12 @@ TestCase TestCase1() {
 /*compression_type*/ CompressionType::ZLIB,
 /*buffer_size*/ 10,
 /*expected_outputs*/
-{CreateTensor<string>(TensorShape({}), {"1"}),
-CreateTensor<string>(TensorShape({}), {"22"}),
-CreateTensor<string>(TensorShape({}), {"333"}),
-CreateTensor<string>(TensorShape({}), {"a"}),
-CreateTensor<string>(TensorShape({}), {"bb"}),
-CreateTensor<string>(TensorShape({}), {"ccc"})},
+{CreateTensor<tstring>(TensorShape({}), {"1"}),
+CreateTensor<tstring>(TensorShape({}), {"22"}),
+CreateTensor<tstring>(TensorShape({}), {"333"}),
+CreateTensor<tstring>(TensorShape({}), {"a"}),
+CreateTensor<tstring>(TensorShape({}), {"bb"}),
+CreateTensor<tstring>(TensorShape({}), {"ccc"})},
 /*expected_output_dtypes*/ {DT_STRING},
 /*expected_output_shapes*/ {PartialTensorShape({})},
 /*expected_cardinality*/ kUnknownCardinality,
@@ -105,12 +105,12 @@ TestCase TestCase2() {
 /*compression_type*/ CompressionType::GZIP,
 /*buffer_size*/ 10,
 /*expected_outputs*/
-{CreateTensor<string>(TensorShape({}), {"1"}),
-CreateTensor<string>(TensorShape({}), {"22"}),
-CreateTensor<string>(TensorShape({}), {"333"}),
-CreateTensor<string>(TensorShape({}), {"a"}),
-CreateTensor<string>(TensorShape({}), {"bb"}),
-CreateTensor<string>(TensorShape({}), {"ccc"})},
+{CreateTensor<tstring>(TensorShape({}), {"1"}),
+CreateTensor<tstring>(TensorShape({}), {"22"}),
+CreateTensor<tstring>(TensorShape({}), {"333"}),
+CreateTensor<tstring>(TensorShape({}), {"a"}),
+CreateTensor<tstring>(TensorShape({}), {"bb"}),
+CreateTensor<tstring>(TensorShape({}), {"ccc"})},
 /*expected_output_dtypes*/ {DT_STRING},
 /*expected_output_shapes*/ {PartialTensorShape({})},
 /*expected_cardinality*/ kUnknownCardinality,
@@ -127,12 +127,12 @@ TestCase TestCase3() {
 /*compression_type*/ CompressionType::UNCOMPRESSED,
 /*buffer_size*/ 10,
 /*expected_outputs*/
-{CreateTensor<string>(TensorShape({}), {"1"}),
-CreateTensor<string>(TensorShape({}), {"22"}),
-CreateTensor<string>(TensorShape({}), {"333"}),
-CreateTensor<string>(TensorShape({}), {"a"}),
-CreateTensor<string>(TensorShape({}), {"bb"}),
-CreateTensor<string>(TensorShape({}), {"ccc"})},
+{CreateTensor<tstring>(TensorShape({}), {"1"}),
+CreateTensor<tstring>(TensorShape({}), {"22"}),
+CreateTensor<tstring>(TensorShape({}), {"333"}),
+CreateTensor<tstring>(TensorShape({}), {"a"}),
+CreateTensor<tstring>(TensorShape({}), {"bb"}),
+CreateTensor<tstring>(TensorShape({}), {"ccc"})},
 /*expected_output_dtypes*/ {DT_STRING},
 /*expected_output_shapes*/ {PartialTensorShape({})},
 /*expected_cardinality*/ kUnknownCardinality,
@@ -156,8 +156,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, GetNext) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -206,8 +206,8 @@ TEST_F(TFRecordDatasetOpTest, DatasetNodeName) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -239,8 +239,8 @@ TEST_F(TFRecordDatasetOpTest, DatasetTypeString) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -273,8 +273,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputDtypes) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -307,8 +307,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputShapes) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -341,8 +341,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -374,8 +374,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -416,8 +416,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputShapes) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -458,8 +458,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputPrefix) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -501,8 +501,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, Roundtrip) {
 int64 num_files = test_case.filenames.size();
 Tensor filenames =
-CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-Tensor compression_type = CreateTensor<string>(
+CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+Tensor compression_type = CreateTensor<tstring>(
 TensorShape({}), {ToString(test_case.compression_type)});
 Tensor buffer_size =
 CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});


@@ -34,7 +34,7 @@ class MemoryInputStream : public io::InputStreamInterface {
 ~MemoryInputStream() override {}
-Status ReadNBytes(int64 bytes_to_read, string* result) override {
+Status ReadNBytes(int64 bytes_to_read, tstring* result) override {
 result->clear();
 if (bytes_to_read < 0) {
 return errors::InvalidArgument("Can't read a negative number of bytes: ",
@@ -106,7 +106,7 @@ class DecodeCompressedOp : public OpKernel {
 new io::ZlibInputStream(
 input_stream.get(), static_cast<size_t>(kBufferSize),
 static_cast<size_t>(kBufferSize), zlib_options));
-string output_string;
+tstring output_string;
 Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string);
 OP_REQUIRES(context, (s.ok() || errors::IsOutOfRange(s)), s);
 output_flat(i) = std::move(output_string);


@@ -77,7 +77,7 @@ class FixedLengthRecordReader : public ReaderBase {
 return Status::OK();
 }
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 // We will always "hop" the hop_bytes_ except the first record
 // where record_number_ == 0


@@ -33,7 +33,7 @@ class IdentityReader : public ReaderBase {
 explicit IdentityReader(const string& node_name)
 : ReaderBase(strings::StrCat("IdentityReader '", node_name, "'")) {}
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 *key = current_work();
 *value = current_work();
@@ -44,14 +44,14 @@ class IdentityReader : public ReaderBase {
 // Stores state in a ReaderBaseState proto, since IdentityReader has
 // no additional state beyond ReaderBase.
-Status SerializeStateLocked(string* state) override {
+Status SerializeStateLocked(tstring* state) override {
 ReaderBaseState base_state;
 SaveBaseState(&base_state);
-base_state.SerializeToString(state);
+SerializeToTString(base_state, state);
 return Status::OK();
 }
-Status RestoreStateLocked(const string& state) override {
+Status RestoreStateLocked(const tstring& state) override {
 ReaderBaseState base_state;
 if (!ParseProtoUnlimited(&base_state, state)) {
 return errors::InvalidArgument("Could not parse state for ", name(), ": ",

tensorflow/core/kernels/lmdb_reader_op.cc (mode changed: executable file → normal file)

@@ -68,7 +68,7 @@ class LMDBReader : public ReaderBase {
 return Status::OK();
 }
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 if (mdb_cursor_ == nullptr) {
 MDB_CHECK(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_));
@@ -82,9 +82,10 @@ class LMDBReader : public ReaderBase {
 return Status::OK();
 }
 }
-*key = string(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
-*value = string(static_cast<const char*>(mdb_value_.mv_data),
-mdb_value_.mv_size);
+*key =
+tstring(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+*value = tstring(static_cast<const char*>(mdb_value_.mv_data),
+mdb_value_.mv_size);
 *produced = true;
 return Status::OK();
 }
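
LMDB returns (pointer, length) pairs that are not NUL-terminated, so the two-argument tstring constructor is the right entry point, exactly as the string one was before. A freestanding version, with a local struct mimicking MDB_val:

#include <cassert>
#include <cstddef>
#include <string>

// Shape of an LMDB value: raw bytes plus an explicit length.
struct Val {
  std::size_t mv_size;
  void* mv_data;
};

int main() {
  char raw[] = {'k', 'e', 'y', '\0', 'x'};  // embedded NUL, no terminator
  Val v{sizeof(raw), raw};
  // The length-taking constructor preserves every byte, including the NUL;
  // a const char* constructor would have stopped at raw[3].
  std::string key(static_cast<const char*>(v.mv_data), v.mv_size);
  assert(key.size() == 5);
  return 0;
}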


@@ -90,9 +90,12 @@ class ReaderReadOp : public ReaderVerbAsyncOpKernel {
 OP_REQUIRES_OK(context,
 context->allocate_output("value", TensorShape({}), &value));
-auto key_scalar = key->scalar<string>();
-auto value_scalar = value->scalar<string>();
-reader->Read(queue, &key_scalar(), &value_scalar(), context);
+auto key_scalar = key->scalar<tstring>();
+auto value_scalar = value->scalar<tstring>();
+tstring key_out, val_out;
+reader->Read(queue, &key_out, &val_out, context);
+key_scalar() = key_out;
+value_scalar() = val_out;
 }
 };
@@ -115,9 +118,9 @@ class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
 GetResourceFromContext(context, "queue_handle", &queue));
 core::ScopedUnref unref_me(queue);
-std::vector<string> keys_vec;
+std::vector<tstring> keys_vec;
 keys_vec.reserve(num_records);
-std::vector<string> values_vec;
+std::vector<tstring> values_vec;
 values_vec.reserve(num_records);
 int64 num_actually_read =
@@ -200,7 +203,7 @@ class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
 OP_REQUIRES_OK(context,
 context->allocate_output("state", TensorShape({}), &output));
 OP_REQUIRES_OK(context,
-reader->SerializeState(&output->scalar<string>()()));
+reader->SerializeState(&output->scalar<tstring>()()));
 }
 };


@@ -70,7 +70,7 @@ Status RecordYielder::YieldOne(tstring* value) {
 struct RecordYielder::Shard {
 int index; // Shard index.
-std::vector<string> filenames; // File names given to this shard.
+std::vector<tstring> filenames; // File names given to this shard.
 Notification done; // Notified when this shard is done.
 Status status; // Shard status.
 };
@@ -211,7 +211,7 @@ void RecordYielder::ShardLoop(Shard* shard) {
 opts_.compression_type);
 io::RecordReader rdr(file.get(), options);
 uint64 offset = 0;
-string record;
+tstring record;
 while (true) {
 Status s = rdr.ReadRecord(&offset, &record);
 if (s.ok()) {


@@ -56,7 +56,7 @@ class TextLineReader : public ReaderBase {
 return Status::OK();
 }
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 Status status = input_buffer_->ReadLine(value);
 ++line_number_;


@@ -50,7 +50,7 @@ class TFRecordReader : public ReaderBase {
 return Status::OK();
 }
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 *key = strings::StrCat(current_work(), ":", offset_);
 Status status = reader_->ReadRecord(&offset_, value);


@@ -50,7 +50,7 @@ class WholeFileReader : public ReaderBase {
 : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")),
 env_(env) {}
-Status ReadLocked(string* key, string* value, bool* produced,
+Status ReadLocked(tstring* key, tstring* value, bool* produced,
 bool* at_end) override {
 *key = current_work();
 TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value));
@@ -61,14 +61,14 @@ class WholeFileReader : public ReaderBase {
 // Stores state in a ReaderBaseState proto, since WholeFileReader has
 // no additional state beyond ReaderBase.
-Status SerializeStateLocked(string* state) override {
+Status SerializeStateLocked(tstring* state) override {
 ReaderBaseState base_state;
 SaveBaseState(&base_state);
-base_state.SerializeToString(state);
+SerializeToTString(base_state, state);
 return Status::OK();
 }
-Status RestoreStateLocked(const string& state) override {
+Status RestoreStateLocked(const tstring& state) override {
 ReaderBaseState base_state;
 if (!ParseProtoUnlimited(&base_state, state)) {
 return errors::InvalidArgument("Could not parse state for ", name(), ": ",


@@ -85,7 +85,7 @@ Status BufferedInputStream::ReadLineHelper(string* result, bool include_eol) {
 return s;
 }
-Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
+Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
 if (bytes_to_read < 0) {
 return errors::InvalidArgument("Can't read a negative number of bytes: ",
 bytes_to_read);


@@ -41,7 +41,7 @@ class BufferedInputStream : public InputStreamInterface {
 ~BufferedInputStream() override;
-Status ReadNBytes(int64 bytes_to_read, string* result) override;
+Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
 Status SkipNBytes(int64 bytes_to_skip) override;
@@ -90,7 +90,7 @@ class BufferedInputStream : public InputStreamInterface {
 InputStreamInterface* input_stream_; // not owned.
 size_t size_; // buffer size.
-string buf_; // the buffer itself.
+tstring buf_; // the buffer itself.
 // buf_[pos_, limit_) holds the valid "read ahead" data in the file.
 size_t pos_ = 0; // current position in buf_.
 size_t limit_ = 0; // just past the end of valid data in buf_.


@@ -163,7 +163,7 @@ TEST(BufferedInputStream, ReadNBytes) {
 for (auto buf_size : BufferSizes()) {
 std::unique_ptr<RandomAccessInputStream> input_stream(
 new RandomAccessInputStream(file.get()));
-string read;
+tstring read;
 BufferedInputStream in(input_stream.get(), buf_size);
 EXPECT_EQ(0, in.Tell());
 TF_ASSERT_OK(in.ReadNBytes(3, &read));
@@ -200,7 +200,7 @@ TEST(BufferedInputStream, SkipNBytes) {
 for (auto buf_size : BufferSizes()) {
 std::unique_ptr<RandomAccessInputStream> input_stream(
 new RandomAccessInputStream(file.get()));
-string read;
+tstring read;
 BufferedInputStream in(input_stream.get(), buf_size);
 EXPECT_EQ(0, in.Tell());
 TF_ASSERT_OK(in.SkipNBytes(3));
@@ -235,7 +235,7 @@ TEST(BufferedInputStream, ReadNBytesRandomAccessFile) {
 TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
 for (auto buf_size : BufferSizes()) {
-string read;
+tstring read;
 BufferedInputStream in(file.get(), buf_size);
 EXPECT_EQ(0, in.Tell());
 TF_ASSERT_OK(in.ReadNBytes(3, &read));
@@ -270,7 +270,7 @@ TEST(BufferedInputStream, SkipNBytesRandomAccessFile) {
 TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
 for (auto buf_size : BufferSizes()) {
-string read;
+tstring read;
 BufferedInputStream in(file.get(), buf_size);
 EXPECT_EQ(0, in.Tell());
 TF_ASSERT_OK(in.SkipNBytes(3));
@@ -307,7 +307,7 @@ TEST(BufferedInputStream, Seek) {
 for (auto buf_size : BufferSizes()) {
 std::unique_ptr<RandomAccessInputStream> input_stream(
 new RandomAccessInputStream(file.get()));
-string read;
+tstring read;
 BufferedInputStream in(input_stream.get(), buf_size);
 // Seek forward
@@ -378,7 +378,7 @@ void BM_BufferedReaderSmallReads(const int iters, const int buff_size,
 std::unique_ptr<RandomAccessFile> file;
 TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-string result;
+tstring result;
 testing::StartTiming();
 for (int itr = 0; itr < iters; ++itr) {


@@ -42,7 +42,8 @@ Status InputBuffer::FillBuffer() {
 return s;
 }
-Status InputBuffer::ReadLine(string* result) {
+template <typename T>
+Status InputBuffer::ReadLine(T* result) {
 result->clear();
 Status s;
 do {
@@ -71,6 +72,11 @@ Status InputBuffer::ReadLine(string* result) {
 return s;
 }
+template Status InputBuffer::ReadLine<string>(string* result);
+#ifdef USE_TSTRING
+template Status InputBuffer::ReadLine<tstring>(tstring* result);
+#endif // USE_TSTRING
 Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
 result->clear();
 if (bytes_to_read < 0) {
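
Templating ReadLine in the .cc file only links because of the explicit instantiations added above: the definition stays out of the header, and code for each supported T is emitted once in that translation unit. A single-file sketch of the pattern, with the header/.cc split indicated by comments:

#include <string>

// --- header: declaration only ---
template <typename T>
void ReadLine(T* result);

// --- .cc file: definition plus explicit instantiation ---
template <typename T>
void ReadLine(T* result) {
  result->clear();
  result->append("a line", 6);
}

// Emits the code for this T; without it, callers in other translation
// units that only saw the declaration would fail to link.
template void ReadLine<std::string>(std::string* result);

int main() {
  std::string line;
  ReadLine(&line);
  return line.size() == 6 ? 0 : 1;
}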


@@ -43,7 +43,8 @@ class InputBuffer {
 // If successful, returns OK. If we are already at the end of the
 // file, we return an OUT_OF_RANGE error. Otherwise, we return
 // some other non-OK status.
-Status ReadLine(string* result);
+template <typename T>
+Status ReadLine(T* result);
 // Reads bytes_to_read bytes into *result, overwriting *result.
 //


@@ -28,7 +28,7 @@ Status InputStreamInterface::SkipNBytes(int64 bytes_to_skip) {
 if (bytes_to_skip < 0) {
 return errors::InvalidArgument("Can't skip a negative number of bytes");
 }
-string unused;
+tstring unused;
 // Read kDefaultSkipSize at a time till bytes_to_skip.
 while (bytes_to_skip > 0) {
 int64 bytes_to_read = std::min<int64>(kMaxSkipSize, bytes_to_skip);
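
SkipNBytes above is the interface's default implementation: it discards bytes by repeatedly reading bounded chunks into a scratch buffer. The loop shape, reduced to a hypothetical in-memory stream:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>

// Skip by bounded reads into a scratch buffer that is thrown away. The
// 8 KiB cap is an assumed value, not the constant used by TensorFlow.
void SkipNBytes(int64_t bytes_to_skip, std::string* stream_data) {
  constexpr int64_t kMaxSkipSize = 8 << 10;
  std::string unused;
  while (bytes_to_skip > 0) {
    const int64_t n = std::min<int64_t>(kMaxSkipSize, bytes_to_skip);
    unused.assign(*stream_data, 0, static_cast<std::size_t>(n));  // fake read
    stream_data->erase(0, static_cast<std::size_t>(n));
    bytes_to_skip -= n;
  }
}

int main() {
  std::string data(20000, 'x');
  SkipNBytes(16384, &data);
  return data.size() == 20000 - 16384 ? 0 : 1;
}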


@@ -35,7 +35,7 @@ class InputStreamInterface {
 // Reads the next bytes_to_read from the file. Typical return codes:
 // * OK - in case of success.
 // * OUT_OF_RANGE - not enough bytes remaining before end of file.
-virtual Status ReadNBytes(int64 bytes_to_read, string* result) = 0;
+virtual Status ReadNBytes(int64 bytes_to_read, tstring* result) = 0;
 #if defined(PLATFORM_GOOGLE)
 // Reads the next bytes_to_read from the file. Typical return codes:


@@ -27,7 +27,7 @@ class TestStringStream : public InputStreamInterface {
 public:
 explicit TestStringStream(const string& content) : content_(content) {}
-Status ReadNBytes(int64 bytes_to_read, string* result) override {
+Status ReadNBytes(int64 bytes_to_read, tstring* result) override {
 result->clear();
 if (pos_ + bytes_to_read > content_.size()) {
 return errors::OutOfRange("limit reached");
@@ -51,7 +51,7 @@ class TestStringStream : public InputStreamInterface {
 TEST(InputStreamInterface, Basic) {
 TestStringStream ss("This is a test string");
-string res;
+tstring res;
 TF_ASSERT_OK(ss.ReadNBytes(4, &res));
 EXPECT_EQ("This", res);
 TF_ASSERT_OK(ss.SkipNBytes(6));


@@ -30,7 +30,7 @@ RandomAccessInputStream::~RandomAccessInputStream() {
 }
 Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
-string* result) {
+tstring* result) {
 if (bytes_to_read < 0) {
 return errors::InvalidArgument("Cannot read negative number of bytes");
 }


@@ -33,7 +33,7 @@ class RandomAccessInputStream : public InputStreamInterface {
 ~RandomAccessInputStream();
-Status ReadNBytes(int64 bytes_to_read, string* result) override;
+Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
 #if defined(PLATFORM_GOOGLE)
 Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override;


@@ -30,7 +30,7 @@ TEST(RandomInputStream, ReadNBytes) {
 std::unique_ptr<RandomAccessFile> file;
 TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-string read;
+tstring read;
 RandomAccessInputStream in(file.get());
 TF_ASSERT_OK(in.ReadNBytes(3, &read));
 EXPECT_EQ(read, "012");
@@ -59,7 +59,7 @@ TEST(RandomInputStream, SkipNBytes) {
 std::unique_ptr<RandomAccessFile> file;
 TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-string read;
+tstring read;
 RandomAccessInputStream in(file.get());
 TF_ASSERT_OK(in.SkipNBytes(3));
 EXPECT_EQ(3, in.Tell());
@@ -90,7 +90,7 @@ TEST(RandomInputStream, Seek) {
 std::unique_ptr<RandomAccessFile> file;
 TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-string read;
+tstring read;
 RandomAccessInputStream in(file.get());
 // Seek forward


@@ -84,7 +84,7 @@ RecordReader::RecordReader(RandomAccessFile* file,
 //
 // offset corresponds to the user-provided value to ReadRecord()
 // and is used only in error messages.
-Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
+Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) {
 if (n >= SIZE_MAX - sizeof(uint32)) {
 return errors::DataLoss("record size too large");
 }
@@ -125,7 +125,7 @@ Status RecordReader::GetMetadata(Metadata* md) {
 // loop should be guaranteed to either return after reaching EOF
 // or encountering an error.
 uint64 offset = 0;
-string record;
+tstring record;
 while (true) {
 // Read header, containing size of data.
 Status s = ReadChecksummed(offset, sizeof(uint64), &record);
@@ -161,7 +161,7 @@ Status RecordReader::GetMetadata(Metadata* md) {
 return Status::OK();
 }
-Status RecordReader::ReadRecord(uint64* offset, string* record) {
+Status RecordReader::ReadRecord(uint64* offset, tstring* record) {
 // Position the input stream.
 int64 curr_pos = input_stream_->Tell();
 int64 desired_pos = static_cast<int64>(*offset);


@@ -89,7 +89,7 @@ class RecordReader {
 // Read the record at "*offset" into *record and update *offset to
 // point to the offset of the next record. Returns OK on success,
 // OUT_OF_RANGE for end of file, or something else for an error.
-Status ReadRecord(uint64* offset, string* record);
+Status ReadRecord(uint64* offset, tstring* record);
 // Return the metadata of the Record file.
 //
@@ -103,7 +103,7 @@ class RecordReader {
 Status GetMetadata(Metadata* md);
 private:
-Status ReadChecksummed(uint64 offset, size_t n, string* result);
+Status ReadChecksummed(uint64 offset, size_t n, tstring* result);
 RecordReaderOptions options_;
 std::unique_ptr<InputStreamInterface> input_stream_;
@@ -129,7 +129,7 @@ class SequentialRecordReader {
 // Reads the next record in the file into *record. Returns OK on success,
 // OUT_OF_RANGE for end of file, or something else for an error.
-Status ReadRecord(string* record) {
+Status ReadRecord(tstring* record) {
 return underlying_.ReadRecord(&offset_, record);
 }


@@ -86,7 +86,7 @@ void VerifyFlush(const io::RecordWriterOptions& options) {
 // Verify that file has all records written so far and no more.
 uint64 offset = 0;
-string record;
+tstring record;
 for (size_t j = 0; j <= i; j++) {
 // Check that j'th record is written correctly.
 TF_CHECK_OK(reader.ReadRecord(&offset, &record));
@@ -142,7 +142,7 @@ TEST(RecordReaderWriterTest, TestBasics) {
 options.zlib_options.input_buffer_size = buf_size;
 io::RecordReader reader(read_file.get(), options);
 uint64 offset = 0;
-string record;
+tstring record;
 TF_CHECK_OK(reader.ReadRecord(&offset, &record));
 EXPECT_EQ("abc", record);
 TF_CHECK_OK(reader.ReadRecord(&offset, &record));
@@ -187,7 +187,7 @@ TEST(RecordReaderWriterTest, TestZlib) {
 options.zlib_options.input_buffer_size = buf_size;
 io::RecordReader reader(read_file.get(), options);
 uint64 offset = 0;
-string record;
+tstring record;
 TF_CHECK_OK(reader.ReadRecord(&offset, &record));
 EXPECT_EQ("abc", record);
 TF_CHECK_OK(reader.ReadRecord(&offset, &record));


@@ -149,7 +149,7 @@ class RecordioTest : public ::testing::Test {
 if (!reading_) {
 reading_ = true;
 }
-string record;
+tstring record;
 Status s = reader_->ReadRecord(&readpos_, &record);
 if (s.ok()) {
 return record;
@@ -183,7 +183,7 @@ class RecordioTest : public ::testing::Test {
 Write(BigString("x", 10000));
 reading_ = true;
 uint64 offset = WrittenBytes() + offset_past_end;
-string record;
+tstring record;
 Status s = reader_->ReadRecord(&offset, &record);
 ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
 }
@@ -261,7 +261,7 @@ void TestNonSequentialReads(const RecordWriterOptions& writer_options,
 StringSource file(&contents);
 RecordReader reader(&file, reader_options);
-string record;
+tstring record;
 // First read sequentially to fill in the offsets table.
 uint64 offsets[10] = {0};
 uint64 offset = 0;
@@ -315,7 +315,7 @@ void TestReadError(const RecordWriterOptions& writer_options,
 RecordReader reader(&file, reader_options);
 uint64 offset = 0;
-string read;
+tstring read;
 file.force_error();
 Status status = reader.ReadRecord(&offset, &read);
 ASSERT_TRUE(errors::IsDataLoss(status));


@@ -121,7 +121,7 @@ Status TestMultipleWrites(size_t compress_input_buf_size,
 for (int attempt = 0; attempt < 2; ++attempt) {
 string actual_result;
 for (int i = 0; i < num_writes; i++) {
-string decompressed_output;
+tstring decompressed_output;
 TF_RETURN_IF_ERROR(in.ReadNBytes(data.size(), &decompressed_output));
 strings::StrAppend(&actual_result, decompressed_output);
 }


@@ -29,7 +29,7 @@ SnappyInputBuffer::SnappyInputBuffer(
 output_buffer_(new char[output_buffer_capacity_]),
 next_in_(input_buffer_.get()) {}
-Status SnappyInputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
+Status SnappyInputBuffer::ReadNBytes(int64 bytes_to_read, tstring* result) {
 result->clear();
 // Read as many bytes as possible from cache.
 bytes_to_read -= ReadBytesFromCache(bytes_to_read, result);
@@ -62,7 +62,7 @@ Status SnappyInputBuffer::Reset() {
 }
 size_t SnappyInputBuffer::ReadBytesFromCache(size_t bytes_to_read,
-string* result) {
+tstring* result) {
 size_t can_read_bytes = std::min(bytes_to_read, avail_out_);
 if (can_read_bytes > 0) {
 result->append(next_out_, can_read_bytes);


@@ -54,7 +54,7 @@ class SnappyInputBuffer : public InputStreamInterface {
 // If input_buffer_ is smaller in size than a compressed block.
 // others:
 // If reading from file failed.
-Status ReadNBytes(int64 bytes_to_read, string* result) override;
+Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
 int64 Tell() const override;
@@ -86,7 +86,7 @@ class SnappyInputBuffer : public InputStreamInterface {
 // bytes have been read or `next_out_` is reached.
 // Returns the number of bytes read and advances the `next_out_`
 // pointer to the next location to read from.
-size_t ReadBytesFromCache(size_t bytes_to_read, string* result);
+size_t ReadBytesFromCache(size_t bytes_to_read, tstring* result);
 // Reads the length of the next *compressed* block and stores in `length`.
 // The length is stored in 4 bytes in little endian notation.


@@ -69,7 +69,7 @@ void TestAllCombinations(CompressionOptions input_options,
 for (auto output_buf_size : OutputBufferSizes()) {
 std::unique_ptr<WritableFile> file_writer;
 TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
-string result;
+tstring result;
 ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
 output_options);
@@ -142,7 +142,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size,
 input_options);
 for (int i = 0; i < num_writes; i++) {
-string decompressed_output;
+tstring decompressed_output;
 TF_ASSERT_OK(in.ReadNBytes(data.size(), &decompressed_output));
 strings::StrAppend(&actual_result, decompressed_output);
 }
@@ -171,7 +171,7 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) {
 string data = GenTestString(10);
 std::unique_ptr<WritableFile> file_writer;
 TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
-string result;
+tstring result;
 ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
 output_options);
 TF_ASSERT_OK(out.Init());
@@ -229,8 +229,8 @@ void TestTell(CompressionOptions input_options,
 ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
 input_options);
-string first_half(data, 0, data.size() / 2);
-string bytes_read;
+tstring first_half(string(data, 0, data.size() / 2));
+tstring bytes_read;
 // Read the first half of the uncompressed file and expect that Tell()
 // returns half the uncompressed length of the file.
@@ -240,7 +240,7 @@ void TestTell(CompressionOptions input_options,
 // Read the remaining half of the uncompressed file and expect that
 // Tell() points past the end of file.
-string second_half;
+tstring second_half;
 TF_ASSERT_OK(
 in.ReadNBytes(data.size() - first_half.size(), &second_half));
 EXPECT_EQ(in.Tell(), data.size());
@@ -283,7 +283,7 @@ void TestSkipNBytes(CompressionOptions input_options,
 // Expect that second half is read correctly and Tell() returns past
 // end of file after reading complete file.
-string bytes_read;
+tstring bytes_read;
 TF_ASSERT_OK(in.ReadNBytes(second_half.size(), &bytes_read));
 EXPECT_EQ(bytes_read, second_half);
 EXPECT_EQ(in.Tell(), data.size());


@@ -132,7 +132,7 @@ Status ZlibInputStream::ReadFromStream() {
 bytes_to_read -= z_stream_def_->stream->avail_in;
 read_location += z_stream_def_->stream->avail_in;
 }
-string data;
+tstring data;
 // Try to read enough data to fill up z_stream_def_->input.
 // TODO(rohanj): Add a char* version of ReadNBytes to InputStreamInterface
 // and use that instead to make this more efficient.
@@ -166,7 +166,7 @@ Status ZlibInputStream::ReadFromStream() {
 }
 size_t ZlibInputStream::ReadBytesFromCache(size_t bytes_to_read,
-string* result) {
+tstring* result) {
 size_t unread_bytes =
 reinterpret_cast<char*>(z_stream_def_->stream->next_out) -
 next_unread_byte_;
@@ -186,7 +186,7 @@ size_t ZlibInputStream::NumUnreadBytes() const {
 read_bytes;
 }
-Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
+Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
 result->clear();
 // Read as many bytes as possible from cache.
 bytes_to_read -= ReadBytesFromCache(bytes_to_read, result);


@@ -66,7 +66,7 @@ class ZlibInputStream : public InputStreamInterface {
 // ABORTED: If inflate() fails, we return the error code with the
 // error message in `z_stream_->msg`.
 // others: If reading from stream failed.
-Status ReadNBytes(int64 bytes_to_read, string* result) override;
+Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
 int64 Tell() const override;
@@ -107,7 +107,7 @@ class ZlibInputStream : public InputStreamInterface {
 // bytes have been read or `z_stream_->next_out` is reached.
 // Returns the number of bytes read and advances the `next_unread_byte_`
 // pointer to the next location to read from.
-size_t ReadBytesFromCache(size_t bytes_to_read, string* result);
+size_t ReadBytesFromCache(size_t bytes_to_read, tstring* result);
 // The number of unread bytes in z_stream_output_.
 //


@@ -163,6 +163,8 @@ class tstring {
 const char* data() const { return str_.data(); }
 char back() const { return str_.back(); }
 const char& operator[](size_t i) const { return str_[i]; }
+char* data() { return &str_[0]; }
@@ -209,6 +211,15 @@ class tstring {
 return *this;
 }
 void swap(tstring& str) { str_.swap(str.str_); }
+tstring& insert(size_t pos, const tstring& str, size_t subpos,
+size_t sublen) {
+str_.insert(pos, str.str_, subpos, sublen);
+return *this;
+}
+void push_back(char ch) { str_.push_back(ch); }
 friend const tstring operator+(const tstring& a, const tstring& b);
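
The insert and push_back overloads added above follow the class's interim approach: grow the tstring API by forwarding to the private std::string until the backing store is swapped out. A cut-down illustration of that forwarding style (a stand-in, not the real class):

#include <cassert>
#include <cstddef>
#include <string>

class tstring {
 public:
  explicit tstring(const char* s) : str_(s) {}
  // New members delegate to the backing std::string, so behavior matches
  // std::string exactly until the storage changes underneath.
  tstring& insert(std::size_t pos, const tstring& str, std::size_t subpos,
                  std::size_t sublen) {
    str_.insert(pos, str.str_, subpos, sublen);
    return *this;
  }
  void push_back(char ch) { str_.push_back(ch); }
  const std::string& str() const { return str_; }

 private:
  std::string str_;
};

int main() {
  tstring a("abcf");
  tstring b("XdeY");
  a.insert(3, b, 1, 2);  // splice "de" in before the 'f'
  a.push_back('g');
  assert(a.str() == "abcdefg");
  return 0;
}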


@@ -96,7 +96,7 @@ int main(int argc, char* argv[]) {
 uint64 start = env->NowMicros();
 uint64 records = 0;
 uint64 offset = 0;
-string record;
+tstring record;
 while (true) {
 std::unique_ptr<Event> event = std::unique_ptr<Event>(new Event);
 Status s = reader.ReadRecord(&offset, &record);


@@ -69,7 +69,7 @@ class SummaryFileWriterTest : public ::testing::Test {
 TF_CHECK_OK(env_.NewRandomAccessFile(io::JoinPath(testing::TmpDir(), f),
 &read_file));
 io::RecordReader reader(read_file.get(), io::RecordReaderOptions());
-string record;
+tstring record;
 uint64 offset = 0;
 TF_CHECK_OK(
 reader.ReadRecord(&offset,


@@ -53,7 +53,7 @@ void WriteFile(EventsWriter* writer) {
 static bool ReadEventProto(io::RecordReader* reader, uint64* offset,
 Event* proto) {
-string record;
+tstring record;
 Status s = reader->ReadRecord(offset, &record);
 if (!s.ok()) {
 return false;


@@ -228,7 +228,7 @@ int64 TellFile(tensorflow::WritableFile* file, TF_Status* status) {
 string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
 size_t bytes,
 TF_Status* status) {
-string result;
+tensorflow::tstring result;
 tensorflow::Status s = stream->ReadNBytes(bytes, &result);
 if (!s.ok() && s.code() != tensorflow::error::OUT_OF_RANGE) {
 Set_TF_Status_from_Status(status, s);


@@ -63,7 +63,7 @@ class PyRecordReader {
 uint64 offset_;
 RandomAccessFile* file_; // Owned
 io::RecordReader* reader_; // Owned
-string record_;
+tstring record_;
 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
 };