Return matchFailure instead of opError in decomposition of RngReadAndSkip.
If we encounter unexpected conditions during decomposition of tf.RngReadAndSkip, we should label them as match failures to prevent the decomposition, not as op errors. PiperOrigin-RevId: 357965593 Change-Id: Ibb576f424b1e86cd284a72a00014fa7e46b363f5
This commit is contained in:
parent
4bbf4dd025
commit
72d0c2ad78
@ -107,13 +107,12 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
||||
|
||||
DenseIntElementsAttr alg_constant;
|
||||
if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
|
||||
op->emitOpError() << "unable to determine algorithm statically";
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to determine algorithm statically");
|
||||
}
|
||||
|
||||
if (alg_constant.getNumElements() != 1) {
|
||||
op->emitOpError() << "expected alg to be a scalar";
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
|
||||
}
|
||||
|
||||
uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
|
||||
@ -123,8 +122,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
||||
} else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
|
||||
alg = tensorflow::RNG_ALG_THREEFRY;
|
||||
} else {
|
||||
op->emitOpError() << "unsupported alg";
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "unsupported alg");
|
||||
}
|
||||
|
||||
Type state_element_type = rewriter.getI64Type();
|
||||
@ -132,13 +130,11 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
||||
{tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
|
||||
state_element_type);
|
||||
if (op_type != rng_op.getType()) {
|
||||
op->emitOpError() << "unexpected op type";
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "unexpected op type");
|
||||
}
|
||||
|
||||
if (!HasResourceSubtype(rng_op.resource())) {
|
||||
op->emitOpError() << "missing resource subtype";
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "missing resource subtype");
|
||||
}
|
||||
|
||||
int counter_size = tensorflow::GetCounterSize(alg);
|
||||
@ -146,8 +142,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
||||
RankedTensorType res_type =
|
||||
RankedTensorType::get({state_size}, state_element_type);
|
||||
if (res_type != GetResourceSubtype(rng_op.resource())) {
|
||||
op->emitOpError() << "unexpected resource subtype";
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
|
||||
}
|
||||
|
||||
Location loc = op->getLoc();
|
||||
|
Loading…
Reference in New Issue
Block a user