Avoid CHECKs in BundleReader, propagate errors instead.
Motivation: We'd like to evolve the checkpoint format over time (e.g., to enable different types of compression). Without this change, a TensorFlow version that encounters a format it doesn't understand would CHECK-fail with an unhelpful error message. With this change, it propagates a clearer error message up, giving the user some hints about what could be wrong. I don't have a unit test for this - I thought about writing a bundle and then strategically corrupting the bytes on disk before reading it back, but that seems a bit much. The intention of this change is to enable graceful reporting of forward-compatibility breakages. Ideas for an appropriate unit test are appreciated. PiperOrigin-RevId: 157620358
This commit is contained in:
parent
ee05b8b690
commit
c9cc388dc2
@ -238,6 +238,23 @@ bool IsFullSlice(const TensorSlice& slice_spec,
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the Status reported when `filename` cannot be read.
//
// The message names the file, suggests that it may be corrupt or was produced
// by a newer version of TensorFlow with format changes, and embeds `detail`
// in parentheses. When `in_status` is OK the result is created through
// errors::Internal; otherwise the code of `in_status` is kept and its error
// message is appended after ": ".
Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  // Text shared by both outcomes; only the error code and suffix differ.
  const string explanation = strings::StrCat(
      "Unable to read file (", filename,
      "). Perhaps the file is corrupt or was produced by a newer version of "
      "TensorFlow with format changes (",
      detail, ")");

  if (!in_status.ok()) {
    // Preserve the caller's error code and surface the underlying message.
    return Status(
        in_status.code(),
        strings::StrCat(explanation, ": ", in_status.error_message()));
  }
  return errors::Internal(explanation);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
BundleWriter::BundleWriter(Env* env, StringPiece prefix)
|
||||
@ -433,11 +450,13 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
|
||||
// Process header.
|
||||
{
|
||||
iter->Seek(kHeaderEntryKey);
|
||||
CHECK(iter->Valid()) << "File: " << filename
|
||||
<< ", iterator status: " << iter->status();
|
||||
if (!iter->Valid()) {
|
||||
return CorruptFileError(iter->status(), filename,
|
||||
"failed to seek to header entry");
|
||||
}
|
||||
BundleHeaderProto header;
|
||||
TF_CHECK_OK(ParseEntryProto(iter->key(), iter->value(), &header));
|
||||
CHECK_GE(header.num_shards(), 0);
|
||||
Status s = ParseEntryProto(iter->key(), iter->value(), &header);
|
||||
if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
|
||||
|
||||
merge_state->num_shards += header.num_shards();
|
||||
if (!merge_state->seen_first_bundle) {
|
||||
@ -584,10 +603,17 @@ BundleReader::BundleReader(Env* env, StringPiece prefix)
|
||||
|
||||
// Reads "num_shards_" from the first entry.
|
||||
iter_->Seek(kHeaderEntryKey);
|
||||
CHECK(iter_->Valid()) << "File: " << filename
|
||||
<< ", iterator status: " << iter_->status();
|
||||
if (!iter_->Valid()) {
|
||||
status_ = CorruptFileError(iter_->status(), filename,
|
||||
"failed to seek to header entry");
|
||||
return;
|
||||
}
|
||||
BundleHeaderProto header;
|
||||
TF_CHECK_OK(ParseEntryProto(iter_->key(), iter_->value(), &header));
|
||||
status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
|
||||
if (!status_.ok()) {
|
||||
status_ = CorruptFileError(status_, filename, "unable to parse header");
|
||||
return;
|
||||
}
|
||||
num_shards_ = header.num_shards();
|
||||
if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
|
||||
(header.endianness() == BundleHeaderProto::LITTLE &&
|
||||
|
Loading…
Reference in New Issue
Block a user