[Java] Add base classes and utilities for operation wrappers. (#11188)
* Add base classes and utilities for operation wrappers. * Rename Input interface to Operand * Introduce changes after code review
This commit is contained in:
parent
a72fc31bca
commit
7c1fe9068b
@ -162,6 +162,32 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "PrimitiveOpTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/op/PrimitiveOpTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.op.PrimitiveOpTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "OperandsTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/op/OperandsTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.op.OperandsTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflow_jni",
|
||||
srcs = select({
|
||||
|
@ -21,20 +21,20 @@ package org.tensorflow;
|
||||
* <p>Example usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* // The "decodeJpeg" operation can be used as input to the "cast" operation
|
||||
* Input decodeJpeg = ops.image().decodeJpeg(...);
|
||||
* // The "decodeJpeg" operation can be used as an operand to the "cast" operation
|
||||
* Operand decodeJpeg = ops.image().decodeJpeg(...);
|
||||
* ops.math().cast(decodeJpeg, DataType.FLOAT);
|
||||
*
|
||||
* // The output "y" of the "unique" operation can be used as input to the "cast" operation
|
||||
* // The output "y" of the "unique" operation can be used as an operand to the "cast" operation
|
||||
* Output y = ops.array().unique(...).y();
|
||||
* ops.math().cast(y, DataType.FLOAT);
|
||||
*
|
||||
* // The "split" operation can be used as input list to the "concat" operation
|
||||
* Iterable<? extends Input> split = ops.array().split(...);
|
||||
* // The "split" operation can be used as operand list to the "concat" operation
|
||||
* Iterable<? extends Operand> split = ops.array().split(...);
|
||||
* ops.array().concat(0, split);
|
||||
* }</pre>
|
||||
*/
|
||||
public interface Input {
|
||||
public interface Operand {
|
||||
|
||||
/**
|
||||
* Returns the symbolic handle of a tensor.
|
@ -91,6 +91,21 @@ public final class Operation {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns symbolic handles to a list of tensors produced by this operation.
|
||||
*
|
||||
* @param idx index of the first tensor of the list
|
||||
* @param length number of tensors in the list
|
||||
* @return array of {@code Output}
|
||||
*/
|
||||
public Output[] outputList(int idx, int length) {
|
||||
Output[] outputs = new Output[length];
|
||||
for (int i = 0; i < length; ++i) {
|
||||
outputs[i] = output(idx + i);
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
/** Returns a symbolic handle to one of the tensors produced by this operation. */
|
||||
public Output output(int idx) {
|
||||
return new Output(this, idx);
|
||||
|
@ -22,10 +22,10 @@ import java.util.Objects;
|
||||
* <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
|
||||
* the {@link Operation} in a {@link Session}.
|
||||
*
|
||||
* <p>By implementing the {@link Input} interface, instances of this class could also be passed
|
||||
* directly in input to an operation.
|
||||
* <p>By implementing the {@link Operand} interface, instances of this class also act as operands to
|
||||
* {@link org.tensorflow.op.Op Op} instances.
|
||||
*/
|
||||
public final class Output implements Input {
|
||||
public final class Output implements Operand {
|
||||
|
||||
/** Handle to the idx-th output of the Operation {@code op}. */
|
||||
public Output(Operation op, int idx) {
|
||||
|
35
tensorflow/java/src/main/java/org/tensorflow/op/Op.java
Normal file
35
tensorflow/java/src/main/java/org/tensorflow/op/Op.java
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op;
|
||||
|
||||
/**
|
||||
* A marker interface for all operation wrappers.
|
||||
*
|
||||
* <p>Operation wrappers provide strongly typed interfaces for building operations and linking them
|
||||
* into a graph without the use of literals and indexes required by the core classes.
|
||||
*
|
||||
* <p>This interface allows keeping references to any operation wrapper using a common type.
|
||||
*
|
||||
* <pre>{@code
|
||||
* // All values returned by an Ops call can be referred as a Op
|
||||
* Op split = ops.array().split(...);
|
||||
* Op shape = ops.array().shape(...);
|
||||
*
|
||||
* // All operations could be added to an Op collection
|
||||
* Collection<Op> allOps = Arrays.asList(split, shape);
|
||||
* }
|
||||
*/
|
||||
public interface Op {}
|
@ -0,0 +1,47 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.tensorflow.Operand;
|
||||
import org.tensorflow.OperationBuilder;
|
||||
import org.tensorflow.Output;
|
||||
|
||||
/** Utilities for manipulating operand related types and lists. */
|
||||
public final class Operands {
|
||||
|
||||
/**
|
||||
* Converts a list of {@link Operand} into an array of {@link Output}.
|
||||
*
|
||||
* <p>Operation wrappers need to convert back a list of inputs into an array of outputs in order
|
||||
* to build an operation, see {@link OperationBuilder#addInputList(Output[])}.
|
||||
*
|
||||
* @param inputs an iteration of input operands
|
||||
* @return an array of outputs
|
||||
*/
|
||||
public static Output[] asOutputs(Iterable<? extends Operand> inputs) {
|
||||
List<Output> outputList = new ArrayList<>();
|
||||
for (Operand input : inputs) {
|
||||
outputList.add(input.asOutput());
|
||||
}
|
||||
return outputList.toArray(new Output[outputList.size()]);
|
||||
}
|
||||
|
||||
// Disabled constructor
|
||||
private Operands() {}
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op;
|
||||
|
||||
import org.tensorflow.Operation;
|
||||
|
||||
/**
|
||||
* A base class for {@link Op} implementations that are backed by a single {@link Operation}.
|
||||
*
|
||||
* <p>Each operation registered in the TensorFlow core is a primitive and is provided as a {@code
|
||||
* PrimitiveOp}. Custom operations working with only one primitive may also derive from this class.
|
||||
*/
|
||||
public abstract class PrimitiveOp implements Op {
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return operation.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
// Note: we consider that all objects wrapping the same operation are equal, no matter their
|
||||
// implementation
|
||||
if (!(obj instanceof PrimitiveOp)) {
|
||||
return false;
|
||||
}
|
||||
return operation.equals(((PrimitiveOp) obj).operation);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String toString() {
|
||||
return String.format("<%s '%s'>", operation.type(), operation.name());
|
||||
}
|
||||
|
||||
/**
|
||||
* Underlying operation. It is deliberately not exposed by a getter method to avoid any name
|
||||
* conflict with generated methods of the subclasses.
|
||||
*/
|
||||
protected final Operation operation;
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*
|
||||
* @param operation the underlying operation
|
||||
*/
|
||||
protected PrimitiveOp(Operation operation) {
|
||||
this.operation = operation;
|
||||
}
|
||||
}
|
@ -24,6 +24,7 @@ import static org.junit.Assert.fail;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
@ -153,6 +154,19 @@ public class OperationTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputList() {
|
||||
try (Graph g = new Graph()) {
|
||||
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
|
||||
Output[] outputs = split.outputList(1, 2);
|
||||
assertNotNull(outputs);
|
||||
assertEquals(2, outputs.length);
|
||||
for (int i = 0; i < outputs.length; ++i) {
|
||||
assertEquals(i + 1, outputs[i].index());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static int split(int[] values, int num_split) {
|
||||
try (Graph g = new Graph()) {
|
||||
return g.opBuilder("Split", "Split")
|
||||
|
@ -48,6 +48,14 @@ public class TestUtil {
|
||||
.output(0);
|
||||
}
|
||||
|
||||
public static Operation split(Graph g, String name, int[] values, int num_split) {
|
||||
return g.opBuilder("Split", name)
|
||||
.addInput(constant(g, "split_dim", 0))
|
||||
.addInput(constant(g, "values", values))
|
||||
.setAttr("num_split", num_split)
|
||||
.build();
|
||||
}
|
||||
|
||||
public static void transpose_A_times_X(Graph g, int[][] a) {
|
||||
matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
|
||||
}
|
||||
|
@ -0,0 +1,47 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertSame;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Operation;
|
||||
import org.tensorflow.Output;
|
||||
import org.tensorflow.TestUtil;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.op.Operands}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class OperandsTest {
|
||||
|
||||
@Test
|
||||
public void createOutputArrayFromOperandList() {
|
||||
try (Graph g = new Graph()) {
|
||||
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
|
||||
List<Output> list = Arrays.asList(split.output(0), split.output(2));
|
||||
Output[] array = Operands.asOutputs(list);
|
||||
assertEquals(list.size(), array.length);
|
||||
assertSame(array[0], list.get(0));
|
||||
assertSame(array[1], list.get(1));
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Output;
|
||||
import org.tensorflow.TestUtil;
|
||||
|
||||
public class PrimitiveOpTest {
|
||||
|
||||
@Test
|
||||
public void equalsHashcode() {
|
||||
try (Graph g = new Graph()) {
|
||||
Output array = TestUtil.constant(g, "array", new int[2]);
|
||||
|
||||
PrimitiveOp test1 =
|
||||
new PrimitiveOp(g.opBuilder("Shape", "shape1").addInput(array).build()) {};
|
||||
PrimitiveOp test2 =
|
||||
new PrimitiveOp(g.opBuilder("Shape", "shape2").addInput(array).build()) {};
|
||||
PrimitiveOp test3 = new PrimitiveOp(test1.operation) {};
|
||||
|
||||
// equals() tests
|
||||
assertNotEquals(test1, test2);
|
||||
assertEquals(test1, test3);
|
||||
assertEquals(test3, test1);
|
||||
assertNotEquals(test2, test3);
|
||||
|
||||
// hashcode() tests
|
||||
Set<PrimitiveOp> ops = new HashSet<>();
|
||||
assertTrue(ops.add(test1));
|
||||
assertTrue(ops.add(test2));
|
||||
assertFalse(ops.add(test3));
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user