Avoid CHECKs in BundleReader, propagate errors instead.

Motivation:
We'd like to evolve the checkpoint format over time (e.g., enable
different types of compression). Without this change, a TensorFlow
version that encounters a format that it doesn't understand would CHECK fail
with an unhelpful error message.

With this, it propagates a clearer error message up, giving the user some
hints about what could be wrong.

I don't have a unittest for this - I thought about writing a bundle and
then strategically corrupting the bytes on disk before reading it back,
but that seems a bit much. The intention of this change is to enable
graceful reporting of forward compatibility breakages. Ideas for an
appropriate unittest are appreciated.

PiperOrigin-RevId: 157620358
This commit is contained in:
Asim Shankar 2017-05-31 12:39:14 -07:00 committed by TensorFlower Gardener
parent ee05b8b690
commit c9cc388dc2

View File

@ -238,6 +238,23 @@ bool IsFullSlice(const TensorSlice& slice_spec,
}
}
Status CorruptFileError(const Status& in_status, const string& filename,
const string& detail) {
if (in_status.ok()) {
return errors::Internal("Unable to read file (", filename,
"). Perhaps the file is corrupt or was produced by "
"a newer version of TensorFlow with format changes "
"(",
detail, ")");
}
return Status(
in_status.code(),
strings::StrCat("Unable to read file (", filename,
"). Perhaps the file is corrupt or was produced by a "
"newer version of TensorFlow with format changes (",
detail, "): ", in_status.error_message()));
}
} // namespace
BundleWriter::BundleWriter(Env* env, StringPiece prefix)
@ -433,11 +450,13 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
// Process header.
{
iter->Seek(kHeaderEntryKey);
CHECK(iter->Valid()) << "File: " << filename
<< ", iterator status: " << iter->status();
if (!iter->Valid()) {
return CorruptFileError(iter->status(), filename,
"failed to seek to header entry");
}
BundleHeaderProto header;
TF_CHECK_OK(ParseEntryProto(iter->key(), iter->value(), &header));
CHECK_GE(header.num_shards(), 0);
Status s = ParseEntryProto(iter->key(), iter->value(), &header);
if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
merge_state->num_shards += header.num_shards();
if (!merge_state->seen_first_bundle) {
@ -584,10 +603,17 @@ BundleReader::BundleReader(Env* env, StringPiece prefix)
// Reads "num_shards_" from the first entry.
iter_->Seek(kHeaderEntryKey);
CHECK(iter_->Valid()) << "File: " << filename
<< ", iterator status: " << iter_->status();
if (!iter_->Valid()) {
status_ = CorruptFileError(iter_->status(), filename,
"failed to seek to header entry");
return;
}
BundleHeaderProto header;
TF_CHECK_OK(ParseEntryProto(iter_->key(), iter_->value(), &header));
status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
if (!status_.ok()) {
status_ = CorruptFileError(status_, filename, "unable to parse header");
return;
}
num_shards_ = header.num_shards();
if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
(header.endianness() == BundleHeaderProto::LITTLE &&