Simplify and improve generics handling in generator
This commit is contained in:
parent
4bfedb4f2e
commit
eac1479f04
@ -15,6 +15,7 @@ JAVA_VERSION_OPTS = [
|
|||||||
XLINT_OPTS = [
|
XLINT_OPTS = [
|
||||||
"-Werror",
|
"-Werror",
|
||||||
"-Xlint:all",
|
"-Xlint:all",
|
||||||
|
"-Xlint:-processing",
|
||||||
"-Xlint:-serial",
|
"-Xlint:-serial",
|
||||||
"-Xlint:-try",
|
"-Xlint:-try",
|
||||||
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
|
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
|
||||||
|
@ -67,10 +67,10 @@ int main(int argc, char* argv[]) {
|
|||||||
QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
|
QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
|
||||||
std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
|
std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
|
||||||
api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
|
api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
|
||||||
tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs);
|
tensorflow::java::OpGenerator generator(api_dirs);
|
||||||
tensorflow::OpList ops;
|
tensorflow::OpList ops;
|
||||||
tensorflow::OpRegistry::Global()->Export(false, &ops);
|
tensorflow::OpRegistry::Global()->Export(false, &ops);
|
||||||
TF_CHECK_OK(generator.Run(ops));
|
TF_CHECK_OK(generator.Run(ops, base_package, output_dir));
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <ctime>
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
@ -38,23 +39,18 @@ namespace {
|
|||||||
const char* kLicenseSnippet =
|
const char* kLicenseSnippet =
|
||||||
"tensorflow/java/src/gen/resources/license.java.snippet";
|
"tensorflow/java/src/gen/resources/license.java.snippet";
|
||||||
|
|
||||||
const std::map<string, Type> kPrimitiveAttrTypes = {
|
|
||||||
{ "Boolean", Type::Boolean() },
|
|
||||||
{ "Byte", Type::Byte() },
|
|
||||||
{ "Character", Type::Byte() },
|
|
||||||
{ "Float", Type::Float() },
|
|
||||||
{ "Integer", Type::Long() },
|
|
||||||
{ "Long", Type::Long() },
|
|
||||||
{ "Short", Type::Long() },
|
|
||||||
{ "Double", Type::Float() },
|
|
||||||
};
|
|
||||||
|
|
||||||
enum RenderMode {
|
enum RenderMode {
|
||||||
DEFAULT,
|
DEFAULT,
|
||||||
SINGLE_OUTPUT,
|
SINGLE_OUTPUT,
|
||||||
SINGLE_LIST_OUTPUT
|
SINGLE_LIST_OUTPUT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline void AddArgument(const Variable& var, const string& description,
|
||||||
|
Method* method_out, Javadoc* javadoc_out) {
|
||||||
|
method_out->add_argument(var);
|
||||||
|
javadoc_out->add_param_tag(var.name(), description);
|
||||||
|
}
|
||||||
|
|
||||||
void CollectOpDependencies(const OpSpec& op, RenderMode mode,
|
void CollectOpDependencies(const OpSpec& op, RenderMode mode,
|
||||||
std::list<Type>* out) {
|
std::list<Type>* out) {
|
||||||
out->push_back(Type::Class("Operation", "org.tensorflow"));
|
out->push_back(Type::Class("Operation", "org.tensorflow"));
|
||||||
@ -81,9 +77,7 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
|
|||||||
}
|
}
|
||||||
for (const AttributeSpec& attribute : op.attributes()) {
|
for (const AttributeSpec& attribute : op.attributes()) {
|
||||||
out->push_back(attribute.var().type());
|
out->push_back(attribute.var().type());
|
||||||
if (attribute.var().type().name() == "Class") {
|
out->push_back(attribute.jni_type());
|
||||||
out->push_back(Type::Enum("DataType", "org.tensorflow"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
|
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
|
||||||
out->push_back(optional_attribute.var().type());
|
out->push_back(optional_attribute.var().type());
|
||||||
@ -92,45 +86,38 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
|
|||||||
|
|
||||||
void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
|
void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
|
||||||
SourceWriter* writer) {
|
SourceWriter* writer) {
|
||||||
string var = optional ? "opts." + attr.var().name() : attr.var().name();
|
string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
|
||||||
if (attr.iterable()) {
|
if (attr.iterable()) {
|
||||||
const Type& type = attr.type();
|
string array_name = attr.var().name() + "Array";
|
||||||
std::map<string, Type>::const_iterator it =
|
writer->AppendType(attr.jni_type())
|
||||||
kPrimitiveAttrTypes.find(type.name());
|
.Append("[] " + array_name + " = new ")
|
||||||
if (it != kPrimitiveAttrTypes.end()) {
|
.AppendType(attr.jni_type())
|
||||||
string array = attr.var().name() + "Array";
|
.Append("[" + var_name + ".size()];")
|
||||||
writer->AppendType(it->second)
|
|
||||||
.Append("[] " + array + " = new ")
|
|
||||||
.AppendType(it->second)
|
|
||||||
.Append("[" + var + ".size()];")
|
|
||||||
.EndLine();
|
|
||||||
writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)")
|
|
||||||
.Append(array + "[i] = " + var + ".get(i);")
|
|
||||||
.EndLine()
|
.EndLine()
|
||||||
.EndBlock()
|
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
|
||||||
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array)
|
.Append(array_name + "[i] = ");
|
||||||
.Append(");")
|
if (attr.type().kind() == Type::GENERIC) {
|
||||||
.EndLine();
|
writer->Append("DataType.fromClass(" + var_name + ".get(i));");
|
||||||
} else {
|
} else {
|
||||||
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var)
|
writer->Append(var_name + ".get(i);");
|
||||||
.Append(".toArray(new ")
|
|
||||||
.AppendType(type)
|
|
||||||
.Append("[" + var + ".size()]));")
|
|
||||||
.EndLine();
|
|
||||||
}
|
}
|
||||||
|
writer->EndLine()
|
||||||
|
.EndBlock()
|
||||||
|
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
|
||||||
|
.Append(array_name + ");")
|
||||||
|
.EndLine();
|
||||||
} else {
|
} else {
|
||||||
Type type = attr.var().type();
|
|
||||||
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
|
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
|
||||||
if (type.name() == "Class") {
|
if (attr.var().type().name() == "Class") {
|
||||||
writer->Append("DataType.fromClass(" + attr.var().name() + "));");
|
writer->Append("DataType.fromClass(" + var_name + "));");
|
||||||
} else {
|
} else {
|
||||||
writer->Append(var + ");");
|
writer->Append(var_name + ");");
|
||||||
}
|
}
|
||||||
writer->EndLine();
|
writer->EndLine();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
|
void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
|
||||||
SourceWriter* writer) {
|
SourceWriter* writer) {
|
||||||
Method factory = Method::Create("create", op_class);
|
Method factory = Method::Create("create", op_class);
|
||||||
Javadoc factory_doc = Javadoc::Create(
|
Javadoc factory_doc = Javadoc::Create(
|
||||||
@ -138,27 +125,24 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
|
|||||||
+ " operation to the graph.");
|
+ " operation to the graph.");
|
||||||
Variable scope =
|
Variable scope =
|
||||||
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
|
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
|
||||||
factory.add_argument(scope);
|
AddArgument(scope, "current graph scope", &factory, &factory_doc);
|
||||||
factory_doc.add_param_tag(scope.name(), "Current graph scope");
|
|
||||||
for (const ArgumentSpec& input : op.inputs()) {
|
for (const ArgumentSpec& input : op.inputs()) {
|
||||||
factory.add_argument(input.var());
|
AddArgument(input.var(), input.description(), &factory, &factory_doc);
|
||||||
factory_doc.add_param_tag(input.var().name(), input.description());
|
|
||||||
}
|
}
|
||||||
for (const AttributeSpec& attribute : op.attributes()) {
|
for (const AttributeSpec& attr : op.attributes()) {
|
||||||
factory.add_argument(attribute.var());
|
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
|
||||||
factory_doc.add_param_tag(attribute.var().name(), attribute.description());
|
|
||||||
}
|
}
|
||||||
if (!op.optional_attributes().empty()) {
|
if (!op.optional_attributes().empty()) {
|
||||||
factory.add_argument(Variable::Varargs("options", Type::Class("Options")));
|
AddArgument(Variable::Varargs("options", Type::Class("Options")),
|
||||||
factory_doc.add_param_tag("options", "carries optional attributes values");
|
"carries optional attributes values", &factory, &factory_doc);
|
||||||
}
|
}
|
||||||
factory_doc.add_tag("return", "a new instance of " + op_class.name());
|
factory_doc.add_tag("return", "a new instance of " + op_class.name());
|
||||||
|
|
||||||
writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc);
|
writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc);
|
||||||
writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\""
|
writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\""
|
||||||
+ op.graph_op_name() + "\", scope.makeOpName(\""
|
+ op.graph_op_name() + "\", scope.makeOpName(\""
|
||||||
+ op_class.name() + "\"));");
|
+ op_class.name() + "\"));");
|
||||||
writer->EndLine();
|
writer->EndLine();
|
||||||
|
|
||||||
for (const ArgumentSpec& input : op.inputs()) {
|
for (const ArgumentSpec& input : op.inputs()) {
|
||||||
if (input.iterable()) {
|
if (input.iterable()) {
|
||||||
writer->Append("opBuilder.addInputList(Operands.asOutputs("
|
writer->Append("opBuilder.addInputList(Operands.asOutputs("
|
||||||
@ -192,10 +176,9 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
|
|||||||
|
|
||||||
void RenderConstructor(const OpSpec& op, const Type& op_class,
|
void RenderConstructor(const OpSpec& op, const Type& op_class,
|
||||||
SourceWriter* writer) {
|
SourceWriter* writer) {
|
||||||
Method constructor = Method::ConstructorFor(op_class)
|
Variable operation =
|
||||||
.add_argument(
|
Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
|
||||||
Variable::Create("operation",
|
Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
|
||||||
Type::Class("Operation", "org.tensorflow")));
|
|
||||||
for (const ArgumentSpec& output : op.outputs()) {
|
for (const ArgumentSpec& output : op.outputs()) {
|
||||||
if (output.iterable() && !output.type().unknown()) {
|
if (output.iterable() && !output.type().unknown()) {
|
||||||
constructor.add_annotation(
|
constructor.add_annotation(
|
||||||
@ -237,15 +220,14 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
|
void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
|
||||||
for (const AttributeSpec& attribute : op.optional_attributes()) {
|
for (const AttributeSpec& attr : op.optional_attributes()) {
|
||||||
Method setter =
|
Method setter =
|
||||||
Method::Create(attribute.var().name(), Type::Class("Options"))
|
Method::Create(attr.var().name(), Type::Class("Options"));
|
||||||
.add_argument(attribute.var());
|
Javadoc setter_doc = Javadoc::Create();
|
||||||
Javadoc setter_doc = Javadoc::Create()
|
AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
|
||||||
.add_param_tag(attribute.var().name(), attribute.description());
|
|
||||||
writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc)
|
writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc)
|
||||||
.Append("return new Options()." + attribute.var().name() + "("
|
.Append("return new Options()." + attr.var().name() + "("
|
||||||
+ attribute.var().name() + ");")
|
+ attr.var().name() + ");")
|
||||||
.EndLine()
|
.EndLine()
|
||||||
.EndMethod();
|
.EndMethod();
|
||||||
}
|
}
|
||||||
@ -311,14 +293,12 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class,
|
|||||||
Javadoc options_doc = Javadoc::Create(
|
Javadoc options_doc = Javadoc::Create(
|
||||||
"Optional attributes for {@link " + op_class.full_name() + "}");
|
"Optional attributes for {@link " + op_class.full_name() + "}");
|
||||||
writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
|
writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
|
||||||
for (const AttributeSpec& attribute : op.optional_attributes()) {
|
for (const AttributeSpec& attr : op.optional_attributes()) {
|
||||||
Method setter = Method::Create(attribute.var().name(), options_class)
|
Method setter = Method::Create(attr.var().name(), options_class);
|
||||||
.add_argument(attribute.var());
|
Javadoc setter_doc = Javadoc::Create();
|
||||||
Javadoc setter_doc = Javadoc::Create()
|
AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
|
||||||
.add_param_tag(attribute.var().name(), attribute.description());
|
|
||||||
writer->BeginMethod(setter, PUBLIC, &setter_doc)
|
writer->BeginMethod(setter, PUBLIC, &setter_doc)
|
||||||
.Append("this." + attribute.var().name() + " = "
|
.Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
|
||||||
+ attribute.var().name() + ";")
|
|
||||||
.EndLine()
|
.EndLine()
|
||||||
.Append("return this;")
|
.Append("return this;")
|
||||||
.EndLine()
|
.EndLine()
|
||||||
@ -339,12 +319,13 @@ inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
|
void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
|
||||||
const string& base_package, const string& output_dir, Env* env) {
|
const string& base_package, const string& output_dir, Env* env,
|
||||||
|
const std::tm* timestamp) {
|
||||||
Type op_class(ClassOf(endpoint, base_package)
|
Type op_class(ClassOf(endpoint, base_package)
|
||||||
.add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
|
.add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
|
||||||
Javadoc op_javadoc(endpoint.javadoc());
|
Javadoc op_javadoc(endpoint.javadoc());
|
||||||
|
|
||||||
// implement Operand (or Iterable<Operand>) if the op has only one output
|
// op interfaces
|
||||||
RenderMode mode = DEFAULT;
|
RenderMode mode = DEFAULT;
|
||||||
if (op.outputs().size() == 1) {
|
if (op.outputs().size() == 1) {
|
||||||
const ArgumentSpec& output = op.outputs().front();
|
const ArgumentSpec& output = op.outputs().front();
|
||||||
@ -360,18 +341,22 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
|
|||||||
op_class.add_supertype(operand_inf);
|
op_class.add_supertype(operand_inf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// declare all outputs generics at the op class level
|
// op generic parameters
|
||||||
std::set<string> generics;
|
std::set<string> generics;
|
||||||
for (const ArgumentSpec& output : op.outputs()) {
|
for (const ArgumentSpec& output : op.outputs()) {
|
||||||
if (output.type().kind() == Type::GENERIC && !output.type().unknown()
|
if (output.type().kind() == Type::GENERIC && !output.type().unknown()
|
||||||
&& generics.find(output.type().name()) == generics.end()) {
|
&& generics.find(output.type().name()) == generics.end()) {
|
||||||
op_class.add_parameter(output.type());
|
op_class.add_parameter(output.type());
|
||||||
op_javadoc.add_param_tag("<" + output.type().name() + ">",
|
op_javadoc.add_param_tag("<" + output.type().name() + ">",
|
||||||
"data type of output {@code " + output.var().name() + "}");
|
"data type for {@code " + output.var().name() + "()} output");
|
||||||
generics.insert(output.type().name());
|
generics.insert(output.type().name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// handle endpoint deprecation
|
// op annotations
|
||||||
|
char date[20];
|
||||||
|
strftime(date, sizeof date, "%FT%TZ", timestamp);
|
||||||
|
op_class.add_annotation(Annotation::Create("Generated", "javax.annotation")
|
||||||
|
.attributes(string("value = \"op_generator\", date = \"") + date + "\""));
|
||||||
if (endpoint.deprecated()) {
|
if (endpoint.deprecated()) {
|
||||||
op_class.add_annotation(Annotation::Create("Deprecated"));
|
op_class.add_annotation(Annotation::Create("Deprecated"));
|
||||||
string explanation;
|
string explanation;
|
||||||
@ -384,8 +369,8 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
|
|||||||
}
|
}
|
||||||
op_javadoc.add_tag("deprecated", explanation);
|
op_javadoc.add_tag("deprecated", explanation);
|
||||||
}
|
}
|
||||||
// expose the op in the Ops Graph API only if it is visible
|
|
||||||
if (!op.hidden()) {
|
if (!op.hidden()) {
|
||||||
|
// expose the op in the Ops Graph API only if it is visible
|
||||||
op_class.add_annotation(
|
op_class.add_annotation(
|
||||||
Annotation::Create("Operator", "org.tensorflow.op.annotation")
|
Annotation::Create("Operator", "org.tensorflow.op.annotation")
|
||||||
.attributes("group = \"" + endpoint.package() + "\""));
|
.attributes("group = \"" + endpoint.package() + "\""));
|
||||||
@ -405,15 +390,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
|
|||||||
std::list<Type> dependencies;
|
std::list<Type> dependencies;
|
||||||
CollectOpDependencies(op, mode, &dependencies);
|
CollectOpDependencies(op, mode, &dependencies);
|
||||||
writer.WriteFromFile(kLicenseSnippet)
|
writer.WriteFromFile(kLicenseSnippet)
|
||||||
.EndLine()
|
|
||||||
.Append("// This file is machine generated, DO NOT EDIT!")
|
|
||||||
.EndLine()
|
|
||||||
.EndLine()
|
.EndLine()
|
||||||
.BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc);
|
.BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc);
|
||||||
if (!op.optional_attributes().empty()) {
|
if (!op.optional_attributes().empty()) {
|
||||||
RenderOptionsClass(op, op_class, &writer);
|
RenderOptionsClass(op, op_class, &writer);
|
||||||
}
|
}
|
||||||
RenderFactoryMethod(op, op_class, &writer);
|
RenderFactoryMethods(op, op_class, &writer);
|
||||||
RenderGettersAndSetters(op, &writer);
|
RenderGettersAndSetters(op, &writer);
|
||||||
if (mode != DEFAULT) {
|
if (mode != DEFAULT) {
|
||||||
RenderInterfaceImpl(op, mode, &writer);
|
RenderInterfaceImpl(op, mode, &writer);
|
||||||
@ -428,13 +410,8 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
OpGenerator::OpGenerator(const string& base_package, const string& output_dir,
|
Status OpGenerator::Run(const OpList& op_list, const string& base_package,
|
||||||
const std::vector<string>& api_dirs, Env* env)
|
const string& output_dir) {
|
||||||
: base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs),
|
|
||||||
env_(env) {
|
|
||||||
}
|
|
||||||
|
|
||||||
Status OpGenerator::Run(const OpList& op_list) {
|
|
||||||
ApiDefMap api_map(op_list);
|
ApiDefMap api_map(op_list);
|
||||||
if (!api_dirs_.empty()) {
|
if (!api_dirs_.empty()) {
|
||||||
// Only load api files that correspond to the requested "op_list"
|
// Only load api files that correspond to the requested "op_list"
|
||||||
@ -449,12 +426,14 @@ Status OpGenerator::Run(const OpList& op_list) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
api_map.UpdateDocs();
|
api_map.UpdateDocs();
|
||||||
|
time_t now;
|
||||||
|
time(&now);
|
||||||
for (const auto& op_def : op_list.op()) {
|
for (const auto& op_def : op_list.op()) {
|
||||||
const ApiDef* api_def = api_map.GetApiDef(op_def.name());
|
const ApiDef* api_def = api_map.GetApiDef(op_def.name());
|
||||||
if (api_def->visibility() != ApiDef::SKIP) {
|
if (api_def->visibility() != ApiDef::SKIP) {
|
||||||
OpSpec op(OpSpec::Create(op_def, *api_def));
|
OpSpec op(OpSpec::Create(op_def, *api_def));
|
||||||
for (const EndpointSpec& endpoint : op.endpoints()) {
|
for (const EndpointSpec& endpoint : op.endpoints()) {
|
||||||
GenerateOp(op, endpoint, base_package_, output_dir_, env_);
|
GenerateOp(op, endpoint, base_package, output_dir, env_, gmtime(&now));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,18 +36,17 @@ namespace java {
|
|||||||
// ops definitions.
|
// ops definitions.
|
||||||
class OpGenerator {
|
class OpGenerator {
|
||||||
public:
|
public:
|
||||||
OpGenerator(const string& base_package, const string& output_dir,
|
explicit OpGenerator(const std::vector<string>& api_dirs,
|
||||||
const std::vector<string>& api_dirs, Env* env = Env::Default());
|
Env* env = Env::Default()) : api_dirs_(api_dirs), env_(env) {}
|
||||||
|
|
||||||
// Generates wrappers for the given list of 'ops'.
|
// Generates wrappers for the given list of 'ops'.
|
||||||
//
|
//
|
||||||
// Output files are generated in <output_dir>/<base_package>/<lib_package>,
|
// Output files are generated in <output_dir>/<base_package>/<op_package>,
|
||||||
// where 'lib_package' is derived from ops endpoints.
|
// where 'op_package' is derived from ops endpoints.
|
||||||
Status Run(const OpList& op_list);
|
Status Run(const OpList& op_list, const string& base_package,
|
||||||
|
const string& output_dir);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const string base_package_;
|
|
||||||
const string output_dir_;
|
|
||||||
const std::vector<string> api_dirs_;
|
const std::vector<string> api_dirs_;
|
||||||
Env* env_;
|
Env* env_;
|
||||||
};
|
};
|
||||||
|
@ -46,14 +46,30 @@ class TypeResolver {
|
|||||||
explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
|
explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
|
||||||
|
|
||||||
Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
|
Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
|
||||||
Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
|
std::pair<Type, Type> TypeOf(const OpDef_AttrDef& attr_def,
|
||||||
|
bool *iterable_out);
|
||||||
bool IsAttributeVisited(const string& attr_name) {
|
bool IsAttributeVisited(const string& attr_name) {
|
||||||
return visited_attrs_.find(attr_name) != visited_attrs_.cend();
|
return visited_attrs_.find(attr_name) != visited_attrs_.cend();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const OpDef op_def_;
|
const OpDef op_def_;
|
||||||
std::map<std::string, Type> visited_attrs_;
|
std::map<std::string, Type> visited_attrs_;
|
||||||
char next_generic_ = 'T';
|
char next_generic_letter_ = 'T';
|
||||||
|
|
||||||
|
std::pair<Type, Type> MakeTypePair(const Type& type, const Type& jni_type) {
|
||||||
|
return std::make_pair(type, jni_type);
|
||||||
|
}
|
||||||
|
std::pair<Type, Type> MakeTypePair(const Type& type) {
|
||||||
|
return std::make_pair(type, type);
|
||||||
|
}
|
||||||
|
Type NextGeneric() {
|
||||||
|
char generic_letter = next_generic_letter_++;
|
||||||
|
if (next_generic_letter_ > 'Z') {
|
||||||
|
next_generic_letter_ = 'A';
|
||||||
|
}
|
||||||
|
return Type::Generic(string(1, generic_letter));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
|
Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
|
||||||
@ -107,7 +123,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
|
|||||||
} else {
|
} else {
|
||||||
for (const auto& attr_def : op_def_.attr()) {
|
for (const auto& attr_def : op_def_.attr()) {
|
||||||
if (attr_def.name() == arg_def.type_attr()) {
|
if (attr_def.name() == arg_def.type_attr()) {
|
||||||
type = TypeOf(attr_def, iterable_out);
|
type = TypeOf(attr_def, iterable_out).first;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -125,51 +141,47 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
|
|||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def,
|
std::pair<Type, Type> TypeResolver::TypeOf(const OpDef_AttrDef& attr_def,
|
||||||
bool* iterable_out) {
|
bool* iterable_out) {
|
||||||
|
std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
|
||||||
*iterable_out = false;
|
*iterable_out = false;
|
||||||
StringPiece attr_type = attr_def.type();
|
StringPiece attr_type = attr_def.type();
|
||||||
if (str_util::ConsumePrefix(&attr_type, "list(")) {
|
if (str_util::ConsumePrefix(&attr_type, "list(")) {
|
||||||
attr_type.remove_suffix(1); // remove closing brace
|
attr_type.remove_suffix(1); // remove closing brace
|
||||||
*iterable_out = true;
|
*iterable_out = true;
|
||||||
}
|
}
|
||||||
Type type = *iterable_out ? Type::Wildcard() : Type::Class("Object");
|
if (attr_type == "string") {
|
||||||
if (attr_type == "type") {
|
types = MakeTypePair(Type::Class("String"));
|
||||||
if (*iterable_out) {
|
|
||||||
type = Type::Enum("DataType", "org.tensorflow");
|
|
||||||
} else {
|
|
||||||
type = Type::Generic(string(1, next_generic_));
|
|
||||||
next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
|
|
||||||
if (IsRealNumbers(attr_def.allowed_values())) {
|
|
||||||
// enforce real numbers datasets by extending java.lang.Number
|
|
||||||
type.add_supertype(Type::Class("Number"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (attr_type == "string") {
|
|
||||||
type = Type::Class("String");
|
|
||||||
|
|
||||||
} else if (attr_type == "int") {
|
} else if (attr_type == "int") {
|
||||||
type = Type::Class("Integer");
|
types = MakeTypePair(Type::Class("Long"), Type::Long());
|
||||||
|
|
||||||
} else if (attr_type == "float") {
|
} else if (attr_type == "float") {
|
||||||
type = Type::Class("Float");
|
types = MakeTypePair(Type::Class("Float"), Type::Float());
|
||||||
|
|
||||||
} else if (attr_type == "bool") {
|
} else if (attr_type == "bool") {
|
||||||
type = Type::Class("Boolean");
|
types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
|
||||||
|
|
||||||
} else if (attr_type == "shape") {
|
} else if (attr_type == "shape") {
|
||||||
type = Type::Class("Shape", "org.tensorflow");
|
types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
|
||||||
|
|
||||||
} else if (attr_type == "tensor") {
|
} else if (attr_type == "tensor") {
|
||||||
type = Type::Class("Tensor", "org.tensorflow")
|
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
|
||||||
.add_parameter(Type::Wildcard());
|
.add_parameter(Type::Wildcard()));
|
||||||
|
|
||||||
|
} else if (attr_type == "type") {
|
||||||
|
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
|
||||||
|
if (IsRealNumbers(attr_def.allowed_values())) {
|
||||||
|
type.add_supertype(Type::Class("Number"));
|
||||||
|
}
|
||||||
|
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
|
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
|
||||||
<< "\" in operation \"" << op_def_.name() << "\"";
|
<< "\" in operation \"" << op_def_.name() << "\"";
|
||||||
}
|
}
|
||||||
visited_attrs_.insert(std::make_pair(attr_def.name(), type));
|
visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
|
||||||
return type;
|
return types;
|
||||||
}
|
}
|
||||||
|
|
||||||
string SnakeToCamelCase(const string& str, bool upper = false) {
|
string SnakeToCamelCase(const string& str, bool upper = false) {
|
||||||
@ -307,19 +319,19 @@ ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
|
|||||||
AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
|
AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
|
||||||
const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
|
const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
|
||||||
bool iterable = false;
|
bool iterable = false;
|
||||||
Type type = type_resolver->TypeOf(attr_def, &iterable);
|
std::pair<Type, Type> types = type_resolver->TypeOf(attr_def, &iterable);
|
||||||
// type attributes must be passed explicitly in methods as a Class<> parameter
|
Type var_type = types.first.kind() == Type::GENERIC ?
|
||||||
bool is_explicit = type.kind() == Type::GENERIC && !iterable;
|
Type::Class("Class").add_parameter(types.first) : types.first;
|
||||||
Type var_type = is_explicit ? Type::Class("Class").add_parameter(type) : type;
|
|
||||||
if (iterable) {
|
if (iterable) {
|
||||||
var_type = Type::ListOf(type);
|
var_type = Type::ListOf(var_type);
|
||||||
}
|
}
|
||||||
return AttributeSpec(attr_api_def.name(),
|
return AttributeSpec(attr_api_def.name(),
|
||||||
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
|
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
|
||||||
type,
|
types.first,
|
||||||
|
types.second,
|
||||||
ParseDocumentation(attr_api_def.description()),
|
ParseDocumentation(attr_api_def.description()),
|
||||||
iterable,
|
iterable,
|
||||||
attr_api_def.has_default_value() && !is_explicit);
|
attr_api_def.has_default_value());
|
||||||
}
|
}
|
||||||
|
|
||||||
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
|
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
|
||||||
@ -340,7 +352,6 @@ ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
|
|||||||
|
|
||||||
EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
|
EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
|
||||||
const ApiDef_Endpoint& endpoint_def) {
|
const ApiDef_Endpoint& endpoint_def) {
|
||||||
|
|
||||||
std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
|
std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
|
||||||
string package;
|
string package;
|
||||||
string name;
|
string name;
|
||||||
@ -381,7 +392,7 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
|
|||||||
AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
|
AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
|
||||||
&type_resolver);
|
&type_resolver);
|
||||||
// attributes with a default value are optional
|
// attributes with a default value are optional
|
||||||
if (attr.optional()) {
|
if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
|
||||||
op.optional_attributes_.push_back(attr);
|
op.optional_attributes_.push_back(attr);
|
||||||
} else {
|
} else {
|
||||||
op.attributes_.push_back(attr);
|
op.attributes_.push_back(attr);
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
#include "tensorflow/core/framework/api_def.pb.h"
|
#include "tensorflow/core/framework/api_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/java/src/gen/cc/java_defs.h"
|
#include "tensorflow/java/src/gen/cc/java_defs.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -87,20 +88,23 @@ class AttributeSpec : public ArgumentSpec {
|
|||||||
// op_def_name: attribute name, as known by TensorFlow core
|
// op_def_name: attribute name, as known by TensorFlow core
|
||||||
// var: a variable to represent this attribute in Java
|
// var: a variable to represent this attribute in Java
|
||||||
// type: the type of this attribute
|
// type: the type of this attribute
|
||||||
|
// jni_type: the type of this attribute in JNI layer (see OperationBuilder)
|
||||||
// description: a description of this attribute, in javadoc
|
// description: a description of this attribute, in javadoc
|
||||||
// iterable: true if this attribute is a list
|
// iterable: true if this attribute is a list
|
||||||
// optional: true if this attribute does not require to be set explicitly
|
// has_default_value: true if this attribute has a default value if not set
|
||||||
AttributeSpec(const string& op_def_name, const Variable& var,
|
AttributeSpec(const string& op_def_name, const Variable& var,
|
||||||
const Type& type, const string& description, bool iterable,
|
const Type& type, const Type& jni_type, const string& description,
|
||||||
bool optional)
|
bool iterable, bool has_default_value)
|
||||||
: ArgumentSpec(op_def_name, var, type, description, iterable),
|
: ArgumentSpec(op_def_name, var, type, description, iterable),
|
||||||
optional_(optional) {}
|
jni_type_(jni_type), has_default_value_(has_default_value) {}
|
||||||
virtual ~AttributeSpec() = default;
|
virtual ~AttributeSpec() = default;
|
||||||
|
|
||||||
bool optional() const { return optional_; }
|
const Type& jni_type() const { return jni_type_; }
|
||||||
|
bool has_default_value() const { return has_default_value_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const bool optional_;
|
const Type jni_type_;
|
||||||
|
const bool has_default_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class OpSpec {
|
class OpSpec {
|
||||||
|
Loading…
Reference in New Issue
Block a user