diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc index ae9d0aa2099..cfecbf0e0cd 100644 --- a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc @@ -80,7 +80,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context, *result = {}; // Check to see if we have been given a default for all placeholders. - if (context.params.count("type")) { + if (context.params.count("shape")) { if (context.params.at("shape").size() != 1) { return errors::InvalidArgument( "You must pass no more than one default 'shape' to " @@ -91,10 +91,10 @@ Status ShapeForPlaceholder(const TransformFuncContext& context, } // See if there's a particular type specified for this placeholder. - if (context.params.count("name") || context.params.count("type_for_name")) { + if (context.params.count("name") || context.params.count("shape_for_name")) { if (!context.params.count("name") || - !context.params.count("type_for_name") || - (context.params.at("type_for_name").size() != + !context.params.count("shape_for_name") || + (context.params.at("shape_for_name").size() != context.params.at("name").size())) { return errors::InvalidArgument( "You must pass a 'shape_for_name' arg for every 'name', e.g. "