Relax type checking for data operand and results of tf.SwitchN
Now we only require that data operand must be broadcastable to results, before we required them to be equal which is problematic for shape inference. The new checking is more consistent with tf.Switch and tf.Merge. Also added more tests for tf.SwitchN and tf.Merge. PiperOrigin-RevId: 304935489 Change-Id: Ife8d1ea097cced6ad1eddd577fb95c0c19648281
This commit is contained in:
parent
be449950ea
commit
e068a4c413
@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) {
|
||||
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got "
|
||||
<< (switchn.getNumResults() - 1);
|
||||
|
||||
// Check that operand can be broadcasted to each output type.
|
||||
auto operand0_type = switchn.getOperand(0).getType();
|
||||
for (Value result : switchn.outputs())
|
||||
if (operand0_type != result.getType())
|
||||
return switchn.emitOpError()
|
||||
<< "type mismatch between data operand and result: "
|
||||
<< operand0_type << " vs " << result.getType();
|
||||
TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
|
||||
if (!operand0_tensor_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects data operand to have tensor type but got "
|
||||
<< operand0_type;
|
||||
}
|
||||
for (Type output_type : switchn.getResultTypes()) {
|
||||
if (output_type.isa<ControlType>()) break;
|
||||
|
||||
TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
|
||||
if (!output_tensor_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects outputs to have tensor type but got " << output_type;
|
||||
}
|
||||
|
||||
// If the output type is a ref type, then the operand type should also be of
|
||||
// the same ref type. However, if the output type is a non-ref type T, then
|
||||
// the operand can be tensor of type T or T_REF.
|
||||
bool is_output_ref =
|
||||
output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
|
||||
if (is_output_ref &&
|
||||
!operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects same operand and output element type but got "
|
||||
<< operand0_tensor_type << " vs " << output_tensor_type;
|
||||
}
|
||||
Type broadcasted_type = OpTrait::util::getBroadcastedType(
|
||||
DropRefType(DropTypeSubTypes(operand0_tensor_type)),
|
||||
DropRefType(DropTypeSubTypes(output_tensor_type)));
|
||||
if (!broadcasted_type) {
|
||||
return switchn.emitOpError()
|
||||
<< "expects data operand to be broadcastable with all output types"
|
||||
<< " but got " << operand0_tensor_type << " vs "
|
||||
<< output_tensor_type;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// Check that switchN result type matches the input type.
|
||||
func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// Check that data operands of SwitchN have tensor type
|
||||
func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
|
||||
%result = tf_executor.graph {
|
||||
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}}
|
||||
tf_executor.fetch %1#0 : tensor<*xi32>
|
||||
}
|
||||
return %result : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that result of SwitchN has tensor type
|
||||
func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
|
||||
%result = tf_executor.graph {
|
||||
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}}
|
||||
tf_executor.fetch %1#0 : i32
|
||||
}
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that if any result is a ref type, then data operand needs to be ref too.
|
||||
func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> {
|
||||
%fetches = tf_executor.graph {
|
||||
|
||||
%1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, i32, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}}
|
||||
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
|
||||
tf_executor.fetch %1#0 : tensor<4x!tf.f32ref>
|
||||
}
|
||||
return %fetches : tensor<4x!tf.f32ref>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that switchN data operand is broadcastable with all output types
|
||||
func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
|
||||
%fetches = tf_executor.graph {
|
||||
|
||||
%1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}}
|
||||
|
||||
tf_executor.fetch %1#0 : tensor<*xf32>
|
||||
}
|
||||
@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Check that data operands of merge have tensor type
|
||||
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> {
|
||||
%result = tf_executor.graph {
|
||||
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}}
|
||||
tf_executor.fetch %value : tensor<*xi32>
|
||||
}
|
||||
return %result : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that result of merge has tensor type
|
||||
func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
|
||||
%result = tf_executor.graph {
|
||||
%value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
|
||||
// expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}}
|
||||
tf_executor.fetch %value : i32
|
||||
}
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that merge data inputs are all the same type
|
||||
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*xf32> {
|
||||
%result = tf_executor.graph {
|
||||
|
Loading…
Reference in New Issue
Block a user