Internal test infra change

PiperOrigin-RevId: 310261779
Change-Id: Iac550ffeb52a444c6ae58dbc85bb67bf80f50dd8
This commit is contained in:
Jared Duke 2020-05-06 17:32:36 -07:00 committed by TensorFlower Gardener
parent a967cad22b
commit 154044d0f2
1 changed files with 16 additions and 10 deletions

View File

@ -18,6 +18,8 @@ package org.tensorflow.lite.gpu;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
@ -28,7 +30,7 @@ import org.tensorflow.lite.TestUtils;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public final class GpuDelegateTest { public final class GpuDelegateTest {
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin"; private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH); private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
@Test @Test
@ -43,15 +45,19 @@ public final class GpuDelegateTest {
Interpreter.Options options = new Interpreter.Options(); Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate = new GpuDelegate(); try (GpuDelegate delegate = new GpuDelegate();
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) { Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
float[] oneD = {1.23f, 6.54f, 7.81f}; float[] input0 = {1.23f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; float[] input1 = {2.43f};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; Object[] inputs = {input0, input1, input0, input1};
float[][][][] fourD = {threeD, threeD}; float[] parsedOutput0 = new float[1];
float[][][][] parsedOutputs = new float[2][8][8][3]; float[] parsedOutput1 = new float[1];
interpreter.run(fourD, parsedOutputs); Map<Integer, Object> outputs = new HashMap<>();
float[] outputOneD = parsedOutputs[0][0][0]; outputs.put(0, parsedOutput0);
float[] expected = {3.69f, 19.62f, 23.43f}; outputs.put(1, parsedOutput1);
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); interpreter.runForMultipleInputsOutputs(inputs, outputs);
float[] expected0 = {4.89f};
float[] expected1 = {6.09f};
assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder();
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
} }
} }
} }