Fuse the back to back tfr.cast with unranked input to tf.EnsureShape
After this step, all the tfr ops should be raised to the tf ops PiperOrigin-RevId: 317756602 Change-Id: I15731f231caf6c47eac1719fc1c2ed3d01cff515
This commit is contained in:
parent
a26044ac2f
commit
557916fa3f
@ -2806,6 +2806,27 @@ the corresponding feature.
|
|||||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_EnsureShapeOp : TF_Op<"EnsureShape", [NoSideEffect]> {
|
||||||
|
let summary = "Ensures that the tensor's shape matches the expected shape.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Raises an error if the input tensor's shape does not match the specified shape.
|
||||||
|
Returns the input tensor otherwise.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_Tensor:$input,
|
||||||
|
|
||||||
|
TF_ShapeAttr:$shape
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Tensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
|
def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> {
|
||||||
let summary = "Returns the truth value of (x == y) element-wise.";
|
let summary = "Returns the truth value of (x == y) element-wise.";
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user