Support casting to int16.
PiperOrigin-RevId: 350114031 Change-Id: Id7c3f6d172ae4430ee7eb968d8e2d1b9fd1956a7
This commit is contained in:
parent
467e16b8a5
commit
62e391cc6c
@ -3453,10 +3453,10 @@ def TFL_CastOp : TFL_Op<"cast", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
|
||||
TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
|
||||
|
||||
// TFLite's cast op does not utilize CastOptions, instead derives types
|
||||
// from the TfLiteTensors.
|
||||
|
@ -1320,6 +1320,22 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
func @castFloat32ToI16(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
|
||||
return %0 : tensor<1x2x2x5xi16>
|
||||
|
||||
// CHECK-LABEL: castFloat32ToI16
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
|
||||
}
|
||||
|
||||
func @castI16ToFloat32(%arg0: tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
|
||||
return %0 : tensor<1x2x2x5xf32>
|
||||
|
||||
// CHECK-LABEL: castI16ToFloat32
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
|
||||
return %0 : tensor<1x2x2x5xcomplex<f32>>
|
||||
|
@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
|
||||
case kTfLiteInt32:
|
||||
copyCast(in, out->data.i32, num_elements);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
copyCast(in, out->data.i16, num_elements);
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
copyCast(in, out->data.uint8, num_elements);
|
||||
break;
|
||||
@ -113,6 +116,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return copyToTensor(context, input->data.i64, output, num_elements);
|
||||
case kTfLiteInt32:
|
||||
return copyToTensor(context, input->data.i32, output, num_elements);
|
||||
case kTfLiteInt16:
|
||||
return copyToTensor(context, input->data.i16, output, num_elements);
|
||||
case kTfLiteUInt8:
|
||||
return copyToTensor(context, input->data.uint8, output, num_elements);
|
||||
case kTfLiteFloat32:
|
||||
|
@ -46,6 +46,22 @@ class CastOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(CastOpModel, CastInt16ToFloat) {
|
||||
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||
ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt16ToInt32) {
|
||||
CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
|
||||
m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
|
||||
ElementsAreArray({100, 200, 300, 400, 500, 600}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt32ToFloat) {
|
||||
CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
@ -62,6 +78,14 @@ TEST(CastOpModel, CastFloatToInt32) {
|
||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastFloatToInt16) {
|
||||
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
|
||||
m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<int16_t>(m.output()),
|
||||
ElementsAreArray({100, 20, 3, 0, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(CastOpModel, CastInt64ToFloat) {
|
||||
CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
|
||||
m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
|
||||
|
@ -26,11 +26,30 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
|
||||
@register_make_test_function()
|
||||
def make_cast_tests(options):
|
||||
"""Generate examples for cast."""
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.int32],
|
||||
"output_dtype": [tf.float32],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
}]
|
||||
if options.use_experimental_converter:
|
||||
test_parameters = [
|
||||
{
|
||||
"input_dtype": [tf.float32],
|
||||
"output_dtype": [tf.int16],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
},
|
||||
{
|
||||
"input_dtype": [tf.int16],
|
||||
"output_dtype": [tf.float32],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
},
|
||||
{
|
||||
"input_dtype": [tf.int32],
|
||||
"output_dtype": [tf.float32],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
}]
|
||||
else:
|
||||
test_parameters = [
|
||||
{
|
||||
"input_dtype": [tf.int32],
|
||||
"output_dtype": [tf.float32],
|
||||
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the cast testing graph."""
|
||||
|
Loading…
Reference in New Issue
Block a user