Build rules and skeleton for Java operation wrappers generator

This commit is contained in:
karl@kubx.ca 2017-07-28 18:33:03 -04:00 committed by Martin Wicke
parent fad50ea1ca
commit 621c2dcf27
6 changed files with 264 additions and 1 deletions

View File

@ -6,6 +6,7 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load("build_defs", "JAVACOPTS") load("build_defs", "JAVACOPTS")
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_gen_op_wrappers_java")
java_library( java_library(
name = "tensorflow", name = "tensorflow",
@ -34,12 +35,38 @@ filegroup(
filegroup( filegroup(
name = "java_op_sources", name = "java_op_sources",
srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]), srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
":java_op_gensources",
],
visibility = [ visibility = [
"//tensorflow/java:__pkg__", "//tensorflow/java:__pkg__",
], ],
) )
tf_gen_op_wrappers_java(
name = "java_op_gensources",
ops_libs = [
"array_ops",
"candidate_sampling_ops",
"control_flow_ops",
"data_flow_ops",
"image_ops",
"io_ops",
"linalg_ops",
"logging_ops",
"math_ops",
"nn_ops",
"no_op",
"parsing_ops",
"random_ops",
"sparse_ops",
"state_ops",
"string_ops",
"training_ops",
"user_ops",
],
)
java_library( java_library(
name = "testutil", name = "testutil",
testonly = 1, testonly = 1,
@ -282,3 +309,20 @@ filegroup(
), ),
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
) )
cc_library(
name = "java_ops_gentool",
srcs = glob([
"src/gen/cc/gen_util.h",
"src/gen/cc/ops/*.h",
"src/gen/cc/ops/*.cc",
]),
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_
#include <string>
namespace tensorflow {
namespace gen_util {
inline std::string StripChar(const std::string& str, const char old_char) {
std::string ret;
for (const char& c : str) {
if (c != old_char) {
ret.push_back(c);
}
}
return ret;
}
inline std::string ReplaceChar(const std::string& str, const char old_char,
const char new_char) {
std::string ret;
for (const char& c : str) {
ret.push_back(c == old_char ? new_char : c);
}
return ret;
}
} // namespace gen_util
} // namespace tensorflow
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_

View File

@ -0,0 +1,33 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/java/src/gen/cc/ops/op_generator.h"
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
const std::string& output_dir = argv[1];
const std::string& ops_lib = argv[2];
LOG(INFO) << "Generating Java operation wrappers for \""
<< ops_lib << "\" library";
tensorflow::OpGenerator generator(ops_lib);
return generator.Run(output_dir + "/src/main/java/");
}

View File

@ -0,0 +1,52 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/java/src/gen/cc/gen_util.h"
#include "tensorflow/java/src/gen/cc/ops/op_generator.h"
namespace tensorflow {
OpGenerator::OpGenerator(const std::string& ops_lib)
: ops_lib(ops_lib) {
const std::string& lib_name = ops_lib.substr(0, ops_lib.find_last_of('_'));
package_name = gen_util::StripChar("org.tensorflow.op." + lib_name, '_');
}
OpGenerator::~OpGenerator() {}
int OpGenerator::Run(const std::string& output_path) {
tensorflow::Env* env = tensorflow::Env::Default();
std::string package_path(output_path);
package_path += gen_util::ReplaceChar(package_name, '.', '/');
if (!env->FileExists(package_path).ok()) {
TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
}
OpList ops;
OpRegistry::Global()->Export(true, &ops);
// TODO(karllessard) generate wrappers from collected ops
return 0;
}
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_
#include <string>
namespace tensorflow {
class OpGenerator {
public:
explicit OpGenerator(const std::string& ops_lib);
virtual ~OpGenerator();
int Run(const std::string& output_path);
private:
const std::string ops_lib;
std::string package_name;
};
} // namespace tensorflow
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_

View File

@ -430,6 +430,57 @@ def tf_gen_op_wrapper_py(name,
],) ],)
# Given a list of "ops_libs" (a list of files in the core/ops directory
# without their .cc extensions), generate Java wrapper code for all operations
# found in the ops files.
# Then, combine all those source files into a single archive (.srcjar).
#
# For example:
# tf_gen_op_wrappers_java("java_op_gensources", [ "array_ops", "math_ops" ])
#
# will create a genrule named "java_op_gensources" that first generate source files:
# ops/src/main/java/org/tensorflow/op/array/*.java
# ops/src/main/java/org/tensorflow/op/math/*.java
#
# and then archive those source files in:
# ops/java_op_gensources.srcjar
#
def tf_gen_op_wrappers_java(name,
ops_libs=[],
ops_libs_pkg="//tensorflow/core",
out_dir="ops/",
gen_main=clean_dep("//tensorflow/java:java_ops_gentool"),
visibility=["//tensorflow/java:__pkg__"]):
gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
# Construct an op generator binary for each ops library.
for ops_lib in ops_libs:
gen_tool = out_dir + ops_lib + "_gentool"
native.cc_binary(
name=gen_tool,
copts=tf_copts(),
linkopts=["-lm"],
linkstatic=1, # Faster to link this one-time-use binary dynamically
deps=[gen_main, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
gen_tools += [":" + gen_tool]
gen_cmds += ["$(location :" + gen_tool + ") $(@D) " + ops_lib]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
native.genrule(
name=name,
srcs=["@local_jdk//:jar"],
outs=[gen_srcjar],
tools=gen_tools,
cmd='&&'.join(gen_cmds))
# Define a bazel macro that creates cc_test for tensorflow. # Define a bazel macro that creates cc_test for tensorflow.
# TODO(opensource): we need to enable this to work around the hidden symbol # TODO(opensource): we need to enable this to work around the hidden symbol
# __cudaRegisterFatBinary error. Need more investigations. # __cudaRegisterFatBinary error. Need more investigations.