Internal test infra change
PiperOrigin-RevId: 310261779 Change-Id: Iac550ffeb52a444c6ae58dbc85bb67bf80f50dd8
This commit is contained in:
parent
a967cad22b
commit
154044d0f2
|
@ -18,6 +18,8 @@ package org.tensorflow.lite.gpu;
|
|||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
@ -28,7 +30,7 @@ import org.tensorflow.lite.TestUtils;
|
|||
@RunWith(JUnit4.class)
|
||||
public final class GpuDelegateTest {
|
||||
|
||||
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
|
||||
private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
|
||||
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
|
||||
|
||||
@Test
|
||||
|
@ -43,15 +45,19 @@ public final class GpuDelegateTest {
|
|||
Interpreter.Options options = new Interpreter.Options();
|
||||
try (GpuDelegate delegate = new GpuDelegate();
|
||||
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||
interpreter.run(fourD, parsedOutputs);
|
||||
float[] outputOneD = parsedOutputs[0][0][0];
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
float[] input0 = {1.23f};
|
||||
float[] input1 = {2.43f};
|
||||
Object[] inputs = {input0, input1, input0, input1};
|
||||
float[] parsedOutput0 = new float[1];
|
||||
float[] parsedOutput1 = new float[1];
|
||||
Map<Integer, Object> outputs = new HashMap<>();
|
||||
outputs.put(0, parsedOutput0);
|
||||
outputs.put(1, parsedOutput1);
|
||||
interpreter.runForMultipleInputsOutputs(inputs, outputs);
|
||||
float[] expected0 = {4.89f};
|
||||
float[] expected1 = {6.09f};
|
||||
assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder();
|
||||
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue