diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc index babf55cd5f2..fb93bb6d8e8 100644 --- a/tensorflow/js/ops/ts_op_gen.cc +++ b/tensorflow/js/ops/ts_op_gen.cc @@ -38,6 +38,15 @@ struct ArgDefs { const ApiDef::Arg& api_def_arg; }; +// Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op. +struct OpAttrs { + OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr) + : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {} + + const OpDef::AttrDef& op_def_attr; + const ApiDef::Attr& api_def_attr; +}; + // Helper class to generate TypeScript code for a given OpDef: class GenTypeScriptOp { public: @@ -49,8 +58,12 @@ class GenTypeScriptOp { private: void ProcessArgs(); + void ProcessAttrs(); + void AddAttrForArg(const string& attr, int arg_index); + string InputForAttr(const OpDef::AttrDef& op_def_attr); void AddMethodSignature(); + void AddOpAttrs(); void AddMethodReturnAndClose(); const OpDef& op_def_; @@ -62,6 +75,13 @@ class GenTypeScriptOp { // Holds in-order vector of Op inputs: std::vector<ArgDefs> input_op_args_; + // Holds in-order vector of Op attributes: + std::vector<OpAttrs> op_attrs_; + + // Stores attributes-to-arguments by name: + typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap; + AttrArgIdxMap attr_arg_idx_map_; + // Holds number of outputs: int num_outputs_; }; @@ -73,9 +93,11 @@ GenTypeScriptOp::~GenTypeScriptOp() {} string GenTypeScriptOp::Code() { ProcessArgs(); + ProcessAttrs(); // Generate exported function for Op: AddMethodSignature(); + AddOpAttrs(); AddMethodReturnAndClose(); strings::StrAppend(&result_, "\n"); @@ -96,12 +118,52 @@ void GenTypeScriptOp::ProcessArgs() { << api_def_.arg_order(i); continue; } + + // Map attr names to arg indexes: + if (!op_def_arg->type_attr().empty()) { + AddAttrForArg(op_def_arg->type_attr(), i); + } else if (!op_def_arg->type_list_attr().empty()) { + AddAttrForArg(op_def_arg->type_list_attr(), i); + } + if (!op_def_arg->number_attr().empty()) { + AddAttrForArg(op_def_arg->number_attr(), i); + } + input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg)); } num_outputs_ = api_def_.out_arg_size(); } +void GenTypeScriptOp::ProcessAttrs() { + for (int i = 0; i < op_def_.attr_size(); i++) { + op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i))); + } +} + +void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) { + // Keep track of attributes-to-arguments by name. These will be used for + // construction Op attributes that require information about the inputs. + auto iter = attr_arg_idx_map_.find(attr); + if (iter == attr_arg_idx_map_.end()) { + attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index})); + } else { + iter->second.push_back(arg_index); + } +} + +string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) { + string inputs; + auto arg_list = attr_arg_idx_map_.find(op_def_attr.name()); + if (arg_list != attr_arg_idx_map_.end()) { + for (auto iter = arg_list->second.begin(); iter != arg_list->second.end(); + ++iter) { + strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name()); + } + } + return inputs; +} + void GenTypeScriptOp::AddMethodSignature() { strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(), "("); @@ -131,6 +193,35 @@ void GenTypeScriptOp::AddMethodSignature() { } } +void GenTypeScriptOp::AddOpAttrs() { + strings::StrAppend(&result_, " const opAttrs = [\n"); + + bool is_first = true; + for (auto& attr : op_attrs_) { + if (is_first) { + is_first = false; + } else { + strings::StrAppend(&result_, ",\n"); + } + + // Append 4 spaces to start: + strings::StrAppend(&result_, " "); + + if (attr.op_def_attr.type() == "type") { + // Type OpAttributes can be generated from a helper function: + strings::StrAppend(&result_, "createTensorsTypeOpAttr('", + attr.op_def_attr.name(), "', ", + InputForAttr(attr.op_def_attr), ")"); + } else if (attr.op_def_attr.type() == "int") { + strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', "); + strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, "); + strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr), + ".length}"); + } + } + strings::StrAppend(&result_, "\n ];\n"); +} + void GenTypeScriptOp::AddMethodReturnAndClose() { strings::StrAppend(&result_, " return null;\n}\n"); } @@ -162,7 +253,7 @@ void StartFile(WritableFile* ts_file) { // This file is MACHINE GENERATED! Do not edit import * as tfc from '@tensorflow/tfjs-core'; -import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; +import {createTensorsTypeOpAttr, nodeBackend} from './op_utils'; )header"; diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc index 9a85c021b09..03241689b5f 100644 --- a/tensorflow/js/ops/ts_op_gen_test.cc +++ b/tensorflow/js/ops/ts_op_gen_test.cc @@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) { << "'" << s << "' does not contain '" << expected << "'"; } -// TODO(kreeger): Add multiple outputs here? constexpr char kBaseOpDef[] = R"( op { name: "Foo" @@ -79,50 +78,15 @@ op { summary: "Summary for op Foo." description: "Description for op Foo." } -op { - name: "DeprecatedFoo" - input_arg { - name: "input" - description: "Description for input." - type: DT_FLOAT - } - output_arg { - name: "output" - description: "Description for output." - type: DT_FLOAT - } - deprecation { - explanation: "Deprecated." - } -} -op { - name: "MultiOutputFoo" - input_arg { - name: "input" - description: "Description for input." - type: DT_FLOAT - } - output_arg { - name: "output1" - description: "Description for output 1." - type: DT_FLOAT - } - output_arg { - name: "output2" - description: "Description for output 2." - type: DT_FLOAT - } - summary: "Summary for op MultiOutputFoo." - description: "Description for op MultiOutputFoo." -} )"; // Generate TypeScript code -// @param api_def_str TODO doc me. -void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { +void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str, + string* ts_file_text) { Env* env = Env::Default(); OpList op_defs; - protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); + protobuf::TextFormat::ParseFromString( + op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs); ApiDefMap api_def_map(op_defs); if (!api_def_str.empty()) { @@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { TEST(TsOpGenTest, TestImports) { string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText("", "", &ts_file_text); const string expected = R"( import * as tfc from '@tensorflow/tfjs-core'; -import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; +import {createTensorsTypeOpAttr, nodeBackend} from './op_utils'; )"; ExpectContainsStr(ts_file_text, expected); } @@ -160,12 +124,10 @@ op { )"; string ts_file_text; - GenerateTsOpFileText(api_def, &ts_file_text); + GenerateTsOpFileText("", api_def, &ts_file_text); const string expected = R"( export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { - return null; -} )"; ExpectContainsStr(ts_file_text, expected); } @@ -179,34 +141,106 @@ op { )"; string ts_file_text; - GenerateTsOpFileText(api_def, &ts_file_text); + GenerateTsOpFileText("", api_def, &ts_file_text); const string expected = R"( export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { - return null; -} )"; ExpectDoesNotContainStr(ts_file_text, expected); } TEST(TsOpGenTest, SkipDeprecated) { + const string op_def = R"( +op { + name: "DeprecatedFoo" + input_arg { + name: "input" + type_attr: "T" + description: "Description for input." + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for input" + allowed_values { + list { + type: DT_FLOAT + } + } + } + deprecation { + explanation: "Deprecated." + } +} +)"; + string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText(op_def, "", &ts_file_text); ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo"); } TEST(TsOpGenTest, MultiOutput) { + const string op_def = R"( +op { + name: "MultiOutputFoo" + input_arg { + name: "input" + description: "Description for input." + type_attr: "T" + } + output_arg { + name: "output1" + description: "Description for output 1." + type: DT_FLOAT + } + output_arg { + name: "output2" + description: "Description for output 2." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for input" + allowed_values { + list { + type: DT_FLOAT + } + } + } + summary: "Summary for op MultiOutputFoo." + description: "Description for op MultiOutputFoo." +} +)"; + string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText(op_def, "", &ts_file_text); const string expected = R"( export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] { - return null; -} )"; ExpectContainsStr(ts_file_text, expected); } +TEST(TsOpGenTest, OpAttrs) { + string ts_file_text; + GenerateTsOpFileText("", "", &ts_file_text); + + const string expectedFooAttrs = R"( + const opAttrs = [ + createTensorsTypeOpAttr('T', images), + {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length} + ]; +)"; + + ExpectContainsStr(ts_file_text, expectedFooAttrs); +} + } // namespace } // namespace tensorflow