Teach TFLite model verifier about all supported types
PiperOrigin-RevId: 259073172
This commit is contained in:
parent
fdc106e412
commit
8f55026cc8
@ -91,7 +91,8 @@ cc_test(
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
@ -130,20 +130,30 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
|
||||
case TensorType_FLOAT32:
|
||||
bytes_required *= sizeof(float);
|
||||
break;
|
||||
case TensorType_INT8:
|
||||
bytes_required *= sizeof(int8_t);
|
||||
break;
|
||||
case TensorType_UINT8:
|
||||
bytes_required *= sizeof(uint8_t);
|
||||
case TensorType_FLOAT16:
|
||||
bytes_required *= sizeof(uint16_t);
|
||||
break;
|
||||
case TensorType_INT32:
|
||||
bytes_required *= sizeof(int32_t);
|
||||
break;
|
||||
case TensorType_UINT8:
|
||||
bytes_required *= sizeof(uint8_t);
|
||||
break;
|
||||
case TensorType_INT8:
|
||||
bytes_required *= sizeof(int8_t);
|
||||
break;
|
||||
case TensorType_INT64:
|
||||
bytes_required *= sizeof(int64_t);
|
||||
break;
|
||||
case TensorType_FLOAT16:
|
||||
// FALLTHROUGH_INTENDED;
|
||||
case TensorType_BOOL:
|
||||
bytes_required *= sizeof(bool);
|
||||
break;
|
||||
case TensorType_INT16:
|
||||
bytes_required *= sizeof(uint16_t);
|
||||
break;
|
||||
case TensorType_COMPLEX64:
|
||||
bytes_required *= sizeof(std::complex<float>);
|
||||
break;
|
||||
default:
|
||||
ReportError(error_reporter, "Tensor %s invalid type: %d",
|
||||
tensor.name()->c_str(), tensor.type());
|
||||
|
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/verifier.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -21,11 +23,12 @@ limitations under the License.
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/lite/allocation.h"
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/error_reporter.h"
|
||||
#include "tensorflow/lite/op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
#include "tensorflow/lite/tools/verifier.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -516,6 +519,42 @@ TEST(VerifyModel, OpWithOptionalTensor) {
|
||||
EXPECT_EQ("", builder.GetErrorString());
|
||||
}
|
||||
|
||||
TEST(VerifyModel, TypedTensorShapeMismatchWithTensorBufferSize) {
|
||||
TfLiteFlatbufferModelBuilder builder;
|
||||
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
|
||||
++tensor_type) {
|
||||
if (tensor_type == TensorType_STRING) continue;
|
||||
builder.AddTensor({2, 3}, static_cast<TensorType>(tensor_type),
|
||||
{1, 2, 3, 4}, "input");
|
||||
builder.FinishModel({}, {});
|
||||
ASSERT_FALSE(builder.Verify());
|
||||
EXPECT_THAT(
|
||||
builder.GetErrorString(),
|
||||
::testing::ContainsRegex("Tensor input requires .* bytes, but is "
|
||||
"allocated with 4 bytes buffer"));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
|
||||
TfLiteFlatbufferModelBuilder builder;
|
||||
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
|
||||
++tensor_type) {
|
||||
if (tensor_type == TensorType_STRING) continue;
|
||||
TfLiteType lite_type = kTfLiteNoType;
|
||||
ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
|
||||
&lite_type, /*error_reporter=*/nullptr),
|
||||
kTfLiteOk);
|
||||
size_t size_bytes = 0;
|
||||
ASSERT_EQ(GetSizeOfType(/*context=*/nullptr, lite_type, &size_bytes),
|
||||
kTfLiteOk);
|
||||
std::vector<uint8_t> buffer(size_bytes);
|
||||
builder.AddTensor({1}, static_cast<TensorType>(tensor_type), buffer,
|
||||
"input");
|
||||
builder.FinishModel({}, {});
|
||||
ASSERT_TRUE(builder.Verify());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yichengfan): make up malicious files to test with.
|
||||
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user