Fix invalid check of ShapeForPlaceholder

This commit is contained in:
tianyapiaozi 2018-07-14 20:08:05 +08:00
parent ee9a16b203
commit 2cae609660

View File

@ -80,7 +80,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
*result = {}; *result = {};
// Check to see if we have been given a default for all placeholders. // Check to see if we have been given a default for all placeholders.
if (context.params.count("type")) { if (context.params.count("shape")) {
if (context.params.at("shape").size() != 1) { if (context.params.at("shape").size() != 1) {
return errors::InvalidArgument( return errors::InvalidArgument(
"You must pass no more than one default 'shape' to " "You must pass no more than one default 'shape' to "
@ -91,10 +91,10 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
} }
// See if there's a particular type specified for this placeholder. // See if there's a particular type specified for this placeholder.
if (context.params.count("name") || context.params.count("type_for_name")) { if (context.params.count("name") || context.params.count("shape_for_name")) {
if (!context.params.count("name") || if (!context.params.count("name") ||
!context.params.count("type_for_name") || !context.params.count("shape_for_name") ||
(context.params.at("type_for_name").size() != (context.params.at("shape_for_name").size() !=
context.params.at("name").size())) { context.params.at("name").size())) {
return errors::InvalidArgument( return errors::InvalidArgument(
"You must pass a 'shape_for_name' arg for every 'name', e.g. " "You must pass a 'shape_for_name' arg for every 'name', e.g. "