Add method for validating GPUCompatibilityList flatbuffer
PiperOrigin-RevId: 332847117 Change-Id: I8c123d26a9ed0f57510b3a5c00c9748866bfae6f
This commit is contained in:
parent
b253e82d53
commit
06f1b2e533
@ -102,5 +102,13 @@ bool GPUCompatibilityList::IsDatabaseLoaded() const {
|
||||
return database_ != nullptr;
|
||||
}
|
||||
|
||||
// static
|
||||
bool GPUCompatibilityList::IsValidFlatbuffer(const unsigned char* data,
|
||||
int len) {
|
||||
// Verify opensource db.
|
||||
flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(data), len);
|
||||
return tflite::acceleration::VerifyDeviceDatabaseBuffer(verifier);
|
||||
}
|
||||
|
||||
} // namespace acceleration
|
||||
} // namespace tflite
|
||||
|
@ -78,6 +78,9 @@ class GPUCompatibilityList {
|
||||
GPUCompatibilityList& operator=(const GPUCompatibilityList&) = delete;
|
||||
bool IsDatabaseLoaded() const;
|
||||
|
||||
// Checks if the provided byte array represents a valid compatibility list
|
||||
static bool IsValidFlatbuffer(const unsigned char* data, int len);
|
||||
|
||||
protected:
|
||||
const DeviceDatabase* database_;
|
||||
};
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
@ -84,4 +85,17 @@ TEST_F(GPUCompatibilityTest, ReturnsDefaultOptions) {
|
||||
default_options.max_delegated_partitions);
|
||||
}
|
||||
|
||||
TEST(GPUCompatibility, RecogniseValidCompatibilityListFlatbuffer) {
|
||||
EXPECT_TRUE(tflite::acceleration::GPUCompatibilityList::IsValidFlatbuffer(
|
||||
g_tflite_acceleration_devicedb_sample_binary,
|
||||
g_tflite_acceleration_devicedb_sample_binary_len));
|
||||
}
|
||||
|
||||
TEST(GPUCompatibility, RecogniseInvalidCompatibilityListFlatbuffer) {
|
||||
unsigned char invalid_buffer[100];
|
||||
std::fill(invalid_buffer, invalid_buffer + 100, ' ');
|
||||
EXPECT_FALSE(tflite::acceleration::GPUCompatibilityList::IsValidFlatbuffer(
|
||||
invalid_buffer, 100));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user