diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 64b37677357..50dca9270c7 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -6,6 +6,7 @@ package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 load("build_defs", "JAVACOPTS") +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_gen_op_wrappers_java") java_library( name = "tensorflow", @@ -34,12 +35,38 @@ filegroup( filegroup( name = "java_op_sources", - srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]), + srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [ + ":java_op_gensources", + ], visibility = [ "//tensorflow/java:__pkg__", ], ) +tf_gen_op_wrappers_java( + name = "java_op_gensources", + ops_libs = [ + "array_ops", + "candidate_sampling_ops", + "control_flow_ops", + "data_flow_ops", + "image_ops", + "io_ops", + "linalg_ops", + "logging_ops", + "math_ops", + "nn_ops", + "no_op", + "parsing_ops", + "random_ops", + "sparse_ops", + "state_ops", + "string_ops", + "training_ops", + "user_ops", + ], +) + java_library( name = "testutil", testonly = 1, @@ -282,3 +309,20 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +cc_library( + name = "java_ops_gentool", + srcs = glob([ + "src/gen/cc/gen_util.h", + "src/gen/cc/ops/*.h", + "src/gen/cc/ops/*.cc", + ]), + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], +) diff --git a/tensorflow/java/src/gen/cc/gen_util.h b/tensorflow/java/src/gen/cc/gen_util.h new file mode 100644 index 00000000000..b6c419dcf82 --- /dev/null +++ b/tensorflow/java/src/gen/cc/gen_util.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_ + +#include + +namespace tensorflow { +namespace gen_util { + +inline std::string StripChar(const std::string& str, const char old_char) { + std::string ret; + for (const char& c : str) { + if (c != old_char) { + ret.push_back(c); + } + } + return ret; +} + +inline std::string ReplaceChar(const std::string& str, const char old_char, + const char new_char) { + std::string ret; + for (const char& c : str) { + ret.push_back(c == old_char ? new_char : c); + } + return ret; +} + +} // namespace gen_util +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_ diff --git a/tensorflow/java/src/gen/cc/ops/main.cc b/tensorflow/java/src/gen/cc/ops/main.cc new file mode 100644 index 00000000000..eff93e9ee5b --- /dev/null +++ b/tensorflow/java/src/gen/cc/ops/main.cc @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include + +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/ops/op_generator.h" + +int main(int argc, char* argv[]) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + const std::string& output_dir = argv[1]; + const std::string& ops_lib = argv[2]; + + LOG(INFO) << "Generating Java operation wrappers for \"" + << ops_lib << "\" library"; + + tensorflow::OpGenerator generator(ops_lib); + + return generator.Run(output_dir + "/src/main/java/"); +} diff --git a/tensorflow/java/src/gen/cc/ops/op_generator.cc b/tensorflow/java/src/gen/cc/ops/op_generator.cc new file mode 100644 index 00000000000..34b97948afe --- /dev/null +++ b/tensorflow/java/src/gen/cc/ops/op_generator.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/java/src/gen/cc/gen_util.h" +#include "tensorflow/java/src/gen/cc/ops/op_generator.h" + +namespace tensorflow { + +OpGenerator::OpGenerator(const std::string& ops_lib) + : ops_lib(ops_lib) { + + const std::string& lib_name = ops_lib.substr(0, ops_lib.find_last_of('_')); + package_name = gen_util::StripChar("org.tensorflow.op." + lib_name, '_'); +} + +OpGenerator::~OpGenerator() {} + +int OpGenerator::Run(const std::string& output_path) { + tensorflow::Env* env = tensorflow::Env::Default(); + std::string package_path(output_path); + package_path += gen_util::ReplaceChar(package_name, '.', '/'); + + if (!env->FileExists(package_path).ok()) { + TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); + } + + OpList ops; + OpRegistry::Global()->Export(true, &ops); + + // TODO(karllessard) generate wrappers from collected ops + + return 0; +} + +} // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/ops/op_generator.h b/tensorflow/java/src/gen/cc/ops/op_generator.h new file mode 100644 index 00000000000..fc52d1ac18d --- /dev/null +++ b/tensorflow/java/src/gen/cc/ops/op_generator.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_ + +#include + +namespace tensorflow { + +class OpGenerator { + public: + explicit OpGenerator(const std::string& ops_lib); + virtual ~OpGenerator(); + + int Run(const std::string& output_path); + + private: + const std::string ops_lib; + std::string package_name; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_ diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f0301937fba..efb2d9a4949 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -430,6 +430,57 @@ def tf_gen_op_wrapper_py(name, ],) +# Given a list of "ops_libs" (a list of files in the core/ops directory +# without their .cc extensions), generate Java wrapper code for all operations +# found in the ops files. +# Then, combine all those source files into a single archive (.srcjar). +# +# For example: +# tf_gen_op_wrappers_java("java_op_gensources", [ "array_ops", "math_ops" ]) +# +# will create a genrule named "java_op_gensources" that first generate source files: +# ops/src/main/java/org/tensorflow/op/array/*.java +# ops/src/main/java/org/tensorflow/op/math/*.java +# +# and then archive those source files in: +# ops/java_op_gensources.srcjar +# +def tf_gen_op_wrappers_java(name, + ops_libs=[], + ops_libs_pkg="//tensorflow/core", + out_dir="ops/", + gen_main=clean_dep("//tensorflow/java:java_ops_gentool"), + visibility=["//tensorflow/java:__pkg__"]): + + gen_tools = [] + gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + + # Construct an op generator binary for each ops library. + for ops_lib in ops_libs: + gen_tool = out_dir + ops_lib + "_gentool" + + native.cc_binary( + name=gen_tool, + copts=tf_copts(), + linkopts=["-lm"], + linkstatic=1, # Faster to link this one-time-use binary dynamically + deps=[gen_main, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) + + gen_tools += [":" + gen_tool] + gen_cmds += ["$(location :" + gen_tool + ") $(@D) " + ops_lib] + + # Generate a source archive containing generated code for these ops. + gen_srcjar = out_dir + name + ".srcjar" + gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."] + + native.genrule( + name=name, + srcs=["@local_jdk//:jar"], + outs=[gen_srcjar], + tools=gen_tools, + cmd='&&'.join(gen_cmds)) + + # Define a bazel macro that creates cc_test for tensorflow. # TODO(opensource): we need to enable this to work around the hidden symbol # __cudaRegisterFatBinary error. Need more investigations.