Fix invalid check of ShapeForPlaceholder
This commit is contained in:
parent
ee9a16b203
commit
2cae609660
@ -80,7 +80,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
|
||||
*result = {};
|
||||
|
||||
// Check to see if we have been given a default for all placeholders.
|
||||
if (context.params.count("type")) {
|
||||
if (context.params.count("shape")) {
|
||||
if (context.params.at("shape").size() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"You must pass no more than one default 'shape' to "
|
||||
@ -91,10 +91,10 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
|
||||
}
|
||||
|
||||
// See if there's a particular type specified for this placeholder.
|
||||
if (context.params.count("name") || context.params.count("type_for_name")) {
|
||||
if (context.params.count("name") || context.params.count("shape_for_name")) {
|
||||
if (!context.params.count("name") ||
|
||||
!context.params.count("type_for_name") ||
|
||||
(context.params.at("type_for_name").size() !=
|
||||
!context.params.count("shape_for_name") ||
|
||||
(context.params.at("shape_for_name").size() !=
|
||||
context.params.at("name").size())) {
|
||||
return errors::InvalidArgument(
|
||||
"You must pass a 'shape_for_name' arg for every 'name', e.g. "
|
||||
|
Loading…
Reference in New Issue
Block a user