Teach TFLite model verifier about all supported types
PiperOrigin-RevId: 259073172
This commit is contained in:
parent
fdc106e412
commit
8f55026cc8
@ -91,7 +91,8 @@ cc_test(
|
|||||||
"//tensorflow/core:framework_lite",
|
"//tensorflow/core:framework_lite",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:schema_fbs_version",
|
"//tensorflow/lite:schema_fbs_version",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite:util",
|
||||||
|
"//tensorflow/lite/core/api",
|
||||||
"//tensorflow/lite/schema:schema_fbs",
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
"//tensorflow/lite/testing:util",
|
"//tensorflow/lite/testing:util",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
|
|||||||
@ -130,20 +130,30 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
|
|||||||
case TensorType_FLOAT32:
|
case TensorType_FLOAT32:
|
||||||
bytes_required *= sizeof(float);
|
bytes_required *= sizeof(float);
|
||||||
break;
|
break;
|
||||||
case TensorType_INT8:
|
case TensorType_FLOAT16:
|
||||||
bytes_required *= sizeof(int8_t);
|
bytes_required *= sizeof(uint16_t);
|
||||||
break;
|
|
||||||
case TensorType_UINT8:
|
|
||||||
bytes_required *= sizeof(uint8_t);
|
|
||||||
break;
|
break;
|
||||||
case TensorType_INT32:
|
case TensorType_INT32:
|
||||||
bytes_required *= sizeof(int32_t);
|
bytes_required *= sizeof(int32_t);
|
||||||
break;
|
break;
|
||||||
|
case TensorType_UINT8:
|
||||||
|
bytes_required *= sizeof(uint8_t);
|
||||||
|
break;
|
||||||
|
case TensorType_INT8:
|
||||||
|
bytes_required *= sizeof(int8_t);
|
||||||
|
break;
|
||||||
case TensorType_INT64:
|
case TensorType_INT64:
|
||||||
bytes_required *= sizeof(int64_t);
|
bytes_required *= sizeof(int64_t);
|
||||||
break;
|
break;
|
||||||
case TensorType_FLOAT16:
|
case TensorType_BOOL:
|
||||||
// FALLTHROUGH_INTENDED;
|
bytes_required *= sizeof(bool);
|
||||||
|
break;
|
||||||
|
case TensorType_INT16:
|
||||||
|
bytes_required *= sizeof(uint16_t);
|
||||||
|
break;
|
||||||
|
case TensorType_COMPLEX64:
|
||||||
|
bytes_required *= sizeof(std::complex<float>);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
ReportError(error_reporter, "Tensor %s invalid type: %d",
|
ReportError(error_reporter, "Tensor %s invalid type: %d",
|
||||||
tensor.name()->c_str(), tensor.type());
|
tensor.name()->c_str(), tensor.type());
|
||||||
|
|||||||
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/tools/verifier.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -21,11 +23,12 @@ limitations under the License.
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "tensorflow/core/framework/numeric_types.h"
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
#include "tensorflow/lite/allocation.h"
|
#include "tensorflow/lite/allocation.h"
|
||||||
|
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||||
#include "tensorflow/lite/error_reporter.h"
|
#include "tensorflow/lite/error_reporter.h"
|
||||||
#include "tensorflow/lite/op_resolver.h"
|
#include "tensorflow/lite/op_resolver.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/testing/util.h"
|
#include "tensorflow/lite/testing/util.h"
|
||||||
#include "tensorflow/lite/tools/verifier.h"
|
#include "tensorflow/lite/util.h"
|
||||||
#include "tensorflow/lite/version.h"
|
#include "tensorflow/lite/version.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -516,6 +519,42 @@ TEST(VerifyModel, OpWithOptionalTensor) {
|
|||||||
EXPECT_EQ("", builder.GetErrorString());
|
EXPECT_EQ("", builder.GetErrorString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(VerifyModel, TypedTensorShapeMismatchWithTensorBufferSize) {
|
||||||
|
TfLiteFlatbufferModelBuilder builder;
|
||||||
|
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
|
||||||
|
++tensor_type) {
|
||||||
|
if (tensor_type == TensorType_STRING) continue;
|
||||||
|
builder.AddTensor({2, 3}, static_cast<TensorType>(tensor_type),
|
||||||
|
{1, 2, 3, 4}, "input");
|
||||||
|
builder.FinishModel({}, {});
|
||||||
|
ASSERT_FALSE(builder.Verify());
|
||||||
|
EXPECT_THAT(
|
||||||
|
builder.GetErrorString(),
|
||||||
|
::testing::ContainsRegex("Tensor input requires .* bytes, but is "
|
||||||
|
"allocated with 4 bytes buffer"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
|
||||||
|
TfLiteFlatbufferModelBuilder builder;
|
||||||
|
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
|
||||||
|
++tensor_type) {
|
||||||
|
if (tensor_type == TensorType_STRING) continue;
|
||||||
|
TfLiteType lite_type = kTfLiteNoType;
|
||||||
|
ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
|
||||||
|
&lite_type, /*error_reporter=*/nullptr),
|
||||||
|
kTfLiteOk);
|
||||||
|
size_t size_bytes = 0;
|
||||||
|
ASSERT_EQ(GetSizeOfType(/*context=*/nullptr, lite_type, &size_bytes),
|
||||||
|
kTfLiteOk);
|
||||||
|
std::vector<uint8_t> buffer(size_bytes);
|
||||||
|
builder.AddTensor({1}, static_cast<TensorType>(tensor_type), buffer,
|
||||||
|
"input");
|
||||||
|
builder.FinishModel({}, {});
|
||||||
|
ASSERT_TRUE(builder.Verify());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(yichengfan): make up malicious files to test with.
|
// TODO(yichengfan): make up malicious files to test with.
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user