Support casting to int16.

PiperOrigin-RevId: 350114031
Change-Id: Id7c3f6d172ae4430ee7eb968d8e2d1b9fd1956a7
A. Unique TensorFlower 2021-01-05 04:16:23 -08:00 committed by TensorFlower Gardener
parent 467e16b8a5
commit 62e391cc6c
5 changed files with 71 additions and 7 deletions
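For orientation, the sketch below shows what this change enables end to end: a float32 -> int16 tf.cast converted and executed as a TFLite builtin. This is an illustration written for this note, assuming the MLIR-based converter with this change applied; it is not code from the commit, and converter defaults vary across TensorFlow versions.

import numpy as np
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([2, 3], tf.float32)])
def cast_to_int16(x):
  # With int16 accepted by tfl.cast, this should lower to the builtin cast op.
  return tf.cast(x, tf.int16)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [cast_to_int16.get_concrete_function()])
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"],
                       np.array([[0.4, 0.999, 1.1], [100., 200., 300.]],
                                dtype=np.float32))
interpreter.invoke()
print(interpreter.get_tensor(out["index"]))  # int16 values, truncated toward zero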


@@ -3453,10 +3453,10 @@ def TFL_CastOp : TFL_Op<"cast", [
  }];

  let arguments = (ins
-    TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
+    TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$input
  );

-  let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I1, I16, I32, I64, TFL_Quint8, UI8, Complex<F<32>>]>:$output);

  // TFLite's cast op does not utilize CastOptions, instead derives types
  // from the TfLiteTensors.


@@ -1320,6 +1320,22 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
  // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
}

+func @castFloat32ToI16(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16> {
+  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
+  return %0 : tensor<1x2x2x5xi16>
+
+  // CHECK-LABEL: castFloat32ToI16
+  // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xi16>
+}
+
+func @castI16ToFloat32(%arg0: tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32> {
+  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
+  return %0 : tensor<1x2x2x5xf32>
+
+  // CHECK-LABEL: castI16ToFloat32
+  // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi16>) -> tensor<1x2x2x5xf32>
+}
+
func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
  return %0 : tensor<1x2x2x5xcomplex<f32>>


@@ -80,6 +80,9 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
    case kTfLiteInt32:
      copyCast(in, out->data.i32, num_elements);
      break;
+    case kTfLiteInt16:
+      copyCast(in, out->data.i16, num_elements);
+      break;
    case kTfLiteUInt8:
      copyCast(in, out->data.uint8, num_elements);
      break;
@@ -113,6 +116,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      return copyToTensor(context, input->data.i64, output, num_elements);
    case kTfLiteInt32:
      return copyToTensor(context, input->data.i32, output, num_elements);
+    case kTfLiteInt16:
+      return copyToTensor(context, input->data.i16, output, num_elements);
    case kTfLiteUInt8:
      return copyToTensor(context, input->data.uint8, output, num_elements);
    case kTfLiteFloat32:
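The kernel change mirrors the file's existing two-level dispatch: Eval switches on the input tensor's type, copyToTensor switches on the output type, and the shared copyCast helper performs the element-wise conversion, so int16 only needs one new case in each switch. The toy Python model below is purely illustrative of that structure; the names (_NUMPY_TYPES, copy_to_tensor, eval_cast) are invented for the sketch and are not TFLite APIs.

import numpy as np

# Stand-ins for the TfLiteType cases handled in the switches above.
_NUMPY_TYPES = {
    "kTfLiteFloat32": np.float32,
    "kTfLiteInt64": np.int64,
    "kTfLiteInt32": np.int32,
    "kTfLiteInt16": np.int16,  # the case this commit adds
    "kTfLiteUInt8": np.uint8,
}

def copy_to_tensor(values, output_type):
  # Analogue of copyToTensor: dispatch on the output type, cast element-wise.
  return np.asarray(values).astype(_NUMPY_TYPES[output_type])

def eval_cast(values, input_type, output_type):
  # Analogue of Eval: interpret the input buffer's type, then delegate.
  typed_in = np.asarray(values, dtype=_NUMPY_TYPES[input_type])
  return copy_to_tensor(typed_in, output_type)

print(eval_cast([100, 200, 300], "kTfLiteInt16", "kTfLiteFloat32"))  # [100. 200. 300.]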


@@ -46,6 +46,22 @@ class CastOpModel : public SingleOpModel {
  int output_;
};

+TEST(CastOpModel, CastInt16ToFloat) {
+  CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+  m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray({100.f, 200.f, 300.f, 400.f, 500.f, 600.f}));
+}
+
+TEST(CastOpModel, CastInt16ToInt32) {
+  CastOpModel m({TensorType_INT16, {2, 3}}, {TensorType_INT32, {2, 3}});
+  m.PopulateTensor<int16_t>(m.input(), {100, 200, 300, 400, 500, 600});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+              ElementsAreArray({100, 200, 300, 400, 500, 600}));
+}
+
TEST(CastOpModel, CastInt32ToFloat) {
  CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
  m.PopulateTensor<int32_t>(m.input(), {100, 200, 300, 400, 500, 600});
@@ -62,6 +78,14 @@ TEST(CastOpModel, CastFloatToInt32) {
              ElementsAreArray({100, 20, 3, 0, 0, 1}));
}

+TEST(CastOpModel, CastFloatToInt16) {
+  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_INT16, {3, 2}});
+  m.PopulateTensor<float>(m.input(), {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int16_t>(m.output()),
+              ElementsAreArray({100, 20, 3, 0, 0, 1}));
+}
+
TEST(CastOpModel, CastInt64ToFloat) {
  CastOpModel m({TensorType_INT64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
  m.PopulateTensor<int64_t>(m.input(), {100, 200, 300, 400, 500, 600});
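A note on the expected values in CastFloatToInt16 above: {100.f, 20.f, 3.f, 0.4f, 0.999f, 1.1f} becomes {100, 20, 3, 0, 0, 1}, i.e. truncation toward zero, which is what a C++ float-to-integer static_cast does (and presumably what copyCast applies element-wise). NumPy's astype has the same semantics, so the values can be reproduced as below; this is only an illustration of the semantics, not TFLite code.

import numpy as np

x = np.array([100., 20., 3., 0.4, 0.999, 1.1], dtype=np.float32)
print(x.astype(np.int16))  # [100  20   3   0   0   1] -- truncation toward zero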


@@ -26,11 +26,30 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_cast_tests(options):
  """Generate examples for cast."""
-  test_parameters = [{
-      "input_dtype": [tf.int32],
-      "output_dtype": [tf.float32],
-      "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
-  }]
+  if options.use_experimental_converter:
+    test_parameters = [
+        {
+            "input_dtype": [tf.float32],
+            "output_dtype": [tf.int16],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        },
+        {
+            "input_dtype": [tf.int16],
+            "output_dtype": [tf.float32],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        },
+        {
+            "input_dtype": [tf.int32],
+            "output_dtype": [tf.float32],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        }]
+  else:
+    test_parameters = [
+        {
+            "input_dtype": [tf.int32],
+            "output_dtype": [tf.float32],
+            "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+        }]

  def build_graph(parameters):
    """Build the cast testing graph."""