Internal test infra change

PiperOrigin-RevId: 310261779
Change-Id: Iac550ffeb52a444c6ae58dbc85bb67bf80f50dd8
This commit is contained in:
Jared Duke 2020-05-06 17:32:36 -07:00 committed by TensorFlower Gardener
parent a967cad22b
commit 154044d0f2
1 changed files with 16 additions and 10 deletions

View File

@ -18,6 +18,8 @@ package org.tensorflow.lite.gpu;
import static com.google.common.truth.Truth.assertThat;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -28,7 +30,7 @@ import org.tensorflow.lite.TestUtils;
@RunWith(JUnit4.class)
public final class GpuDelegateTest {
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
@Test
@ -43,15 +45,19 @@ public final class GpuDelegateTest {
Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate = new GpuDelegate();
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
float[] input0 = {1.23f};
float[] input1 = {2.43f};
Object[] inputs = {input0, input1, input0, input1};
float[] parsedOutput0 = new float[1];
float[] parsedOutput1 = new float[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, parsedOutput0);
outputs.put(1, parsedOutput1);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
float[] expected0 = {4.89f};
float[] expected1 = {6.09f};
assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder();
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
}
}
}