diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index 8d670d96748..0ca4364f9cd 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) {
            << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
            << (switchn.getNumResults() - 1);
 
+  // Check that operand can be broadcasted to each output type.
   auto operand0_type = switchn.getOperand(0).getType();
-  for (Value result : switchn.outputs())
-    if (operand0_type != result.getType())
-      return switchn.emitOpError()
-             << "type mismatch between data operand and result: "
-             << operand0_type << " vs " << result.getType();
+  TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>();
+  if (!operand0_tensor_type) {
+    return switchn.emitOpError()
+           << "expects data operand to have tensor type but got "
+           << operand0_type;
+  }
+  for (Type output_type : switchn.getResultTypes()) {
+    if (output_type.isa<ControlType>()) break;
+    TensorType output_tensor_type = output_type.dyn_cast<TensorType>();
+    if (!output_tensor_type) {
+      return switchn.emitOpError()
+             << "expects outputs to have tensor type but got " << output_type;
+    }
+
+    // If the output type is a ref type, then the operand type should also be of
+    // the same ref type. However, if the output type is a non-ref type T, then
+    // the operand can be tensor of type T or T_REF.
+    bool is_output_ref =
+        output_tensor_type.getElementType().isa<TF::TensorFlowRefType>();
+    if (is_output_ref &&
+        !operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) {
+      return switchn.emitOpError()
+             << "expects same operand and output element type but got "
+             << operand0_tensor_type << " vs " << output_tensor_type;
+    }
+    Type broadcasted_type = OpTrait::util::getBroadcastedType(
+        DropRefType(DropTypeSubTypes(operand0_tensor_type)),
+        DropRefType(DropTypeSubTypes(output_tensor_type)));
+    if (!broadcasted_type) {
+      return switchn.emitOpError()
+             << "expects data operand to be broadcastable with all output types"
+             << " but got " << operand0_tensor_type << " vs "
+             << output_tensor_type;
+    }
+  }
   return success();
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
index 10cafb354d7..db9db1518d7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir
@@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
 
 // -----
 
-// Check that switchN result type matches the input type.
-func @invalid_switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+// Check that data operands of SwitchN have tensor type
+func @invalid_switchN(%arg0: i32, %arg1: tensor<i32>) -> tensor<*xi32> {
+  %result = tf_executor.graph {
+    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor<i32>) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}}
+    tf_executor.fetch %1#0 : tensor<*xi32>
+  }
+  return %result : tensor<*xi32>
+}
+
+// -----
+
+// Check that result of SwitchN has tensor type
+func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
+  %result = tf_executor.graph {
+    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}}
+    tf_executor.fetch %1#0 : i32
+  }
+  return %result : i32
+}
+
+// -----
+
+// Check that if any result is a ref type, then data operand needs to be ref too.
+func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4x!tf.f32ref> {
   %fetches = tf_executor.graph {
-    %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, i32, !tf_executor.control)
-// expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}}
+    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor<i32>) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}}
+    tf_executor.fetch %1#0 : tensor<4x!tf.f32ref>
+  }
+  return %fetches : tensor<4x!tf.f32ref>
+}
+
+// -----
+
+// Check that switchN data operand is broadcastable with all output types
+func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
+  %fetches = tf_executor.graph {
+
+    %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor<i32>'}}
     tf_executor.fetch %1#0 : tensor<*xf32>
   }
   return %fetches : tensor<*xf32>
 }
@@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
 
 // -----
 
+// Check that data operands of merge have tensor type
+func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> {
+  %result = tf_executor.graph {
+    %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor<i32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}}
+    tf_executor.fetch %value : tensor<*xi32>
+  }
+  return %result : tensor<*xi32>
+}
+
+// -----
+
+// Check that result of merge has tensor type
+func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> i32 {
+  %result = tf_executor.graph {
+    %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> (i32, tensor<i32>, !tf_executor.control)
+// expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}}
+    tf_executor.fetch %value : i32
+  }
+  return %result : i32
+}
+
+// -----
+
 // Check that merge data inputs are all the same type
 func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
   %result = tf_executor.graph {