Extend error message when super new GraphDef import fails

PiperOrigin-RevId: 302064508
Change-Id: Ibc3174157ebf91811628c6d49fb225d991d1f6c9
This commit is contained in:
Igor Ganichev 2020-03-20 11:19:45 -07:00 committed by TensorFlower Gardener
parent 7d7a5c9b4f
commit 0ffd38260a
3 changed files with 71 additions and 5 deletions

View File

@ -2520,6 +2520,7 @@ tf_cuda_library(
"//third_party/eigen3",
] + if_static([
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
]),
)

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@ -457,6 +458,33 @@ class NodeDefMovingGraphConstructor : public GraphConstructor {
std::vector<bool> is_consumed_;
};
bool ForwardCompatibilityWindowPassed(const VersionDef& versions) {
// TF_GRAPH_DEF_VERSION is incremented daily.
// TF has a 3 week forward compatibility guarantee.
return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21;
}
Status MaybeAppendVersionWarning(const VersionDef* versions,
const Status& import_status) {
if (versions && ForwardCompatibilityWindowPassed(*versions)) {
return Status(
import_status.code(),
absl::StrCat(
"Converting GraphDef to Graph has failed. The binary trying to "
"import the GraphDef was built when GraphDef version was ",
TF_GRAPH_DEF_VERSION,
". The GraphDef was produced by a binary built when GraphDef "
"version was ",
versions->producer(),
". The difference between these versions is larger than "
"TensorFlow's forward compatibility guarantee. The following error "
"might be due to the binary trying to import the GraphDef being "
"too old: ",
import_status.error_message()));
}
return import_status;
}
/* static */ Status GraphConstructor::Construct(
const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
@ -471,8 +499,11 @@ class NodeDefMovingGraphConstructor : public GraphConstructor {
NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, g,
refiner, return_tensors, return_nodes,
missing_unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
Status s = c.TryImport();
if (!s.ok()) {
c.Undo();
s = MaybeAppendVersionWarning(versions, s);
}
return s;
}
@ -484,11 +515,15 @@ class NodeDefMovingGraphConstructor : public GraphConstructor {
TF_RETURN_IF_ERROR(CheckVersions(graph_def.versions(), TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
VersionDef version_def = graph_def.versions();
NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner,
return_tensors, return_nodes,
missing_unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
Status s = c.TryImport();
if (!s.ok()) {
c.Undo();
s = MaybeAppendVersionWarning(&version_def, s);
}
return s;
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include <vector>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -51,7 +52,8 @@ class GraphConstructorTest : public ::testing::Test {
}
void ExpectError(const string& gdef_ascii,
const std::vector<string>& expected_error_strs) {
const std::vector<string>& expected_error_strs,
string not_expected_error_str = "") {
// Used to verify that errors don't change graph
const string original_graph_description = GraphDebugString();
@ -65,6 +67,13 @@ class GraphConstructorTest : public ::testing::Test {
<< "Expected to find '" << error << "' in " << status;
}
if (!not_expected_error_str.empty()) {
EXPECT_TRUE(status.error_message().find(not_expected_error_str) ==
string::npos)
<< "Expected not to find '" << not_expected_error_str << "' in "
<< status;
}
EXPECT_EQ(original_graph_description, GraphDebugString());
}
@ -825,6 +834,27 @@ TEST_F(GraphConstructorTest, VersionGraph) {
ExpectVersions(TF_GRAPH_DEF_VERSION_MIN_CONSUMER, TF_GRAPH_DEF_VERSION);
}
TEST_F(GraphConstructorTest, ForwardCompatError) {
ExpectError(
strings::StrCat(
"node { name: 'a:b' op: 'ABC' }\n" // 'a:b' is an invalid name.
"versions { producer: ",
TF_GRAPH_DEF_VERSION + 22,
" min_consumer: ", TF_GRAPH_DEF_VERSION_MIN_CONSUMER, "}"),
{"forward compatibility guarantee"});
}
TEST_F(GraphConstructorTest, NoForwardCompatError) {
ExpectError(
strings::StrCat(
"node { name: 'a:b' op: 'ABC' }\n" // 'a:b' is an invalid name.
"versions { producer: ",
TF_GRAPH_DEF_VERSION + 21,
" min_consumer: ", TF_GRAPH_DEF_VERSION_MIN_CONSUMER, "}"),
{"Node name contains invalid characters"},
"forward compatibility guarantee");
}
TEST_F(GraphConstructorTest, LowVersion) {
ExpectError(strings::StrCat("versions { producer: ", -1, " }"),
{strings::StrCat("GraphDef producer version -1 below min "