Updated ReaderInterface and subclasses to use tstring.
This is part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91

PiperOrigin-RevId: 265822025

parent 8df6f08527
commit 7ba3600c94
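For context on the pattern repeated throughout this diff: `tensorflow::tstring` is introduced as a drop-in replacement for `std::string` in reader out-parameters. Below is a minimal sketch of why call sites keep compiling after each signature flips; the `tstring_sketch` stand-in is hypothetical and far smaller than the real class (which is shown in the `class tstring` hunks near the end of this diff).

```cpp
// Minimal stand-in for tensorflow::tstring (an assumption, for illustration
// only): a thin wrapper delegating to std::string. Because the wrapper
// mirrors the std::string calls the readers use (clear, append, size, data),
// a reader body is unchanged when its out-parameter flips from string* to
// tstring*.
#include <cstddef>
#include <iostream>
#include <string>

class tstring_sketch {
 public:
  void clear() { str_.clear(); }
  void append(const char* s, size_t n) { str_.append(s, n); }
  size_t size() const { return str_.size(); }
  const char* data() const { return str_.data(); }

 private:
  std::string str_;
};

// The same body works whether T is std::string or the sketch above.
template <typename T>
void ReadAll(const char* src, size_t n, T* result) {
  result->clear();
  result->append(src, n);
}

int main() {
  tstring_sketch out;
  ReadAll("hello", 5, &out);
  std::cout << std::string(out.data(), out.size()) << "\n";  // hello
}
```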
@@ -186,7 +186,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt(
       /*input_buffer_bytes=*/k_buffer_size,
       /*output_buffer_bytes=*/k_buffer_size,
       io::ZlibCompressionOptions::GZIP());
-  string decompressed_pbtxt_string;
+  tstring decompressed_pbtxt_string;
   Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string);
   if (!s.ok() && !errors::IsOutOfRange(s)) {
     // OutOfRange is fine since we set the number of read bytes to INT_MAX.

@@ -121,7 +121,7 @@ class InitializeTRTResource : public OpKernel {
     uint64 offset = 0;
     int num_loaded_engine = 0;
     do {
-      string record;
+      tstring record;
       Status status = reader->ReadRecord(&offset, &record);
       if (errors::IsOutOfRange(status)) break;

@@ -66,7 +66,7 @@ class BigQueryReader : public ReaderBase {
     return Status::OK();
   }

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
    *at_end = false;
    *produced = false;

@@ -31,12 +31,13 @@ class SequenceFileReader {
             new io::BufferedInputStream(file, kSequenceFileBufferSize)) {}

   Status ReadHeader() {
-    string version;
+    tstring version;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version));
-    if (version.substr(0, 3) != "SEQ" || version[3] != 6) {
+    StringPiece version_view(version);
+    if (version_view.substr(0, 3) != "SEQ" || version[3] != 6) {
       return errors::InvalidArgument(
           "sequence file header must starts with `SEQ6`, received \"",
-          version.substr(0, 3), static_cast<int>(version[3]), "\"");
+          version_view.substr(0, 3), static_cast<int>(version[3]), "\"");
     }
     TF_RETURN_IF_ERROR(ReadString(&key_class_name_));
     TF_RETURN_IF_ERROR(ReadString(&value_class_name_));
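The `ReadHeader` hunk above shows a second recurring pattern in this change: `substr` calls move off the buffer type and onto a `StringPiece` view of it, so the buffer type only needs to expose `data()`/`size()`. A sketch with `std::string_view` standing in for `StringPiece` (an assumption; `StringPiece` is string_view-like), which also illustrates why the check is `version[3] != 6`, a raw byte, not the character `'6'`:

```cpp
// View-based substring pattern: take a cheap, non-owning view over the
// buffer and slice the view, instead of requiring substr() on the buffer
// type (which would also allocate a temporary for the comparison).
#include <cassert>
#include <string>
#include <string_view>

bool HasSeqHeader(const std::string& version) {
  std::string_view version_view(version);  // non-owning view
  return version_view.substr(0, 3) == "SEQ" && version.size() > 3 &&
         version[3] == 6;  // raw version byte 0x06, not the character '6'
}

int main() {
  std::string ok("SEQ");
  ok.push_back(6);                 // append the raw byte 6
  assert(HasSeqHeader(ok));
  assert(!HasSeqHeader("SEQ6"));   // '6' is 0x36, not the byte 6
}
```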
@@ -50,7 +51,7 @@ class SequenceFileReader {
                                      "' is currently not supported");
     }

-    string buffer;
+    tstring buffer;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer));
     compression_ = buffer[0];
     block_compression_ = buffer[1];

@@ -84,12 +85,12 @@ class SequenceFileReader {
     return Status::OK();
   }

-  Status ReadRecord(string* key, string* value) {
+  Status ReadRecord(tstring* key, tstring* value) {
     uint32 length = 0;
     TF_RETURN_IF_ERROR(ReadUInt32(&length));
     if (length == static_cast<uint32>(-1)) {
       // Sync marker.
-      string sync_marker;
+      tstring sync_marker;
       TF_RETURN_IF_ERROR(
           input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker));
       if (sync_marker != sync_marker_) {

@@ -114,7 +115,7 @@ class SequenceFileReader {
     return Status::OK();
   }

-  Status ReadString(string* value) {
+  Status ReadString(tstring* value) {
     int64 length = 0;
     TF_RETURN_IF_ERROR(ReadVInt(&length));
     if (value == nullptr) {

@@ -124,7 +125,7 @@ class SequenceFileReader {
   }

   Status ReadUInt32(uint32* value) {
-    string buffer;
+    tstring buffer;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer));
     *value = ((static_cast<uint32>(buffer[0]) << 24) |
               static_cast<uint32>(buffer[1]) << 16) |

@@ -134,7 +135,7 @@ class SequenceFileReader {
   }

   Status ReadVInt(int64* value) {
-    string buffer;
+    tstring buffer;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer));
     if (buffer[0] >= -112) {
       *value = static_cast<int64>(buffer[0]);

@@ -167,12 +168,12 @@ class SequenceFileReader {

  private:
   std::unique_ptr<io::InputStreamInterface> input_stream_;
-  string key_class_name_;
-  string value_class_name_;
-  string sync_marker_;
+  tstring key_class_name_;
+  tstring value_class_name_;
+  tstring sync_marker_;
   bool compression_;
   bool block_compression_;
-  string compression_codec_class_name_;
+  tstring compression_codec_class_name_;
   TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader);
 };

 class SequenceFileDatasetOp : public DatasetOpKernel {

@@ -258,7 +259,7 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
       do {
         // We are currently processing a file, so try to read the next record.
         if (reader_) {
-          string key, value;
+          tstring key, value;
           Status status = reader_->ReadRecord(&key, &value);
           if (!errors::IsOutOfRange(status)) {
             TF_RETURN_IF_ERROR(status);
@@ -53,16 +53,16 @@ Status ReaderBase::ResetLocked() {
   return Status::OK();
 }

-Status ReaderBase::SerializeState(string* state) {
+Status ReaderBase::SerializeState(tstring* state) {
   mutex_lock lock(mu_);
   return SerializeStateLocked(state);
 }

-Status ReaderBase::SerializeStateLocked(string* state) {
+Status ReaderBase::SerializeStateLocked(tstring* state) {
   return errors::Unimplemented("Reader SerializeState");
 }

-Status ReaderBase::RestoreState(const string& state) {
+Status ReaderBase::RestoreState(const tstring& state) {
   mutex_lock lock(mu_);
   Status status = RestoreStateLocked(state);
   if (!status.ok()) {

@@ -71,13 +71,13 @@ Status ReaderBase::RestoreState(const string& state) {
   return status;
 }

-Status ReaderBase::RestoreStateLocked(const string& state) {
+Status ReaderBase::RestoreStateLocked(const tstring& state) {
   return errors::Unimplemented("Reader RestoreState");
 }

 int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
-                           std::vector<string>* keys,
-                           std::vector<string>* values,
+                           std::vector<tstring>* keys,
+                           std::vector<tstring>* values,
                            OpKernelContext* context) {
   mutex_lock lock(mu_);
   int64 records_produced_this_call = 0;

@@ -133,16 +133,16 @@ int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
 }

 // Default implementation just reads one record at a time.
-Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
-                                  std::vector<string>* values, int64* num_read,
+Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<tstring>* keys,
+                                  std::vector<tstring>* values, int64* num_read,
                                   bool* at_end) {
   bool produced = false;
-  string key;
-  string value;
+  tstring key;
+  tstring value;
   Status status = ReadLocked(&key, &value, &produced, at_end);
   if (produced) {
-    keys->emplace_back(key);
-    values->emplace_back(value);
+    keys->push_back(std::move(key));
+    values->push_back(std::move(value));
     *num_read = 1;
   } else {
     *num_read = 0;
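One behavioral improvement rides along in `ReadUpToLocked` above: `emplace_back(key)` with an lvalue still copies the record, while `push_back(std::move(key))` transfers the heap buffer. A standalone sketch, with `std::string` standing in for `tstring`:

```cpp
// Copy-vs-move when appending a large record to the output vector.
// With typical standard-library implementations the 1 MiB buffer below is
// stolen by the move rather than copied (an observation, not a guarantee
// of the standard for all sizes; small strings may be copied via SSO).
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> keys;
  std::string key(1 << 20, 'k');   // 1 MiB record key
  const char* before = key.data();

  keys.push_back(std::move(key));  // transfers the heap buffer
  std::cout << (keys.back().data() == before ? "moved" : "copied") << "\n";
}
```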
@@ -150,7 +150,7 @@ Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
   return status;
 }

-void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
+void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value,
                       OpKernelContext* context) {
   mutex_lock lock(mu_);
   while (true) {

@@ -228,10 +228,19 @@ void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
   state->set_work_started(work_started_);
   state->set_work_finished(work_finished_);
   state->set_num_records_produced(num_records_produced_);
-  state->set_current_work(work_);
+  // Unfortunately, external proto does not accept string_view.
+#if defined(PLATFORM_GOOGLE)
+  // TODO(dero): Remove NOLINT after USE_TSTRING is enabled. The external proto
+  // compiler does not create an overloaded set method that accepts
+  // absl::string_view, and string_view to std::string is an explicit
+  // conversion.
+  state->set_current_work(StringPiece(work_));  // NOLINT
+#else
+  state->set_current_work(string(work_));
+#endif
 }

-string ReaderBase::KeyName(const string& key) const {
+tstring ReaderBase::KeyName(const tstring& key) const {
   return strings::StrCat(current_work(), ":", key);
 }
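The `#if` block in `SaveBaseState` above exists because the externally generated proto setter only accepts `std::string`, and a view type converts to `std::string` only explicitly. A stand-in sketch of that conversion chain; none of these types are the real protos, and the member names are invented for illustration:

```cpp
#include <iostream>
#include <string>
#include <string_view>

// Stand-in for the generated proto setter, which accepts std::string only.
struct ReaderBaseStateStandIn {
  void set_current_work(std::string s) { current_work = std::move(s); }
  std::string current_work;
};

// Stand-in for a tstring that, like string_view, converts to std::string
// only explicitly.
struct WorkBuffer {
  explicit operator std::string_view() const { return bytes; }
  std::string bytes;
};

int main() {
  WorkBuffer work{"file-0042"};
  ReaderBaseStateStandIn state;
  // state.set_current_work(work);  // would not compile: no implicit path
  state.set_current_work(std::string(std::string_view(work)));  // explicit
  std::cout << state.current_work << "\n";
}
```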
@@ -52,15 +52,15 @@ class ReaderBase : public ReaderInterface {
   //  d) If there was an error producing (e.g. an error reading the file,
   //     data corruption), return a non-OK() status. ReadLocked may be
   //     called again if the user reruns this part of the graph.
-  virtual Status ReadLocked(string* key, string* value, bool* produced,
+  virtual Status ReadLocked(tstring* key, tstring* value, bool* produced,
                             bool* at_end) = 0;

   // Descendants may optionally implement these -------------------------------

   // Produce up to num_records next key/value pairs from the current
   // work item, in the same manner of ReadLocked.
-  virtual Status ReadUpToLocked(int64 num_records, std::vector<string>* keys,
-                                std::vector<string>* values, int64* num_read,
+  virtual Status ReadUpToLocked(int64 num_records, std::vector<tstring>* keys,
+                                std::vector<tstring>* values, int64* num_read,
                                 bool* at_end);

   // Called when work starts / finishes.

@@ -72,8 +72,8 @@ class ReaderBase : public ReaderInterface {

   // Default implementation generates an Unimplemented error.
   // See the protected helper methods below.
-  virtual Status SerializeStateLocked(string* state);
-  virtual Status RestoreStateLocked(const string& state);
+  virtual Status SerializeStateLocked(tstring* state);
+  virtual Status RestoreStateLocked(const tstring& state);

   // Accessors ----------------------------------------------------------------

@@ -83,13 +83,13 @@ class ReaderBase : public ReaderInterface {
   // Returns the name of the current work item (valid if
   // work_in_progress() returns true). May change between calls to
   // ReadLocked().
-  const string& current_work() const { return work_; }
+  const tstring& current_work() const { return work_; }

   // What was passed to the constructor.
   const string& name() const { return name_; }

   // Produce the key name (from current_work and the actual key).
-  string KeyName(const string& key) const;
+  tstring KeyName(const tstring& key) const;

  protected:
   // For descendants wishing to implement serialize & restore state.

@@ -110,27 +110,27 @@ class ReaderBase : public ReaderInterface {

   // Implementations of ReaderInterface methods. These ensure thread-safety
   // and call the methods above to do the work.
-  void Read(QueueInterface* queue, string* key, string* value,
+  void Read(QueueInterface* queue, tstring* key, tstring* value,
             OpKernelContext* context) override;

   // Produces up to num_records.
   // In this implementation all the records come from the same work unit.
   int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
-                 std::vector<string>* keys, std::vector<string>* value,
+                 std::vector<tstring>* keys, std::vector<tstring>* value,
                  OpKernelContext* context) override;

   Status Reset() override;
   int64 NumRecordsProduced() override;
   int64 NumWorkUnitsCompleted() override;
-  Status SerializeState(string* state) override;
-  Status RestoreState(const string& state) override;
+  Status SerializeState(tstring* state) override;
+  Status RestoreState(const tstring& state) override;

   mutable mutex mu_;
   const string name_;
   int64 work_started_ = 0;
   int64 work_finished_ = 0;
   int64 num_records_produced_ = 0;
-  string work_;
+  tstring work_;
 };

 }  // namespace tensorflow

@@ -48,7 +48,7 @@ class ReaderInterface : public ResourceBase {
   // *context with an OutOfRange Status if the current work is
   // complete and the queue is done (closed and empty).
   // This method may block.
-  virtual void Read(QueueInterface* queue, string* key, string* value,
+  virtual void Read(QueueInterface* queue, tstring* key, tstring* value,
                     OpKernelContext* context) = 0;

   // Read up to num_records records into keys / values. May get more work from

@@ -60,7 +60,8 @@ class ReaderInterface : public ResourceBase {
   // structures (that have most likely been reserve(num_records)).
   // Returns how many records were actually read.
   virtual int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
-                         std::vector<string>* keys, std::vector<string>* value,
+                         std::vector<tstring>* keys,
+                         std::vector<tstring>* value,
                          OpKernelContext* context) = 0;

   // Restore this reader to its newly-constructed state.

@@ -72,9 +73,9 @@ class ReaderInterface : public ResourceBase {

   // -- Serialization/Restoration support --
   // Not all readers will support saving and restoring state.
-  virtual Status SerializeState(string* state) = 0;
+  virtual Status SerializeState(tstring* state) = 0;
   // Note: Must Reset on error.
-  virtual Status RestoreState(const string& state) = 0;
+  virtual Status RestoreState(const tstring& state) = 0;

   string DebugString() const override { return "a reader"; }
@@ -419,7 +419,7 @@ class CSVDatasetOp : public DatasetOpKernel {
     Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
                              size_t* start, bool include)
         EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      string temp_buffer;
+      tstring temp_buffer;

       buffer_.swap(temp_buffer);
       if (include && pos_ > *start) {

@@ -622,7 +622,7 @@ class CSVDatasetOp : public DatasetOpKernel {
       }
     }

-    Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    Status FillBuffer(tstring* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       result->clear();
       ++num_buffer_reads_;
       Status s = input_stream_->ReadNBytes(

@@ -827,7 +827,7 @@ class CSVDatasetOp : public DatasetOpKernel {
     }

     mutex mu_;
-    string buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
+    tstring buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
     size_t pos_ GUARDED_BY(
         mu_);  // Index into the buffer must be maintained between iters
     size_t num_buffer_reads_ GUARDED_BY(mu_);

@@ -162,11 +162,11 @@ class SnapshotReader {
     }
   }

-  Status ReadRecord(string* record) {
+  Status ReadRecord(tstring* record) {
     profiler::TraceMe activity(
         absl::StrCat(kClassName, kSeparator, kReadString),
         profiler::TraceMeLevel::kInfo);
-    string header;
+    tstring header;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
     uint64 length = core::DecodeFixed64(header.data());
     return input_stream_->ReadNBytes(length, record);

@@ -176,14 +176,14 @@ class SnapshotReader {
   Status ReadRecord(absl::Cord* record) {
     profiler::TraceMe activity(absl::StrCat(kClassName, kSeparator, kReadCord),
                                profiler::TraceMeLevel::kInfo);
-    string header;
+    tstring header;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
     uint64 length = core::DecodeFixed64(header.data());

     if (compression_type_ == io::compression::kNone) {
       return input_stream_->ReadNBytes(length, record);
     } else {
-      string tmp_str;
+      tstring tmp_str;
       Status s = input_stream_->ReadNBytes(length, &tmp_str);
       record->Append(tmp_str);
       return s;

@@ -224,7 +224,7 @@ Status ReadMetadataFile(const string& hash_dir,
   std::unique_ptr<RandomAccessFile> file;
   TF_CHECK_OK(Env::Default()->NewRandomAccessFile(metadata_filename, &file));

-  string record_bytes;
+  tstring record_bytes;
   auto reader = absl::make_unique<SnapshotReader>(file.get());
   TF_CHECK_OK(reader->ReadRecord(&record_bytes));
@@ -258,7 +258,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
         if (dataset()->compression_type_.empty()) {
           DCHECK_GE(file_pos_limit_, 0);
           if (current_pos < file_pos_limit_) {
-            string record;
+            tstring record;
             TF_RETURN_IF_ERROR(buffered_input_stream_->ReadNBytes(
                 dataset()->record_bytes_, &record));
             metrics::RecordTFDataBytesRead(kDatasetType,

@@ -272,16 +272,18 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
             return Status::OK();
           }
         } else {
-          string record;
+          tstring record;
           Status s = buffered_input_stream_->ReadNBytes(
               dataset()->record_bytes_, &record);
           if (s.ok()) {
             metrics::RecordTFDataBytesRead(kDatasetType,
                                            dataset()->record_bytes_);
             lookahead_cache_.append(record);
-            record = lookahead_cache_.substr(0, dataset()->record_bytes_);
-            lookahead_cache_ =
-                lookahead_cache_.substr(dataset()->record_bytes_);
+            StringPiece lookahead_cache_view(lookahead_cache_);
+            record = tstring(
+                lookahead_cache_view.substr(0, dataset()->record_bytes_));
+            lookahead_cache_ = tstring(
+                lookahead_cache_view.substr(dataset()->record_bytes_));
             // Produce the record as output.
             Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
             record_tensor.scalar<tstring>()() = std::move(record);
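The lookahead-cache rewrite above follows the same view pattern as `ReadHeader`: take a `StringPiece` over the cache, then materialize the record and the remaining cache from sub-views rather than calling `substr` on the buffer type. A sketch of the slicing step, with `std::string`/`std::string_view` as stand-ins:

```cpp
// Slice a cached byte buffer through a non-owning view. The temporaries
// built from the sub-views are fully constructed before the cache is
// reassigned, so reading from the cache's own buffer here is safe.
#include <iostream>
#include <string>
#include <string_view>

int main() {
  std::string lookahead_cache = "AAAABBBBCC";  // cached bytes
  const size_t record_bytes = 4;

  std::string_view cache_view(lookahead_cache);
  std::string record(cache_view.substr(0, record_bytes));
  lookahead_cache = std::string(cache_view.substr(record_bytes));
  // cache_view now dangles and must not be used again.

  std::cout << record << " | " << lookahead_cache << "\n";  // AAAA | BBBBCC
}
```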
@@ -433,7 +435,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
       std::unique_ptr<io::InputStreamInterface> buffered_input_stream_
           GUARDED_BY(mu_);
       int64 file_pos_limit_ GUARDED_BY(mu_) = -1;
-      string lookahead_cache_ GUARDED_BY(mu_);
+      tstring lookahead_cache_ GUARDED_BY(mu_);
     };

     const std::vector<string> filenames_;

@@ -107,7 +107,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
           out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                     TensorShape({}));
           Status s =
-              reader_->ReadRecord(&out_tensors->back().scalar<string>()());
+              reader_->ReadRecord(&out_tensors->back().scalar<tstring>()());
           if (s.ok()) {
             metrics::RecordTFDataBytesRead(
                 kDatasetType, out_tensors->back().scalar<tstring>()().size());

@@ -208,7 +208,7 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
   };

   const std::vector<string> filenames_;
-  const string compression_type_;
+  const tstring compression_type_;
   io::RecordReaderOptions options_;
 };

@@ -230,9 +230,9 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
     filenames.push_back(filenames_tensor->flat<tstring>()(i));
   }

-  string compression_type;
-  OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, kCompressionType,
-                                                  &compression_type));
+  tstring compression_type;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kCompressionType,
+                                                   &compression_type));

   int64 buffer_size = -1;
   OP_REQUIRES_OK(ctx,
@@ -46,7 +46,7 @@ class TFRecordDatasetOpTest : public DatasetOpsTestBase {
 };

 struct TestCase {
-  std::vector<string> filenames;
+  std::vector<tstring> filenames;
   std::vector<std::vector<string>> contents;
   CompressionType compression_type;
   int64 buffer_size;

@@ -84,12 +84,12 @@ TestCase TestCase1() {
           /*compression_type*/ CompressionType::ZLIB,
           /*buffer_size*/ 10,
           /*expected_outputs*/
-          {CreateTensor<string>(TensorShape({}), {"1"}),
-           CreateTensor<string>(TensorShape({}), {"22"}),
-           CreateTensor<string>(TensorShape({}), {"333"}),
-           CreateTensor<string>(TensorShape({}), {"a"}),
-           CreateTensor<string>(TensorShape({}), {"bb"}),
-           CreateTensor<string>(TensorShape({}), {"ccc"})},
+          {CreateTensor<tstring>(TensorShape({}), {"1"}),
+           CreateTensor<tstring>(TensorShape({}), {"22"}),
+           CreateTensor<tstring>(TensorShape({}), {"333"}),
+           CreateTensor<tstring>(TensorShape({}), {"a"}),
+           CreateTensor<tstring>(TensorShape({}), {"bb"}),
+           CreateTensor<tstring>(TensorShape({}), {"ccc"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,

@@ -105,12 +105,12 @@ TestCase TestCase2() {
           /*compression_type*/ CompressionType::GZIP,
           /*buffer_size*/ 10,
           /*expected_outputs*/
-          {CreateTensor<string>(TensorShape({}), {"1"}),
-           CreateTensor<string>(TensorShape({}), {"22"}),
-           CreateTensor<string>(TensorShape({}), {"333"}),
-           CreateTensor<string>(TensorShape({}), {"a"}),
-           CreateTensor<string>(TensorShape({}), {"bb"}),
-           CreateTensor<string>(TensorShape({}), {"ccc"})},
+          {CreateTensor<tstring>(TensorShape({}), {"1"}),
+           CreateTensor<tstring>(TensorShape({}), {"22"}),
+           CreateTensor<tstring>(TensorShape({}), {"333"}),
+           CreateTensor<tstring>(TensorShape({}), {"a"}),
+           CreateTensor<tstring>(TensorShape({}), {"bb"}),
+           CreateTensor<tstring>(TensorShape({}), {"ccc"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,

@@ -127,12 +127,12 @@ TestCase TestCase3() {
           /*compression_type*/ CompressionType::UNCOMPRESSED,
           /*buffer_size*/ 10,
           /*expected_outputs*/
-          {CreateTensor<string>(TensorShape({}), {"1"}),
-           CreateTensor<string>(TensorShape({}), {"22"}),
-           CreateTensor<string>(TensorShape({}), {"333"}),
-           CreateTensor<string>(TensorShape({}), {"a"}),
-           CreateTensor<string>(TensorShape({}), {"bb"}),
-           CreateTensor<string>(TensorShape({}), {"ccc"})},
+          {CreateTensor<tstring>(TensorShape({}), {"1"}),
+           CreateTensor<tstring>(TensorShape({}), {"22"}),
+           CreateTensor<tstring>(TensorShape({}), {"333"}),
+           CreateTensor<tstring>(TensorShape({}), {"a"}),
+           CreateTensor<tstring>(TensorShape({}), {"bb"}),
+           CreateTensor<tstring>(TensorShape({}), {"ccc"})},
           /*expected_output_dtypes*/ {DT_STRING},
           /*expected_output_shapes*/ {PartialTensorShape({})},
           /*expected_cardinality*/ kUnknownCardinality,

@@ -156,8 +156,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, GetNext) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -206,8 +206,8 @@ TEST_F(TFRecordDatasetOpTest, DatasetNodeName) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -239,8 +239,8 @@ TEST_F(TFRecordDatasetOpTest, DatasetTypeString) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -273,8 +273,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputDtypes) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -307,8 +307,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputShapes) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -341,8 +341,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -374,8 +374,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -416,8 +416,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputShapes) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -458,8 +458,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputPrefix) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});

@@ -501,8 +501,8 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, Roundtrip) {

   int64 num_files = test_case.filenames.size();
   Tensor filenames =
-      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
-  Tensor compression_type = CreateTensor<string>(
+      CreateTensor<tstring>(TensorShape({num_files}), test_case.filenames);
+  Tensor compression_type = CreateTensor<tstring>(
       TensorShape({}), {ToString(test_case.compression_type)});
   Tensor buffer_size =
       CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
@@ -34,7 +34,7 @@ class MemoryInputStream : public io::InputStreamInterface {

   ~MemoryInputStream() override {}

-  Status ReadNBytes(int64 bytes_to_read, string* result) override {
+  Status ReadNBytes(int64 bytes_to_read, tstring* result) override {
     result->clear();
     if (bytes_to_read < 0) {
       return errors::InvalidArgument("Can't read a negative number of bytes: ",

@@ -106,7 +106,7 @@ class DecodeCompressedOp : public OpKernel {
             new io::ZlibInputStream(
                 input_stream.get(), static_cast<size_t>(kBufferSize),
                 static_cast<size_t>(kBufferSize), zlib_options));
-        string output_string;
+        tstring output_string;
         Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string);
         OP_REQUIRES(context, (s.ok() || errors::IsOutOfRange(s)), s);
         output_flat(i) = std::move(output_string);

@@ -77,7 +77,7 @@ class FixedLengthRecordReader : public ReaderBase {
     return Status::OK();
   }

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
     // We will always "hop" the hop_bytes_ except the first record
     // where record_number_ == 0

@@ -33,7 +33,7 @@ class IdentityReader : public ReaderBase {
   explicit IdentityReader(const string& node_name)
       : ReaderBase(strings::StrCat("IdentityReader '", node_name, "'")) {}

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
     *key = current_work();
     *value = current_work();

@@ -44,14 +44,14 @@ class IdentityReader : public ReaderBase {

   // Stores state in a ReaderBaseState proto, since IdentityReader has
   // no additional state beyond ReaderBase.
-  Status SerializeStateLocked(string* state) override {
+  Status SerializeStateLocked(tstring* state) override {
     ReaderBaseState base_state;
     SaveBaseState(&base_state);
-    base_state.SerializeToString(state);
+    SerializeToTString(base_state, state);
     return Status::OK();
   }

-  Status RestoreStateLocked(const string& state) override {
+  Status RestoreStateLocked(const tstring& state) override {
     ReaderBaseState base_state;
     if (!ParseProtoUnlimited(&base_state, state)) {
       return errors::InvalidArgument("Could not parse state for ", name(), ": ",
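`SerializeToTString(base_state, state)` replaces `base_state.SerializeToString(state)` above because the external proto API only serializes into `std::string`. A hedged sketch of what such a helper might do; the real helper's implementation is not shown in this diff, and serializing into a temporary then copying is only one possibility (resizing the tstring and using `SerializeToArray` is another):

```cpp
#include <iostream>
#include <string>

// Stand-in for any message exposing the protobuf-style API.
struct FakeProto {
  std::string payload;
  bool SerializeToString(std::string* out) const {
    *out = payload;
    return true;
  }
};

// Sketch of a SerializeToTString-style helper: serialize into a
// std::string, then copy the bytes into the tstring-like buffer, which
// only needs an assign(data, size) method.
template <typename Proto, typename TString>
bool SerializeToTStringSketch(const Proto& proto, TString* out) {
  std::string tmp;
  if (!proto.SerializeToString(&tmp)) return false;
  out->assign(tmp.data(), tmp.size());
  return true;
}

int main() {
  FakeProto p{"reader-state-bytes"};
  std::string out;  // std::string also has assign(data, size)
  SerializeToTStringSketch(p, &out);
  std::cout << out << "\n";
}
```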
tensorflow/core/kernels/lmdb_reader_op.cc (mode changed: executable file → normal file)

@@ -68,7 +68,7 @@ class LMDBReader : public ReaderBase {
     return Status::OK();
   }

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
     if (mdb_cursor_ == nullptr) {
       MDB_CHECK(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_));

@@ -82,9 +82,10 @@ class LMDBReader : public ReaderBase {
         return Status::OK();
       }
     }
-    *key = string(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
-    *value = string(static_cast<const char*>(mdb_value_.mv_data),
-                    mdb_value_.mv_size);
+    *key =
+        tstring(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+    *value = tstring(static_cast<const char*>(mdb_value_.mv_data),
+                     mdb_value_.mv_size);
     *produced = true;
     return Status::OK();
   }
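The LMDB hunk relies on `tstring` offering a `(const char*, size_t)` constructor, since LMDB hands back untyped `(void*, size)` pairs whose bytes may include NULs. A sketch with `std::string` standing in for `tstring` and an invented stand-in for LMDB's `MDB_val`:

```cpp
// Constructing a string type from a raw pointer + explicit length copies
// exactly that many bytes, preserving embedded NULs (unlike a const char*
// constructor, which would stop at the first NUL).
#include <cstddef>
#include <iostream>
#include <string>

struct MDB_val_standin {  // mirrors LMDB's MDB_val: untyped data + size
  void* mv_data;
  size_t mv_size;
};

int main() {
  char raw[] = {'k', 'e', 'y', '\0', '!'};  // contains a NUL byte
  MDB_val_standin v{raw, sizeof(raw)};
  std::string key(static_cast<const char*>(v.mv_data), v.mv_size);
  std::cout << key.size() << "\n";  // 5, the NUL is preserved
}
```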
@@ -90,9 +90,12 @@ class ReaderReadOp : public ReaderVerbAsyncOpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output("value", TensorShape({}), &value));

-    auto key_scalar = key->scalar<string>();
-    auto value_scalar = value->scalar<string>();
-    reader->Read(queue, &key_scalar(), &value_scalar(), context);
+    auto key_scalar = key->scalar<tstring>();
+    auto value_scalar = value->scalar<tstring>();
+    tstring key_out, val_out;
+    reader->Read(queue, &key_out, &val_out, context);
+    key_scalar() = key_out;
+    value_scalar() = val_out;
   }
 };

@@ -115,9 +118,9 @@ class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
                    GetResourceFromContext(context, "queue_handle", &queue));
     core::ScopedUnref unref_me(queue);

-    std::vector<string> keys_vec;
+    std::vector<tstring> keys_vec;
     keys_vec.reserve(num_records);
-    std::vector<string> values_vec;
+    std::vector<tstring> values_vec;
     values_vec.reserve(num_records);

     int64 num_actually_read =

@@ -200,7 +203,7 @@ class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output("state", TensorShape({}), &output));
     OP_REQUIRES_OK(context,
-                   reader->SerializeState(&output->scalar<tstring>()()));
+                   reader->SerializeState(&output->scalar<tstring>()()));
   }
 };
@@ -70,7 +70,7 @@ Status RecordYielder::YieldOne(tstring* value) {

 struct RecordYielder::Shard {
   int index;                       // Shard index.
-  std::vector<string> filenames;   // File names given to this shard.
+  std::vector<tstring> filenames;  // File names given to this shard.
   Notification done;               // Notified when this shard is done.
   Status status;                   // Shard status.
 };

@@ -211,7 +211,7 @@ void RecordYielder::ShardLoop(Shard* shard) {
                                            opts_.compression_type);
     io::RecordReader rdr(file.get(), options);
     uint64 offset = 0;
-    string record;
+    tstring record;
     while (true) {
       Status s = rdr.ReadRecord(&offset, &record);
       if (s.ok()) {

@@ -56,7 +56,7 @@ class TextLineReader : public ReaderBase {
     return Status::OK();
   }

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
     Status status = input_buffer_->ReadLine(value);
     ++line_number_;

@@ -50,7 +50,7 @@ class TFRecordReader : public ReaderBase {
     return Status::OK();
   }

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
     *key = strings::StrCat(current_work(), ":", offset_);
     Status status = reader_->ReadRecord(&offset_, value);

@@ -50,7 +50,7 @@ class WholeFileReader : public ReaderBase {
       : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")),
         env_(env) {}

-  Status ReadLocked(string* key, string* value, bool* produced,
+  Status ReadLocked(tstring* key, tstring* value, bool* produced,
                     bool* at_end) override {
     *key = current_work();
     TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value));

@@ -61,14 +61,14 @@ class WholeFileReader : public ReaderBase {

   // Stores state in a ReaderBaseState proto, since WholeFileReader has
   // no additional state beyond ReaderBase.
-  Status SerializeStateLocked(string* state) override {
+  Status SerializeStateLocked(tstring* state) override {
     ReaderBaseState base_state;
     SaveBaseState(&base_state);
-    base_state.SerializeToString(state);
+    SerializeToTString(base_state, state);
     return Status::OK();
   }

-  Status RestoreStateLocked(const string& state) override {
+  Status RestoreStateLocked(const tstring& state) override {
     ReaderBaseState base_state;
     if (!ParseProtoUnlimited(&base_state, state)) {
       return errors::InvalidArgument("Could not parse state for ", name(), ": ",
@@ -85,7 +85,7 @@ Status BufferedInputStream::ReadLineHelper(string* result, bool include_eol) {
   return s;
 }

-Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
+Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
   if (bytes_to_read < 0) {
     return errors::InvalidArgument("Can't read a negative number of bytes: ",
                                    bytes_to_read);

@@ -41,7 +41,7 @@ class BufferedInputStream : public InputStreamInterface {

   ~BufferedInputStream() override;

-  Status ReadNBytes(int64 bytes_to_read, string* result) override;
+  Status ReadNBytes(int64 bytes_to_read, tstring* result) override;

   Status SkipNBytes(int64 bytes_to_skip) override;

@@ -90,7 +90,7 @@ class BufferedInputStream : public InputStreamInterface {

   InputStreamInterface* input_stream_;  // not owned.
   size_t size_;                         // buffer size.
-  string buf_;                          // the buffer itself.
+  tstring buf_;                         // the buffer itself.
   // buf_[pos_, limit_) holds the valid "read ahead" data in the file.
   size_t pos_ = 0;    // current position in buf_.
   size_t limit_ = 0;  // just past the end of valid data in buf_.

@@ -163,7 +163,7 @@ TEST(BufferedInputStream, ReadNBytes) {
   for (auto buf_size : BufferSizes()) {
     std::unique_ptr<RandomAccessInputStream> input_stream(
         new RandomAccessInputStream(file.get()));
-    string read;
+    tstring read;
     BufferedInputStream in(input_stream.get(), buf_size);
     EXPECT_EQ(0, in.Tell());
     TF_ASSERT_OK(in.ReadNBytes(3, &read));

@@ -200,7 +200,7 @@ TEST(BufferedInputStream, SkipNBytes) {
   for (auto buf_size : BufferSizes()) {
     std::unique_ptr<RandomAccessInputStream> input_stream(
         new RandomAccessInputStream(file.get()));
-    string read;
+    tstring read;
     BufferedInputStream in(input_stream.get(), buf_size);
     EXPECT_EQ(0, in.Tell());
     TF_ASSERT_OK(in.SkipNBytes(3));

@@ -235,7 +235,7 @@ TEST(BufferedInputStream, ReadNBytesRandomAccessFile) {
   TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));

   for (auto buf_size : BufferSizes()) {
-    string read;
+    tstring read;
     BufferedInputStream in(file.get(), buf_size);
     EXPECT_EQ(0, in.Tell());
     TF_ASSERT_OK(in.ReadNBytes(3, &read));

@@ -270,7 +270,7 @@ TEST(BufferedInputStream, SkipNBytesRandomAccessFile) {
   TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));

   for (auto buf_size : BufferSizes()) {
-    string read;
+    tstring read;
     BufferedInputStream in(file.get(), buf_size);
     EXPECT_EQ(0, in.Tell());
     TF_ASSERT_OK(in.SkipNBytes(3));

@@ -307,7 +307,7 @@ TEST(BufferedInputStream, Seek) {
   for (auto buf_size : BufferSizes()) {
     std::unique_ptr<RandomAccessInputStream> input_stream(
         new RandomAccessInputStream(file.get()));
-    string read;
+    tstring read;
     BufferedInputStream in(input_stream.get(), buf_size);

     // Seek forward

@@ -378,7 +378,7 @@ void BM_BufferedReaderSmallReads(const int iters, const int buff_size,
   std::unique_ptr<RandomAccessFile> file;
   TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));

-  string result;
+  tstring result;
   testing::StartTiming();

   for (int itr = 0; itr < iters; ++itr) {
@@ -42,7 +42,8 @@ Status InputBuffer::FillBuffer() {
   return s;
 }

-Status InputBuffer::ReadLine(string* result) {
+template <typename T>
+Status InputBuffer::ReadLine(T* result) {
   result->clear();
   Status s;
   do {

@@ -71,6 +72,11 @@ Status InputBuffer::ReadLine(string* result) {
   return s;
 }

+template Status InputBuffer::ReadLine<string>(string* result);
+#ifdef USE_TSTRING
+template Status InputBuffer::ReadLine<tstring>(tstring* result);
+#endif  // USE_TSTRING
+
 Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
   result->clear();
   if (bytes_to_read < 0) {
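`InputBuffer::ReadLine` is handled differently from the other readers: instead of flipping the signature, the definition becomes a template over the result type, with explicit instantiations kept in the `.cc` file (the new one guarded by `USE_TSTRING`) so both `string` and `tstring` callers link while the header stays light, as the `inputbuffer.h` hunk below shows. A minimal sketch of that pattern:

```cpp
// Template definition that would live in a .cc file, with only the
// declaration in the header. The explicit instantiation emits code for
// each supported result type without exposing the template body.
#include <iostream>
#include <string>

template <typename T>
void ReadLineSketch(const char* src, T* result) {
  result->clear();
  for (; *src != '\0' && *src != '\n'; ++src) result->push_back(*src);
}

// Explicit instantiation, mirroring the diff's
// `template Status InputBuffer::ReadLine<string>(string* result);` lines.
template void ReadLineSketch<std::string>(const char*, std::string*);

int main() {
  std::string line;
  ReadLineSketch("first\nsecond", &line);
  std::cout << line << "\n";  // first
}
```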
@@ -43,7 +43,8 @@ class InputBuffer {
   // If successful, returns OK. If we are already at the end of the
   // file, we return an OUT_OF_RANGE error. Otherwise, we return
   // some other non-OK status.
-  Status ReadLine(string* result);
+  template <typename T>
+  Status ReadLine(T* result);

   // Reads bytes_to_read bytes into *result, overwriting *result.
   //
@@ -28,7 +28,7 @@ Status InputStreamInterface::SkipNBytes(int64 bytes_to_skip) {
   if (bytes_to_skip < 0) {
     return errors::InvalidArgument("Can't skip a negative number of bytes");
   }
-  string unused;
+  tstring unused;
   // Read kDefaultSkipSize at a time till bytes_to_skip.
   while (bytes_to_skip > 0) {
     int64 bytes_to_read = std::min<int64>(kMaxSkipSize, bytes_to_skip);
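The `SkipNBytes` hunk above shows the discard strategy around the changed buffer: skip by reading into a scratch buffer in bounded chunks (at most `kMaxSkipSize` at a time), so skipping a huge range never allocates a huge buffer. A standalone sketch over `std::istream`; `kMaxSkipSize` and the stream type are stand-ins, not TensorFlow's values:

```cpp
// Chunked skip: discard bytes by reading them in bounded pieces.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

constexpr int64_t kMaxSkipSize = 8 << 10;  // 8 KiB per discarded chunk

int64_t SkipNBytes(std::istream& stream, int64_t bytes_to_skip) {
  std::string unused;
  int64_t skipped = 0;
  while (bytes_to_skip > 0) {
    const int64_t n = std::min<int64_t>(kMaxSkipSize, bytes_to_skip);
    unused.resize(n);
    stream.read(&unused[0], n);
    if (stream.gcount() == 0) break;  // end of stream
    skipped += stream.gcount();
    bytes_to_skip -= stream.gcount();
  }
  return skipped;
}

int main() {
  std::istringstream in(std::string(20000, 'x'));
  std::cout << SkipNBytes(in, 10000) << "\n";  // 10000
}
```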
@@ -35,7 +35,7 @@ class InputStreamInterface {
   // Reads the next bytes_to_read from the file. Typical return codes:
   //  * OK - in case of success.
   //  * OUT_OF_RANGE - not enough bytes remaining before end of file.
-  virtual Status ReadNBytes(int64 bytes_to_read, string* result) = 0;
+  virtual Status ReadNBytes(int64 bytes_to_read, tstring* result) = 0;

 #if defined(PLATFORM_GOOGLE)
   // Reads the next bytes_to_read from the file. Typical return codes:

@@ -27,7 +27,7 @@ class TestStringStream : public InputStreamInterface {
  public:
   explicit TestStringStream(const string& content) : content_(content) {}

-  Status ReadNBytes(int64 bytes_to_read, string* result) override {
+  Status ReadNBytes(int64 bytes_to_read, tstring* result) override {
     result->clear();
     if (pos_ + bytes_to_read > content_.size()) {
       return errors::OutOfRange("limit reached");

@@ -51,7 +51,7 @@ class TestStringStream : public InputStreamInterface {

 TEST(InputStreamInterface, Basic) {
   TestStringStream ss("This is a test string");
-  string res;
+  tstring res;
   TF_ASSERT_OK(ss.ReadNBytes(4, &res));
   EXPECT_EQ("This", res);
   TF_ASSERT_OK(ss.SkipNBytes(6));

@@ -30,7 +30,7 @@ RandomAccessInputStream::~RandomAccessInputStream() {
 }

 Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
-                                           string* result) {
+                                           tstring* result) {
   if (bytes_to_read < 0) {
     return errors::InvalidArgument("Cannot read negative number of bytes");
   }

@@ -33,7 +33,7 @@ class RandomAccessInputStream : public InputStreamInterface {

   ~RandomAccessInputStream();

-  Status ReadNBytes(int64 bytes_to_read, string* result) override;
+  Status ReadNBytes(int64 bytes_to_read, tstring* result) override;

 #if defined(PLATFORM_GOOGLE)
   Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override;

@@ -30,7 +30,7 @@ TEST(RandomInputStream, ReadNBytes) {

   std::unique_ptr<RandomAccessFile> file;
   TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-  string read;
+  tstring read;
   RandomAccessInputStream in(file.get());
   TF_ASSERT_OK(in.ReadNBytes(3, &read));
   EXPECT_EQ(read, "012");

@@ -59,7 +59,7 @@ TEST(RandomInputStream, SkipNBytes) {

   std::unique_ptr<RandomAccessFile> file;
   TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-  string read;
+  tstring read;
   RandomAccessInputStream in(file.get());
   TF_ASSERT_OK(in.SkipNBytes(3));
   EXPECT_EQ(3, in.Tell());

@@ -90,7 +90,7 @@ TEST(RandomInputStream, Seek) {

   std::unique_ptr<RandomAccessFile> file;
   TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file));
-  string read;
+  tstring read;
   RandomAccessInputStream in(file.get());

   // Seek forward
@@ -84,7 +84,7 @@ RecordReader::RecordReader(RandomAccessFile* file,
 //
 // offset corresponds to the user-provided value to ReadRecord()
 // and is used only in error messages.
-Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
+Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) {
   if (n >= SIZE_MAX - sizeof(uint32)) {
     return errors::DataLoss("record size too large");
   }

@@ -125,7 +125,7 @@ Status RecordReader::GetMetadata(Metadata* md) {
   // loop should be guaranteed to either return after reaching EOF
   // or encountering an error.
   uint64 offset = 0;
-  string record;
+  tstring record;
   while (true) {
     // Read header, containing size of data.
     Status s = ReadChecksummed(offset, sizeof(uint64), &record);

@@ -161,7 +161,7 @@ Status RecordReader::GetMetadata(Metadata* md) {
   return Status::OK();
 }

-Status RecordReader::ReadRecord(uint64* offset, string* record) {
+Status RecordReader::ReadRecord(uint64* offset, tstring* record) {
   // Position the input stream.
   int64 curr_pos = input_stream_->Tell();
   int64 desired_pos = static_cast<int64>(*offset);

@@ -89,7 +89,7 @@ class RecordReader {
   // Read the record at "*offset" into *record and update *offset to
   // point to the offset of the next record. Returns OK on success,
   // OUT_OF_RANGE for end of file, or something else for an error.
-  Status ReadRecord(uint64* offset, string* record);
+  Status ReadRecord(uint64* offset, tstring* record);

   // Return the metadata of the Record file.
   //

@@ -103,7 +103,7 @@ class RecordReader {
   Status GetMetadata(Metadata* md);

  private:
-  Status ReadChecksummed(uint64 offset, size_t n, string* result);
+  Status ReadChecksummed(uint64 offset, size_t n, tstring* result);

   RecordReaderOptions options_;
   std::unique_ptr<InputStreamInterface> input_stream_;

@@ -129,7 +129,7 @@ class SequentialRecordReader {

   // Reads the next record in the file into *record. Returns OK on success,
   // OUT_OF_RANGE for end of file, or something else for an error.
-  Status ReadRecord(string* record) {
+  Status ReadRecord(tstring* record) {
     return underlying_.ReadRecord(&offset_, record);
   }
@@ -86,7 +86,7 @@ void VerifyFlush(const io::RecordWriterOptions& options) {

     // Verify that file has all records written so far and no more.
     uint64 offset = 0;
-    string record;
+    tstring record;
     for (size_t j = 0; j <= i; j++) {
       // Check that j'th record is written correctly.
       TF_CHECK_OK(reader.ReadRecord(&offset, &record));

@@ -142,7 +142,7 @@ TEST(RecordReaderWriterTest, TestBasics) {
       options.zlib_options.input_buffer_size = buf_size;
       io::RecordReader reader(read_file.get(), options);
       uint64 offset = 0;
-      string record;
+      tstring record;
       TF_CHECK_OK(reader.ReadRecord(&offset, &record));
       EXPECT_EQ("abc", record);
       TF_CHECK_OK(reader.ReadRecord(&offset, &record));

@@ -187,7 +187,7 @@ TEST(RecordReaderWriterTest, TestZlib) {
       options.zlib_options.input_buffer_size = buf_size;
       io::RecordReader reader(read_file.get(), options);
       uint64 offset = 0;
-      string record;
+      tstring record;
       TF_CHECK_OK(reader.ReadRecord(&offset, &record));
       EXPECT_EQ("abc", record);
       TF_CHECK_OK(reader.ReadRecord(&offset, &record));

@@ -149,7 +149,7 @@ class RecordioTest : public ::testing::Test {
     if (!reading_) {
       reading_ = true;
     }
-    string record;
+    tstring record;
     Status s = reader_->ReadRecord(&readpos_, &record);
     if (s.ok()) {
       return record;

@@ -183,7 +183,7 @@ class RecordioTest : public ::testing::Test {
     Write(BigString("x", 10000));
     reading_ = true;
    uint64 offset = WrittenBytes() + offset_past_end;
-    string record;
+    tstring record;
     Status s = reader_->ReadRecord(&offset, &record);
     ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
   }

@@ -261,7 +261,7 @@ void TestNonSequentialReads(const RecordWriterOptions& writer_options,
   StringSource file(&contents);
   RecordReader reader(&file, reader_options);

-  string record;
+  tstring record;
   // First read sequentially to fill in the offsets table.
   uint64 offsets[10] = {0};
   uint64 offset = 0;

@@ -315,7 +315,7 @@ void TestReadError(const RecordWriterOptions& writer_options,
   RecordReader reader(&file, reader_options);

   uint64 offset = 0;
-  string read;
+  tstring read;
   file.force_error();
   Status status = reader.ReadRecord(&offset, &read);
   ASSERT_TRUE(errors::IsDataLoss(status));
@@ -121,7 +121,7 @@ Status TestMultipleWrites(size_t compress_input_buf_size,
   for (int attempt = 0; attempt < 2; ++attempt) {
     string actual_result;
     for (int i = 0; i < num_writes; i++) {
-      string decompressed_output;
+      tstring decompressed_output;
       TF_RETURN_IF_ERROR(in.ReadNBytes(data.size(), &decompressed_output));
       strings::StrAppend(&actual_result, decompressed_output);
     }

@@ -29,7 +29,7 @@ SnappyInputBuffer::SnappyInputBuffer(
       output_buffer_(new char[output_buffer_capacity_]),
       next_in_(input_buffer_.get()) {}

-Status SnappyInputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
+Status SnappyInputBuffer::ReadNBytes(int64 bytes_to_read, tstring* result) {
   result->clear();
   // Read as many bytes as possible from cache.
   bytes_to_read -= ReadBytesFromCache(bytes_to_read, result);

@@ -62,7 +62,7 @@ Status SnappyInputBuffer::Reset() {
 }

 size_t SnappyInputBuffer::ReadBytesFromCache(size_t bytes_to_read,
-                                             string* result) {
+                                             tstring* result) {
   size_t can_read_bytes = std::min(bytes_to_read, avail_out_);
   if (can_read_bytes > 0) {
     result->append(next_out_, can_read_bytes);

@@ -54,7 +54,7 @@ class SnappyInputBuffer : public InputStreamInterface {
   //    If input_buffer_ is smaller in size than a compressed block.
   //  others:
   //    If reading from file failed.
-  Status ReadNBytes(int64 bytes_to_read, string* result) override;
+  Status ReadNBytes(int64 bytes_to_read, tstring* result) override;

   int64 Tell() const override;

@@ -86,7 +86,7 @@ class SnappyInputBuffer : public InputStreamInterface {
   // bytes have been read or `next_out_` is reached.
   // Returns the number of bytes read and advances the `next_out_`
   // pointer to the next location to read from.
-  size_t ReadBytesFromCache(size_t bytes_to_read, string* result);
+  size_t ReadBytesFromCache(size_t bytes_to_read, tstring* result);

   // Reads the length of the next *compressed* block and stores in `length`.
   // The length is stored in 4 bytes in little endian notation.
@@ -69,7 +69,7 @@ void TestAllCombinations(CompressionOptions input_options,
       for (auto output_buf_size : OutputBufferSizes()) {
         std::unique_ptr<WritableFile> file_writer;
         TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
-        string result;
+        tstring result;

         ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
                              output_options);

@@ -142,7 +142,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size,
                      input_options);

   for (int i = 0; i < num_writes; i++) {
-    string decompressed_output;
+    tstring decompressed_output;
     TF_ASSERT_OK(in.ReadNBytes(data.size(), &decompressed_output));
     strings::StrAppend(&actual_result, decompressed_output);
   }

@@ -171,7 +171,7 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) {
   string data = GenTestString(10);
   std::unique_ptr<WritableFile> file_writer;
   TF_ASSERT_OK(env->NewWritableFile(fname, &file_writer));
-  string result;
+  tstring result;
   ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
                        output_options);
   TF_ASSERT_OK(out.Init());

@@ -229,8 +229,8 @@ void TestTell(CompressionOptions input_options,
   ZlibInputStream in(input_stream.get(), input_buf_size, output_buf_size,
                      input_options);

-  string first_half(data, 0, data.size() / 2);
-  string bytes_read;
+  tstring first_half(string(data, 0, data.size() / 2));
+  tstring bytes_read;

   // Read the first half of the uncompressed file and expect that Tell()
   // returns half the uncompressed length of the file.

@@ -240,7 +240,7 @@ void TestTell(CompressionOptions input_options,

   // Read the remaining half of the uncompressed file and expect that
   // Tell() points past the end of file.
-  string second_half;
+  tstring second_half;
   TF_ASSERT_OK(
       in.ReadNBytes(data.size() - first_half.size(), &second_half));
   EXPECT_EQ(in.Tell(), data.size());

@@ -283,7 +283,7 @@ void TestSkipNBytes(CompressionOptions input_options,

   // Expect that second half is read correctly and Tell() returns past
   // end of file after reading complete file.
-  string bytes_read;
+  tstring bytes_read;
   TF_ASSERT_OK(in.ReadNBytes(second_half.size(), &bytes_read));
   EXPECT_EQ(bytes_read, second_half);
   EXPECT_EQ(in.Tell(), data.size());

@@ -132,7 +132,7 @@ Status ZlibInputStream::ReadFromStream() {
     bytes_to_read -= z_stream_def_->stream->avail_in;
     read_location += z_stream_def_->stream->avail_in;
   }
-  string data;
+  tstring data;
   // Try to read enough data to fill up z_stream_def_->input.
   // TODO(rohanj): Add a char* version of ReadNBytes to InputStreamInterface
   // and use that instead to make this more efficient.

@@ -166,7 +166,7 @@ Status ZlibInputStream::ReadFromStream() {
 }

 size_t ZlibInputStream::ReadBytesFromCache(size_t bytes_to_read,
-                                           string* result) {
+                                           tstring* result) {
   size_t unread_bytes =
       reinterpret_cast<char*>(z_stream_def_->stream->next_out) -
       next_unread_byte_;

@@ -186,7 +186,7 @@ size_t ZlibInputStream::NumUnreadBytes() const {
          read_bytes;
 }

-Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, string* result) {
+Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
   result->clear();
   // Read as many bytes as possible from cache.
   bytes_to_read -= ReadBytesFromCache(bytes_to_read, result);

@@ -66,7 +66,7 @@ class ZlibInputStream : public InputStreamInterface {
   //  ABORTED: If inflate() fails, we return the error code with the
   //           error message in `z_stream_->msg`.
   //  others: If reading from stream failed.
-  Status ReadNBytes(int64 bytes_to_read, string* result) override;
+  Status ReadNBytes(int64 bytes_to_read, tstring* result) override;

   int64 Tell() const override;

@@ -107,7 +107,7 @@ class ZlibInputStream : public InputStreamInterface {
   // bytes have been read or `z_stream_->next_out` is reached.
   // Returns the number of bytes read and advances the `next_unread_byte_`
   // pointer to the next location to read from.
-  size_t ReadBytesFromCache(size_t bytes_to_read, string* result);
+  size_t ReadBytesFromCache(size_t bytes_to_read, tstring* result);

   // The number of unread bytes in z_stream_output_.
   //
@@ -163,6 +163,8 @@ class tstring {

   const char* data() const { return str_.data(); }

+  char back() const { return str_.back(); }
+
   const char& operator[](size_t i) const { return str_[i]; }

   char* data() { return &str_[0]; }

@@ -209,6 +211,15 @@ class tstring {
     return *this;
   }

+  void swap(tstring& str) { str_.swap(str.str_); }
+
+  tstring& insert(size_t pos, const tstring& str, size_t subpos,
+                  size_t sublen) {
+    str_.insert(pos, str.str_, subpos, sublen);
+
+    return *this;
+  }
+
+  void push_back(char ch) { str_.push_back(ch); }
+
   friend const tstring operator+(const tstring& a, const tstring& b);
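The `tstring` hunks above add more of the `std::string` surface that the migrated readers rely on (`back`, `swap`, `insert`, `push_back`), each a one-line delegation to the wrapped `std::string`. A compilable sketch of the delegation pattern, with `str_` left public only for brevity:

```cpp
#include <cassert>
#include <cstddef>
#include <string>

class TstringSketch {
 public:
  char back() const { return str_.back(); }
  void swap(TstringSketch& other) { str_.swap(other.str_); }
  TstringSketch& insert(size_t pos, const TstringSketch& s, size_t subpos,
                        size_t sublen) {
    str_.insert(pos, s.str_, subpos, sublen);
    return *this;
  }
  void push_back(char ch) { str_.push_back(ch); }

  std::string str_;  // public only for this sketch
};

int main() {
  TstringSketch a{"abc"}, b{"xyz"};
  a.push_back('d');           // "abcd"
  assert(a.back() == 'd');
  a.swap(b);                  // a == "xyz", b == "abcd"
  assert(a.str_ == "xyz" && b.str_ == "abcd");
  b.insert(1, a, 0, 2);       // insert "xy" at position 1
  assert(b.str_ == "axybcd");
}
```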
@@ -96,7 +96,7 @@ int main(int argc, char* argv[]) {
   uint64 start = env->NowMicros();
   uint64 records = 0;
   uint64 offset = 0;
-  string record;
+  tstring record;
   while (true) {
     std::unique_ptr<Event> event = std::unique_ptr<Event>(new Event);
     Status s = reader.ReadRecord(&offset, &record);

@@ -69,7 +69,7 @@ class SummaryFileWriterTest : public ::testing::Test {
     TF_CHECK_OK(env_.NewRandomAccessFile(io::JoinPath(testing::TmpDir(), f),
                                          &read_file));
     io::RecordReader reader(read_file.get(), io::RecordReaderOptions());
-    string record;
+    tstring record;
     uint64 offset = 0;
     TF_CHECK_OK(
         reader.ReadRecord(&offset,

@@ -53,7 +53,7 @@ void WriteFile(EventsWriter* writer) {

 static bool ReadEventProto(io::RecordReader* reader, uint64* offset,
                            Event* proto) {
-  string record;
+  tstring record;
   Status s = reader->ReadRecord(offset, &record);
   if (!s.ok()) {
     return false;

@@ -228,7 +228,7 @@ int64 TellFile(tensorflow::WritableFile* file, TF_Status* status) {
 string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
                       size_t bytes,
                       TF_Status* status) {
-  string result;
+  tensorflow::tstring result;
   tensorflow::Status s = stream->ReadNBytes(bytes, &result);
   if (!s.ok() && s.code() != tensorflow::error::OUT_OF_RANGE) {
     Set_TF_Status_from_Status(status, s);

@@ -63,7 +63,7 @@ class PyRecordReader {
   uint64 offset_;
   RandomAccessFile* file_;    // Owned
   io::RecordReader* reader_;  // Owned
-  string record_;
+  tstring record_;
   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
 };