Fix bug in xla-legalize-tf-with-tf2xla pass by handling non-tensor operands

Currently, the pass expects only tensor operands, but that assumption does not hold for ops from non-TensorFlow dialects.

PiperOrigin-RevId: 317198672
Change-Id: I1387e664de740d044ef535f6903e07d63fa02f6d
parent: e08382691b
commit: 39504c25d9
path:   tensorflow/compiler/mlir/xla
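For background on the root cause: in MLIR, Type::cast<T>() asserts when the type is not actually a T, whereas dyn_cast<T>() returns a null value that can be checked. The sketch below is not part of the commit; it is a minimal, self-contained illustration, assuming MLIR headers from roughly this era of the codebase (StandardTypes.h, member-style dyn_cast).

#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"

int main() {
  mlir::MLIRContext context;
  mlir::Builder b(&context);

  mlir::Type tensor = mlir::RankedTensorType::get({2}, b.getF32Type());
  mlir::Type tuple = mlir::TupleType::get({tensor}, &context);

  // A ranked tensor is a ShapedType, so the dyn_cast succeeds.
  auto shaped = tensor.dyn_cast<mlir::ShapedType>();
  assert(shaped && shaped.hasStaticShape());

  // A tuple is NOT a ShapedType: cast<ShapedType>() would assert here,
  // while dyn_cast<ShapedType>() simply yields null. This is exactly the
  // case the pass hit on non-TensorFlow dialect ops.
  assert(!tuple.dyn_cast<mlir::ShapedType>());
  return 0;
}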
@@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor
 // CHECK-LABEL: unranked_operand
 func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
+  // expected-remark@+1 {{lowering requires static shaped tensor operands}}
   %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>

   return %0 : tensor<*xf32>
@@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-LABEL: dynamic_operand
 func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
+  // expected-remark@+1 {{lowering requires static shaped tensor operands}}
   %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>

   return %0 : tensor<?xf32>
 }

+// CHECK-LABEL: tuple_type
+func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
+  // Verifies that the pass can handle operands of non-tensor type like tuple
+  // from non TensorFlow ops.
+  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // CHECK-LABEL: unsupported_dtype
 func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
   // CHECK: tf.AddN
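For context, this lit test file is driven by a RUN line at its top, which the hunks above do not show. A typical invocation for this pass would look roughly like the line below; the exact device-type option is an assumption, not something visible in this diff.

// RUN: tf-opt -xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT %s -verify-diagnostics | FileCheck %s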
@@ -337,9 +337,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {

   // Only static shaped operands are supported in XLA builders for now.
   for (Type ty : op->getOperandTypes()) {
-    auto ranked_ty = ty.cast<ShapedType>();
-    if (!ranked_ty.hasStaticShape()) {
-      op->emitRemark() << "lowering requires static shaped operands";
+    auto ranked_ty = ty.dyn_cast<ShapedType>();
+    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
+      op->emitRemark() << "lowering requires static shaped tensor operands";
       return success();
     }
   }
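Reassembled from the hunk above, the operand check after this change reads as follows (surrounding function body elided; explanatory comment added here, not part of the commit):

// Only static shaped operands are supported in XLA builders for now.
for (Type ty : op->getOperandTypes()) {
  // dyn_cast yields a null ShapedType for non-shaped operands (e.g. tuples
  // from non-TensorFlow dialect ops) instead of asserting like cast<> does,
  // so such operands are now rejected gracefully with a remark.
  auto ranked_ty = ty.dyn_cast<ShapedType>();
  if (!ranked_ty || !ranked_ty.hasStaticShape()) {
    op->emitRemark() << "lowering requires static shaped tensor operands";
    return success();
  }
}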