Integrating comments from the 2nd code review round

This commit is contained in:
karl@kubx.ca 2017-08-22 16:41:48 -04:00 committed by Martin Wicke
parent 7dcc1ab72e
commit dca9fe2484
5 changed files with 89 additions and 48 deletions

View File

@ -6,7 +6,7 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load(":build_defs.bzl", "JAVACOPTS") load(":build_defs.bzl", "JAVACOPTS")
load(":src/gen/gen_ops.bzl", "java_op_gen_srcjar") load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_copts")
java_library( java_library(
@ -44,8 +44,9 @@ filegroup(
], ],
) )
java_op_gen_srcjar( tf_java_op_gen_srcjar(
name = "java_op_gen_sources", name = "java_op_gen_sources",
gen_base_package = "org.tensorflow.op",
gen_tool = "java_op_gen_tool", gen_tool = "java_op_gen_tool",
ops_libs = [ ops_libs = [
"array_ops", "array_ops",
@ -70,7 +71,7 @@ java_op_gen_srcjar(
) )
# Build the gen tool as a library, as it will be linked to a core/ops binary # Build the gen tool as a library, as it will be linked to a core/ops binary
# file before making it an executable. See java_op_gen_srcjar(). # file before making it an executable. See tf_java_op_gen_srcjar().
cc_library( cc_library(
name = "java_op_gen_tool", name = "java_op_gen_tool",
srcs = glob([ srcs = glob([

View File

@ -11,7 +11,7 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <string> #include <string>
#include <vector> #include <vector>
@ -19,6 +19,7 @@
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/java/src/gen/cc/op_generator.h" #include "tensorflow/java/src/gen/cc/op_generator.h"
@ -28,33 +29,53 @@ namespace op_gen {
const char kUsageHeader[] = const char kUsageHeader[] =
"\n\nGenerator of operation wrappers in Java.\n\n" "\n\nGenerator of operation wrappers in Java.\n\n"
"This executable generates wrappers for all operations registered in the\n" "This executable generates wrappers for all registered operations it has "
"ops file it has been linked to (i.e. one of the /core/ops/*.o binaries).\n" "been compiled with. A wrapper exposes an intuitive and strongly-typed\n"
"Generated files are output to the path provided as an argument, under\n" "interface for building its underlying operation and linking it into a "
"their appropriate package and using a maven-style directory layout.\n\n"; "graph.\n\n"
"Operation wrappers are generated under the path specified by the "
"'--output_dir' argument. This path can be absolute or relative to the\n"
"current working directory and will be created if it does not exists.\n\n"
"The '--lib_name' argument is used to classify the set of operations. If "
"the chosen name contains more than one word, it must be provided in \n"
"snake_case. This value is declined into other meaningful names, such as "
"the group and package of the generated operations. For example,\n"
"'--lib_name=my_lib' generates the operations under the "
"'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
"group.\n\n"
"Note that the operator group assigned to the generated wrappers is just "
"a annotation tag at this stage. Operations will not be available through\n"
"the Ops API as a group until the generated classes are compiled using an "
"appropriate annotation processor.\n\n"
"Finally, the '--base_package' overrides the default parent package "
"under which the generated subpackage and classes are to be located.\n\n";
} // namespace op_gen } // namespace op_gen
} // namespace tensorflow } // namespace tensorflow
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
tensorflow::string ops_file; tensorflow::string lib_name;
tensorflow::string output_dir; tensorflow::string output_dir;
tensorflow::string base_package = "org.tensorflow.op";
std::vector<tensorflow::Flag> flag_list = { std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("file", &ops_file, tensorflow::Flag("output_dir", &output_dir,
"name of the ops file linked to this executable"), "Root directory into which output files are generated"),
tensorflow::Flag("output", &output_dir, tensorflow::Flag("lib_name", &lib_name,
"base directory where to output generated files") "A name, in snake_case, used to classify this set of operations"),
tensorflow::Flag("base_package", &base_package,
"Package parent to the generated subpackage and classes")
}; };
tensorflow::string usage = tensorflow::op_gen::kUsageHeader; tensorflow::string usage = tensorflow::op_gen::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list); usage += tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
QCHECK(parsed_flags_ok && !ops_file.empty() && !output_dir.empty()) << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv); tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
tensorflow::OpGenerator generator(tensorflow::Env::Default(), output_dir); tensorflow::OpGenerator generator;
tensorflow::OpList ops; tensorflow::OpList ops;
tensorflow::OpRegistry::Global()->Export(true, &ops); tensorflow::OpRegistry::Global()->Export(true, &ops);
tensorflow::Status status = generator.Run(ops_file, ops); tensorflow::Status status =
generator.Run(ops, lib_name, base_package, output_dir);
TF_QCHECK_OK(status); TF_QCHECK_OK(status);
return 0; return 0;

View File

@ -20,25 +20,46 @@ limitations under the License.
#include "tensorflow/java/src/gen/cc/op_generator.h" #include "tensorflow/java/src/gen/cc/op_generator.h"
namespace tensorflow { namespace tensorflow {
namespace op_gen {
OpGenerator::OpGenerator(Env* env, const string& output_dir) string CamelCase(const string& str, char delimiter, bool upper) {
: env(env), output_path(output_dir + "/src/main/java/") { string result;
bool cap = upper;
for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
const char c = *it;
if (c == delimiter) {
cap = true;
} else if (cap) {
result += toupper(c);
cap = false;
} else {
result += c;
}
}
return result;
}
} // namespace op_gen
OpGenerator::OpGenerator()
: env(Env::Default()) {
} }
OpGenerator::~OpGenerator() {} OpGenerator::~OpGenerator() {}
Status OpGenerator::Run(const string& ops_file, const OpList& ops) { Status OpGenerator::Run(const OpList& ops, const string& lib_name,
const string& lib_name = ops_file.substr(0, ops_file.find_last_of('_')); const string& base_package, const string& output_dir) {
const string package_name = const string package =
str_util::StringReplace("org.tensorflow.op." + lib_name, "_", "", true); base_package + '.' + str_util::StringReplace(lib_name, "_", "", true);
const string package_path = const string package_path =
output_path + str_util::StringReplace(package_name, ".", "/", true); output_dir + '/' + str_util::StringReplace(package, ".", "/", true);
const string group = op_gen::CamelCase(lib_name, '_', false);
if (!env->FileExists(package_path).ok()) { if (!env->FileExists(package_path).ok()) {
TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
} }
LOG(INFO) << "Generating Java wrappers for \"" << lib_name << "\" operations"; LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
// TODO(karllessard) generate wrappers from list of ops // TODO(karllessard) generate wrappers from list of ops
return Status::OK(); return Status::OK();

View File

@ -27,31 +27,23 @@ namespace tensorflow {
/// \brief A generator of Java operation wrappers. /// \brief A generator of Java operation wrappers.
/// ///
/// Such generator is normally ran only once per executable, outputting /// Such generator is normally ran only once per executable, outputting
/// wrappers for the ops library it has been linked with. Nonetheless, /// wrappers for the all registered operations it has been compiled with.
/// it is designed to support multiple runs, giving a different list of /// Nonetheless, it is designed to support multiple runs, giving a different
/// operations on each cycle. /// list of operations on each cycle.
class OpGenerator { class OpGenerator {
public: public:
/// \brief Create a new generator, giving an environment and an OpGenerator();
/// output directory path.
explicit OpGenerator(Env* env, const string& output_dir);
virtual ~OpGenerator(); virtual ~OpGenerator();
/// \brief Generates wrappers for the given list of 'ops'. /// \brief Generates wrappers for the given list of 'ops'.
/// ///
/// The list of operations should be issued from the library whose /// Output files are generated in <output_dir>/<base_package>/<lib_package>,
/// file name starts with 'ops_file' (see /core/ops/*.cc). /// where 'lib_package' is derived from 'lib_name'.
/// Status Run(const OpList& ops, const string& lib_name,
/// Generated files are output under this directory: const string& base_package, const string& output_dir);
/// <output_dir>/src/main/java/org/tensorflow/java/op/<group>
/// where
/// 'output_dir' is the directory passed in the constructor and
/// 'group' is extracted from the 'ops_file' name
Status Run(const string& ops_file, const OpList& ops);
private: private:
Env* env; Env* env;
const string output_path;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_copts")
# Then, combine all those source files into a single archive (.srcjar). # Then, combine all those source files into a single archive (.srcjar).
# #
# For example: # For example:
# java_op_gen_srcjar("gen_sources", "gen_tool", [ "array_ops", "math_ops" ]) # tf_java_op_gen_srcjar("gen_sources", "gen_tool", [ "array_ops", "math_ops" ])
# #
# will create a genrule named "gen_sources" that first generate source files: # will create a genrule named "gen_sources" that first generate source files:
# ops/src/main/java/org/tensorflow/op/array/*.java # ops/src/main/java/org/tensorflow/op/array/*.java
@ -17,18 +17,21 @@ load("//tensorflow:tensorflow.bzl", "tf_copts")
# and then archive those source files in: # and then archive those source files in:
# ops/gen_sources.srcjar # ops/gen_sources.srcjar
# #
def java_op_gen_srcjar(name, def tf_java_op_gen_srcjar(name,
gen_tool, gen_tool,
ops_libs=[], gen_base_package,
ops_libs_pkg="//tensorflow/core", ops_libs=[],
out_dir="ops/", ops_libs_pkg="//tensorflow/core",
visibility=["//tensorflow/java:__pkg__"]): out_dir="ops/",
out_src_dir="src/main/java/",
visibility=["//tensorflow/java:__pkg__"]):
gen_tools = [] gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
# Construct an op generator binary for each ops library. # Construct an op generator binary for each ops library.
for ops_lib in ops_libs: for ops_lib in ops_libs:
gen_lib = ops_lib[:ops_lib.rfind('_')]
out_gen_tool = out_dir + ops_lib + "_gen_tool" out_gen_tool = out_dir + ops_lib + "_gen_tool"
native.cc_binary( native.cc_binary(
@ -39,7 +42,10 @@ def java_op_gen_srcjar(name,
deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
gen_tools += [":" + out_gen_tool] gen_tools += [":" + out_gen_tool]
gen_cmds += ["$(location :" + out_gen_tool + ") --output=$(@D) --file=" + ops_lib] gen_cmds += ["$(location :" + out_gen_tool + ")" +
" --output_dir=$(@D)/" + out_src_dir +
" --lib_name=" + gen_lib +
" --base_package=" + gen_base_package]
# Generate a source archive containing generated code for these ops. # Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar" gen_srcjar = out_dir + name + ".srcjar"