Add the Constant operator class ()

Create a custom operator class to create constants in the Graph,
and introduce the Operator marker annotation to identify
operator classes.

Please see  for the master tracking issue.
This commit is contained in:
KB Sriram 2017-07-27 10:53:57 -07:00 committed by Vijay Vasudevan
parent d09304fca4
commit 599165861e
4 changed files with 430 additions and 1 deletions
tensorflow/java
BUILD
src
main/java/org/tensorflow/op
test/java/org/tensorflow/op/core

View File

@ -34,7 +34,7 @@ filegroup(
filegroup(
name = "java_op_sources",
srcs = glob(["src/main/java/org/tensorflow/op/*.java"]),
srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]),
visibility = [
"//tensorflow/java:__pkg__",
],
@ -191,6 +191,19 @@ java_test(
],
)
java_test(
name = "ConstantTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.op.core.ConstantTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
filegroup(
name = "libtensorflow_jni",
srcs = select({

View File

@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op.annotation;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Annotation used by classes to make TensorFlow operations conveniently accessible via {@code
* org.tensorflow.op.Ops}.
*
* <p>An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by
* aggregating all classes annotated as {@code @Operator}s. Each annotated class <b>must</b> have at
* least one public static factory method named {@code create} that accepts a {@link
* org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in
* the {@code Ops} class. For example:
*
* <pre>{@code
* @Operator
* public final class MyOp implements Op {
* public static MyOp create(Scope scope, Operand operand) {
* ...
* }
* }
* }</pre>
*
* <p>results in a method in the {@code Ops} class
*
* <pre>{@code
* import org.tensorflow.op.Ops;
* ...
* Ops ops = new Ops(graph);
* ...
* ops.myOp(operand);
* // and has exactly the same effect as calling
* // MyOp.create(ops.getScope(), operand);
* }</pre>
*/
@Documented
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.CLASS)
public @interface Operator {
/**
* Specify an optional group within the {@code Ops} class.
*
* <p>By default, an annotation processor will create convenience methods directly in the {@code
* Ops} class. An annotated operator may optionally choose to place the method within a group. For
* example:
*
* <pre>{@code
* @Operator(group="math")
* public final class Add extends PrimitiveOp implements Operand {
* ...
* }
* }</pre>
*
* <p>results in the {@code add} method placed within a {@code math} group within the {@code Ops}
* class.
*
* <pre>{@code
* ops.math().add(...);
* }</pre>
*
* <p>The group name must be a <a
* href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
* identifier</a>.
*/
String group() default "";
/**
* Name for the wrapper method used in the {@code Ops} class.
*
* <p>By default, a processor derives the method name in the {@code Ops} class from the class name
* of the operator. This attribute allow you to provide a different name instead. For example:
*
* <pre>{@code
* @Operator(name="myOperation")
* public final class MyRealOperation implements Operand {
* public static MyRealOperation create(...)
* }
* }</pre>
*
* <p>results in this method added to the {@code Ops} class
*
* <pre>{@code
* ops.myOperation(...);
* // and is the same as calling
* // MyRealOperation.create(...)
* }</pre>
*
* <p>The name must be a <a
* href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
* identifier</a>.
*/
String name() default "";
}

View File

@ -0,0 +1,173 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op.core;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.Tensor;
import org.tensorflow.op.PrimitiveOp;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
/** An operator producing a constant value. */
@Operator
public final class Constant extends PrimitiveOp implements Operand {
/**
* Create a constant from a Java object.
*
* <p>The argument {@code object} is first converted into a Tensor using {@link
* org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
* provided. For example:
*
* <pre>{@code
* Constant.create(scope, 7); // returns a constant scalar tensor 7
* }</pre>
*
* @param scope is a scope used to add the underlying operation.
* @param object a Java object representing the constant.
* @see org.tensorflow.Tensor#create(Object) Tensor.create
*/
public static Constant create(Scope scope, Object object) {
try (Tensor value = Tensor.create(object)) {
return createWithTensor(scope, value);
}
}
/**
* Create a {@link DataType#INT32} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant create(Scope scope, long[] shape, IntBuffer data) {
try (Tensor value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
* Create a {@link DataType#FLOAT} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant create(Scope scope, long[] shape, FloatBuffer data) {
try (Tensor value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
* Create a {@link DataType#DOUBLE} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant create(Scope scope, long[] shape, DoubleBuffer data) {
try (Tensor value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
* Create a {@link DataType#INT64} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant create(Scope scope, long[] shape, LongBuffer data) {
try (Tensor value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
* Create a constant with data from the given buffer.
*
* <p>Creates a Constant with the provided shape of any type where the constant data has been
* encoded into {@code data} as per the specification of the TensorFlow <a
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
*
* @param scope is a scope used to add the underlying operation.
* @param dataType the tensor datatype.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) {
try (Tensor value = Tensor.create(dataType, shape, data)) {
return createWithTensor(scope, value);
}
}
private static Constant createWithTensor(Scope scope, Tensor value) {
return new Constant(
scope
.graph()
.opBuilder("Const", scope.makeOpName("Const"))
.setAttr("value", value)
.setAttr("dtype", value.dataType())
.build());
}
@Override
public Output asOutput() {
return output;
}
private Constant(Operation operation) {
super(operation);
output = operation.output(0);
}
private final Output output;
}

View File

@ -0,0 +1,131 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op.core;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
@RunWith(JUnit4.class)
public class ConstantTest {
private static final float EPSILON = 1e-7f;
@Test
public void createIntBuffer() {
int[] ints = {1, 2, 3, 4};
long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
int[] actual = new int[ints.length];
assertArrayEquals(ints, result.copyTo(actual));
}
}
@Test
public void createFloatBuffer() {
float[] floats = {1, 2, 3, 4};
long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
float[] actual = new float[floats.length];
assertArrayEquals(floats, result.copyTo(actual), EPSILON);
}
}
@Test
public void createDoubleBuffer() {
double[] doubles = {1, 2, 3, 4};
long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
}
}
@Test
public void createLongBuffer() {
long[] longs = {1, 2, 3, 4};
long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
long[] actual = new long[longs.length];
assertArrayEquals(longs, result.copyTo(actual));
}
}
@Test
public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
long[] shape = {};
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
// followed by a varint encoded size, followed by the data.
ByteArrayOutputStream baout = new ByteArrayOutputStream();
DataOutputStream out = new DataOutputStream(baout);
// Offset in array.
out.writeLong(0L);
// Varint encoded length of buffer.
// For any number < 0x80, the varint encoding is simply the number itself.
// https://developers.google.com/protocol-buffers/docs/encoding#varints
assertTrue(data.length < 0x80);
out.write(data.length);
out.write(data);
out.close();
byte[] content = baout.toByteArray();
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
assertArrayEquals(data, result.bytesValue());
}
}
}