Don't crash for unranked tensor types.

ShapedType::getShape will abort for unranked types. Bail out for unranked early, and just use RankedTensorType directly.

No test. I don't want to set a precedent of every elementwise op / other op needing individual test cases for every operand that might be ranked.

PiperOrigin-RevId: 309076134
Change-Id: Ibb15856bb29a1feac0a7054d2629f29b320ee8f8
This commit is contained in:
Sean Silva 2020-04-29 12:39:30 -07:00 committed by TensorFlower Gardener
parent e9bb36e11d
commit 4d4cfa046f

View File

@ -1749,10 +1749,12 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));
auto shaped_type = operand.getType().cast<ShapedType>();
auto type = operand.getType().dyn_cast<RankedTensorType>();
if (!type)
return rewriter.notifyMatchFailure(op, "requires ranked tensor type");
auto constant_ones = rewriter.create<BroadcastOp>(
op.getLoc(), shaped_type, scalar_one,
GetI64ElementsAttr(shaped_type.getShape(), &rewriter));
op.getLoc(), type, scalar_one,
GetI64ElementsAttr(type.getShape(), &rewriter));
auto scaled_input = rewriter.create<MulOp>(
op.getLoc(), operand, constant_ones, DenseIntElementsAttr());