From 4d4cfa046fe2f287f5bdffc7b360a974fb7d8268 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 29 Apr 2020 12:39:30 -0700 Subject: [PATCH] Don't crash for unranked tensor types. ShapedType::getShape will abort for unranked types. Bail out for unranked early, and just use RankedTensorType directly. No test. I don't want to set a precedent of every elementwise op / other op needing individual test cases for every operand that might be ranked. PiperOrigin-RevId: 309076134 Change-Id: Ibb15856bb29a1feac0a7054d2629f29b320ee8f8 --- tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 37eb31d0ade..a6a6829b109 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -1749,10 +1749,12 @@ class ConvertSigmoidOp : public OpRewritePattern { op.getLoc(), rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5)); - auto shaped_type = operand.getType().cast(); + auto type = operand.getType().dyn_cast(); + if (!type) + return rewriter.notifyMatchFailure(op, "requires ranked tensor type"); auto constant_ones = rewriter.create( - op.getLoc(), shaped_type, scalar_one, - GetI64ElementsAttr(shaped_type.getShape(), &rewriter)); + op.getLoc(), type, scalar_one, + GetI64ElementsAttr(type.getShape(), &rewriter)); auto scaled_input = rewriter.create( op.getLoc(), operand, constant_ones, DenseIntElementsAttr());