Merge pull request #41819 from wwwind:16x8_slice_transpose_fixes
PiperOrigin-RevId: 331289591
This commit is contained in:
		
						commit
						510ced6135
					
				@ -364,6 +364,7 @@ SelectOpTest/.+,29
 | 
			
		||||
-SliceOpTest/SliceOpTest/SliceString/.+
 | 
			
		||||
-SliceOpTest/SliceOpTest/SliceInt64/.+
 | 
			
		||||
-SliceOpTest/SliceOpTest/SliceBool/.+
 | 
			
		||||
-SliceOpTest/SliceOpTest/SliceInt16/.+
 | 
			
		||||
# Only constant tensors
 | 
			
		||||
SliceOpTest/SliceOpTest/.+/0,29
 | 
			
		||||
 | 
			
		||||
@ -413,6 +414,7 @@ TopKV2OpTest/TopKV2OpTest/.+/0,29
 | 
			
		||||
-TransposeTest/5DDividedIntoTwo2Ds.*
 | 
			
		||||
-TransposeTest/Complex5DTest.*
 | 
			
		||||
-TransposeTest/.+DynamicTensor
 | 
			
		||||
-TransposeTest/TestRefOps4DInt16
 | 
			
		||||
TransposeTest/.+
 | 
			
		||||
 | 
			
		||||
# transpose_conv_test
 | 
			
		||||
 | 
			
		||||
@ -136,7 +136,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
 | 
			
		||||
             /* max_version = */ 4);
 | 
			
		||||
  AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
 | 
			
		||||
             /* min_version = */ 1,
 | 
			
		||||
             /* max_version = */ 4);
 | 
			
		||||
             /* max_version = */ 5);
 | 
			
		||||
  AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(),
 | 
			
		||||
             /* min_version = */ 1,
 | 
			
		||||
             /* max_version = */ 3);
 | 
			
		||||
@ -203,7 +203,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
 | 
			
		||||
  AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
 | 
			
		||||
  AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(),
 | 
			
		||||
             /* min_version = */ 1,
 | 
			
		||||
             /* max_version = */ 3);
 | 
			
		||||
             /* max_version = */ 4);
 | 
			
		||||
  AddBuiltin(BuiltinOperator_SIN, Register_SIN());
 | 
			
		||||
  AddBuiltin(BuiltinOperator_COS, Register_COS());
 | 
			
		||||
  AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
 | 
			
		||||
 | 
			
		||||
@ -214,6 +214,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
    case kTfLiteInt8:
 | 
			
		||||
      TF_LITE_SLICE(int8_t, kernel_type);
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteInt16:
 | 
			
		||||
      TF_LITE_SLICE(int16_t, kernel_type);
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteUInt8:
 | 
			
		||||
      TF_LITE_SLICE(uint8_t, kernel_type);
 | 
			
		||||
      break;
 | 
			
		||||
 | 
			
		||||
@ -226,6 +226,16 @@ TEST_P(SliceOpTest, SliceInt8) {
 | 
			
		||||
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_P(SliceOpTest, SliceInt16) {
 | 
			
		||||
  SliceOpModel<int16_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
 | 
			
		||||
                                   {2, 1, -1, 1}, TensorType_INT32,
 | 
			
		||||
                                   TensorType_INT16, GetParam());
 | 
			
		||||
  m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
 | 
			
		||||
  m.Invoke();
 | 
			
		||||
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
 | 
			
		||||
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_P(SliceOpTest, SliceString) {
 | 
			
		||||
  SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
 | 
			
		||||
                                  {2, 1, -1, 1}, TensorType_INT32,
 | 
			
		||||
 | 
			
		||||
@ -130,6 +130,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
        TF_LITE_TRANSPOSE(reference_ops, int8_t);
 | 
			
		||||
      }
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteInt16:
 | 
			
		||||
      TF_LITE_TRANSPOSE(reference_ops, int16_t);
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteInt64:
 | 
			
		||||
      TF_LITE_TRANSPOSE(reference_ops, int64_t);
 | 
			
		||||
      break;
 | 
			
		||||
 | 
			
		||||
@ -180,13 +180,14 @@ TEST(TransposeTest, TestRefOps4D) {
 | 
			
		||||
  ASSERT_EQ(out, ref);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TransposeTest, TestRefOps4DInt8) {
 | 
			
		||||
  std::vector<int8_t> out;
 | 
			
		||||
template <typename T>
 | 
			
		||||
void TransposeTestTestRefOps4D() {
 | 
			
		||||
  std::vector<T> out;
 | 
			
		||||
  // Basic 4d.
 | 
			
		||||
  RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out);
 | 
			
		||||
  ASSERT_EQ(
 | 
			
		||||
      out,
 | 
			
		||||
      std::vector<int8_t>(
 | 
			
		||||
      std::vector<T>(
 | 
			
		||||
          {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
 | 
			
		||||
           60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
 | 
			
		||||
           5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
 | 
			
		||||
@ -197,11 +198,15 @@ TEST(TransposeTest, TestRefOps4DInt8) {
 | 
			
		||||
           75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}));
 | 
			
		||||
  RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out);
 | 
			
		||||
  // Basic identity.
 | 
			
		||||
  std::vector<int8_t> ref(out.size());
 | 
			
		||||
  std::vector<T> ref(out.size());
 | 
			
		||||
  for (int k = 0; k < ref.size(); k++) ref[k] = k;
 | 
			
		||||
  ASSERT_EQ(out, ref);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TransposeTest, TestRefOps4DInt8) { TransposeTestTestRefOps4D<int8_t>(); }
 | 
			
		||||
 | 
			
		||||
TEST(TransposeTest, TestRefOps4DInt16) { TransposeTestTestRefOps4D<int16_t>(); }
 | 
			
		||||
 | 
			
		||||
class TransposeOpModel : public SingleOpModel {
 | 
			
		||||
 public:
 | 
			
		||||
  void SetInput(std::initializer_list<float> data) {
 | 
			
		||||
 | 
			
		||||
@ -124,6 +124,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
 | 
			
		||||
          {{OperatorType::kTranspose, 1}, "1.6.0"},
 | 
			
		||||
          {{OperatorType::kTranspose, 2}, "1.14.0"},
 | 
			
		||||
          {{OperatorType::kTranspose, 3}, "1.15.0"},
 | 
			
		||||
          {{OperatorType::kTranspose, 5}, kPendingReleaseOpVersion},
 | 
			
		||||
          {{OperatorType::kLstmCell, 1}, "1.7.0"},
 | 
			
		||||
          {{OperatorType::kLstmCell, 2}, "1.10.0"},
 | 
			
		||||
          {{OperatorType::kLstmCell, 3}, "1.14.0"},
 | 
			
		||||
@ -180,6 +181,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
 | 
			
		||||
          {{OperatorType::kSlice, 1}, "1.14.0"},
 | 
			
		||||
          {{OperatorType::kSlice, 2}, "1.14.0"},
 | 
			
		||||
          {{OperatorType::kSlice, 3}, "1.14.0"},
 | 
			
		||||
          {{OperatorType::kSlice, 4}, kPendingReleaseOpVersion},
 | 
			
		||||
          {{OperatorType::kTanh, 1}, "1.14.0"},
 | 
			
		||||
          {{OperatorType::kTanh, 2}, "1.14.0"},
 | 
			
		||||
          {{OperatorType::kTanh, 3}, kPendingReleaseOpVersion},
 | 
			
		||||
 | 
			
		||||
@ -237,6 +237,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
 | 
			
		||||
      return 1;
 | 
			
		||||
 | 
			
		||||
    case BuiltinOperator_TRANSPOSE:
 | 
			
		||||
      if (op_sig.input_types.at(0) == TensorType_INT16) {
 | 
			
		||||
        return 5;
 | 
			
		||||
      }
 | 
			
		||||
      if (op_sig.options.single_input_op.num_dims > 4) {
 | 
			
		||||
        return 4;
 | 
			
		||||
      }
 | 
			
		||||
@ -320,6 +323,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
 | 
			
		||||
      return 1;
 | 
			
		||||
 | 
			
		||||
    case BuiltinOperator_SLICE:
 | 
			
		||||
      if (op_sig.input_types.at(0) == TensorType_INT16) {
 | 
			
		||||
        return 4;
 | 
			
		||||
      }
 | 
			
		||||
      // Version 3 supports string input types.
 | 
			
		||||
      if (op_sig.input_types.at(0) == TensorType_STRING) {
 | 
			
		||||
        return 3;
 | 
			
		||||
 | 
			
		||||
@ -216,6 +216,12 @@ TEST(OpVersionTest, VersioningSpaceToDepthTest) {
 | 
			
		||||
 | 
			
		||||
TEST(OpVersionTest, VersioningSliceTest) {
 | 
			
		||||
  OpSignature fake_op_sig = {
 | 
			
		||||
      .op = BuiltinOperator_SLICE,
 | 
			
		||||
      .input_types = std::vector<TensorType>{TensorType_INT16},
 | 
			
		||||
  };
 | 
			
		||||
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
 | 
			
		||||
 | 
			
		||||
  fake_op_sig = {
 | 
			
		||||
      .op = BuiltinOperator_SLICE,
 | 
			
		||||
      .input_types = std::vector<TensorType>{TensorType_STRING},
 | 
			
		||||
  };
 | 
			
		||||
@ -587,6 +593,12 @@ TEST(OpVersionTest, VersioningTileOperatorTest) {
 | 
			
		||||
}
 | 
			
		||||
TEST(OpVersionTest, VersioningTransposeTest) {
 | 
			
		||||
  OpSignature fake_op_sig = {
 | 
			
		||||
      .op = BuiltinOperator_TRANSPOSE,
 | 
			
		||||
      .input_types = std::vector<TensorType>{TensorType_INT16},
 | 
			
		||||
  };
 | 
			
		||||
  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
 | 
			
		||||
 | 
			
		||||
  fake_op_sig = {
 | 
			
		||||
      .op = BuiltinOperator_TRANSPOSE,
 | 
			
		||||
      .input_types = std::vector<TensorType>{TensorType_BOOL},
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@ -159,6 +159,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
 | 
			
		||||
              {{BuiltinOperator_TRANSPOSE, 2}, "1.14.0"},
 | 
			
		||||
              {{BuiltinOperator_TRANSPOSE, 3}, "1.15.0"},
 | 
			
		||||
              {{BuiltinOperator_TRANSPOSE, 4}, "2.3.0"},
 | 
			
		||||
              {{BuiltinOperator_TRANSPOSE, 5}, kPendingReleaseVersion},
 | 
			
		||||
              {{BuiltinOperator_LSTM, 1}, "1.7.0"},
 | 
			
		||||
              {{BuiltinOperator_LSTM, 2}, "1.10.0"},
 | 
			
		||||
              {{BuiltinOperator_LSTM, 3}, "1.14.0"},
 | 
			
		||||
@ -228,6 +229,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
 | 
			
		||||
              {{BuiltinOperator_SLICE, 1}, "1.14.0"},
 | 
			
		||||
              {{BuiltinOperator_SLICE, 2}, "1.14.0"},
 | 
			
		||||
              {{BuiltinOperator_SLICE, 3}, "1.14.0"},
 | 
			
		||||
              {{BuiltinOperator_SLICE, 4}, kPendingReleaseVersion},
 | 
			
		||||
              {{BuiltinOperator_TANH, 1}, "1.14.0"},
 | 
			
		||||
              {{BuiltinOperator_TANH, 2}, "1.14.0"},
 | 
			
		||||
              {{BuiltinOperator_TANH, 3}, "2.3.0"},
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user