Fix TF Bitcast and Cast ops to HLO legalization

- Allow unsigned types in Bitcast legalization. It got restricted to unsigned when MLIR added unsigned types.
- Restrict Cast op native legalization to non complex types. There is a fallback lowering for this op to cover that.

Enabled relevant tests for these.

PiperOrigin-RevId: 351540425
Change-Id: I74098ae6c0f12c69fe6768b08e961251cc668b24
This commit is contained in:
Smit Hinsu 2021-01-13 01:49:38 -08:00 committed by TensorFlower Gardener
parent 3365a661a8
commit f2a0826b7d
3 changed files with 4 additions and 8 deletions

View File

@ -2122,7 +2122,7 @@ func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
// CHECK-LABEL: func @cast_c2f
func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
//CHECK: "mhlo.convert"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// CHECK: tf.Cast
%0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

View File

@ -618,7 +618,7 @@ foreach Mapping = [
// TODO(bixia): Lower Cast with a Complex type source operand or with
// Truncate=True for floating point value conversions.
def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse),
def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),
(HLO_ConvertOp $arg)>;
def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)),
@ -637,8 +637,8 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in {
def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>;
def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
"getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && "
"getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && "
"getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
"getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
"getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
"getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
"element types must be integers or floats of same width">;

View File

@ -887,8 +887,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
expected=np.array([14., 22.], dtype=np.float32))
@test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation"
)
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
types = {
@ -936,8 +934,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
src,
expected=dst)
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.Bitcast compilation")
def testBitcast(self):
self._assertOpOutputMatchesExpected(
lambda x: array_ops.bitcast(x, dtypes.int32),