Fix parser bug for tf_executor.Switch

Fixes the following error for tf_executor.Switch with one or more valid control
inputs: `custom op 'tf_executor.Switch' expected 2 operands`.
PiperOrigin-RevId: 306755642
Change-Id: I960bf399bfa4638090d623b8ca2b5143c707ed42
This commit is contained in:
Michael Gester 2020-04-15 17:46:50 -07:00 committed by TensorFlower Gardener
parent 856f084c17
commit dba59540c7
3 changed files with 37 additions and 3 deletions

View File

@ -474,7 +474,7 @@ namespace {
ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_infos;
SmallVector<Type, 1> types;
if (parser.parseOperandList(op_infos, 2) || parser.parseColonTypeList(types))
if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 1)
return parser.emitError(parser.getNameLoc())
@ -486,12 +486,15 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
// type).
if (types.front().isa<FunctionType>()) {
FunctionType type = types.front().cast<FunctionType>();
if (type.getNumInputs() != 2)
if (type.getNumInputs() < 2)
return parser.emitError(parser.getNameLoc())
<< " expects a single data type and a predicate";
result.types.assign(type.getResults().begin(), type.getResults().end());
types.assign(type.getInputs().begin(), type.getInputs().end());
} else {
if (op_infos.size() < 2)
return parser.emitError(parser.getNameLoc())
<< " expects a single data type and a predicate";
Type control_type = ControlType::get(parser.getBuilder().getContext());
result.types.append(2, types[0]);
result.types.push_back(control_type);

View File

@ -187,6 +187,26 @@ func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> te
return %result : tensor<*xf32>
}
// CHECK-LABEL: func @switch_with_control_inputs(
func @switch_with_control_inputs(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
%result = tf_executor.graph {
// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
%1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : tensor<i1>
tf_executor.fetch %1#0 : tensor<i1>
}
return %result : tensor<i1>
}
// CHECK-LABEL: func @switch_with_control_inputs_functional(
func @switch_with_control_inputs_functional(%arg0: tensor<i1>, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor<i1> {
%result = tf_executor.graph {
// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor<i1>
%1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : (tensor<i1>, tensor<i1>, !tf_executor.control, !tf_executor.control) -> (tensor<i1>, tensor<i1>, !tf_executor.control)
tf_executor.fetch %1#0 : tensor<i1>
}
return %result : tensor<i1>
}
// CHECK-LABEL: func @switchN(
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph {

View File

@ -333,7 +333,7 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor<i1>) {
// -----
// Check that a switch always takes two arguments.
// Check that a switch always needs at least two arguments.
func @invalid_switch(%arg0: tensor<*xf32>) {
tf_executor.graph {
%true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control)
@ -344,6 +344,17 @@ func @invalid_switch(%arg0: tensor<*xf32>) {
// -----
// Check that a switch always needs at least two arguments.
func @invalid_switch(%arg0: tensor<*xf32>) {
tf_executor.graph {
%true, %false, %ctlSwitch = tf_executor.Switch %arg0 : tensor<*xf32>
// expected-error@-1 {{custom op 'tf_executor.Switch' expects a single data type and a predicate}}
}
return
}
// -----
// Check that a switch second argument must be a valid predicate (i1).
func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%result = tf_executor.graph {