Add the Constant operator class (#11559)
Create a custom operator class to create constants in the Graph, and introduce the Operator marker annotation to identify operator classes. Please see #7149 for the master tracking issue.
This commit is contained in:
parent
d09304fca4
commit
599165861e
tensorflow/java
@ -34,7 +34,7 @@ filegroup(
|
||||
|
||||
filegroup(
|
||||
name = "java_op_sources",
|
||||
srcs = glob(["src/main/java/org/tensorflow/op/*.java"]),
|
||||
srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]),
|
||||
visibility = [
|
||||
"//tensorflow/java:__pkg__",
|
||||
],
|
||||
@ -191,6 +191,19 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "ConstantTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.op.core.ConstantTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflow_jni",
|
||||
srcs = select({
|
||||
|
@ -0,0 +1,112 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op.annotation;
|
||||
|
||||
import java.lang.annotation.Documented;
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Annotation used by classes to make TensorFlow operations conveniently accessible via {@code
|
||||
* org.tensorflow.op.Ops}.
|
||||
*
|
||||
* <p>An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by
|
||||
* aggregating all classes annotated as {@code @Operator}s. Each annotated class <b>must</b> have at
|
||||
* least one public static factory method named {@code create} that accepts a {@link
|
||||
* org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in
|
||||
* the {@code Ops} class. For example:
|
||||
*
|
||||
* <pre>{@code
|
||||
* @Operator
|
||||
* public final class MyOp implements Op {
|
||||
* public static MyOp create(Scope scope, Operand operand) {
|
||||
* ...
|
||||
* }
|
||||
* }
|
||||
* }</pre>
|
||||
*
|
||||
* <p>results in a method in the {@code Ops} class
|
||||
*
|
||||
* <pre>{@code
|
||||
* import org.tensorflow.op.Ops;
|
||||
* ...
|
||||
* Ops ops = new Ops(graph);
|
||||
* ...
|
||||
* ops.myOp(operand);
|
||||
* // and has exactly the same effect as calling
|
||||
* // MyOp.create(ops.getScope(), operand);
|
||||
* }</pre>
|
||||
*/
|
||||
@Documented
|
||||
@Target(ElementType.TYPE)
|
||||
@Retention(RetentionPolicy.CLASS)
|
||||
public @interface Operator {
|
||||
/**
|
||||
* Specify an optional group within the {@code Ops} class.
|
||||
*
|
||||
* <p>By default, an annotation processor will create convenience methods directly in the {@code
|
||||
* Ops} class. An annotated operator may optionally choose to place the method within a group. For
|
||||
* example:
|
||||
*
|
||||
* <pre>{@code
|
||||
* @Operator(group="math")
|
||||
* public final class Add extends PrimitiveOp implements Operand {
|
||||
* ...
|
||||
* }
|
||||
* }</pre>
|
||||
*
|
||||
* <p>results in the {@code add} method placed within a {@code math} group within the {@code Ops}
|
||||
* class.
|
||||
*
|
||||
* <pre>{@code
|
||||
* ops.math().add(...);
|
||||
* }</pre>
|
||||
*
|
||||
* <p>The group name must be a <a
|
||||
* href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
|
||||
* identifier</a>.
|
||||
*/
|
||||
String group() default "";
|
||||
|
||||
/**
|
||||
* Name for the wrapper method used in the {@code Ops} class.
|
||||
*
|
||||
* <p>By default, a processor derives the method name in the {@code Ops} class from the class name
|
||||
* of the operator. This attribute allow you to provide a different name instead. For example:
|
||||
*
|
||||
* <pre>{@code
|
||||
* @Operator(name="myOperation")
|
||||
* public final class MyRealOperation implements Operand {
|
||||
* public static MyRealOperation create(...)
|
||||
* }
|
||||
* }</pre>
|
||||
*
|
||||
* <p>results in this method added to the {@code Ops} class
|
||||
*
|
||||
* <pre>{@code
|
||||
* ops.myOperation(...);
|
||||
* // and is the same as calling
|
||||
* // MyRealOperation.create(...)
|
||||
* }</pre>
|
||||
*
|
||||
* <p>The name must be a <a
|
||||
* href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
|
||||
* identifier</a>.
|
||||
*/
|
||||
String name() default "";
|
||||
}
|
@ -0,0 +1,173 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op.core;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import org.tensorflow.DataType;
|
||||
import org.tensorflow.Operand;
|
||||
import org.tensorflow.Operation;
|
||||
import org.tensorflow.Output;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.op.PrimitiveOp;
|
||||
import org.tensorflow.op.Scope;
|
||||
import org.tensorflow.op.annotation.Operator;
|
||||
|
||||
/** An operator producing a constant value. */
|
||||
@Operator
|
||||
public final class Constant extends PrimitiveOp implements Operand {
|
||||
/**
|
||||
* Create a constant from a Java object.
|
||||
*
|
||||
* <p>The argument {@code object} is first converted into a Tensor using {@link
|
||||
* org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
|
||||
* provided. For example:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Constant.create(scope, 7); // returns a constant scalar tensor 7
|
||||
* }</pre>
|
||||
*
|
||||
* @param scope is a scope used to add the underlying operation.
|
||||
* @param object a Java object representing the constant.
|
||||
* @see org.tensorflow.Tensor#create(Object) Tensor.create
|
||||
*/
|
||||
public static Constant create(Scope scope, Object object) {
|
||||
try (Tensor value = Tensor.create(object)) {
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link DataType#INT32} constant with data from the given buffer.
|
||||
*
|
||||
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||
* method.
|
||||
*
|
||||
* @param scope is a scope used to add the underlying operation.
|
||||
* @param shape the tensor shape.
|
||||
* @param data a buffer containing the tensor data.
|
||||
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||
*/
|
||||
public static Constant create(Scope scope, long[] shape, IntBuffer data) {
|
||||
try (Tensor value = Tensor.create(shape, data)) {
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link DataType#FLOAT} constant with data from the given buffer.
|
||||
*
|
||||
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||
* method.
|
||||
*
|
||||
* @param scope is a scope used to add the underlying operation.
|
||||
* @param shape the tensor shape.
|
||||
* @param data a buffer containing the tensor data.
|
||||
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||
*/
|
||||
public static Constant create(Scope scope, long[] shape, FloatBuffer data) {
|
||||
try (Tensor value = Tensor.create(shape, data)) {
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link DataType#DOUBLE} constant with data from the given buffer.
|
||||
*
|
||||
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||
* method.
|
||||
*
|
||||
* @param scope is a scope used to add the underlying operation.
|
||||
* @param shape the tensor shape.
|
||||
* @param data a buffer containing the tensor data.
|
||||
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||
*/
|
||||
public static Constant create(Scope scope, long[] shape, DoubleBuffer data) {
|
||||
try (Tensor value = Tensor.create(shape, data)) {
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a {@link DataType#INT64} constant with data from the given buffer.
|
||||
*
|
||||
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
|
||||
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
|
||||
* a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
|
||||
* method.
|
||||
*
|
||||
* @param scope is a scope used to add the underlying operation.
|
||||
* @param shape the tensor shape.
|
||||
* @param data a buffer containing the tensor data.
|
||||
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
|
||||
*/
|
||||
public static Constant create(Scope scope, long[] shape, LongBuffer data) {
|
||||
try (Tensor value = Tensor.create(shape, data)) {
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a constant with data from the given buffer.
|
||||
*
|
||||
* <p>Creates a Constant with the provided shape of any type where the constant data has been
|
||||
* encoded into {@code data} as per the specification of the TensorFlow <a
|
||||
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
|
||||
*
|
||||
* @param scope is a scope used to add the underlying operation.
|
||||
* @param dataType the tensor datatype.
|
||||
* @param shape the tensor shape.
|
||||
* @param data a buffer containing the tensor data.
|
||||
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
|
||||
* buffer
|
||||
*/
|
||||
public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) {
|
||||
try (Tensor value = Tensor.create(dataType, shape, data)) {
|
||||
return createWithTensor(scope, value);
|
||||
}
|
||||
}
|
||||
|
||||
private static Constant createWithTensor(Scope scope, Tensor value) {
|
||||
return new Constant(
|
||||
scope
|
||||
.graph()
|
||||
.opBuilder("Const", scope.makeOpName("Const"))
|
||||
.setAttr("value", value)
|
||||
.setAttr("dtype", value.dataType())
|
||||
.build());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output asOutput() {
|
||||
return output;
|
||||
}
|
||||
|
||||
private Constant(Operation operation) {
|
||||
super(operation);
|
||||
output = operation.output(0);
|
||||
}
|
||||
|
||||
private final Output output;
|
||||
}
|
@ -0,0 +1,131 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op.core;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.DataType;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.op.Scope;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public class ConstantTest {
|
||||
private static final float EPSILON = 1e-7f;
|
||||
|
||||
@Test
|
||||
public void createIntBuffer() {
|
||||
int[] ints = {1, 2, 3, 4};
|
||||
long[] shape = {4};
|
||||
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope scope = new Scope(g);
|
||||
Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
|
||||
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||
int[] actual = new int[ints.length];
|
||||
assertArrayEquals(ints, result.copyTo(actual));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createFloatBuffer() {
|
||||
float[] floats = {1, 2, 3, 4};
|
||||
long[] shape = {4};
|
||||
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope scope = new Scope(g);
|
||||
Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
|
||||
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||
float[] actual = new float[floats.length];
|
||||
assertArrayEquals(floats, result.copyTo(actual), EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createDoubleBuffer() {
|
||||
double[] doubles = {1, 2, 3, 4};
|
||||
long[] shape = {4};
|
||||
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope scope = new Scope(g);
|
||||
Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
|
||||
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||
double[] actual = new double[doubles.length];
|
||||
assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createLongBuffer() {
|
||||
long[] longs = {1, 2, 3, 4};
|
||||
long[] shape = {4};
|
||||
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope scope = new Scope(g);
|
||||
Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
|
||||
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||
long[] actual = new long[longs.length];
|
||||
assertArrayEquals(longs, result.copyTo(actual));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createStringBuffer() throws IOException {
|
||||
|
||||
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
|
||||
long[] shape = {};
|
||||
|
||||
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
|
||||
// followed by a varint encoded size, followed by the data.
|
||||
ByteArrayOutputStream baout = new ByteArrayOutputStream();
|
||||
DataOutputStream out = new DataOutputStream(baout);
|
||||
// Offset in array.
|
||||
out.writeLong(0L);
|
||||
// Varint encoded length of buffer.
|
||||
// For any number < 0x80, the varint encoding is simply the number itself.
|
||||
// https://developers.google.com/protocol-buffers/docs/encoding#varints
|
||||
assertTrue(data.length < 0x80);
|
||||
out.write(data.length);
|
||||
out.write(data);
|
||||
out.close();
|
||||
byte[] content = baout.toByteArray();
|
||||
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope scope = new Scope(g);
|
||||
Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
|
||||
Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
|
||||
assertArrayEquals(data, result.bytesValue());
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user