Fix TF Bitcast and Cast ops to HLO legalization

- Allow unsigned types in Bitcast legalization. It got restricted to unsigned when MLIR added unsigned types.
- Restrict Cast op native legalization to non complex types. There is a fallback lowering for this op to cover that.

Enabled relevant tests for these.

PiperOrigin-RevId: 351540425
Change-Id: I74098ae6c0f12c69fe6768b08e961251cc668b24
This commit is contained in:
Smit Hinsu 2021-01-13 01:49:38 -08:00 committed by TensorFlower Gardener
parent 3365a661a8
commit f2a0826b7d
3 changed files with 4 additions and 8 deletions

View File

@ -2122,7 +2122,7 @@ func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
// CHECK-LABEL: func @cast_c2f // CHECK-LABEL: func @cast_c2f
func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
//CHECK: "mhlo.convert"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> // CHECK: tf.Cast
%0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }

View File

@ -618,7 +618,7 @@ foreach Mapping = [
// TODO(bixia): Lower Cast with a Complex type source operand or with // TODO(bixia): Lower Cast with a Complex type source operand or with
// Truncate=True for floating point value conversions. // Truncate=True for floating point value conversions.
def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),
(HLO_ConvertOp $arg)>; (HLO_ConvertOp $arg)>;
def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)),
@ -637,8 +637,8 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in {
def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>; def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>;
def BothElementTypesSameWidthIntOrFloat : Constraint<CPred< def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
"getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && " "getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
"getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && " "getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
"getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == " "getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
"getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">, "getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
"element types must be integers or floats of same width">; "element types must be integers or floats of same width">;

View File

@ -887,8 +887,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
expected=np.array([14., 22.], dtype=np.float32)) expected=np.array([14., 22.], dtype=np.float32))
@test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation"
)
def testCast(self): def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]] shapes = [[], [4], [2, 3], [2, 0, 4]]
types = { types = {
@ -936,8 +934,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
src, src,
expected=dst) expected=dst)
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.Bitcast compilation")
def testBitcast(self): def testBitcast(self):
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
lambda x: array_ops.bitcast(x, dtypes.int32), lambda x: array_ops.bitcast(x, dtypes.int32),