Fix NNAPIDelegate.StatefulDelegateWithBufferHandles test by marking input data is stale after altering the content of the buffer.

PiperOrigin-RevId: 348839211
Change-Id: I3dcab7df6d772989ac191c1fae32e141f69815fc
This commit is contained in:
A. Unique TensorFlower 2020-12-23 13:22:07 -08:00 committed by TensorFlower Gardener
parent 208bf5695b
commit 170ccd937c

View File

@ -72,6 +72,10 @@ class SingleOpModelWithNNAPI : public SingleOpModel {
interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
}
void MarkInputTensorDataStale(int index) {
interpreter_->tensor(index)->data_is_stale = true;
}
TfLiteStatus AllocateTensors() { return interpreter_->AllocateTensors(); }
protected:
@ -391,12 +395,10 @@ TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) {
!NnApiImplementation()->ANeuralNetworksMemory_createFromFd) {
GTEST_SKIP();
}
// TODO(b/176241505): Fix incorrect outputs on API 29.
if (NnApiImplementation()->android_sdk_version == 29) {
GTEST_SKIP();
}
StatefulNnApiDelegate::Options options;
// Allow NNAPI CPU fallback path.
options.disallow_nnapi_cpu = false;
FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
@ -443,6 +445,7 @@ TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) {
auto input1_handle = delegate->RegisterNnapiMemory(
input1_memory, memory_callback, &memory_context);
m.SetBufferHandle(m.input1(), input1_handle);
m.MarkInputTensorDataStale(m.input1());
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
@ -454,6 +457,7 @@ TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) {
auto input1_handle = delegate->RegisterNnapiMemory(
input1_memory, memory_callback, &memory_context);
m.SetBufferHandle(m.input1(), input1_handle);
m.MarkInputTensorDataStale(m.input1());
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9 + i, 0.4, 1.0, 1.3}));