diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index e2366d8cd80..b9972cebdce 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1965,6 +1965,25 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", let hasOptions = 1; } +def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "Cast operator"; + + let description = [{ + Casts input from input type to output type. + }]; + + // TODO(b/135538711): Add complex types here. + let arguments = (ins + TensorOf<[F32, I1, I32, I64]>:$input + ); + + let results = (outs TensorOf<[F32, I1, I32, I64]>:$output); + + // TFLite's cast op does not utilize CastOptions, instead derives types + // from the TfLiteTensors. + let hasOptions = 0; +} + def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ NoSideEffect, TFL_OperandHasRank<1, 2>]> { diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index f771e0b7504..6b29869dde5 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -857,3 +857,11 @@ func @Tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> { // CHECK-LABEL: Tanh // CHECK: %0 = "tfl.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> } + +func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> { + %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> + return %0 : tensor<1x2x2x5xf32> + + // CHECK-LABEL: cast + // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 0fac8a09f86..d4866a4740e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -235,6 +235,8 @@ def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1 def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; +def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; + def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>; def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>; diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 9bdf0547d45..3c9337121f6 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -235,6 +235,7 @@ def generated_test_models(): "arg_min_max", "avg_pool", "batch_to_space_nd", + "cast", "ceil", "concat", "constant", diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 431a9aa541b..31c8f94a075 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -3856,6 +3856,33 @@ def make_zeros_like_tests(options): make_zip_of_tests(options, test_parameters, build_graph, build_inputs) +@register_make_test_function() +def make_cast_tests(options): + """Generate examples for cast.""" + test_parameters = [{ + "input_dtype": [tf.int32], + "output_dtype": [tf.float32], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] + + def build_graph(parameters): + """Build the cast testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + out = tf.cast(input_value, parameters["output_dtype"]) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(options, test_parameters, build_graph, build_inputs) + + def _make_elementwise_tests(op): """Make a set of tests to do element-wise operations."""