Fix parser bug for tf_executor.Switch
Fixes the following error for tf_executor.Switch with one or more valid control inputs: `custom op 'tf_executor.Switch' expected 2 operands`. PiperOrigin-RevId: 306755642 Change-Id: I960bf399bfa4638090d623b8ca2b5143c707ed42
This commit is contained in:
parent
856f084c17
commit
dba59540c7
|
@ -474,7 +474,7 @@ namespace {
|
||||||
ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
||||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||||
SmallVector<Type, 1> types;
|
SmallVector<Type, 1> types;
|
||||||
if (parser.parseOperandList(op_infos, 2) || parser.parseColonTypeList(types))
|
if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
|
||||||
return failure();
|
return failure();
|
||||||
if (types.size() != 1)
|
if (types.size() != 1)
|
||||||
return parser.emitError(parser.getNameLoc())
|
return parser.emitError(parser.getNameLoc())
|
||||||
|
@ -486,12 +486,15 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
||||||
// type).
|
// type).
|
||||||
if (types.front().isa<FunctionType>()) {
|
if (types.front().isa<FunctionType>()) {
|
||||||
FunctionType type = types.front().cast<FunctionType>();
|
FunctionType type = types.front().cast<FunctionType>();
|
||||||
if (type.getNumInputs() != 2)
|
if (type.getNumInputs() < 2)
|
||||||
return parser.emitError(parser.getNameLoc())
|
return parser.emitError(parser.getNameLoc())
|
||||||
<< " expects a single data type and a predicate";
|
<< " expects a single data type and a predicate";
|
||||||
result.types.assign(type.getResults().begin(), type.getResults().end());
|
result.types.assign(type.getResults().begin(), type.getResults().end());
|
||||||
types.assign(type.getInputs().begin(), type.getInputs().end());
|
types.assign(type.getInputs().begin(), type.getInputs().end());
|
||||||
} else {
|
} else {
|
||||||
|
if (op_infos.size() < 2)
|
||||||
|
return parser.emitError(parser.getNameLoc())
|
||||||
|
<< " expects a single data type and a predicate";
|
||||||
Type control_type = ControlType::get(parser.getBuilder().getContext());
|
Type control_type = ControlType::get(parser.getBuilder().getContext());
|
||||||
result.types.append(2, types[0]);
|
result.types.append(2, types[0]);
|
||||||
result.types.push_back(control_type);
|
result.types.push_back(control_type);
|
||||||
|
|
|
@ -187,6 +187,26 @@ func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> te
|
||||||
return %result : tensor<*xf32>
|
return %result : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @switch_with_control_inputs(
|
||||||
|
func @switch_with_control_inputs(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
|
||||||
|
%result = tf_executor.graph {
|
||||||
|
// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
|
||||||
|
%1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : tensor<i1>
|
||||||
|
tf_executor.fetch %1#0 : tensor<i1>
|
||||||
|
}
|
||||||
|
return %result : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @switch_with_control_inputs_functional(
|
||||||
|
func @switch_with_control_inputs_functional(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
|
||||||
|
%result = tf_executor.graph {
|
||||||
|
// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
|
||||||
|
%1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : (tensor<i1>, tensor<i1>, !tf_executor.control, !tf_executor.control) -> (tensor<i1>, tensor<i1>, !tf_executor.control)
|
||||||
|
tf_executor.fetch %1#0 : tensor<i1>
|
||||||
|
}
|
||||||
|
return %result : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @switchN(
|
// CHECK-LABEL: func @switchN(
|
||||||
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
%fetches = tf_executor.graph {
|
%fetches = tf_executor.graph {
|
||||||
|
|
|
@ -333,7 +333,7 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i1>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Check that a switch always takes two arguments.
|
// Check that a switch always needs at least two arguments.
|
||||||
func @invalid_switch(%arg0: tensor<*xf32>) {
|
func @invalid_switch(%arg0: tensor<*xf32>) {
|
||||||
tf_executor.graph {
|
tf_executor.graph {
|
||||||
%true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
|
%true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
|
||||||
|
@ -344,6 +344,17 @@ func @invalid_switch(%arg0: tensor<*xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Check that a switch always needs at least two arguments.
|
||||||
|
func @invalid_switch(%arg0: tensor<*xf32>) {
|
||||||
|
tf_executor.graph {
|
||||||
|
%true, %false, %ctlSwitch = tf_executor.Switch %arg0 : tensor<*xf32>
|
||||||
|
// expected-error@-1 {{custom op 'tf_executor.Switch' expects a single data type and a predicate}}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// Check that a switch second argument must be a valid predicate (i1).
|
// Check that a switch second argument must be a valid predicate (i1).
|
||||||
func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
|
func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
|
||||||
%result = tf_executor.graph {
|
%result = tf_executor.graph {
|
||||||
|
|
Loading…
Reference in New Issue