Generate TypeScript Op attribute values for "type" and "int" OpDef attribute types.

This is an incremental change to first introduce updates to the TypeScript internal library and references to building OpDef attribute structs that the TensorFlow.js Node runtime uses. For now, this change introduces basic "type" and "int" attr types. I'll continue to roll more types and complicated examples in upcoming changes.

PiperOrigin-RevId: 210121141
This commit is contained in:
Nick Kreeger 2018-08-24 10:45:37 -07:00 committed by TensorFlower Gardener
parent 37b2b0eb61
commit 90030cc1ef
2 changed files with 178 additions and 53 deletions

View File

@ -38,6 +38,15 @@ struct ArgDefs {
const ApiDef::Arg& api_def_arg; const ApiDef::Arg& api_def_arg;
}; };
// Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op.
struct OpAttrs {
OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
: op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
const OpDef::AttrDef& op_def_attr;
const ApiDef::Attr& api_def_attr;
};
// Helper class to generate TypeScript code for a given OpDef: // Helper class to generate TypeScript code for a given OpDef:
class GenTypeScriptOp { class GenTypeScriptOp {
public: public:
@ -49,8 +58,12 @@ class GenTypeScriptOp {
private: private:
void ProcessArgs(); void ProcessArgs();
void ProcessAttrs();
void AddAttrForArg(const string& attr, int arg_index);
string InputForAttr(const OpDef::AttrDef& op_def_attr);
void AddMethodSignature(); void AddMethodSignature();
void AddOpAttrs();
void AddMethodReturnAndClose(); void AddMethodReturnAndClose();
const OpDef& op_def_; const OpDef& op_def_;
@ -62,6 +75,13 @@ class GenTypeScriptOp {
// Holds in-order vector of Op inputs: // Holds in-order vector of Op inputs:
std::vector<ArgDefs> input_op_args_; std::vector<ArgDefs> input_op_args_;
// Holds in-order vector of Op attributes:
std::vector<OpAttrs> op_attrs_;
// Stores attributes-to-arguments by name:
typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
AttrArgIdxMap attr_arg_idx_map_;
// Holds number of outputs: // Holds number of outputs:
int num_outputs_; int num_outputs_;
}; };
@ -73,9 +93,11 @@ GenTypeScriptOp::~GenTypeScriptOp() {}
string GenTypeScriptOp::Code() { string GenTypeScriptOp::Code() {
ProcessArgs(); ProcessArgs();
ProcessAttrs();
// Generate exported function for Op: // Generate exported function for Op:
AddMethodSignature(); AddMethodSignature();
AddOpAttrs();
AddMethodReturnAndClose(); AddMethodReturnAndClose();
strings::StrAppend(&result_, "\n"); strings::StrAppend(&result_, "\n");
@ -96,12 +118,52 @@ void GenTypeScriptOp::ProcessArgs() {
<< api_def_.arg_order(i); << api_def_.arg_order(i);
continue; continue;
} }
// Map attr names to arg indexes:
if (!op_def_arg->type_attr().empty()) {
AddAttrForArg(op_def_arg->type_attr(), i);
} else if (!op_def_arg->type_list_attr().empty()) {
AddAttrForArg(op_def_arg->type_list_attr(), i);
}
if (!op_def_arg->number_attr().empty()) {
AddAttrForArg(op_def_arg->number_attr(), i);
}
input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg)); input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
} }
num_outputs_ = api_def_.out_arg_size(); num_outputs_ = api_def_.out_arg_size();
} }
void GenTypeScriptOp::ProcessAttrs() {
for (int i = 0; i < op_def_.attr_size(); i++) {
op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
}
}
void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
// Keep track of attributes-to-arguments by name. These will be used for
// construction Op attributes that require information about the inputs.
auto iter = attr_arg_idx_map_.find(attr);
if (iter == attr_arg_idx_map_.end()) {
attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
} else {
iter->second.push_back(arg_index);
}
}
string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
string inputs;
auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
if (arg_list != attr_arg_idx_map_.end()) {
for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
++iter) {
strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
}
}
return inputs;
}
void GenTypeScriptOp::AddMethodSignature() { void GenTypeScriptOp::AddMethodSignature() {
strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(), strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
"("); "(");
@ -131,6 +193,35 @@ void GenTypeScriptOp::AddMethodSignature() {
} }
} }
void GenTypeScriptOp::AddOpAttrs() {
strings::StrAppend(&result_, " const opAttrs = [\n");
bool is_first = true;
for (auto& attr : op_attrs_) {
if (is_first) {
is_first = false;
} else {
strings::StrAppend(&result_, ",\n");
}
// Append 4 spaces to start:
strings::StrAppend(&result_, " ");
if (attr.op_def_attr.type() == "type") {
// Type OpAttributes can be generated from a helper function:
strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
attr.op_def_attr.name(), "', ",
InputForAttr(attr.op_def_attr), ")");
} else if (attr.op_def_attr.type() == "int") {
strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
".length}");
}
}
strings::StrAppend(&result_, "\n ];\n");
}
void GenTypeScriptOp::AddMethodReturnAndClose() { void GenTypeScriptOp::AddMethodReturnAndClose() {
strings::StrAppend(&result_, " return null;\n}\n"); strings::StrAppend(&result_, " return null;\n}\n");
} }
@ -162,7 +253,7 @@ void StartFile(WritableFile* ts_file) {
// This file is MACHINE GENERATED! Do not edit // This file is MACHINE GENERATED! Do not edit
import * as tfc from '@tensorflow/tfjs-core'; import * as tfc from '@tensorflow/tfjs-core';
import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)header"; )header";

View File

@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
<< "'" << s << "' does not contain '" << expected << "'"; << "'" << s << "' does not contain '" << expected << "'";
} }
// TODO(kreeger): Add multiple outputs here?
constexpr char kBaseOpDef[] = R"( constexpr char kBaseOpDef[] = R"(
op { op {
name: "Foo" name: "Foo"
@ -79,50 +78,15 @@ op {
summary: "Summary for op Foo." summary: "Summary for op Foo."
description: "Description for op Foo." description: "Description for op Foo."
} }
op {
name: "DeprecatedFoo"
input_arg {
name: "input"
description: "Description for input."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
deprecation {
explanation: "Deprecated."
}
}
op {
name: "MultiOutputFoo"
input_arg {
name: "input"
description: "Description for input."
type: DT_FLOAT
}
output_arg {
name: "output1"
description: "Description for output 1."
type: DT_FLOAT
}
output_arg {
name: "output2"
description: "Description for output 2."
type: DT_FLOAT
}
summary: "Summary for op MultiOutputFoo."
description: "Description for op MultiOutputFoo."
}
)"; )";
// Generate TypeScript code // Generate TypeScript code
// @param api_def_str TODO doc me. void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { string* ts_file_text) {
Env* env = Env::Default(); Env* env = Env::Default();
OpList op_defs; OpList op_defs;
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); protobuf::TextFormat::ParseFromString(
op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
ApiDefMap api_def_map(op_defs); ApiDefMap api_def_map(op_defs);
if (!api_def_str.empty()) { if (!api_def_str.empty()) {
@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
TEST(TsOpGenTest, TestImports) { TEST(TsOpGenTest, TestImports) {
string ts_file_text; string ts_file_text;
GenerateTsOpFileText("", &ts_file_text); GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"( const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core'; import * as tfc from '@tensorflow/tfjs-core';
import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)"; )";
ExpectContainsStr(ts_file_text, expected); ExpectContainsStr(ts_file_text, expected);
} }
@ -160,12 +124,10 @@ op {
)"; )";
string ts_file_text; string ts_file_text;
GenerateTsOpFileText(api_def, &ts_file_text); GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"( const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
return null;
}
)"; )";
ExpectContainsStr(ts_file_text, expected); ExpectContainsStr(ts_file_text, expected);
} }
@ -179,34 +141,106 @@ op {
)"; )";
string ts_file_text; string ts_file_text;
GenerateTsOpFileText(api_def, &ts_file_text); GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"( const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
return null;
}
)"; )";
ExpectDoesNotContainStr(ts_file_text, expected); ExpectDoesNotContainStr(ts_file_text, expected);
} }
TEST(TsOpGenTest, SkipDeprecated) { TEST(TsOpGenTest, SkipDeprecated) {
const string op_def = R"(
op {
name: "DeprecatedFoo"
input_arg {
name: "input"
type_attr: "T"
description: "Description for input."
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for input"
allowed_values {
list {
type: DT_FLOAT
}
}
}
deprecation {
explanation: "Deprecated."
}
}
)";
string ts_file_text; string ts_file_text;
GenerateTsOpFileText("", &ts_file_text); GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo"); ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
} }
TEST(TsOpGenTest, MultiOutput) { TEST(TsOpGenTest, MultiOutput) {
const string op_def = R"(
op {
name: "MultiOutputFoo"
input_arg {
name: "input"
description: "Description for input."
type_attr: "T"
}
output_arg {
name: "output1"
description: "Description for output 1."
type: DT_FLOAT
}
output_arg {
name: "output2"
description: "Description for output 2."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for input"
allowed_values {
list {
type: DT_FLOAT
}
}
}
summary: "Summary for op MultiOutputFoo."
description: "Description for op MultiOutputFoo."
}
)";
string ts_file_text; string ts_file_text;
GenerateTsOpFileText("", &ts_file_text); GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"( const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] { export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
return null;
}
)"; )";
ExpectContainsStr(ts_file_text, expected); ExpectContainsStr(ts_file_text, expected);
} }
TEST(TsOpGenTest, OpAttrs) {
string ts_file_text;
GenerateTsOpFileText("", "", &ts_file_text);
const string expectedFooAttrs = R"(
const opAttrs = [
createTensorsTypeOpAttr('T', images),
{name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
];
)";
ExpectContainsStr(ts_file_text, expectedFooAttrs);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow