Add method for validating GPUCompatibilityList flatbuffer

PiperOrigin-RevId: 332847117
Change-Id: I8c123d26a9ed0f57510b3a5c00c9748866bfae6f
This commit is contained in:
Stefano Galarraga 2020-09-21 07:36:58 -07:00 committed by TensorFlower Gardener
parent b253e82d53
commit 06f1b2e533
3 changed files with 25 additions and 0 deletions

View File

@ -102,5 +102,13 @@ bool GPUCompatibilityList::IsDatabaseLoaded() const {
return database_ != nullptr;
}
// static
bool GPUCompatibilityList::IsValidFlatbuffer(const unsigned char* data,
int len) {
// Verify opensource db.
flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(data), len);
return tflite::acceleration::VerifyDeviceDatabaseBuffer(verifier);
}
} // namespace acceleration
} // namespace tflite

View File

@ -78,6 +78,9 @@ class GPUCompatibilityList {
GPUCompatibilityList& operator=(const GPUCompatibilityList&) = delete;
bool IsDatabaseLoaded() const;
// Checks if the provided byte array represents a valid compatibility list
static bool IsValidFlatbuffer(const unsigned char* data, int len);
protected:
const DeviceDatabase* database_;
};

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/acceleration/compatibility/gpu_compatibility.h"
#include <algorithm>
#include <memory>
#include <gmock/gmock.h>
@ -84,4 +85,17 @@ TEST_F(GPUCompatibilityTest, ReturnsDefaultOptions) {
default_options.max_delegated_partitions);
}
TEST(GPUCompatibility, RecogniseValidCompatibilityListFlatbuffer) {
EXPECT_TRUE(tflite::acceleration::GPUCompatibilityList::IsValidFlatbuffer(
g_tflite_acceleration_devicedb_sample_binary,
g_tflite_acceleration_devicedb_sample_binary_len));
}
TEST(GPUCompatibility, RecogniseInvalidCompatibilityListFlatbuffer) {
unsigned char invalid_buffer[100];
std::fill(invalid_buffer, invalid_buffer + 100, ' ');
EXPECT_FALSE(tflite::acceleration::GPUCompatibilityList::IsValidFlatbuffer(
invalid_buffer, 100));
}
} // namespace