Return matchFailure instead of opError in decomposition of RngReadAndSkip.
If we encounter unexpected conditions during decomposition of tf.RngReadAndSkip, we should label them as match failures to prevent the decomposition, not as op errors. PiperOrigin-RevId: 357965593 Change-Id: Ibb576f424b1e86cd284a72a00014fa7e46b363f5
This commit is contained in:
parent
4bbf4dd025
commit
72d0c2ad78
@ -107,13 +107,12 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
|||||||
|
|
||||||
DenseIntElementsAttr alg_constant;
|
DenseIntElementsAttr alg_constant;
|
||||||
if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
|
if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
|
||||||
op->emitOpError() << "unable to determine algorithm statically";
|
return rewriter.notifyMatchFailure(
|
||||||
return failure();
|
op, "unable to determine algorithm statically");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (alg_constant.getNumElements() != 1) {
|
if (alg_constant.getNumElements() != 1) {
|
||||||
op->emitOpError() << "expected alg to be a scalar";
|
return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
|
uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
|
||||||
@ -123,8 +122,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
|||||||
} else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
|
} else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
|
||||||
alg = tensorflow::RNG_ALG_THREEFRY;
|
alg = tensorflow::RNG_ALG_THREEFRY;
|
||||||
} else {
|
} else {
|
||||||
op->emitOpError() << "unsupported alg";
|
return rewriter.notifyMatchFailure(op, "unsupported alg");
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type state_element_type = rewriter.getI64Type();
|
Type state_element_type = rewriter.getI64Type();
|
||||||
@ -132,13 +130,11 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
|||||||
{tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
|
{tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
|
||||||
state_element_type);
|
state_element_type);
|
||||||
if (op_type != rng_op.getType()) {
|
if (op_type != rng_op.getType()) {
|
||||||
op->emitOpError() << "unexpected op type";
|
return rewriter.notifyMatchFailure(op, "unexpected op type");
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!HasResourceSubtype(rng_op.resource())) {
|
if (!HasResourceSubtype(rng_op.resource())) {
|
||||||
op->emitOpError() << "missing resource subtype";
|
return rewriter.notifyMatchFailure(op, "missing resource subtype");
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int counter_size = tensorflow::GetCounterSize(alg);
|
int counter_size = tensorflow::GetCounterSize(alg);
|
||||||
@ -146,8 +142,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern {
|
|||||||
RankedTensorType res_type =
|
RankedTensorType res_type =
|
||||||
RankedTensorType::get({state_size}, state_element_type);
|
RankedTensorType::get({state_size}, state_element_type);
|
||||||
if (res_type != GetResourceSubtype(rng_op.resource())) {
|
if (res_type != GetResourceSubtype(rng_op.resource())) {
|
||||||
op->emitOpError() << "unexpected resource subtype";
|
return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
Loading…
Reference in New Issue
Block a user