Fix TF Bitcast and Cast ops to HLO legalization
- Allow unsigned types in Bitcast legalization. It got restricted to unsigned when MLIR added unsigned types. - Restrict Cast op native legalization to non complex types. There is a fallback lowering for this op to cover that. Enabled relevant tests for these. PiperOrigin-RevId: 351540425 Change-Id: I74098ae6c0f12c69fe6768b08e961251cc668b24
This commit is contained in:
parent
3365a661a8
commit
f2a0826b7d
@ -2122,7 +2122,7 @@ func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
|
||||
// CHECK-LABEL: func @cast_c2f
|
||||
func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
|
||||
//CHECK: "mhlo.convert"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
|
||||
// CHECK: tf.Cast
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
@ -618,7 +618,7 @@ foreach Mapping = [
|
||||
|
||||
// TODO(bixia): Lower Cast with a Complex type source operand or with
|
||||
// Truncate=True for floating point value conversions.
|
||||
def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse),
|
||||
def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),
|
||||
(HLO_ConvertOp $arg)>;
|
||||
|
||||
def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)),
|
||||
@ -637,8 +637,8 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in {
|
||||
def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>;
|
||||
|
||||
def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
|
||||
"getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && "
|
||||
"getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && "
|
||||
"getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
|
||||
"getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
|
||||
"getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
|
||||
"getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
|
||||
"element types must be integers or floats of same width">;
|
||||
|
@ -887,8 +887,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
|
||||
expected=np.array([14., 22.], dtype=np.float32))
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation"
|
||||
)
|
||||
def testCast(self):
|
||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||
types = {
|
||||
@ -936,8 +934,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
src,
|
||||
expected=dst)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/153812660): Handle tf.Bitcast compilation")
|
||||
def testBitcast(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: array_ops.bitcast(x, dtypes.int32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user