From c9cc388dc2beb1325fa624adb458620625683e4d Mon Sep 17 00:00:00 2001
From: Asim Shankar
Date: Wed, 31 May 2017 12:39:14 -0700
Subject: [PATCH] Avoid CHECKs in BundleReader, propagate errors instead.

Motivation: We'd like to evolve the checkpoint format over time (e.g.,
enable different types of compression). Without this change, a
TensorFlow version that encounters a format that it doesn't understand
would CHECK fail with an unhelpful error message. With this, it
propagates a clearer error message up, giving the user some hints about
what could be wrong.

I don't have a unittest for this - I thought about writing a bundle and
then strategically corrupting the bytes on disk before reading it back,
but that seems a bit much. The intention of this change is to enable
graceful reporting of forward compatibility breakages. Ideas for an
appropriate unittest are appreciated.

PiperOrigin-RevId: 157620358
---
 .../core/util/tensor_bundle/tensor_bundle.cc  | 40 +++++++++++++++----
 1 file changed, 33 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index 5c2bda4770f..dd04cea40d1 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -238,6 +238,23 @@ bool IsFullSlice(const TensorSlice& slice_spec,
   }
 }
 
+Status CorruptFileError(const Status& in_status, const string& filename,
+                        const string& detail) {
+  if (in_status.ok()) {
+    return errors::Internal("Unable to read file (", filename,
+                            "). Perhaps the file is corrupt or was produced by "
+                            "a newer version of TensorFlow with format changes "
+                            "(",
+                            detail, ")");
+  }
+  return Status(
+      in_status.code(),
+      strings::StrCat("Unable to read file (", filename,
+                      "). Perhaps the file is corrupt or was produced by a "
+                      "newer version of TensorFlow with format changes (",
+                      detail, "): ", in_status.error_message()));
+}
+
 }  // namespace
 
 BundleWriter::BundleWriter(Env* env, StringPiece prefix)
@@ -433,11 +450,13 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
   // Process header.
   {
     iter->Seek(kHeaderEntryKey);
-    CHECK(iter->Valid()) << "File: " << filename
-                         << ", iterator status: " << iter->status();
+    if (!iter->Valid()) {
+      return CorruptFileError(iter->status(), filename,
+                              "failed to seek to header entry");
+    }
     BundleHeaderProto header;
-    TF_CHECK_OK(ParseEntryProto(iter->key(), iter->value(), &header));
-    CHECK_GE(header.num_shards(), 0);
+    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
+    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
     merge_state->num_shards += header.num_shards();
 
     if (!merge_state->seen_first_bundle) {
@@ -584,10 +603,17 @@ BundleReader::BundleReader(Env* env, StringPiece prefix)
 
   // Reads "num_shards_" from the first entry.
   iter_->Seek(kHeaderEntryKey);
-  CHECK(iter_->Valid()) << "File: " << filename
-                        << ", iterator status: " << iter_->status();
+  if (!iter_->Valid()) {
+    status_ = CorruptFileError(iter_->status(), filename,
+                               "failed to seek to header entry");
+    return;
+  }
   BundleHeaderProto header;
-  TF_CHECK_OK(ParseEntryProto(iter_->key(), iter_->value(), &header));
+  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
+  if (!status_.ok()) {
+    status_ = CorruptFileError(status_, filename, "unable to parse header");
+    return;
+  }
   num_shards_ = header.num_shards();
   if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
       (header.endianness() == BundleHeaderProto::LITTLE &&
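
On the open question about a unittest: below is a rough sketch, not part of this
patch, of the write-then-corrupt idea from the commit message. The test name,
the tensor contents, and the choice of clobbering the whole metadata file are
illustrative only; to exercise the new CorruptFileError paths (rather than the
existing error handling in table::Table::Open) the corruption would need to be
more targeted, e.g. at the header entry's value.

// Hypothetical test sketch (not part of this patch): write a small bundle,
// corrupt the metadata file on disk, and verify that BundleReader surfaces a
// non-OK status instead of CHECK-failing.
TEST(TensorBundleTest, CorruptMetadataReportsError) {
  const string prefix = io::JoinPath(testing::TmpDir(), "corrupt_bundle");
  {
    BundleWriter writer(Env::Default(), prefix);
    Tensor t(DT_FLOAT, TensorShape({2}));
    t.flat<float>().setZero();
    TF_ASSERT_OK(writer.Add("a", t));
    TF_ASSERT_OK(writer.Finish());
  }
  // Clobber the metadata file. A real test would corrupt bytes more
  // selectively so the failure is reported by the new error paths.
  TF_ASSERT_OK(
      WriteStringToFile(Env::Default(), MetaFilename(prefix), "not a bundle"));
  BundleReader reader(Env::Default(), prefix);
  EXPECT_FALSE(reader.status().ok());
}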