diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 37eb31d0ade..a6a6829b109 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -1749,10 +1749,12 @@ class ConvertSigmoidOp : public OpRewritePattern { op.getLoc(), rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5)); - auto shaped_type = operand.getType().cast(); + auto type = operand.getType().dyn_cast(); + if (!type) + return rewriter.notifyMatchFailure(op, "requires ranked tensor type"); auto constant_ones = rewriter.create( - op.getLoc(), shaped_type, scalar_one, - GetI64ElementsAttr(shaped_type.getShape(), &rewriter)); + op.getLoc(), type, scalar_one, + GetI64ElementsAttr(type.getShape(), &rewriter)); auto scaled_input = rewriter.create( op.getLoc(), operand, constant_ones, DenseIntElementsAttr());