Internal test change for Android
PiperOrigin-RevId: 309078882 Change-Id: Ia433a9b69ce75954b2400a4e7d88daf3a06536f8
This commit is contained in:
parent
0782e4933a
commit
ef9563cc4a
@ -348,6 +348,14 @@ filegroup(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "portable_gpu_tests",
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflowlite_jni",
|
||||
srcs = select({
|
||||
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.gpu;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.lite.Interpreter;
|
||||
import org.tensorflow.lite.TestUtils;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class GpuDelegateTest {
|
||||
|
||||
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
|
||||
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
|
||||
|
||||
@Test
|
||||
public void testBasic() throws Exception {
|
||||
try (GpuDelegate delegate = new GpuDelegate()) {
|
||||
assertThat(delegate.getNativeHandle()).isNotEqualTo(0);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInterpreterWithGpu() throws Exception {
|
||||
Interpreter.Options options = new Interpreter.Options();
|
||||
try (GpuDelegate delegate = new GpuDelegate();
|
||||
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||
interpreter.run(fourD, parsedOutputs);
|
||||
float[] outputOneD = parsedOutputs[0][0][0];
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user