From 18c6ed23e77deff0fa8ab5cd68dbf18971400b38 Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Wed, 25 Sep 2019 11:30:53 -0700 Subject: [PATCH] Unconditionally re-insert default attributes in GraphConstructor GraphConstructor used to re-insert default attributes only when it was emulating python's import_graph_def functionality. It should be safe to re-insert always instead. This is needed so that whenever we convert from GraphDef to Graph (a necessary step for execution) default attributes are added. The change also adds an option to perform NodeDef validation to GraphConstructor. This option is needed when we want to convert a GraphDef without default attributes to Graph with default attributes and validate (e.g. in GraphMgr::InitItem). Validating as a separate step is: - more expensive - Existing validation routines work on GraphDefs, not Graphs. We can't use them because the GraphDef without default attrs is considered invalid. This change is the first part of stripping and re-inserting default attributes across WorkerService, which is needed to support forward compatibility across RPCs, i.e. an old server trying to run a new graph. PiperOrigin-RevId: 271168101 --- ...lacer_inspection_required_ops_pass_test.cc | 2 ++ tensorflow/core/graph/graph_constructor.cc | 22 ++++++++++++++++++- tensorflow/core/graph/graph_constructor.h | 4 ++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc index 6fb01c3b28a..904265fdc28 100644 --- a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" @@ -50,6 +51,7 @@ void RunPass(const GraphDef& original, GraphDef* rewritten, IsolatePlacerInspectionRequiredOpsPass pass; TF_ASSERT_OK(pass.Run(options)); graph->ToGraphDef(rewritten); + StripDefaultAttributes(*OpRegistry::Global(), rewritten->mutable_node()); } void RunPassAndCompare(const GraphDef& original, const GraphDef& expected) { diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 65d4ee9c561..a945fd26682 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -92,6 +92,7 @@ class GraphConstructor { : allow_internal_ops(in.allow_internal_ops), expect_device_spec(in.expect_device_spec), importing(false), + validate_nodes(in.validate_nodes), validate_colocation_constraints(false) {} Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit) : allow_internal_ops(false), @@ -107,6 +108,7 @@ class GraphConstructor { return_tensors(in.return_tensors.begin(), in.return_tensors.end()), return_nodes(in.return_nodes), importing(true), + validate_nodes(true), validate_colocation_constraints(in.validate_colocation_constraints), validate_shape(in.validate_shape), default_device(in.default_device) {} @@ -132,6 +134,10 @@ class GraphConstructor { // applicable to ConvertGraphDefToGraph as well, so make an attempt to // remove this. bool importing; + // If true, validates that nodes being converted have all expected attrs + // set and no unknonw attrs set by calling ValidateNodeDef(). + // `validate_nodes` is always true when `importing` is set. + bool validate_nodes; bool validate_colocation_constraints; bool validate_shape = true; @@ -1225,8 +1231,22 @@ Status GraphConstructor::Convert() { if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) { UniquifyNames(input_already_exists, &node_def); } - TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def)); } + + if (opts_.importing) { + TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def)); + } else { + const OpDef* op_def; + TF_RETURN_IF_ERROR( + g_->op_registry()->LookUpOpDef(node_def.op(), &op_def)); + if (opts_.validate_nodes) { + AddDefaultsToNodeDef(*op_def, &node_def); + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def)); + } else { + AddDefaultsToNodeDef(*op_def, &node_def); + } + } + TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node)); if (opts_.importing) { diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 209266877c2..3e9491951c5 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -42,6 +42,10 @@ struct GraphConstructorOptions { // // TODO(zhifengc): if possible, consider removing this option. bool expect_device_spec = false; + + // If true, validates that nodes being converted have all expected attrs + // set and no unknonw attrs set by calling ValidateNodeDef(). + bool validate_nodes = false; }; extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, const GraphDef& gdef, Graph* g);