diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java index b9cbc27052f..1fe4a531624 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java @@ -18,6 +18,8 @@ package org.tensorflow.lite.gpu; import static com.google.common.truth.Truth.assertThat; import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -28,7 +30,7 @@ import org.tensorflow.lite.TestUtils; @RunWith(JUnit4.class) public final class GpuDelegateTest { - private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin"; + private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin"; private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH); @Test @@ -43,15 +45,19 @@ public final class GpuDelegateTest { Interpreter.Options options = new Interpreter.Options(); try (GpuDelegate delegate = new GpuDelegate(); Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) { - float[] oneD = {1.23f, 6.54f, 7.81f}; - float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; - float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; - float[][][][] fourD = {threeD, threeD}; - float[][][][] parsedOutputs = new float[2][8][8][3]; - interpreter.run(fourD, parsedOutputs); - float[] outputOneD = parsedOutputs[0][0][0]; - float[] expected = {3.69f, 19.62f, 23.43f}; - assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + float[] input0 = {1.23f}; + float[] input1 = {2.43f}; + Object[] inputs = {input0, input1, input0, input1}; + float[] parsedOutput0 = new float[1]; + float[] parsedOutput1 = new float[1]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutput0); + outputs.put(1, parsedOutput1); + interpreter.runForMultipleInputsOutputs(inputs, outputs); + float[] expected0 = {4.89f}; + float[] expected1 = {6.09f}; + assertThat(parsedOutput0).usingTolerance(0.1f).containsExactly(expected0).inOrder(); + assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder(); } } }