Relax type checking for data operand and results of tf.SwitchN

Now we only require that the data operand be broadcastable to the results;
previously we required them to be equal, which was problematic for shape
inference. The new check is more consistent with tf.Switch and tf.Merge.
Also added more tests for tf.SwitchN and tf.Merge.

PiperOrigin-RevId: 304935489
Change-Id: Ife8d1ea097cced6ad1eddd577fb95c0c19648281
This commit is contained in:
Michael Gester 2020-04-05 17:54:55 -07:00 committed by TensorFlower Gardener
parent be449950ea
commit e068a4c413
2 changed files with 101 additions and 9 deletions

View File

@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) {
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got "
<< (switchn.getNumResults() - 1);
// Check that operand can be broadcasted to each output type.
auto operand0_type = switchn.getOperand(0).getType();
for (Value result : switchn.outputs())
if (operand0_type != result.getType())
return switchn.emitOpError()
<< "type mismatch between data operand and result: "
<< operand0_type << " vs " << result.getType();
TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
if (!operand0_tensor_type) {
return switchn.emitOpError()
<< "expects data operand to have tensor type but got "
<< operand0_type;
}
for (Type output_type : switchn.getResultTypes()) {
if (output_type.isa<ControlType>()) break;
TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
if (!output_tensor_type) {
return switchn.emitOpError()
<< "expects outputs to have tensor type but got " << output_type;
}
// If the output type is a ref type, then the operand type should also be of
// the same ref type. However, if the output type is a non-ref type T, then
// the operand can be tensor of type T or T_REF.
bool is_output_ref =
output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
if (is_output_ref &&
!operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
return switchn.emitOpError()
<< "expects same operand and output element type but got "
<< operand0_tensor_type << " vs " << output_tensor_type;
}
Type broadcasted_type = OpTrait::util::getBroadcastedType(
DropRefType(DropTypeSubTypes(operand0_tensor_type)),
DropRefType(DropTypeSubTypes(output_tensor_type)));
if (!broadcasted_type) {
return switchn.emitOpError()
<< "expects data operand to be broadcastable with all output types"
<< " but got " << operand0_tensor_type << " vs "
<< output_tensor_type;
}
}
return success();
}

View File

@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
// -----
// Check that switchN result type matches the input type.
func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// Check that data operands of SwitchN have tensor type
func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
%result = tf_executor.graph {
// Data operand %arg0 is a plain 'i32', not a tensor type, so the SwitchN
// verifier must reject it.
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}}
tf_executor.fetch %1#0 : tensor<*xi32>
}
return %result : tensor<*xi32>
}
// -----
// Check that result of SwitchN has tensor type
func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
%result = tf_executor.graph {
// The first result is declared as plain 'i32' rather than a tensor type,
// so the SwitchN verifier must reject it.
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}}
tf_executor.fetch %1#0 : i32
}
return %result : i32
}
// -----
// Check that if any result is a ref type, then data operand needs to be ref too.
func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> {
%fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, i32, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}}
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
tf_executor.fetch %1#0 : tensor<4x!tf.f32ref>
}
return %fetches : tensor<4x!tf.f32ref>
}
// -----
// Check that switchN data operand is broadcastable with all output types
func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
%fetches = tf_executor.graph {
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}}
tf_executor.fetch %1#0 : tensor<*xf32>
}
@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
// -----
// Check that data operands of merge have tensor type
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> {
%result = tf_executor.graph {
// Second data operand %arg1 is a plain 'i32', not a tensor type, so the
// Merge verifier must reject it.
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}}
tf_executor.fetch %value : tensor<*xi32>
}
return %result : tensor<*xi32>
}
// -----
// Check that result of merge has tensor type
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
%result = tf_executor.graph {
// The first result is declared as plain 'i32' rather than a tensor type,
// violating Merge's declared result type constraint.
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
// expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}}
tf_executor.fetch %value : i32
}
return %result : i32
}
// -----
// Check that merge data inputs are all the same type
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
%result = tf_executor.graph {