Integrating comments from the 1st code review round.

This commit is contained in:
karl@kubx.ca 2017-08-14 23:49:29 -04:00 committed by Martin Wicke
parent 621c2dcf27
commit 7dcc1ab72e
9 changed files with 212 additions and 208 deletions

View File

@ -5,8 +5,9 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load("build_defs", "JAVACOPTS") load(":build_defs.bzl", "JAVACOPTS")
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_gen_op_wrappers_java") load(":src/gen/gen_ops.bzl", "java_op_gen_srcjar")
load("//tensorflow:tensorflow.bzl", "tf_copts")
java_library( java_library(
name = "tensorflow", name = "tensorflow",
@ -36,15 +37,16 @@ filegroup(
filegroup( filegroup(
name = "java_op_sources", name = "java_op_sources",
srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [ srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
":java_op_gensources", ":java_op_gen_sources",
], ],
visibility = [ visibility = [
"//tensorflow/java:__pkg__", "//tensorflow/java:__pkg__",
], ],
) )
tf_gen_op_wrappers_java( java_op_gen_srcjar(
name = "java_op_gensources", name = "java_op_gen_sources",
gen_tool = "java_op_gen_tool",
ops_libs = [ ops_libs = [
"array_ops", "array_ops",
"candidate_sampling_ops", "candidate_sampling_ops",
@ -67,6 +69,24 @@ tf_gen_op_wrappers_java(
], ],
) )
# Build the gen tool as a library, as it will be linked to a core/ops binary
# file before making it an executable. See java_op_gen_srcjar().
cc_library(
name = "java_op_gen_tool",
srcs = glob([
"src/gen/cc/*.h",
"src/gen/cc/*.cc",
]),
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
],
)
java_library( java_library(
name = "testutil", name = "testutil",
testonly = 1, testonly = 1,
@ -309,20 +329,3 @@ filegroup(
), ),
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
) )
cc_library(
name = "java_ops_gentool",
srcs = glob([
"src/gen/cc/gen_util.h",
"src/gen/cc/ops/*.h",
"src/gen/cc/ops/*.cc",
]),
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -1,46 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_
#include <string>
namespace tensorflow {
namespace gen_util {
inline std::string StripChar(const std::string& str, const char old_char) {
std::string ret;
for (const char& c : str) {
if (c != old_char) {
ret.push_back(c);
}
}
return ret;
}
inline std::string ReplaceChar(const std::string& str, const char old_char,
const char new_char) {
std::string ret;
for (const char& c : str) {
ret.push_back(c == old_char ? new_char : c);
}
return ret;
}
} // namespace gen_util
} // namespace tensorflow
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_GEN_UTIL_H_

View File

@ -0,0 +1,61 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <vector>
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/java/src/gen/cc/op_generator.h"
namespace tensorflow {
namespace op_gen {
const char kUsageHeader[] =
"\n\nGenerator of operation wrappers in Java.\n\n"
"This executable generates wrappers for all operations registered in the\n"
"ops file it has been linked to (i.e. one of the /core/ops/*.o binaries).\n"
"Generated files are output to the path provided as an argument, under\n"
"their appropriate package and using a maven-style directory layout.\n\n";
} // namespace op_gen
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::string ops_file;
tensorflow::string output_dir;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("file", &ops_file,
"name of the ops file linked to this executable"),
tensorflow::Flag("output", &output_dir,
"base directory where to output generated files")
};
tensorflow::string usage = tensorflow::op_gen::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
QCHECK(parsed_flags_ok && !ops_file.empty() && !output_dir.empty()) << usage;
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
tensorflow::OpGenerator generator(tensorflow::Env::Default(), output_dir);
tensorflow::OpList ops;
tensorflow::OpRegistry::Global()->Export(true, &ops);
tensorflow::Status status = generator.Run(ops_file, ops);
TF_QCHECK_OK(status);
return 0;
}

View File

@ -15,38 +15,33 @@ limitations under the License.
#include <string> #include <string>
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/java/src/gen/cc/op_generator.h"
#include "tensorflow/java/src/gen/cc/gen_util.h"
#include "tensorflow/java/src/gen/cc/ops/op_generator.h"
namespace tensorflow { namespace tensorflow {
OpGenerator::OpGenerator(const std::string& ops_lib) OpGenerator::OpGenerator(Env* env, const string& output_dir)
: ops_lib(ops_lib) { : env(env), output_path(output_dir + "/src/main/java/") {
const std::string& lib_name = ops_lib.substr(0, ops_lib.find_last_of('_'));
package_name = gen_util::StripChar("org.tensorflow.op." + lib_name, '_');
} }
OpGenerator::~OpGenerator() {} OpGenerator::~OpGenerator() {}
int OpGenerator::Run(const std::string& output_path) { Status OpGenerator::Run(const string& ops_file, const OpList& ops) {
tensorflow::Env* env = tensorflow::Env::Default(); const string& lib_name = ops_file.substr(0, ops_file.find_last_of('_'));
std::string package_path(output_path); const string package_name =
package_path += gen_util::ReplaceChar(package_name, '.', '/'); str_util::StringReplace("org.tensorflow.op." + lib_name, "_", "", true);
const string package_path =
output_path + str_util::StringReplace(package_name, ".", "/", true);
if (!env->FileExists(package_path).ok()) { if (!env->FileExists(package_path).ok()) {
TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
} }
OpList ops; LOG(INFO) << "Generating Java wrappers for \"" << lib_name << "\" operations";
OpRegistry::Global()->Export(true, &ops); // TODO(karllessard) generate wrappers from list of ops
// TODO(karllessard) generate wrappers from collected ops return Status::OK();
return 0;
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
#include <string>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
/// \brief A generator of Java operation wrappers.
///
/// Such generator is normally ran only once per executable, outputting
/// wrappers for the ops library it has been linked with. Nonetheless,
/// it is designed to support multiple runs, giving a different list of
/// operations on each cycle.
class OpGenerator {
public:
/// \brief Create a new generator, giving an environment and an
/// output directory path.
explicit OpGenerator(Env* env, const string& output_dir);
virtual ~OpGenerator();
/// \brief Generates wrappers for the given list of 'ops'.
///
/// The list of operations should be issued from the library whose
/// file name starts with 'ops_file' (see /core/ops/*.cc).
///
/// Generated files are output under this directory:
/// <output_dir>/src/main/java/org/tensorflow/java/op/<group>
/// where
/// 'output_dir' is the directory passed in the constructor and
/// 'group' is extracted from the 'ops_file' name
Status Run(const string& ops_file, const OpList& ops);
private:
Env* env;
const string output_path;
};
} // namespace tensorflow
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_

View File

@ -1,33 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/java/src/gen/cc/ops/op_generator.h"
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
const std::string& output_dir = argv[1];
const std::string& ops_lib = argv[2];
LOG(INFO) << "Generating Java operation wrappers for \""
<< ops_lib << "\" library";
tensorflow::OpGenerator generator(ops_lib);
return generator.Run(output_dir + "/src/main/java/");
}

View File

@ -1,37 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_
#include <string>
namespace tensorflow {
class OpGenerator {
public:
explicit OpGenerator(const std::string& ops_lib);
virtual ~OpGenerator();
int Run(const std::string& output_path);
private:
const std::string ops_lib;
std::string package_name;
};
} // namespace tensorflow
#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OPS_OP_GENERATOR_H_

View File

@ -0,0 +1,53 @@
# -*- Python -*-
load("//tensorflow:tensorflow.bzl", "tf_copts")
# Given a list of "ops_libs" (a list of files in the core/ops directory
# without their .cc extensions), generate Java wrapper code for all operations
# found in the ops files.
# Then, combine all those source files into a single archive (.srcjar).
#
# For example:
# java_op_gen_srcjar("gen_sources", "gen_tool", [ "array_ops", "math_ops" ])
#
# will create a genrule named "gen_sources" that first generate source files:
# ops/src/main/java/org/tensorflow/op/array/*.java
# ops/src/main/java/org/tensorflow/op/math/*.java
#
# and then archive those source files in:
# ops/gen_sources.srcjar
#
def java_op_gen_srcjar(name,
gen_tool,
ops_libs=[],
ops_libs_pkg="//tensorflow/core",
out_dir="ops/",
visibility=["//tensorflow/java:__pkg__"]):
gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
# Construct an op generator binary for each ops library.
for ops_lib in ops_libs:
out_gen_tool = out_dir + ops_lib + "_gen_tool"
native.cc_binary(
name=out_gen_tool,
copts=tf_copts(),
linkopts=["-lm"],
linkstatic=1, # Faster to link this one-time-use binary dynamically
deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
gen_tools += [":" + out_gen_tool]
gen_cmds += ["$(location :" + out_gen_tool + ") --output=$(@D) --file=" + ops_lib]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
native.genrule(
name=name,
srcs=["@local_jdk//:jar"],
outs=[gen_srcjar],
tools=gen_tools,
cmd='&&'.join(gen_cmds))

View File

@ -430,57 +430,6 @@ def tf_gen_op_wrapper_py(name,
],) ],)
# Given a list of "ops_libs" (a list of files in the core/ops directory
# without their .cc extensions), generate Java wrapper code for all operations
# found in the ops files.
# Then, combine all those source files into a single archive (.srcjar).
#
# For example:
# tf_gen_op_wrappers_java("java_op_gensources", [ "array_ops", "math_ops" ])
#
# will create a genrule named "java_op_gensources" that first generate source files:
# ops/src/main/java/org/tensorflow/op/array/*.java
# ops/src/main/java/org/tensorflow/op/math/*.java
#
# and then archive those source files in:
# ops/java_op_gensources.srcjar
#
def tf_gen_op_wrappers_java(name,
ops_libs=[],
ops_libs_pkg="//tensorflow/core",
out_dir="ops/",
gen_main=clean_dep("//tensorflow/java:java_ops_gentool"),
visibility=["//tensorflow/java:__pkg__"]):
gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
# Construct an op generator binary for each ops library.
for ops_lib in ops_libs:
gen_tool = out_dir + ops_lib + "_gentool"
native.cc_binary(
name=gen_tool,
copts=tf_copts(),
linkopts=["-lm"],
linkstatic=1, # Faster to link this one-time-use binary dynamically
deps=[gen_main, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
gen_tools += [":" + gen_tool]
gen_cmds += ["$(location :" + gen_tool + ") $(@D) " + ops_lib]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
native.genrule(
name=name,
srcs=["@local_jdk//:jar"],
outs=[gen_srcjar],
tools=gen_tools,
cmd='&&'.join(gen_cmds))
# Define a bazel macro that creates cc_test for tensorflow. # Define a bazel macro that creates cc_test for tensorflow.
# TODO(opensource): we need to enable this to work around the hidden symbol # TODO(opensource): we need to enable this to work around the hidden symbol
# __cudaRegisterFatBinary error. Need more investigations. # __cudaRegisterFatBinary error. Need more investigations.