diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 11c5a2d3022..ee07fc48132 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -6,7 +6,7 @@ package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 load(":build_defs.bzl", "JAVACOPTS") -load(":src/gen/gen_ops.bzl", "java_op_gen_srcjar") +load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar") load("//tensorflow:tensorflow.bzl", "tf_copts") java_library( @@ -44,8 +44,9 @@ filegroup( ], ) -java_op_gen_srcjar( +tf_java_op_gen_srcjar( name = "java_op_gen_sources", + gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ "array_ops", @@ -70,7 +71,7 @@ java_op_gen_srcjar( ) # Build the gen tool as a library, as it will be linked to a core/ops binary -# file before making it an executable. See java_op_gen_srcjar(). +# file before making it an executable. See tf_java_op_gen_srcjar(). cc_library( name = "java_op_gen_tool", srcs = glob([ diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 16cd3e8c71a..c396d78f15c 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ==============================================================================*/ +==============================================================================*/ #include #include @@ -19,6 +19,7 @@ #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/java/src/gen/cc/op_generator.h" @@ -28,33 +29,53 @@ namespace op_gen { const char kUsageHeader[] = "\n\nGenerator of operation wrappers in Java.\n\n" - "This executable generates wrappers for all operations registered in the\n" - "ops file it has been linked to (i.e. one of the /core/ops/*.o binaries).\n" - "Generated files are output to the path provided as an argument, under\n" - "their appropriate package and using a maven-style directory layout.\n\n"; + "This executable generates wrappers for all registered operations it has " + "been compiled with. A wrapper exposes an intuitive and strongly-typed\n" + "interface for building its underlying operation and linking it into a " + "graph.\n\n" + "Operation wrappers are generated under the path specified by the " + "'--output_dir' argument. This path can be absolute or relative to the\n" + "current working directory and will be created if it does not exists.\n\n" + "The '--lib_name' argument is used to classify the set of operations. If " + "the chosen name contains more than one word, it must be provided in \n" + "snake_case. This value is declined into other meaningful names, such as " + "the group and package of the generated operations. For example,\n" + "'--lib_name=my_lib' generates the operations under the " + "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n" + "group.\n\n" + "Note that the operator group assigned to the generated wrappers is just " + "a annotation tag at this stage. Operations will not be available through\n" + "the Ops API as a group until the generated classes are compiled using an " + "appropriate annotation processor.\n\n" + "Finally, the '--base_package' overrides the default parent package " + "under which the generated subpackage and classes are to be located.\n\n"; } // namespace op_gen } // namespace tensorflow int main(int argc, char* argv[]) { - tensorflow::string ops_file; + tensorflow::string lib_name; tensorflow::string output_dir; + tensorflow::string base_package = "org.tensorflow.op"; std::vector flag_list = { - tensorflow::Flag("file", &ops_file, - "name of the ops file linked to this executable"), - tensorflow::Flag("output", &output_dir, - "base directory where to output generated files") + tensorflow::Flag("output_dir", &output_dir, + "Root directory into which output files are generated"), + tensorflow::Flag("lib_name", &lib_name, + "A name, in snake_case, used to classify this set of operations"), + tensorflow::Flag("base_package", &base_package, + "Package parent to the generated subpackage and classes") }; tensorflow::string usage = tensorflow::op_gen::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - QCHECK(parsed_flags_ok && !ops_file.empty() && !output_dir.empty()) << usage; tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - tensorflow::OpGenerator generator(tensorflow::Env::Default(), output_dir); + tensorflow::OpGenerator generator; tensorflow::OpList ops; tensorflow::OpRegistry::Global()->Export(true, &ops); - tensorflow::Status status = generator.Run(ops_file, ops); + tensorflow::Status status = + generator.Run(ops, lib_name, base_package, output_dir); TF_QCHECK_OK(status); return 0; diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index bda1f68abc2..f755f982e43 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -20,25 +20,46 @@ limitations under the License. #include "tensorflow/java/src/gen/cc/op_generator.h" namespace tensorflow { +namespace op_gen { -OpGenerator::OpGenerator(Env* env, const string& output_dir) - : env(env), output_path(output_dir + "/src/main/java/") { +string CamelCase(const string& str, char delimiter, bool upper) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == delimiter) { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +} // namespace op_gen + +OpGenerator::OpGenerator() + : env(Env::Default()) { } OpGenerator::~OpGenerator() {} -Status OpGenerator::Run(const string& ops_file, const OpList& ops) { - const string& lib_name = ops_file.substr(0, ops_file.find_last_of('_')); - const string package_name = - str_util::StringReplace("org.tensorflow.op." + lib_name, "_", "", true); +Status OpGenerator::Run(const OpList& ops, const string& lib_name, + const string& base_package, const string& output_dir) { + const string package = + base_package + '.' + str_util::StringReplace(lib_name, "_", "", true); const string package_path = - output_path + str_util::StringReplace(package_name, ".", "/", true); + output_dir + '/' + str_util::StringReplace(package, ".", "/", true); + const string group = op_gen::CamelCase(lib_name, '_', false); if (!env->FileExists(package_path).ok()) { TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); } - LOG(INFO) << "Generating Java wrappers for \"" << lib_name << "\" operations"; + LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; // TODO(karllessard) generate wrappers from list of ops return Status::OK(); diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index b0d8cb05af5..98a1f8d5346 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -27,31 +27,23 @@ namespace tensorflow { /// \brief A generator of Java operation wrappers. /// /// Such generator is normally ran only once per executable, outputting -/// wrappers for the ops library it has been linked with. Nonetheless, -/// it is designed to support multiple runs, giving a different list of -/// operations on each cycle. +/// wrappers for the all registered operations it has been compiled with. +/// Nonetheless, it is designed to support multiple runs, giving a different +/// list of operations on each cycle. class OpGenerator { public: - /// \brief Create a new generator, giving an environment and an - /// output directory path. - explicit OpGenerator(Env* env, const string& output_dir); + OpGenerator(); virtual ~OpGenerator(); /// \brief Generates wrappers for the given list of 'ops'. /// - /// The list of operations should be issued from the library whose - /// file name starts with 'ops_file' (see /core/ops/*.cc). - /// - /// Generated files are output under this directory: - /// /src/main/java/org/tensorflow/java/op/ - /// where - /// 'output_dir' is the directory passed in the constructor and - /// 'group' is extracted from the 'ops_file' name - Status Run(const string& ops_file, const OpList& ops); + /// Output files are generated in //, + /// where 'lib_package' is derived from 'lib_name'. + Status Run(const OpList& ops, const string& lib_name, + const string& base_package, const string& output_dir); private: Env* env; - const string output_path; }; } // namespace tensorflow diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index e881c35d4df..5ef01d68dff 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_copts") # Then, combine all those source files into a single archive (.srcjar). # # For example: -# java_op_gen_srcjar("gen_sources", "gen_tool", [ "array_ops", "math_ops" ]) +# tf_java_op_gen_srcjar("gen_sources", "gen_tool", [ "array_ops", "math_ops" ]) # # will create a genrule named "gen_sources" that first generate source files: # ops/src/main/java/org/tensorflow/op/array/*.java @@ -17,18 +17,21 @@ load("//tensorflow:tensorflow.bzl", "tf_copts") # and then archive those source files in: # ops/gen_sources.srcjar # -def java_op_gen_srcjar(name, - gen_tool, - ops_libs=[], - ops_libs_pkg="//tensorflow/core", - out_dir="ops/", - visibility=["//tensorflow/java:__pkg__"]): +def tf_java_op_gen_srcjar(name, + gen_tool, + gen_base_package, + ops_libs=[], + ops_libs_pkg="//tensorflow/core", + out_dir="ops/", + out_src_dir="src/main/java/", + visibility=["//tensorflow/java:__pkg__"]): gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files # Construct an op generator binary for each ops library. for ops_lib in ops_libs: + gen_lib = ops_lib[:ops_lib.rfind('_')] out_gen_tool = out_dir + ops_lib + "_gen_tool" native.cc_binary( @@ -39,7 +42,10 @@ def java_op_gen_srcjar(name, deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) gen_tools += [":" + out_gen_tool] - gen_cmds += ["$(location :" + out_gen_tool + ") --output=$(@D) --file=" + ops_lib] + gen_cmds += ["$(location :" + out_gen_tool + ")" + + " --output_dir=$(@D)/" + out_src_dir + + " --lib_name=" + gen_lib + + " --base_package=" + gen_base_package] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar"