From e068a4c413403628ca3ae1b9639fbe1fe8927709 Mon Sep 17 00:00:00 2001 From: Michael Gester Date: Sun, 5 Apr 2020 17:54:55 -0700 Subject: [PATCH] Relax type checking for data operand and results of tf.SwitchN Now we only require that data operand must be broadcastable to results, before we required them to be equal which is problematic for shape inference. The new checking is more consistent with tf.Switch and tf.Merge. Also added more tests for tf.SwitchN and tf.Merge. PiperOrigin-RevId: 304935489 Change-Id: Ife8d1ea097cced6ad1eddd577fb95c0c19648281 --- .../mlir/tensorflow/ir/tf_executor.cc | 41 +++++++++-- .../tests/tf_executor_ops_invalid.mlir | 69 +++++++++++++++++-- 2 files changed, 101 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 8d670d96748..0ca4364f9cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -545,13 +545,44 @@ LogicalResult Verify(SwitchNOp switchn) { << "expect `num_outs` (" << num_outs.getInt() << ") results but got " << (switchn.getNumResults() - 1); + // Check that operand can be broadcasted to each output type. auto operand0_type = switchn.getOperand(0).getType(); - for (Value result : switchn.outputs()) - if (operand0_type != result.getType()) - return switchn.emitOpError() - << "type mismatch between data operand and result: " - << operand0_type << " vs " << result.getType(); + TensorType operand0_tensor_type = operand0_type.dyn_cast(); + if (!operand0_tensor_type) { + return switchn.emitOpError() + << "expects data operand to have tensor type but got " + << operand0_type; + } + for (Type output_type : switchn.getResultTypes()) { + if (output_type.isa()) break; + TensorType output_tensor_type = output_type.dyn_cast(); + if (!output_tensor_type) { + return switchn.emitOpError() + << "expects outputs to have tensor type but got " << output_type; + } + + // If the output type is a ref type, then the operand type should also be of + // the same ref type. However, if the output type is a non-ref type T, then + // the operand can be tensor of type T or T_REF. + bool is_output_ref = + output_tensor_type.getElementType().isa(); + if (is_output_ref && + !operand0_tensor_type.getElementType().isa()) { + return switchn.emitOpError() + << "expects same operand and output element type but got " + << operand0_tensor_type << " vs " << output_tensor_type; + } + Type broadcasted_type = OpTrait::util::getBroadcastedType( + DropRefType(DropTypeSubTypes(operand0_tensor_type)), + DropRefType(DropTypeSubTypes(output_tensor_type))); + if (!broadcasted_type) { + return switchn.emitOpError() + << "expects data operand to be broadcastable with all output types" + << " but got " << operand0_tensor_type << " vs " + << output_tensor_type; + } + } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index 10cafb354d7..db9db1518d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -405,12 +405,49 @@ func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> // ----- -// Check that switchN result type matches the input type. -func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { +// Check that data operands of SwitchN have tensor type +func @invalid_switchN(%arg0: i32, %arg1: tensor) -> tensor<*xi32> { + %result = tf_executor.graph { + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (i32, tensor) -> (tensor<*xi32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to have tensor type but got 'i32'}} + tf_executor.fetch %1#0 : tensor<*xi32> + } + return %result : tensor<*xi32> +} + +// ----- + +// Check that result of SwitchN has tensor type +func @invalid_switchN(%arg0: tensor<*xi32>, %arg1: tensor) -> i32 { + %result = tf_executor.graph { + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xi32>, tensor) -> (i32, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects outputs to have tensor type but got 'i32'}} + tf_executor.fetch %1#0 : i32 + } + return %result : i32 +} + +// ----- + +// Check that if any result is a ref type, then data operand needs to be ref too. +func @invalid_switchN(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4x!tf.f32ref> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, i32, !tf_executor.control) -// expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}} + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<4xf32>, tensor) -> (tensor<4x!tf.f32ref>, tensor<4xf32>, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects same operand and output element type but got 'tensor<4xf32>' vs 'tensor<4x!tf.f32ref>'}} + tf_executor.fetch %1#0 : tensor<4x!tf.f32ref> + } + return %fetches : tensor<4x!tf.f32ref> +} + +// ----- + +// Check that switchN data operand is broadcastable with all output types +func @invalid_switchN(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { + %fetches = tf_executor.graph { + + %1:3 = "tf_executor.SwitchN"(%arg0, %arg1) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.SwitchN' op expects data operand to be broadcastable with all output types but got 'tensor<*xf32>' vs 'tensor'}} tf_executor.fetch %1#0 : tensor<*xf32> } @@ -472,6 +509,30 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { // ----- +// Check that data operands of merge have tensor type +func @invalid_merge(%arg0: tensor<*xi32>, %arg1: i32) -> tensor<*xi32> { + %result = tf_executor.graph { + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, i32) -> (tensor<*xi32>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.Merge' op expects data operands to have tensor type but got 'i32'}} + tf_executor.fetch %value : tensor<*xi32> + } + return %result : tensor<*xi32> +} + +// ----- + +// Check that result of merge has tensor type +func @invalid_merge(%arg0: tensor<*xi32>, %arg1: tensor) -> i32 { + %result = tf_executor.graph { + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> (i32, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.Merge' op result #0 must be tensor of any type values, but got 'i32'}} + tf_executor.fetch %value : i32 + } + return %result : i32 +} + +// ----- + // Check that merge data inputs are all the same type func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %result = tf_executor.graph {