Return matchFailure instead of opError in decomposition of RngReadAndSkip.

If we encounter unexpected conditions during decomposition of
tf.RngReadAndSkip, we should label them as match failures to prevent the
decomposition, not as op errors.

PiperOrigin-RevId: 357965593
Change-Id: Ibb576f424b1e86cd284a72a00014fa7e46b363f5
This commit is contained in:
Richard Uhler 2021-02-17 09:30:29 -08:00 committed by TensorFlower Gardener
parent 4bbf4dd025
commit 72d0c2ad78

View File

@ -107,13 +107,12 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
DenseIntElementsAttr alg_constant;
if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
op->emitOpError() << "unable to determine algorithm statically";
return failure();
return rewriter.notifyMatchFailure(
op, "unable to determine algorithm statically");
}
if (alg_constant.getNumElements() != 1) {
op->emitOpError() << "expected alg to be a scalar";
return failure();
return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
}
uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
@ -123,8 +122,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
} else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
alg = tensorflow::RNG_ALG_THREEFRY;
} else {
op->emitOpError() << "unsupported alg";
return failure();
return rewriter.notifyMatchFailure(op, "unsupported alg");
}
Type state_element_type = rewriter.getI64Type();
@ -132,13 +130,11 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
{tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
state_element_type);
if (op_type != rng_op.getType()) {
op->emitOpError() << "unexpected op type";
return failure();
return rewriter.notifyMatchFailure(op, "unexpected op type");
}
if (!HasResourceSubtype(rng_op.resource())) {
op->emitOpError() << "missing resource subtype";
return failure();
return rewriter.notifyMatchFailure(op, "missing resource subtype");
}
int counter_size = tensorflow::GetCounterSize(alg);
@ -146,8 +142,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
RankedTensorType res_type =
RankedTensorType::get({state_size}, state_element_type);
if (res_type != GetResourceSubtype(rng_op.resource())) {
op->emitOpError() << "unexpected resource subtype";
return failure();
return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
}
Location loc = op->getLoc();