Unconditionally re-insert default attributes in GraphConstructor

GraphConstructor used to re-insert default attributes only when it
was emulating python's import_graph_def functionality. It should be
safe to re-insert always instead. This is needed so that whenever
we convert from GraphDef to Graph (a necessary step for execution)
default attributes are added.

The change also adds an option to perform NodeDef validation to
GraphConstructor. This option is needed when we want to convert
a GraphDef without default attributes to Graph with default attributes
and validate (e.g. in GraphMgr::InitItem). Validating as a separate step is:
 - more expensive
 - Existing validation routines work on GraphDefs, not Graphs. We can't
   use them because the GraphDef without default attrs is considered invalid.

This change is the first part of stripping and re-inserting default
attributes across WorkerService, which is needed to support forward
compatibility across RPCs, i.e. an old server trying to run a new graph.

PiperOrigin-RevId: 271168101
This commit is contained in:
Igor Ganichev 2019-09-25 11:30:53 -07:00 committed by TensorFlower Gardener
parent b581d28aca
commit 18c6ed23e7
3 changed files with 27 additions and 1 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
@ -50,6 +51,7 @@ void RunPass(const GraphDef& original, GraphDef* rewritten,
IsolatePlacerInspectionRequiredOpsPass pass;
TF_ASSERT_OK(pass.Run(options));
graph->ToGraphDef(rewritten);
StripDefaultAttributes(*OpRegistry::Global(), rewritten->mutable_node());
}
void RunPassAndCompare(const GraphDef& original, const GraphDef& expected) {

View File

@ -92,6 +92,7 @@ class GraphConstructor {
: allow_internal_ops(in.allow_internal_ops),
expect_device_spec(in.expect_device_spec),
importing(false),
validate_nodes(in.validate_nodes),
validate_colocation_constraints(false) {}
Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(false),
@ -107,6 +108,7 @@ class GraphConstructor {
return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
return_nodes(in.return_nodes),
importing(true),
validate_nodes(true),
validate_colocation_constraints(in.validate_colocation_constraints),
validate_shape(in.validate_shape),
default_device(in.default_device) {}
@ -132,6 +134,10 @@ class GraphConstructor {
// applicable to ConvertGraphDefToGraph as well, so make an attempt to
// remove this.
bool importing;
// If true, validates that nodes being converted have all expected attrs
// set and no unknonw attrs set by calling ValidateNodeDef().
// `validate_nodes` is always true when `importing` is set.
bool validate_nodes;
bool validate_colocation_constraints;
bool validate_shape = true;
@ -1225,8 +1231,22 @@ Status GraphConstructor::Convert() {
if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
UniquifyNames(input_already_exists, &node_def);
}
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
}
if (opts_.importing) {
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&node_def));
} else {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
if (opts_.validate_nodes) {
AddDefaultsToNodeDef(*op_def, &node_def);
TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, *op_def));
} else {
AddDefaultsToNodeDef(*op_def, &node_def);
}
}
TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node));
if (opts_.importing) {

View File

@ -42,6 +42,10 @@ struct GraphConstructorOptions {
//
// TODO(zhifengc): if possible, consider removing this option.
bool expect_device_spec = false;
// If true, validates that nodes being converted have all expected attrs
// set and no unknonw attrs set by calling ValidateNodeDef().
bool validate_nodes = false;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g);