Fix parser bug for tf_executor.Switch
Fixes the following error for tf_executor.Switch with one or more valid control inputs: `custom op 'tf_executor.Switch' expected 2 operands`. PiperOrigin-RevId: 306755642 Change-Id: I960bf399bfa4638090d623b8ca2b5143c707ed42
This commit is contained in:
parent
856f084c17
commit
dba59540c7
|
@ -474,7 +474,7 @@ namespace {
|
|||
ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
if (parser.parseOperandList(op_infos, 2) || parser.parseColonTypeList(types))
|
||||
if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
|
||||
return failure();
|
||||
if (types.size() != 1)
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
|
@ -486,12 +486,15 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
|||
// type).
|
||||
if (types.front().isa<FunctionType>()) {
|
||||
FunctionType type = types.front().cast<FunctionType>();
|
||||
if (type.getNumInputs() != 2)
|
||||
if (type.getNumInputs() < 2)
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< " expects a single data type and a predicate";
|
||||
result.types.assign(type.getResults().begin(), type.getResults().end());
|
||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||
} else {
|
||||
if (op_infos.size() < 2)
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< " expects a single data type and a predicate";
|
||||
Type control_type = ControlType::get(parser.getBuilder().getContext());
|
||||
result.types.append(2, types[0]);
|
||||
result.types.push_back(control_type);
|
||||
|
|
|
@ -187,6 +187,26 @@ func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> te
|
|||
return %result : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @switch_with_control_inputs(
|
||||
func @switch_with_control_inputs(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
|
||||
%result = tf_executor.graph {
|
||||
// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
|
||||
%1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : tensor<i1>
|
||||
tf_executor.fetch %1#0 : tensor<i1>
|
||||
}
|
||||
return %result : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @switch_with_control_inputs_functional(
|
||||
func @switch_with_control_inputs_functional(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
|
||||
%result = tf_executor.graph {
|
||||
// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
|
||||
%1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : (tensor<i1>, tensor<i1>, !tf_executor.control, !tf_executor.control) -> (tensor<i1>, tensor<i1>, !tf_executor.control)
|
||||
tf_executor.fetch %1#0 : tensor<i1>
|
||||
}
|
||||
return %result : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @switchN(
|
||||
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%fetches = tf_executor.graph {
|
||||
|
|
|
@ -333,7 +333,7 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i1>) {
|
|||
|
||||
// -----
|
||||
|
||||
// Check that a switch always takes two arguments.
|
||||
// Check that a switch always needs at least two arguments.
|
||||
func @invalid_switch(%arg0: tensor<*xf32>) {
|
||||
tf_executor.graph {
|
||||
%true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
|
||||
|
@ -344,6 +344,17 @@ func @invalid_switch(%arg0: tensor<*xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// Check that a switch always needs at least two arguments.
|
||||
func @invalid_switch(%arg0: tensor<*xf32>) {
|
||||
tf_executor.graph {
|
||||
%true, %false, %ctlSwitch = tf_executor.Switch %arg0 : tensor<*xf32>
|
||||
// expected-error@-1 {{custom op 'tf_executor.Switch' expects a single data type and a predicate}}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that a switch second argument must be a valid predicate (i1).
|
||||
func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
|
||||
%result = tf_executor.graph {
|
||||
|
|
Loading…
Reference in New Issue