Modifications per recommended changes
* moved classes to org.tensorflow.op package * BUILD changes to keep op package in a separate target * Updated ci_build to also build new target * Replaced static factory method for Scope with public raw constructor * Replaced Builders for NameScope and Scope with raw constructors * Addressed JavaDoc nits * Added comments pointing to reasons for operator name restrictions * Merged NameScopeTest into ScopeTest * renamed scope.withOpName -> scope.withName
This commit is contained in:
parent
1ec17bd1c6
commit
3072c2f239
@ -15,6 +15,16 @@ java_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Operator convenience wrappers for the core tensorflow library.
|
||||
java_library(
|
||||
name = "tensorflow_op",
|
||||
srcs = [
|
||||
":java_op_sources",
|
||||
],
|
||||
javacopts = JAVACOPTS,
|
||||
deps = [":tensorflow"],
|
||||
)
|
||||
|
||||
# NOTE(ashankar): Rule to include the Java API in the Android Inference Library
|
||||
# .aar. At some point, might make sense for a .aar rule here instead.
|
||||
filegroup(
|
||||
@ -26,6 +36,14 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "java_op_sources",
|
||||
srcs = glob(["src/main/java/org/tensorflow/op/*.java"]),
|
||||
visibility = [
|
||||
"//tensorflow/java:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
java_library(
|
||||
name = "testutil",
|
||||
testonly = 1,
|
||||
@ -47,32 +65,6 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "NameScopeTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/NameScopeTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.NameScopeTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "ScopeTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/ScopeTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.ScopeTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "OperationBuilderTest",
|
||||
size = "small",
|
||||
@ -164,6 +156,20 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "ScopeTest",
|
||||
size = "small",
|
||||
srcs = ["src/test/java/org/tensorflow/op/ScopeTest.java"],
|
||||
javacopts = JAVACOPTS,
|
||||
test_class = "org.tensorflow.op.ScopeTest",
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_op",
|
||||
":testutil",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflow_jni",
|
||||
srcs = select({
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
package org.tensorflow.op;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
@ -40,7 +40,7 @@ final class NameScope {
|
||||
if (baseName == null) {
|
||||
checkPattern(OP_NAME_REGEX, scopeName);
|
||||
} else {
|
||||
checkPattern(ROOT_SCOPE_NAME_REGEX, scopeName);
|
||||
checkPattern(SUBSCOPE_NAME_REGEX, scopeName);
|
||||
}
|
||||
|
||||
if (baseOpName != null) {
|
||||
@ -49,14 +49,14 @@ final class NameScope {
|
||||
}
|
||||
|
||||
String newBaseName = fullyQualify(makeUnique(scopeName));
|
||||
return NameScope.builder().baseName(newBaseName).build();
|
||||
return new NameScope(newBaseName, null, null);
|
||||
}
|
||||
|
||||
NameScope withOpName(String name) {
|
||||
NameScope withName(String name) {
|
||||
checkPattern(OP_NAME_REGEX, name);
|
||||
|
||||
// All context except for the baseOpName is shared with the new scope.
|
||||
return NameScope.builder().ids(ids).baseName(baseName).baseOpName(name).build();
|
||||
return new NameScope(baseName, name, ids);
|
||||
}
|
||||
|
||||
String makeOpName(String opName) {
|
||||
@ -74,20 +74,18 @@ final class NameScope {
|
||||
*
|
||||
* <p>A root-level namescope generates operator names with no components, like {@code Const_72}
|
||||
* and {@code result}.
|
||||
*
|
||||
* @return a NameScope that generates top-level names.
|
||||
*/
|
||||
static NameScope create() {
|
||||
return NameScope.builder().build();
|
||||
NameScope() {
|
||||
this(null, null, null);
|
||||
}
|
||||
|
||||
private NameScope(Builder builder) {
|
||||
baseName = builder.baseName;
|
||||
baseOpName = builder.baseOpName;
|
||||
if (builder.ids != null) {
|
||||
ids = builder.ids;
|
||||
private NameScope(String baseName, String baseOpName, Map<String, Integer> ids) {
|
||||
this.baseName = baseName;
|
||||
this.baseOpName = baseOpName;
|
||||
if (ids != null) {
|
||||
this.ids = ids;
|
||||
} else {
|
||||
ids = new HashMap<String, Integer>();
|
||||
this.ids = new HashMap<String, Integer>();
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,37 +143,16 @@ final class NameScope {
|
||||
}
|
||||
}
|
||||
|
||||
// The constraints for operator and scope names originate from restrictions on node names
|
||||
// noted in the proto definition core/framework/node_def.proto for NodeDef and actually
|
||||
// implemented in core/framework/node_def_util.cc [Note that the proto comment does not include
|
||||
// dash (-) in names, while the actual implementation permits it. These regexs follow the actual
|
||||
// implementation.]
|
||||
//
|
||||
// These two patterns are used to ensure fully qualified names always start with a
|
||||
// LETTER_DIGIT_DOT, followed by zero or more LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE. SLASH is not
|
||||
// permitted in actual user-supplied names to NameScope - it is used as a reserved character to
|
||||
// separate subcomponents within fully qualified names.
|
||||
private static final Pattern OP_NAME_REGEX = Pattern.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-]*");
|
||||
private static final Pattern ROOT_SCOPE_NAME_REGEX = Pattern.compile("[A-Za-z0-9_.\\-]+");
|
||||
|
||||
private static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
private static final class Builder {
|
||||
private Builder() {}
|
||||
|
||||
private Builder baseName(String name) {
|
||||
baseName = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Builder baseOpName(String name) {
|
||||
baseOpName = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Builder ids(Map<String, Integer> map) {
|
||||
ids = map;
|
||||
return this;
|
||||
}
|
||||
|
||||
private NameScope build() {
|
||||
return new NameScope(this);
|
||||
}
|
||||
|
||||
private String baseName = null;
|
||||
private String baseOpName = null;
|
||||
private Map<String, Integer> ids = null;
|
||||
}
|
||||
private static final Pattern SUBSCOPE_NAME_REGEX = Pattern.compile("[A-Za-z0-9_.\\-]+");
|
||||
}
|
@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
package org.tensorflow.op;
|
||||
|
||||
import org.tensorflow.Graph;
|
||||
|
||||
/**
|
||||
* A {@code Scope} represents a set of related properties when creating Tensorflow Operations, such
|
||||
* as a common name prefix.
|
||||
* Manages groups of related properties when creating Tensorflow Operations, such as a common name
|
||||
* prefix.
|
||||
*
|
||||
* <p>A {@code Scope} is a container for common properties applied to TensorFlow Ops. Normal user
|
||||
* code initializes a {@code Scope} and provides it to Operation building classes. For example:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Scope scope = Scope.create(graph);
|
||||
* Scope scope = new Scope(graph);
|
||||
* Constant c = Constant.create(scope, 42);
|
||||
* }</pre>
|
||||
*
|
||||
@ -51,13 +53,13 @@ package org.tensorflow;
|
||||
* <p>An example using {@code Constant} implemented as before:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Scope root = Scope.create(graph);
|
||||
* Scope root = new Scope(graph);
|
||||
*
|
||||
* // The linear subscope will generate names like linear/...
|
||||
* Scope linear = Scope.withSubScope("linear");
|
||||
*
|
||||
* // This op name will be "linear/W"
|
||||
* Constant.create(linear.withOpName("W"), ...);
|
||||
* Constant.create(linear.withName("W"), ...);
|
||||
*
|
||||
* // This op will be "linear/Constant", using the default
|
||||
* // name provided by Constant
|
||||
@ -72,23 +74,23 @@ package org.tensorflow;
|
||||
* <p>Scope objects are thread-safe.
|
||||
*/
|
||||
public final class Scope {
|
||||
|
||||
/**
|
||||
* Create a new top-level scope.
|
||||
*
|
||||
* @param graph The graph instance to be managed by the scope.
|
||||
* @return a top-level Scope.
|
||||
*/
|
||||
public static Scope create(Graph graph) {
|
||||
return builder(graph).build();
|
||||
public Scope(Graph graph) {
|
||||
this(graph, new NameScope());
|
||||
}
|
||||
|
||||
/** @return the graph managed by this scope. */
|
||||
/** Returns the graph managed by this scope. */
|
||||
public Graph graph() {
|
||||
return graph;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new subscope with the provided name.
|
||||
* Returns a new scope where added operations will have the provided name prefix.
|
||||
*
|
||||
* <p>Ops created with this scope will have {@code name/childScopeName/} as the prefix. The actual
|
||||
* name will be unique in the returned scope. All other properties are inherited from the current
|
||||
@ -106,7 +108,7 @@ public final class Scope {
|
||||
* @throws IllegalArgumentException if the name is invalid
|
||||
*/
|
||||
public Scope withSubScope(String childScopeName) {
|
||||
return toBuilder().nameScope(nameScope.withSubScope(childScopeName)).build();
|
||||
return new Scope(graph, nameScope.withSubScope(childScopeName));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -121,8 +123,8 @@ public final class Scope {
|
||||
* @return a new Scope that uses opName for operations.
|
||||
* @throws IllegalArgumentException if the name is invalid
|
||||
*/
|
||||
public Scope withOpName(String opName) {
|
||||
return toBuilder().nameScope(nameScope.withOpName(opName)).build();
|
||||
public Scope withName(String opName) {
|
||||
return new Scope(graph, nameScope.withName(opName));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -158,41 +160,11 @@ public final class Scope {
|
||||
return nameScope.makeOpName(defaultName);
|
||||
}
|
||||
|
||||
private Scope(Builder builder) {
|
||||
graph = builder.graph;
|
||||
if (builder.nameScope != null) {
|
||||
nameScope = builder.nameScope;
|
||||
} else {
|
||||
nameScope = NameScope.create();
|
||||
}
|
||||
private Scope(Graph graph, NameScope nameScope) {
|
||||
this.graph = graph;
|
||||
this.nameScope = nameScope;
|
||||
}
|
||||
|
||||
private final Graph graph;
|
||||
private final NameScope nameScope;
|
||||
|
||||
private Builder toBuilder() {
|
||||
return builder(graph).nameScope(nameScope);
|
||||
}
|
||||
|
||||
private static Builder builder(Graph graph) {
|
||||
return new Builder(graph);
|
||||
}
|
||||
|
||||
private static final class Builder {
|
||||
private Builder(Graph g) {
|
||||
graph = g;
|
||||
}
|
||||
|
||||
private Builder nameScope(NameScope ns) {
|
||||
nameScope = ns;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Scope build() {
|
||||
return new Scope(this);
|
||||
}
|
||||
|
||||
private final Graph graph;
|
||||
private NameScope nameScope;
|
||||
}
|
||||
}
|
@ -1,158 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.NameScope}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class NameScopeTest {
|
||||
|
||||
@Test
|
||||
public void basicNames() {
|
||||
NameScope root = NameScope.create();
|
||||
assertEquals("add", root.makeOpName("add"));
|
||||
assertEquals("add_1", root.makeOpName("add"));
|
||||
assertEquals("add_2", root.makeOpName("add"));
|
||||
assertEquals("mul", root.makeOpName("mul"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hierarchicalNames() {
|
||||
NameScope root = NameScope.create();
|
||||
NameScope child = root.withSubScope("child");
|
||||
assertEquals("child/add", child.makeOpName("add"));
|
||||
assertEquals("child/add_1", child.makeOpName("add"));
|
||||
assertEquals("child/mul", child.makeOpName("mul"));
|
||||
|
||||
NameScope child_1 = root.withSubScope("child");
|
||||
assertEquals("child_1/add", child_1.makeOpName("add"));
|
||||
assertEquals("child_1/add_1", child_1.makeOpName("add"));
|
||||
assertEquals("child_1/mul", child_1.makeOpName("mul"));
|
||||
|
||||
NameScope c_c = root.withSubScope("c").withSubScope("c");
|
||||
assertEquals("c/c/add", c_c.makeOpName("add"));
|
||||
|
||||
NameScope c_1 = root.withSubScope("c");
|
||||
NameScope c_1_c = c_1.withSubScope("c");
|
||||
assertEquals("c_1/c/add", c_1_c.makeOpName("add"));
|
||||
|
||||
NameScope c_1_c_1 = c_1.withSubScope("c");
|
||||
assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scopeAndOpNames() {
|
||||
NameScope root = NameScope.create();
|
||||
NameScope child = root.withSubScope("child");
|
||||
|
||||
assertEquals("child/add", child.makeOpName("add"));
|
||||
assertEquals("child_1", root.makeOpName("child"));
|
||||
assertEquals("child_2/p", root.withSubScope("child").makeOpName("p"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void names() {
|
||||
NameScope root = NameScope.create();
|
||||
|
||||
final String[] invalid_names = {
|
||||
"_", // Names are constrained to start with [A-Za-z0-9.]
|
||||
null, "", "a$", // Invalid characters
|
||||
"a/b", // slashes not allowed
|
||||
};
|
||||
|
||||
for (String name : invalid_names) {
|
||||
try {
|
||||
root.withOpName(name);
|
||||
fail("failed to catch invalid op name.");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
// expected
|
||||
}
|
||||
// Root scopes follow the same rules as opnames
|
||||
try {
|
||||
root.withSubScope(name);
|
||||
fail("failed to catch invalid scope name: " + name);
|
||||
} catch (IllegalArgumentException ex) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
// Non-root scopes have a less restrictive constraint.
|
||||
assertEquals("a/_/hello", root.withSubScope("a").withSubScope("_").makeOpName("hello"));
|
||||
}
|
||||
|
||||
// A dummy composite op - it should create a tree of op names.
|
||||
private static void topCompositeOp(NameScope scope, List<String> opnames) {
|
||||
NameScope compScope = scope.withSubScope("top");
|
||||
opnames.add(compScope.makeOpName("mul"));
|
||||
opnames.add(compScope.makeOpName("bias_add"));
|
||||
intermediateOp(compScope, opnames);
|
||||
}
|
||||
|
||||
private static void intermediateOp(NameScope scope, List<String> opnames) {
|
||||
NameScope compScope = scope.withSubScope("intermediate");
|
||||
opnames.add(compScope.makeOpName("c1"));
|
||||
opnames.add(compScope.makeOpName("mul"));
|
||||
leafOp(compScope.withOpName("c2"), opnames);
|
||||
}
|
||||
|
||||
private static void leafOp(NameScope scope, List<String> opnames) {
|
||||
opnames.add(scope.makeOpName("const"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void compositeOp() {
|
||||
NameScope root = NameScope.create();
|
||||
List<String> names = new ArrayList<String>();
|
||||
topCompositeOp(root, names);
|
||||
assertEquals(
|
||||
Arrays.asList(
|
||||
"top/mul",
|
||||
"top/bias_add",
|
||||
"top/intermediate/c1",
|
||||
"top/intermediate/mul",
|
||||
"top/intermediate/c2"),
|
||||
names);
|
||||
|
||||
assertEquals("top_1", root.makeOpName("top"));
|
||||
|
||||
names.clear();
|
||||
topCompositeOp(root, names);
|
||||
assertEquals(
|
||||
Arrays.asList(
|
||||
"top_2/mul",
|
||||
"top_2/bias_add",
|
||||
"top_2/intermediate/c1",
|
||||
"top_2/intermediate/mul",
|
||||
"top_2/intermediate/c2"),
|
||||
names);
|
||||
|
||||
names.clear();
|
||||
topCompositeOp(root.withOpName("c"), names);
|
||||
assertEquals(
|
||||
Arrays.asList(
|
||||
"c/mul", "c/bias_add", "c/intermediate/c1", "c/intermediate/mul", "c/intermediate/c2"),
|
||||
names);
|
||||
}
|
||||
}
|
@ -1,184 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.Scope}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class ScopeTest {
|
||||
|
||||
@Test
|
||||
public void basic() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope s = Scope.create(g);
|
||||
Const c1 = Const.create(s, 42);
|
||||
assertEquals("Const", c1.input().op().name());
|
||||
Const c2 = Const.create(s, 7);
|
||||
assertEquals("Const_1", c2.input().op().name());
|
||||
Const c3 = Const.create(s.withOpName("four"), 4);
|
||||
assertEquals("four", c3.input().op().name());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hierarchy() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope root = Scope.create(g);
|
||||
Scope child = root.withSubScope("child");
|
||||
assertEquals("child/Const", Const.create(child, 42).input().op().name());
|
||||
assertEquals("child/four", Const.create(child.withOpName("four"), 4).input().op().name());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void composite() {
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope s = Scope.create(g);
|
||||
Const data = Const.create(s.withOpName("data"), new int[] {600, 470, 170, 430, 300});
|
||||
|
||||
// Create a composite op with a customized name
|
||||
Variance var1 = Variance.create(s.withOpName("example"), data);
|
||||
assertEquals("example/variance", var1.input().op().name());
|
||||
|
||||
// Confirm internally added ops have the right names.
|
||||
assertNotNull(g.operation("example/squared_deviation"));
|
||||
assertNotNull(g.operation("example/Mean"));
|
||||
assertNotNull(g.operation("example/zero"));
|
||||
|
||||
// Same composite op with a default name
|
||||
Variance var2 = Variance.create(s, data);
|
||||
assertEquals("variance/variance", var2.input().op().name());
|
||||
|
||||
// Confirm internally added ops have the right names.
|
||||
assertNotNull(g.operation("variance/squared_deviation"));
|
||||
assertNotNull(g.operation("variance/Mean"));
|
||||
assertNotNull(g.operation("variance/zero"));
|
||||
|
||||
// Verify correct results as well.
|
||||
Tensor result = sess.runner().fetch(var1.input()).run().get(0);
|
||||
assertEquals(21704, result.intValue());
|
||||
result = sess.runner().fetch(var2.input()).run().get(0);
|
||||
assertEquals(21704, result.intValue());
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience interface
|
||||
// TODO: replace when standardized
|
||||
interface Input {
|
||||
Output input();
|
||||
}
|
||||
|
||||
// "handwritten" sample operator classes
|
||||
private static final class Const implements Input {
|
||||
private final Output output;
|
||||
|
||||
private static Const create(Scope s, Object v) {
|
||||
try (Tensor value = Tensor.create(v)) {
|
||||
return new Const(
|
||||
s.graph()
|
||||
.opBuilder("Const", s.makeOpName("Const"))
|
||||
.setAttr("dtype", value.dataType())
|
||||
.setAttr("value", value)
|
||||
.build()
|
||||
.output(0));
|
||||
}
|
||||
}
|
||||
|
||||
private Const(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output input() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class Mean implements Input {
|
||||
private final Output output;
|
||||
|
||||
private static Mean create(Scope s, Input input, Input reductionIndices) {
|
||||
return new Mean(
|
||||
s.graph()
|
||||
.opBuilder("Mean", s.makeOpName("Mean"))
|
||||
.addInput(input.input())
|
||||
.addInput(reductionIndices.input())
|
||||
.build()
|
||||
.output(0));
|
||||
}
|
||||
|
||||
private Mean(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output input() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class SquaredDifference implements Input {
|
||||
private final Output output;
|
||||
|
||||
private static SquaredDifference create(Scope s, Input x, Input y) {
|
||||
return new SquaredDifference(
|
||||
s.graph()
|
||||
.opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
|
||||
.addInput(x.input())
|
||||
.addInput(y.input())
|
||||
.build()
|
||||
.output(0));
|
||||
}
|
||||
|
||||
private SquaredDifference(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output input() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class Variance implements Input {
|
||||
private final Output output;
|
||||
|
||||
private static Variance create(Scope base, Input x) {
|
||||
Scope s = base.withSubScope("variance");
|
||||
Const zero = Const.create(s.withOpName("zero"), new int[] {0});
|
||||
SquaredDifference sqdiff =
|
||||
SquaredDifference.create(s.withOpName("squared_deviation"), x, Mean.create(s, x, zero));
|
||||
|
||||
return new Variance(Mean.create(s.withOpName("variance"), sqdiff, zero).input());
|
||||
}
|
||||
|
||||
private Variance(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Output input() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
}
|
265
tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
Normal file
265
tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
Normal file
@ -0,0 +1,265 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.op;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Output;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.Scope}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class ScopeTest {
|
||||
|
||||
@Test
|
||||
public void basicNames() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope root = new Scope(g);
|
||||
assertEquals("add", root.makeOpName("add"));
|
||||
assertEquals("add_1", root.makeOpName("add"));
|
||||
assertEquals("add_2", root.makeOpName("add"));
|
||||
assertEquals("mul", root.makeOpName("mul"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hierarchicalNames() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope root = new Scope(g);
|
||||
Scope child = root.withSubScope("child");
|
||||
assertEquals("child/add", child.makeOpName("add"));
|
||||
assertEquals("child/add_1", child.makeOpName("add"));
|
||||
assertEquals("child/mul", child.makeOpName("mul"));
|
||||
|
||||
Scope child_1 = root.withSubScope("child");
|
||||
assertEquals("child_1/add", child_1.makeOpName("add"));
|
||||
assertEquals("child_1/add_1", child_1.makeOpName("add"));
|
||||
assertEquals("child_1/mul", child_1.makeOpName("mul"));
|
||||
|
||||
Scope c_c = root.withSubScope("c").withSubScope("c");
|
||||
assertEquals("c/c/add", c_c.makeOpName("add"));
|
||||
|
||||
Scope c_1 = root.withSubScope("c");
|
||||
Scope c_1_c = c_1.withSubScope("c");
|
||||
assertEquals("c_1/c/add", c_1_c.makeOpName("add"));
|
||||
|
||||
Scope c_1_c_1 = c_1.withSubScope("c");
|
||||
assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void scopeAndOpNames() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope root = new Scope(g);
|
||||
|
||||
Scope child = root.withSubScope("child");
|
||||
|
||||
assertEquals("child/add", child.makeOpName("add"));
|
||||
assertEquals("child_1", root.makeOpName("child"));
|
||||
assertEquals("child_2/p", root.withSubScope("child").makeOpName("p"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void validateNames() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope root = new Scope(g);
|
||||
|
||||
final String[] invalid_names = {
|
||||
"_", // Names are constrained to start with [A-Za-z0-9.]
|
||||
null, "", "a$", // Invalid characters
|
||||
"a/b", // slashes not allowed
|
||||
};
|
||||
|
||||
for (String name : invalid_names) {
|
||||
try {
|
||||
root.withName(name);
|
||||
fail("failed to catch invalid op name.");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
// expected
|
||||
}
|
||||
// Root scopes follow the same rules as opnames
|
||||
try {
|
||||
root.withSubScope(name);
|
||||
fail("failed to catch invalid scope name: " + name);
|
||||
} catch (IllegalArgumentException ex) {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
// Non-root scopes have a less restrictive constraint.
|
||||
assertEquals("a/_/hello", root.withSubScope("a").withSubScope("_").makeOpName("hello"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basic() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope s = new Scope(g);
|
||||
Const c1 = Const.create(s, 42);
|
||||
assertEquals("Const", c1.output().op().name());
|
||||
Const c2 = Const.create(s, 7);
|
||||
assertEquals("Const_1", c2.output().op().name());
|
||||
Const c3 = Const.create(s.withName("four"), 4);
|
||||
assertEquals("four", c3.output().op().name());
|
||||
Const c4 = Const.create(s.withName("four"), 4);
|
||||
assertEquals("four_1", c4.output().op().name());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hierarchy() {
|
||||
try (Graph g = new Graph()) {
|
||||
Scope root = new Scope(g);
|
||||
Scope child = root.withSubScope("child");
|
||||
assertEquals("child/Const", Const.create(child, 42).output().op().name());
|
||||
assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void composite() {
|
||||
try (Graph g = new Graph();
|
||||
Session sess = new Session(g)) {
|
||||
Scope s = new Scope(g);
|
||||
Output data = Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
|
||||
|
||||
// Create a composite op with a customized name
|
||||
Variance var1 = Variance.create(s.withName("example"), data);
|
||||
assertEquals("example/variance", var1.output().op().name());
|
||||
|
||||
// Confirm internally added ops have the right names.
|
||||
assertNotNull(g.operation("example/squared_deviation"));
|
||||
assertNotNull(g.operation("example/Mean"));
|
||||
assertNotNull(g.operation("example/zero"));
|
||||
|
||||
// Same composite op with a default name
|
||||
Variance var2 = Variance.create(s, data);
|
||||
assertEquals("variance/variance", var2.output().op().name());
|
||||
|
||||
// Confirm internally added ops have the right names.
|
||||
assertNotNull(g.operation("variance/squared_deviation"));
|
||||
assertNotNull(g.operation("variance/Mean"));
|
||||
assertNotNull(g.operation("variance/zero"));
|
||||
|
||||
// Verify correct results as well.
|
||||
Tensor result = sess.runner().fetch(var1.output()).run().get(0);
|
||||
assertEquals(21704, result.intValue());
|
||||
result = sess.runner().fetch(var2.output()).run().get(0);
|
||||
assertEquals(21704, result.intValue());
|
||||
}
|
||||
}
|
||||
|
||||
// "handwritten" sample operator classes
|
||||
private static final class Const {
|
||||
private final Output output;
|
||||
|
||||
private static Const create(Scope s, Object v) {
|
||||
try (Tensor value = Tensor.create(v)) {
|
||||
return new Const(
|
||||
s.graph()
|
||||
.opBuilder("Const", s.makeOpName("Const"))
|
||||
.setAttr("dtype", value.dataType())
|
||||
.setAttr("value", value)
|
||||
.build()
|
||||
.output(0));
|
||||
}
|
||||
}
|
||||
|
||||
private Const(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
public Output output() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class Mean {
|
||||
private final Output output;
|
||||
|
||||
private static Mean create(Scope s, Output input, Output reductionIndices) {
|
||||
return new Mean(
|
||||
s.graph()
|
||||
.opBuilder("Mean", s.makeOpName("Mean"))
|
||||
.addInput(input)
|
||||
.addInput(reductionIndices)
|
||||
.build()
|
||||
.output(0));
|
||||
}
|
||||
|
||||
private Mean(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
public Output output() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class SquaredDifference {
|
||||
private final Output output;
|
||||
|
||||
private static SquaredDifference create(Scope s, Output x, Output y) {
|
||||
return new SquaredDifference(
|
||||
s.graph()
|
||||
.opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
|
||||
.addInput(x)
|
||||
.addInput(y)
|
||||
.build()
|
||||
.output(0));
|
||||
}
|
||||
|
||||
private SquaredDifference(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
public Output output() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class Variance {
|
||||
private final Output output;
|
||||
|
||||
private static Variance create(Scope base, Output x) {
|
||||
Scope s = base.withSubScope("variance");
|
||||
Output zero = Const.create(s.withName("zero"), new int[] {0}).output();
|
||||
Output sqdiff =
|
||||
SquaredDifference.create(
|
||||
s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
|
||||
.output();
|
||||
|
||||
return new Variance(Mean.create(s.withName("variance"), sqdiff, zero).output());
|
||||
}
|
||||
|
||||
private Variance(Output o) {
|
||||
output = o;
|
||||
}
|
||||
|
||||
public Output output() {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
}
|
@ -20,6 +20,8 @@
|
||||
# And jars:
|
||||
# (3) Java API .jar
|
||||
# (4) Java API sources .jar
|
||||
# (5) Java Op wrappers .jar
|
||||
# (6) Java Op wrapper sources .jar
|
||||
#
|
||||
# These binary distributions will allow use of TensorFlow in various languages
|
||||
# without having to compile the TensorFlow framework from sources, which takes
|
||||
@ -34,6 +36,8 @@
|
||||
# - lib_package/libtensorflow_jni${SUFFIX}.tar.gz
|
||||
# - lib_package/libtensorflow.jar
|
||||
# - lib_package/libtensorflow-src.jar
|
||||
# - lib_package/libtensorflow_op.jar
|
||||
# - lib_package/libtensorflow_op-src.jar
|
||||
# - lib_package/libtensorflow_proto.zip
|
||||
#
|
||||
# ASSUMPTIONS:
|
||||
@ -75,12 +79,15 @@ function build_libtensorflow_tarball() {
|
||||
//tensorflow/tools/lib_package:libtensorflow_jni.tar.gz \
|
||||
//tensorflow/java:libtensorflow.jar \
|
||||
//tensorflow/java:libtensorflow-src.jar \
|
||||
//tensorflow/java:libtensorflow_op.jar \
|
||||
//tensorflow/java:libtensorflow_op-src.jar \
|
||||
//tensorflow/tools/lib_package:libtensorflow_proto.zip
|
||||
|
||||
mkdir -p ${DIR}
|
||||
cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz ${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz
|
||||
cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz ${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz
|
||||
cp bazel-bin/tensorflow/java/libtensorflow.jar bazel-bin/tensorflow/java/libtensorflow-src.jar ${DIR}
|
||||
cp bazel-bin/tensorflow/java/libtensorflow_op.jar bazel-bin/tensorflow/java/libtensorflow_op-src.jar ${DIR}
|
||||
cp bazel-genfiles/tensorflow/tools/lib_package/libtensorflow_proto.zip ${DIR}
|
||||
chmod -x ${DIR}/*
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user