[Java] Add base classes and utilities for operation wrappers. (#11188)

* Add base classes and utilities for operation wrappers.

* Rename Input interface to Operand

* Introduce changes after code review
This commit is contained in:
Karl Lessard 2017-07-12 14:10:53 -04:00 committed by Frank Chen
parent a72fc31bca
commit 7c1fe9068b
11 changed files with 323 additions and 9 deletions

View File

@ -162,6 +162,32 @@ java_test(
],
)
java_test(
name = "PrimitiveOpTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/op/PrimitiveOpTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.op.PrimitiveOpTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
java_test(
name = "OperandsTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/op/OperandsTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.op.OperandsTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
filegroup(
name = "libtensorflow_jni",
srcs = select({

View File

@ -21,20 +21,20 @@ package org.tensorflow;
* <p>Example usage:
*
* <pre>{@code
* // The "decodeJpeg" operation can be used as input to the "cast" operation
* Input decodeJpeg = ops.image().decodeJpeg(...);
* // The "decodeJpeg" operation can be used as an operand to the "cast" operation
* Operand decodeJpeg = ops.image().decodeJpeg(...);
* ops.math().cast(decodeJpeg, DataType.FLOAT);
*
* // The output "y" of the "unique" operation can be used as input to the "cast" operation
* // The output "y" of the "unique" operation can be used as an operand to the "cast" operation
* Output y = ops.array().unique(...).y();
* ops.math().cast(y, DataType.FLOAT);
*
* // The "split" operation can be used as input list to the "concat" operation
* Iterable<? extends Input> split = ops.array().split(...);
* // The "split" operation can be used as operand list to the "concat" operation
* Iterable<? extends Operand> split = ops.array().split(...);
* ops.array().concat(0, split);
* }</pre>
*/
public interface Input {
public interface Operand {
/**
* Returns the symbolic handle of a tensor.

View File

@ -91,6 +91,21 @@ public final class Operation {
}
}
/**
* Returns symbolic handles to a list of tensors produced by this operation.
*
* @param idx index of the first tensor of the list
* @param length number of tensors in the list
* @return array of {@code Output}
*/
public Output[] outputList(int idx, int length) {
Output[] outputs = new Output[length];
for (int i = 0; i < length; ++i) {
outputs[i] = output(idx + i);
}
return outputs;
}
/** Returns a symbolic handle to one of the tensors produced by this operation. */
public Output output(int idx) {
return new Output(this, idx);

View File

@ -22,10 +22,10 @@ import java.util.Objects;
* <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
* the {@link Operation} in a {@link Session}.
*
* <p>By implementing the {@link Input} interface, instances of this class could also be passed
* directly in input to an operation.
* <p>By implementing the {@link Operand} interface, instances of this class also act as operands to
* {@link org.tensorflow.op.Op Op} instances.
*/
public final class Output implements Input {
public final class Output implements Operand {
/** Handle to the idx-th output of the Operation {@code op}. */
public Output(Operation op, int idx) {

View File

@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op;
/**
* A marker interface for all operation wrappers.
*
* <p>Operation wrappers provide strongly typed interfaces for building operations and linking them
* into a graph without the use of literals and indexes required by the core classes.
*
* <p>This interface allows keeping references to any operation wrapper using a common type.
*
* <pre>{@code
* // All values returned by an Ops call can be referred as a Op
* Op split = ops.array().split(...);
* Op shape = ops.array().shape(...);
*
* // All operations could be added to an Op collection
* Collection<Op> allOps = Arrays.asList(split, shape);
* }
*/
public interface Op {}

View File

@ -0,0 +1,47 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
/** Utilities for manipulating operand related types and lists. */
public final class Operands {
/**
* Converts a list of {@link Operand} into an array of {@link Output}.
*
* <p>Operation wrappers need to convert back a list of inputs into an array of outputs in order
* to build an operation, see {@link OperationBuilder#addInputList(Output[])}.
*
* @param inputs an iteration of input operands
* @return an array of outputs
*/
public static Output[] asOutputs(Iterable<? extends Operand> inputs) {
List<Output> outputList = new ArrayList<>();
for (Operand input : inputs) {
outputList.add(input.asOutput());
}
return outputList.toArray(new Output[outputList.size()]);
}
// Disabled constructor
private Operands() {}
}

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op;
import org.tensorflow.Operation;
/**
* A base class for {@link Op} implementations that are backed by a single {@link Operation}.
*
* <p>Each operation registered in the TensorFlow core is a primitive and is provided as a {@code
* PrimitiveOp}. Custom operations working with only one primitive may also derive from this class.
*/
public abstract class PrimitiveOp implements Op {
@Override
public final int hashCode() {
return operation.hashCode();
}
@Override
public final boolean equals(Object obj) {
if (this == obj) {
return true;
}
// Note: we consider that all objects wrapping the same operation are equal, no matter their
// implementation
if (!(obj instanceof PrimitiveOp)) {
return false;
}
return operation.equals(((PrimitiveOp) obj).operation);
}
@Override
public final String toString() {
return String.format("<%s '%s'>", operation.type(), operation.name());
}
/**
* Underlying operation. It is deliberately not exposed by a getter method to avoid any name
* conflict with generated methods of the subclasses.
*/
protected final Operation operation;
/**
* Constructor.
*
* @param operation the underlying operation
*/
protected PrimitiveOp(Operation operation) {
this.operation = operation;
}
}

View File

@ -24,6 +24,7 @@ import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@ -153,6 +154,19 @@ public class OperationTest {
}
}
@Test
public void outputList() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
Output[] outputs = split.outputList(1, 2);
assertNotNull(outputs);
assertEquals(2, outputs.length);
for (int i = 0; i < outputs.length; ++i) {
assertEquals(i + 1, outputs[i].index());
}
}
}
private static int split(int[] values, int num_split) {
try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split")

View File

@ -48,6 +48,14 @@ public class TestUtil {
.output(0);
}
public static Operation split(Graph g, String name, int[] values, int num_split) {
return g.opBuilder("Split", name)
.addInput(constant(g, "split_dim", 0))
.addInput(constant(g, "values", values))
.setAttr("num_split", num_split)
.build();
}
public static void transpose_A_times_X(Graph g, int[][] a) {
matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
}

View File

@ -0,0 +1,47 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.TestUtil;
/** Unit tests for {@link org.tensorflow.op.Operands}. */
@RunWith(JUnit4.class)
public class OperandsTest {
@Test
public void createOutputArrayFromOperandList() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
List<Output> list = Arrays.asList(split.output(0), split.output(2));
Output[] array = Operands.asOutputs(list);
assertEquals(list.size(), array.length);
assertSame(array[0], list.get(0));
assertSame(array[1], list.get(1));
}
}
}

View File

@ -0,0 +1,57 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Set;
import org.junit.Test;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.TestUtil;
public class PrimitiveOpTest {
@Test
public void equalsHashcode() {
try (Graph g = new Graph()) {
Output array = TestUtil.constant(g, "array", new int[2]);
PrimitiveOp test1 =
new PrimitiveOp(g.opBuilder("Shape", "shape1").addInput(array).build()) {};
PrimitiveOp test2 =
new PrimitiveOp(g.opBuilder("Shape", "shape2").addInput(array).build()) {};
PrimitiveOp test3 = new PrimitiveOp(test1.operation) {};
// equals() tests
assertNotEquals(test1, test2);
assertEquals(test1, test3);
assertEquals(test3, test1);
assertNotEquals(test2, test3);
// hashcode() tests
Set<PrimitiveOp> ops = new HashSet<>();
assertTrue(ops.add(test1));
assertTrue(ops.add(test2));
assertFalse(ops.add(test3));
}
}
}