Fix bug in xla-legalize-tf-with-tf2xla pass by handling non-tensor operands

Currently, the pass expects all operands to be tensors, but that assumption does not hold for ops from non-TensorFlow dialects (for example, xla_hlo ops can take tuple operands).
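
For context, the crash mode this fixes: `Type::cast<ShapedType>()` asserts when the type is not a shaped type (e.g. a tuple), whereas `Type::dyn_cast<ShapedType>()` returns a null handle that can be tested. Below is a minimal sketch of the guard, assuming the member-function casting API used in this commit (newer MLIR spells these as the free functions `mlir::cast`/`mlir::dyn_cast`); `IsStaticShapedType` is a hypothetical helper, not part of the change:

#include "mlir/IR/StandardTypes.h"  // ShapedType lived here in 2020; BuiltinTypes.h in newer MLIR
#include "mlir/IR/Types.h"

// Hypothetical helper (not in this commit): true iff `ty` is a shaped type
// whose dimensions are all statically known.
static bool IsStaticShapedType(mlir::Type ty) {
  // dyn_cast yields a null ShapedType for non-shaped types such as
  // tuple<...>, where cast<ShapedType> would assert at runtime.
  auto shaped_ty = ty.dyn_cast<mlir::ShapedType>();
  return shaped_ty && shaped_ty.hasStaticShape();
}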

PiperOrigin-RevId: 317198672
Change-Id: I1387e664de740d044ef535f6903e07d63fa02f6d
Smit Hinsu 2020-06-18 15:55:24 -07:00 committed by TensorFlower Gardener
parent e08382691b
commit 39504c25d9
2 changed files with 13 additions and 5 deletions
tensorflow/compiler/mlir/xla


@@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor
 // CHECK-LABEL: unranked_operand
 func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
+  // expected-remark@+1 {{lowering requires static shaped tensor operands}}
   %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>

   return %0 : tensor<*xf32>
@@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-LABEL: dynamic_operand
 func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
+  // expected-remark@+1 {{lowering requires static shaped tensor operands}}
   %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>

   return %0 : tensor<?xf32>
 }

+// CHECK-LABEL: tuple_type
+func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
+  // Verifies that the pass can handle operands of non-tensor types such as
+  // tuples from non-TensorFlow ops.
+  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // CHECK-LABEL: unsupported_dtype
 func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
   // CHECK: tf.AddN


@@ -337,9 +337,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
   // Only static shaped operands are supported in XLA builders for now.
   for (Type ty : op->getOperandTypes()) {
-    auto ranked_ty = ty.cast<ShapedType>();
-    if (!ranked_ty.hasStaticShape()) {
-      op->emitRemark() << "lowering requires static shaped operands";
+    auto ranked_ty = ty.dyn_cast<ShapedType>();
+    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
+      op->emitRemark() << "lowering requires static shaped tensor operands";
       return success();
     }
   }
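
For illustration, a standalone sketch (hypothetical, not part of this change) that exercises the fixed check on the two operand kinds covered by the tests above; it assumes a recent MLIR build where `TupleType::get` takes `(MLIRContext*, TypeRange)`, and the exact spellings of these APIs have shifted across MLIR versions:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext context;
  mlir::Type f32 = mlir::FloatType::getF32(&context);

  // A fully static tensor passes the check.
  mlir::Type static_tensor = mlir::RankedTensorType::get({2, 3}, f32);
  // A tuple is not a ShapedType: the old cast<ShapedType> asserted on it;
  // with dyn_cast the op is merely reported as unsupported.
  mlir::Type tuple = mlir::TupleType::get(&context, {static_tensor});

  for (mlir::Type ty : {static_tensor, tuple}) {
    auto shaped = ty.dyn_cast<mlir::ShapedType>();
    bool supported = shaped && shaped.hasStaticShape();
    (void)supported;  // static_tensor -> true, tuple -> false
  }
  return 0;
}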