Adds support for cast in MLIR converter. Also adds zip test for cast.
PiperOrigin-RevId: 257431619
This commit is contained in:
parent
f2b53e96cb
commit
84c176726f
@ -1965,6 +1965,25 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "Cast operator";
|
||||
|
||||
let description = [{
|
||||
Casts input from input type to output type.
|
||||
}];
|
||||
|
||||
// TODO(b/135538711): Add complex types here.
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I1, I32, I64]>:$input
|
||||
);
|
||||
|
||||
let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
|
||||
|
||||
// TFLite's cast op does not utilize CastOptions, instead derives types
|
||||
// from the TfLiteTensors.
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
||||
|
||||
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
|
||||
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
|
||||
|
@ -857,3 +857,11 @@ func @Tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
// CHECK-LABEL: Tanh
|
||||
// CHECK: %0 = "tfl.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
}
|
||||
|
||||
func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
|
||||
return %0 : tensor<1x2x2x5xf32>
|
||||
|
||||
// CHECK-LABEL: cast
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
@ -235,6 +235,8 @@ def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1
|
||||
|
||||
def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
|
||||
|
||||
def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
|
||||
|
||||
def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
|
||||
|
||||
def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
|
||||
|
@ -235,6 +235,7 @@ def generated_test_models():
|
||||
"arg_min_max",
|
||||
"avg_pool",
|
||||
"batch_to_space_nd",
|
||||
"cast",
|
||||
"ceil",
|
||||
"concat",
|
||||
"constant",
|
||||
|
@ -3856,6 +3856,33 @@ def make_zeros_like_tests(options):
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_cast_tests(options):
|
||||
"""Generate examples for cast."""
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.int32],
|
||||
"output_dtype": [tf.float32],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the cast testing graph."""
|
||||
input_value = tf.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input",
|
||||
shape=parameters["input_shape"])
|
||||
out = tf.cast(input_value, parameters["output_dtype"])
|
||||
return [input_value], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value = create_tensor_data(parameters["input_dtype"],
|
||||
parameters["input_shape"])
|
||||
return [input_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
||||
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def _make_elementwise_tests(op):
|
||||
"""Make a set of tests to do element-wise operations."""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user