Don't crash for unranked tensor types.
ShapedType::getShape will abort for unranked types. Bail out for unranked early, and just use RankedTensorType directly. No test. I don't want to set a precedent of every elementwise op / other op needing individual test cases for every operand that might be ranked. PiperOrigin-RevId: 309076134 Change-Id: Ibb15856bb29a1feac0a7054d2629f29b320ee8f8
This commit is contained in:
parent
e9bb36e11d
commit
4d4cfa046f
@ -1749,10 +1749,12 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
|
||||
op.getLoc(),
|
||||
rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));
|
||||
|
||||
auto shaped_type = operand.getType().cast<ShapedType>();
|
||||
auto type = operand.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return rewriter.notifyMatchFailure(op, "requires ranked tensor type");
|
||||
auto constant_ones = rewriter.create<BroadcastOp>(
|
||||
op.getLoc(), shaped_type, scalar_one,
|
||||
GetI64ElementsAttr(shaped_type.getShape(), &rewriter));
|
||||
op.getLoc(), type, scalar_one,
|
||||
GetI64ElementsAttr(type.getShape(), &rewriter));
|
||||
|
||||
auto scaled_input = rewriter.create<MulOp>(
|
||||
op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
|
||||
|
Loading…
Reference in New Issue
Block a user