Consider ops implementing InferTypeOpInterface as supported in shape inference.

IsSupportedNonTFOp is used to check whether case is needed or the ops type can be refined. Previously it only consider TF dialects ops, but ops implementing InferTypeOpInterface also get refined. Expand check to include such ops.

PiperOrigin-RevId: 359197188
Change-Id: I44c6cb0d080a6bcb7e6d173a5c0e11b03aecc691
This commit is contained in:
Renjie Liu 2021-02-23 20:29:14 -08:00 committed by TensorFlower Gardener
parent 5bb5bcc643
commit f083f1834c

View File

@ -117,7 +117,8 @@ bool IsSupportedNonTFOp(Operation* op) {
tf_executor::GraphOp, tf_executor::IslandOp,
tf_executor::LoopCondOp, tf_executor::MergeOp,
tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
tf_executor::SwitchOp, tf_executor::YieldOp>(op);
tf_executor::SwitchOp, tf_executor::YieldOp>(op) ||
isa<InferTypeOpInterface>(op);
}
// Returns whether a cast back would need to be inserted, e.g., whether the