Fix invalid check of ShapeForPlaceholder
This commit is contained in:
parent
ee9a16b203
commit
2cae609660
@ -80,7 +80,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
|
|||||||
*result = {};
|
*result = {};
|
||||||
|
|
||||||
// Check to see if we have been given a default for all placeholders.
|
// Check to see if we have been given a default for all placeholders.
|
||||||
if (context.params.count("type")) {
|
if (context.params.count("shape")) {
|
||||||
if (context.params.at("shape").size() != 1) {
|
if (context.params.at("shape").size() != 1) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"You must pass no more than one default 'shape' to "
|
"You must pass no more than one default 'shape' to "
|
||||||
@ -91,10 +91,10 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// See if there's a particular type specified for this placeholder.
|
// See if there's a particular type specified for this placeholder.
|
||||||
if (context.params.count("name") || context.params.count("type_for_name")) {
|
if (context.params.count("name") || context.params.count("shape_for_name")) {
|
||||||
if (!context.params.count("name") ||
|
if (!context.params.count("name") ||
|
||||||
!context.params.count("type_for_name") ||
|
!context.params.count("shape_for_name") ||
|
||||||
(context.params.at("type_for_name").size() !=
|
(context.params.at("shape_for_name").size() !=
|
||||||
context.params.at("name").size())) {
|
context.params.at("name").size())) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"You must pass a 'shape_for_name' arg for every 'name', e.g. "
|
"You must pass a 'shape_for_name' arg for every 'name', e.g. "
|
||||||
|
Loading…
Reference in New Issue
Block a user