Teach TFLite model verifier about all supported types

PiperOrigin-RevId: 259073172
This commit is contained in:
Jared Duke 2019-07-19 18:18:43 -07:00 committed by TensorFlower Gardener
parent fdc106e412
commit 8f55026cc8
3 changed files with 59 additions and 9 deletions

View File

@ -91,7 +91,8 @@ cc_test(
"//tensorflow/core:framework_lite",
"//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite:util",
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",

View File

@ -130,20 +130,30 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
case TensorType_FLOAT32:
bytes_required *= sizeof(float);
break;
case TensorType_INT8:
bytes_required *= sizeof(int8_t);
break;
case TensorType_UINT8:
bytes_required *= sizeof(uint8_t);
case TensorType_FLOAT16:
bytes_required *= sizeof(uint16_t);
break;
case TensorType_INT32:
bytes_required *= sizeof(int32_t);
break;
case TensorType_UINT8:
bytes_required *= sizeof(uint8_t);
break;
case TensorType_INT8:
bytes_required *= sizeof(int8_t);
break;
case TensorType_INT64:
bytes_required *= sizeof(int64_t);
break;
case TensorType_FLOAT16:
// FALLTHROUGH_INTENDED;
case TensorType_BOOL:
bytes_required *= sizeof(bool);
break;
case TensorType_INT16:
bytes_required *= sizeof(uint16_t);
break;
case TensorType_COMPLEX64:
bytes_required *= sizeof(std::complex<float>);
break;
default:
ReportError(error_reporter, "Tensor %s invalid type: %d",
tensor.name()->c_str(), tensor.type());

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/verifier.h"
#include <string>
#include <vector>
@ -21,11 +23,12 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/tools/verifier.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"
namespace tflite {
@ -516,6 +519,42 @@ TEST(VerifyModel, OpWithOptionalTensor) {
EXPECT_EQ("", builder.GetErrorString());
}
TEST(VerifyModel, TypedTensorShapeMismatchWithTensorBufferSize) {
TfLiteFlatbufferModelBuilder builder;
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
++tensor_type) {
if (tensor_type == TensorType_STRING) continue;
builder.AddTensor({2, 3}, static_cast<TensorType>(tensor_type),
{1, 2, 3, 4}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
EXPECT_THAT(
builder.GetErrorString(),
::testing::ContainsRegex("Tensor input requires .* bytes, but is "
"allocated with 4 bytes buffer"));
}
}
TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
TfLiteFlatbufferModelBuilder builder;
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
++tensor_type) {
if (tensor_type == TensorType_STRING) continue;
TfLiteType lite_type = kTfLiteNoType;
ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
&lite_type, /*error_reporter=*/nullptr),
kTfLiteOk);
size_t size_bytes = 0;
ASSERT_EQ(GetSizeOfType(/*context=*/nullptr, lite_type, &size_bytes),
kTfLiteOk);
std::vector<uint8_t> buffer(size_bytes);
builder.AddTensor({1}, static_cast<TensorType>(tensor_type), buffer,
"input");
builder.FinishModel({}, {});
ASSERT_TRUE(builder.Verify());
}
}
// TODO(yichengfan): make up malicious files to test with.
} // namespace tflite