Modifications per recommended changes

* moved classes to org.tensorflow.op package
* BUILD changes to keep op package in a separate target
* Updated ci_build to also build new target
* Replaced static factory method for Scope with public raw constructor
* Replaced Builders for NameScope and Scope with raw constructors
* Addressed JavaDoc nits
* Added comments pointing to reasons for operator name restrictions
* Merged NameScopeTest into ScopeTest
* renamed scope.withOpName -> scope.withName
This commit is contained in:
KB Sriram 2017-06-06 17:10:16 -07:00 committed by Martin Wicke
parent 1ec17bd1c6
commit 3072c2f239
7 changed files with 347 additions and 462 deletions

View File

@ -15,6 +15,16 @@ java_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Operator convenience wrappers for the core tensorflow library.
java_library(
name = "tensorflow_op",
srcs = [
":java_op_sources",
],
javacopts = JAVACOPTS,
deps = [":tensorflow"],
)
# NOTE(ashankar): Rule to include the Java API in the Android Inference Library # NOTE(ashankar): Rule to include the Java API in the Android Inference Library
# .aar. At some point, might make sense for a .aar rule here instead. # .aar. At some point, might make sense for a .aar rule here instead.
filegroup( filegroup(
@ -26,6 +36,14 @@ filegroup(
], ],
) )
filegroup(
name = "java_op_sources",
srcs = glob(["src/main/java/org/tensorflow/op/*.java"]),
visibility = [
"//tensorflow/java:__pkg__",
],
)
java_library( java_library(
name = "testutil", name = "testutil",
testonly = 1, testonly = 1,
@ -47,32 +65,6 @@ java_test(
], ],
) )
java_test(
name = "NameScopeTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/NameScopeTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.NameScopeTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
java_test(
name = "ScopeTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/ScopeTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.ScopeTest",
deps = [
":tensorflow",
":testutil",
"@junit",
],
)
java_test( java_test(
name = "OperationBuilderTest", name = "OperationBuilderTest",
size = "small", size = "small",
@ -164,6 +156,20 @@ java_test(
], ],
) )
java_test(
name = "ScopeTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/op/ScopeTest.java"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.op.ScopeTest",
deps = [
":tensorflow",
":tensorflow_op",
":testutil",
"@junit",
],
)
filegroup( filegroup(
name = "libtensorflow_jni", name = "libtensorflow_jni",
srcs = select({ srcs = select({

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
package org.tensorflow; package org.tensorflow.op;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -40,7 +40,7 @@ final class NameScope {
if (baseName == null) { if (baseName == null) {
checkPattern(OP_NAME_REGEX, scopeName); checkPattern(OP_NAME_REGEX, scopeName);
} else { } else {
checkPattern(ROOT_SCOPE_NAME_REGEX, scopeName); checkPattern(SUBSCOPE_NAME_REGEX, scopeName);
} }
if (baseOpName != null) { if (baseOpName != null) {
@ -49,14 +49,14 @@ final class NameScope {
} }
String newBaseName = fullyQualify(makeUnique(scopeName)); String newBaseName = fullyQualify(makeUnique(scopeName));
return NameScope.builder().baseName(newBaseName).build(); return new NameScope(newBaseName, null, null);
} }
NameScope withOpName(String name) { NameScope withName(String name) {
checkPattern(OP_NAME_REGEX, name); checkPattern(OP_NAME_REGEX, name);
// All context except for the baseOpName is shared with the new scope. // All context except for the baseOpName is shared with the new scope.
return NameScope.builder().ids(ids).baseName(baseName).baseOpName(name).build(); return new NameScope(baseName, name, ids);
} }
String makeOpName(String opName) { String makeOpName(String opName) {
@ -74,20 +74,18 @@ final class NameScope {
* *
* <p>A root-level namescope generates operator names with no components, like {@code Const_72} * <p>A root-level namescope generates operator names with no components, like {@code Const_72}
* and {@code result}. * and {@code result}.
*
* @return a NameScope that generates top-level names.
*/ */
static NameScope create() { NameScope() {
return NameScope.builder().build(); this(null, null, null);
} }
private NameScope(Builder builder) { private NameScope(String baseName, String baseOpName, Map<String, Integer> ids) {
baseName = builder.baseName; this.baseName = baseName;
baseOpName = builder.baseOpName; this.baseOpName = baseOpName;
if (builder.ids != null) { if (ids != null) {
ids = builder.ids; this.ids = ids;
} else { } else {
ids = new HashMap<String, Integer>(); this.ids = new HashMap<String, Integer>();
} }
} }
@ -145,37 +143,16 @@ final class NameScope {
} }
} }
// The constraints for operator and scope names originate from restrictions on node names
// noted in the proto definition core/framework/node_def.proto for NodeDef and actually
// implemented in core/framework/node_def_util.cc [Note that the proto comment does not include
// dash (-) in names, while the actual implementation permits it. These regexs follow the actual
// implementation.]
//
// These two patterns are used to ensure fully qualified names always start with a
// LETTER_DIGIT_DOT, followed by zero or more LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE. SLASH is not
// permitted in actual user-supplied names to NameScope - it is used as a reserved character to
// separate subcomponents within fully qualified names.
private static final Pattern OP_NAME_REGEX = Pattern.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-]*"); private static final Pattern OP_NAME_REGEX = Pattern.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-]*");
private static final Pattern ROOT_SCOPE_NAME_REGEX = Pattern.compile("[A-Za-z0-9_.\\-]+"); private static final Pattern SUBSCOPE_NAME_REGEX = Pattern.compile("[A-Za-z0-9_.\\-]+");
private static Builder builder() {
return new Builder();
}
private static final class Builder {
private Builder() {}
private Builder baseName(String name) {
baseName = name;
return this;
}
private Builder baseOpName(String name) {
baseOpName = name;
return this;
}
private Builder ids(Map<String, Integer> map) {
ids = map;
return this;
}
private NameScope build() {
return new NameScope(this);
}
private String baseName = null;
private String baseOpName = null;
private Map<String, Integer> ids = null;
}
} }

View File

@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
package org.tensorflow; package org.tensorflow.op;
import org.tensorflow.Graph;
/** /**
* A {@code Scope} represents a set of related properties when creating Tensorflow Operations, such * Manages groups of related properties when creating Tensorflow Operations, such as a common name
* as a common name prefix. * prefix.
* *
* <p>A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user * <p>A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user
* code initializes a {@code Scope} and provides it to Operation building classes. For example: * code initializes a {@code Scope} and provides it to Operation building classes. For example:
* *
* <pre>{@code * <pre>{@code
* Scope scope = Scope.create(graph); * Scope scope = new Scope(graph);
* Constant c = Constant.create(scope, 42); * Constant c = Constant.create(scope, 42);
* }</pre> * }</pre>
* *
@ -51,13 +53,13 @@ package org.tensorflow;
* <p>An example using {@code Constant} implemented as before: * <p>An example using {@code Constant} implemented as before:
* *
* <pre>{@code * <pre>{@code
* Scope root = Scope.create(graph); * Scope root = new Scope(graph);
* *
* // The linear subscope will generate names like linear/... * // The linear subscope will generate names like linear/...
* Scope linear = Scope.withSubScope("linear"); * Scope linear = Scope.withSubScope("linear");
* *
* // This op name will be "linear/W" * // This op name will be "linear/W"
* Constant.create(linear.withOpName("W"), ...); * Constant.create(linear.withName("W"), ...);
* *
* // This op will be "linear/Constant", using the default * // This op will be "linear/Constant", using the default
* // name provided by Constant * // name provided by Constant
@ -72,23 +74,23 @@ package org.tensorflow;
* <p>Scope objects are thread-safe. * <p>Scope objects are thread-safe.
*/ */
public final class Scope { public final class Scope {
/** /**
* Create a new top-level scope. * Create a new top-level scope.
* *
* @param graph The graph instance to be managed by the scope. * @param graph The graph instance to be managed by the scope.
* @return a top-level Scope.
*/ */
public static Scope create(Graph graph) { public Scope(Graph graph) {
return builder(graph).build(); this(graph, new NameScope());
} }
/** @return the graph managed by this scope. */ /** Returns the graph managed by this scope. */
public Graph graph() { public Graph graph() {
return graph; return graph;
} }
/** /**
* Return a new subscope with the provided name. * Returns a new scope where added operations will have the provided name prefix.
* *
* <p>Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual * <p>Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual
* name will be unique in the returned scope. All other properties are inherited from the current * name will be unique in the returned scope. All other properties are inherited from the current
@ -106,7 +108,7 @@ public final class Scope {
* @throws IllegalArgumentException if the name is invalid * @throws IllegalArgumentException if the name is invalid
*/ */
public Scope withSubScope(String childScopeName) { public Scope withSubScope(String childScopeName) {
return toBuilder().nameScope(nameScope.withSubScope(childScopeName)).build(); return new Scope(graph, nameScope.withSubScope(childScopeName));
} }
/** /**
@ -121,8 +123,8 @@ public final class Scope {
* @return a new Scope that uses opName for operations. * @return a new Scope that uses opName for operations.
* @throws IllegalArgumentException if the name is invalid * @throws IllegalArgumentException if the name is invalid
*/ */
public Scope withOpName(String opName) { public Scope withName(String opName) {
return toBuilder().nameScope(nameScope.withOpName(opName)).build(); return new Scope(graph, nameScope.withName(opName));
} }
/** /**
@ -158,41 +160,11 @@ public final class Scope {
return nameScope.makeOpName(defaultName); return nameScope.makeOpName(defaultName);
} }
private Scope(Builder builder) { private Scope(Graph graph, NameScope nameScope) {
graph = builder.graph; this.graph = graph;
if (builder.nameScope != null) { this.nameScope = nameScope;
nameScope = builder.nameScope;
} else {
nameScope = NameScope.create();
}
} }
private final Graph graph; private final Graph graph;
private final NameScope nameScope; private final NameScope nameScope;
private Builder toBuilder() {
return builder(graph).nameScope(nameScope);
}
private static Builder builder(Graph graph) {
return new Builder(graph);
}
private static final class Builder {
private Builder(Graph g) {
graph = g;
}
private Builder nameScope(NameScope ns) {
nameScope = ns;
return this;
}
private Scope build() {
return new Scope(this);
}
private final Graph graph;
private NameScope nameScope;
}
} }

View File

@ -1,158 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link org.tensorflow.NameScope}. */
@RunWith(JUnit4.class)
public class NameScopeTest {
@Test
public void basicNames() {
NameScope root = NameScope.create();
assertEquals("add", root.makeOpName("add"));
assertEquals("add_1", root.makeOpName("add"));
assertEquals("add_2", root.makeOpName("add"));
assertEquals("mul", root.makeOpName("mul"));
}
@Test
public void hierarchicalNames() {
NameScope root = NameScope.create();
NameScope child = root.withSubScope("child");
assertEquals("child/add", child.makeOpName("add"));
assertEquals("child/add_1", child.makeOpName("add"));
assertEquals("child/mul", child.makeOpName("mul"));
NameScope child_1 = root.withSubScope("child");
assertEquals("child_1/add", child_1.makeOpName("add"));
assertEquals("child_1/add_1", child_1.makeOpName("add"));
assertEquals("child_1/mul", child_1.makeOpName("mul"));
NameScope c_c = root.withSubScope("c").withSubScope("c");
assertEquals("c/c/add", c_c.makeOpName("add"));
NameScope c_1 = root.withSubScope("c");
NameScope c_1_c = c_1.withSubScope("c");
assertEquals("c_1/c/add", c_1_c.makeOpName("add"));
NameScope c_1_c_1 = c_1.withSubScope("c");
assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add"));
}
@Test
public void scopeAndOpNames() {
NameScope root = NameScope.create();
NameScope child = root.withSubScope("child");
assertEquals("child/add", child.makeOpName("add"));
assertEquals("child_1", root.makeOpName("child"));
assertEquals("child_2/p", root.withSubScope("child").makeOpName("p"));
}
@Test
public void names() {
NameScope root = NameScope.create();
final String[] invalid_names = {
"_", // Names are constrained to start with [A-Za-z0-9.]
null, "", "a$", // Invalid characters
"a/b", // slashes not allowed
};
for (String name : invalid_names) {
try {
root.withOpName(name);
fail("failed to catch invalid op name.");
} catch (IllegalArgumentException ex) {
// expected
}
// Root scopes follow the same rules as opnames
try {
root.withSubScope(name);
fail("failed to catch invalid scope name: " + name);
} catch (IllegalArgumentException ex) {
// expected
}
}
// Non-root scopes have a less restrictive constraint.
assertEquals("a/_/hello", root.withSubScope("a").withSubScope("_").makeOpName("hello"));
}
// A dummy composite op - it should create a tree of op names.
private static void topCompositeOp(NameScope scope, List<String> opnames) {
NameScope compScope = scope.withSubScope("top");
opnames.add(compScope.makeOpName("mul"));
opnames.add(compScope.makeOpName("bias_add"));
intermediateOp(compScope, opnames);
}
private static void intermediateOp(NameScope scope, List<String> opnames) {
NameScope compScope = scope.withSubScope("intermediate");
opnames.add(compScope.makeOpName("c1"));
opnames.add(compScope.makeOpName("mul"));
leafOp(compScope.withOpName("c2"), opnames);
}
private static void leafOp(NameScope scope, List<String> opnames) {
opnames.add(scope.makeOpName("const"));
}
@Test
public void compositeOp() {
NameScope root = NameScope.create();
List<String> names = new ArrayList<String>();
topCompositeOp(root, names);
assertEquals(
Arrays.asList(
"top/mul",
"top/bias_add",
"top/intermediate/c1",
"top/intermediate/mul",
"top/intermediate/c2"),
names);
assertEquals("top_1", root.makeOpName("top"));
names.clear();
topCompositeOp(root, names);
assertEquals(
Arrays.asList(
"top_2/mul",
"top_2/bias_add",
"top_2/intermediate/c1",
"top_2/intermediate/mul",
"top_2/intermediate/c2"),
names);
names.clear();
topCompositeOp(root.withOpName("c"), names);
assertEquals(
Arrays.asList(
"c/mul", "c/bias_add", "c/intermediate/c1", "c/intermediate/mul", "c/intermediate/c2"),
names);
}
}

View File

@ -1,184 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link org.tensorflow.Scope}. */
@RunWith(JUnit4.class)
public class ScopeTest {
@Test
public void basic() {
try (Graph g = new Graph()) {
Scope s = Scope.create(g);
Const c1 = Const.create(s, 42);
assertEquals("Const", c1.input().op().name());
Const c2 = Const.create(s, 7);
assertEquals("Const_1", c2.input().op().name());
Const c3 = Const.create(s.withOpName("four"), 4);
assertEquals("four", c3.input().op().name());
}
}
@Test
public void hierarchy() {
try (Graph g = new Graph()) {
Scope root = Scope.create(g);
Scope child = root.withSubScope("child");
assertEquals("child/Const", Const.create(child, 42).input().op().name());
assertEquals("child/four", Const.create(child.withOpName("four"), 4).input().op().name());
}
}
@Test
public void composite() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope s = Scope.create(g);
Const data = Const.create(s.withOpName("data"), new int[] {600, 470, 170, 430, 300});
// Create a composite op with a customized name
Variance var1 = Variance.create(s.withOpName("example"), data);
assertEquals("example/variance", var1.input().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("example/squared_deviation"));
assertNotNull(g.operation("example/Mean"));
assertNotNull(g.operation("example/zero"));
// Same composite op with a default name
Variance var2 = Variance.create(s, data);
assertEquals("variance/variance", var2.input().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("variance/squared_deviation"));
assertNotNull(g.operation("variance/Mean"));
assertNotNull(g.operation("variance/zero"));
// Verify correct results as well.
Tensor result = sess.runner().fetch(var1.input()).run().get(0);
assertEquals(21704, result.intValue());
result = sess.runner().fetch(var2.input()).run().get(0);
assertEquals(21704, result.intValue());
}
}
// Convenience interface
// TODO: replace when standardized
interface Input {
Output input();
}
// "handwritten" sample operator classes
private static final class Const implements Input {
private final Output output;
private static Const create(Scope s, Object v) {
try (Tensor value = Tensor.create(v)) {
return new Const(
s.graph()
.opBuilder("Const", s.makeOpName("Const"))
.setAttr("dtype", value.dataType())
.setAttr("value", value)
.build()
.output(0));
}
}
private Const(Output o) {
output = o;
}
@Override
public Output input() {
return output;
}
}
private static final class Mean implements Input {
private final Output output;
private static Mean create(Scope s, Input input, Input reductionIndices) {
return new Mean(
s.graph()
.opBuilder("Mean", s.makeOpName("Mean"))
.addInput(input.input())
.addInput(reductionIndices.input())
.build()
.output(0));
}
private Mean(Output o) {
output = o;
}
@Override
public Output input() {
return output;
}
}
private static final class SquaredDifference implements Input {
private final Output output;
private static SquaredDifference create(Scope s, Input x, Input y) {
return new SquaredDifference(
s.graph()
.opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
.addInput(x.input())
.addInput(y.input())
.build()
.output(0));
}
private SquaredDifference(Output o) {
output = o;
}
@Override
public Output input() {
return output;
}
}
private static final class Variance implements Input {
private final Output output;
private static Variance create(Scope base, Input x) {
Scope s = base.withSubScope("variance");
Const zero = Const.create(s.withOpName("zero"), new int[] {0});
SquaredDifference sqdiff =
SquaredDifference.create(s.withOpName("squared_deviation"), x, Mean.create(s, x, zero));
return new Variance(Mean.create(s.withOpName("variance"), sqdiff, zero).input());
}
private Variance(Output o) {
output = o;
}
@Override
public Output input() {
return output;
}
}
}

View File

@ -0,0 +1,265 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.op;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
/** Unit tests for {@link org.tensorflow.Scope}. */
@RunWith(JUnit4.class)
public class ScopeTest {
@Test
public void basicNames() {
try (Graph g = new Graph()) {
Scope root = new Scope(g);
assertEquals("add", root.makeOpName("add"));
assertEquals("add_1", root.makeOpName("add"));
assertEquals("add_2", root.makeOpName("add"));
assertEquals("mul", root.makeOpName("mul"));
}
}
@Test
public void hierarchicalNames() {
try (Graph g = new Graph()) {
Scope root = new Scope(g);
Scope child = root.withSubScope("child");
assertEquals("child/add", child.makeOpName("add"));
assertEquals("child/add_1", child.makeOpName("add"));
assertEquals("child/mul", child.makeOpName("mul"));
Scope child_1 = root.withSubScope("child");
assertEquals("child_1/add", child_1.makeOpName("add"));
assertEquals("child_1/add_1", child_1.makeOpName("add"));
assertEquals("child_1/mul", child_1.makeOpName("mul"));
Scope c_c = root.withSubScope("c").withSubScope("c");
assertEquals("c/c/add", c_c.makeOpName("add"));
Scope c_1 = root.withSubScope("c");
Scope c_1_c = c_1.withSubScope("c");
assertEquals("c_1/c/add", c_1_c.makeOpName("add"));
Scope c_1_c_1 = c_1.withSubScope("c");
assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add"));
}
}
@Test
public void scopeAndOpNames() {
try (Graph g = new Graph()) {
Scope root = new Scope(g);
Scope child = root.withSubScope("child");
assertEquals("child/add", child.makeOpName("add"));
assertEquals("child_1", root.makeOpName("child"));
assertEquals("child_2/p", root.withSubScope("child").makeOpName("p"));
}
}
@Test
public void validateNames() {
try (Graph g = new Graph()) {
Scope root = new Scope(g);
final String[] invalid_names = {
"_", // Names are constrained to start with [A-Za-z0-9.]
null, "", "a$", // Invalid characters
"a/b", // slashes not allowed
};
for (String name : invalid_names) {
try {
root.withName(name);
fail("failed to catch invalid op name.");
} catch (IllegalArgumentException ex) {
// expected
}
// Root scopes follow the same rules as opnames
try {
root.withSubScope(name);
fail("failed to catch invalid scope name: " + name);
} catch (IllegalArgumentException ex) {
// expected
}
}
// Non-root scopes have a less restrictive constraint.
assertEquals("a/_/hello", root.withSubScope("a").withSubScope("_").makeOpName("hello"));
}
}
@Test
public void basic() {
try (Graph g = new Graph()) {
Scope s = new Scope(g);
Const c1 = Const.create(s, 42);
assertEquals("Const", c1.output().op().name());
Const c2 = Const.create(s, 7);
assertEquals("Const_1", c2.output().op().name());
Const c3 = Const.create(s.withName("four"), 4);
assertEquals("four", c3.output().op().name());
Const c4 = Const.create(s.withName("four"), 4);
assertEquals("four_1", c4.output().op().name());
}
}
@Test
public void hierarchy() {
try (Graph g = new Graph()) {
Scope root = new Scope(g);
Scope child = root.withSubScope("child");
assertEquals("child/Const", Const.create(child, 42).output().op().name());
assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name());
}
}
@Test
public void composite() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope s = new Scope(g);
Output data = Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
// Create a composite op with a customized name
Variance var1 = Variance.create(s.withName("example"), data);
assertEquals("example/variance", var1.output().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("example/squared_deviation"));
assertNotNull(g.operation("example/Mean"));
assertNotNull(g.operation("example/zero"));
// Same composite op with a default name
Variance var2 = Variance.create(s, data);
assertEquals("variance/variance", var2.output().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("variance/squared_deviation"));
assertNotNull(g.operation("variance/Mean"));
assertNotNull(g.operation("variance/zero"));
// Verify correct results as well.
Tensor result = sess.runner().fetch(var1.output()).run().get(0);
assertEquals(21704, result.intValue());
result = sess.runner().fetch(var2.output()).run().get(0);
assertEquals(21704, result.intValue());
}
}
// "handwritten" sample operator classes
private static final class Const {
private final Output output;
private static Const create(Scope s, Object v) {
try (Tensor value = Tensor.create(v)) {
return new Const(
s.graph()
.opBuilder("Const", s.makeOpName("Const"))
.setAttr("dtype", value.dataType())
.setAttr("value", value)
.build()
.output(0));
}
}
private Const(Output o) {
output = o;
}
public Output output() {
return output;
}
}
private static final class Mean {
private final Output output;
private static Mean create(Scope s, Output input, Output reductionIndices) {
return new Mean(
s.graph()
.opBuilder("Mean", s.makeOpName("Mean"))
.addInput(input)
.addInput(reductionIndices)
.build()
.output(0));
}
private Mean(Output o) {
output = o;
}
public Output output() {
return output;
}
}
private static final class SquaredDifference {
private final Output output;
private static SquaredDifference create(Scope s, Output x, Output y) {
return new SquaredDifference(
s.graph()
.opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
.addInput(x)
.addInput(y)
.build()
.output(0));
}
private SquaredDifference(Output o) {
output = o;
}
public Output output() {
return output;
}
}
private static final class Variance {
private final Output output;
private static Variance create(Scope base, Output x) {
Scope s = base.withSubScope("variance");
Output zero = Const.create(s.withName("zero"), new int[] {0}).output();
Output sqdiff =
SquaredDifference.create(
s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
.output();
return new Variance(Mean.create(s.withName("variance"), sqdiff, zero).output());
}
private Variance(Output o) {
output = o;
}
public Output output() {
return output;
}
}
}

View File

@ -20,6 +20,8 @@
# And jars: # And jars:
# (3) Java API .jar # (3) Java API .jar
# (4) Java API sources .jar # (4) Java API sources .jar
# (5) Java Op wrappers .jar
# (6) Java Op wrapper sources .jar
# #
# These binary distributions will allow use of TensorFlow in various languages # These binary distributions will allow use of TensorFlow in various languages
# without having to compile the TensorFlow framework from sources, which takes # without having to compile the TensorFlow framework from sources, which takes
@ -34,6 +36,8 @@
# - lib_package/libtensorflow_jni${SUFFIX}.tar.gz # - lib_package/libtensorflow_jni${SUFFIX}.tar.gz
# - lib_package/libtensorflow.jar # - lib_package/libtensorflow.jar
# - lib_package/libtensorflow-src.jar # - lib_package/libtensorflow-src.jar
# - lib_package/libtensorflow_op.jar
# - lib_package/libtensorflow_op-src.jar
# - lib_package/libtensorflow_proto.zip # - lib_package/libtensorflow_proto.zip
# #
# ASSUMPTIONS: # ASSUMPTIONS:
@ -75,12 +79,15 @@ function build_libtensorflow_tarball() {
//tensorflow/tools/lib_package:libtensorflow_jni.tar.gz \ //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz \
//tensorflow/java:libtensorflow.jar \ //tensorflow/java:libtensorflow.jar \
//tensorflow/java:libtensorflow-src.jar \ //tensorflow/java:libtensorflow-src.jar \
//tensorflow/java:libtensorflow_op.jar \
//tensorflow/java:libtensorflow_op-src.jar \
//tensorflow/tools/lib_package:libtensorflow_proto.zip //tensorflow/tools/lib_package:libtensorflow_proto.zip
mkdir -p ${DIR} mkdir -p ${DIR}
cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz ${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz ${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz
cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz ${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz ${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz
cp bazel-bin/tensorflow/java/libtensorflow.jar bazel-bin/tensorflow/java/libtensorflow-src.jar ${DIR} cp bazel-bin/tensorflow/java/libtensorflow.jar bazel-bin/tensorflow/java/libtensorflow-src.jar ${DIR}
cp bazel-bin/tensorflow/java/libtensorflow_op.jar bazel-bin/tensorflow/java/libtensorflow_op-src.jar ${DIR}
cp bazel-genfiles/tensorflow/tools/lib_package/libtensorflow_proto.zip ${DIR} cp bazel-genfiles/tensorflow/tools/lib_package/libtensorflow_proto.zip ${DIR}
chmod -x ${DIR}/* chmod -x ${DIR}/*
} }