Introduce basic CC library for generating TypeScript files for TensorFlow.js from registered Ops.
This initial change provides the very basics to start generating TypeScript. Non-deprecated and visible Ops are exported as a typescript function using internal functionality that is used the @tensorflow/tfjs-node repo (https://github.com/tensorflow/tfjs-node). Future changes will introduce more code generation + tests. This initial change will help set the foundation for those upcoming changes. PiperOrigin-RevId: 209528126
This commit is contained in:
parent
debd8b6b4e
commit
49115abfd3
52
tensorflow/js/BUILD
Normal file
52
tensorflow/js/BUILD
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Description:
|
||||||
|
# JavaScript/TypeScript code generation for TensorFlow.js
|
||||||
|
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
]
|
||||||
|
|
||||||
|
package(default_visibility = visibility)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_cc_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "ts_op_gen",
|
||||||
|
srcs = [
|
||||||
|
"ops/ts_op_gen.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"ops/ts_op_gen.h",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:op_gen_lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "ts_op_gen_test",
|
||||||
|
srcs = [
|
||||||
|
"ops/ts_op_gen.cc",
|
||||||
|
"ops/ts_op_gen.h",
|
||||||
|
"ops/ts_op_gen_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:op_gen_lib",
|
||||||
|
"//tensorflow/core:proto_text",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
199
tensorflow/js/ops/ts_op_gen.cc
Normal file
199
tensorflow/js/ops/ts_op_gen.cc
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/js/ops/ts_op_gen.h"
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/api_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op_def_util.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool IsListAttr(const OpDef_ArgDef& arg) {
|
||||||
|
return !arg.type_list_attr().empty() || !arg.number_attr().empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Struct to hold a combo OpDef and ArgDef for a given Op argument:
|
||||||
|
struct ArgDefs {
|
||||||
|
ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
|
||||||
|
: op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
|
||||||
|
|
||||||
|
const OpDef::ArgDef& op_def_arg;
|
||||||
|
const ApiDef::Arg& api_def_arg;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper class to generate TypeScript code for a given OpDef:
|
||||||
|
class GenTypeScriptOp {
|
||||||
|
public:
|
||||||
|
GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
|
||||||
|
~GenTypeScriptOp();
|
||||||
|
|
||||||
|
// Returns the generated code as a string:
|
||||||
|
string Code();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void ProcessArgs();
|
||||||
|
|
||||||
|
void AddMethodSignature();
|
||||||
|
void AddMethodReturnAndClose();
|
||||||
|
|
||||||
|
const OpDef& op_def_;
|
||||||
|
const ApiDef& api_def_;
|
||||||
|
|
||||||
|
// Placeholder string for all generated code:
|
||||||
|
string result_;
|
||||||
|
|
||||||
|
// Holds in-order vector of Op inputs:
|
||||||
|
std::vector<ArgDefs> input_op_args_;
|
||||||
|
|
||||||
|
// Holds number of outputs:
|
||||||
|
int num_outputs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
|
||||||
|
: op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
|
||||||
|
|
||||||
|
GenTypeScriptOp::~GenTypeScriptOp() {}
|
||||||
|
|
||||||
|
string GenTypeScriptOp::Code() {
|
||||||
|
ProcessArgs();
|
||||||
|
|
||||||
|
// Generate exported function for Op:
|
||||||
|
AddMethodSignature();
|
||||||
|
AddMethodReturnAndClose();
|
||||||
|
|
||||||
|
strings::StrAppend(&result_, "\n");
|
||||||
|
return result_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GenTypeScriptOp::ProcessArgs() {
|
||||||
|
for (int i = 0; i < api_def_.arg_order_size(); i++) {
|
||||||
|
auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
|
||||||
|
if (op_def_arg == nullptr) {
|
||||||
|
LOG(WARNING) << "Could not find OpDef::ArgDef for "
|
||||||
|
<< api_def_.arg_order(i);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
|
||||||
|
if (api_def_arg == nullptr) {
|
||||||
|
LOG(WARNING) << "Could not find ApiDef::Arg for "
|
||||||
|
<< api_def_.arg_order(i);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
|
||||||
|
}
|
||||||
|
|
||||||
|
num_outputs_ = api_def_.out_arg_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GenTypeScriptOp::AddMethodSignature() {
|
||||||
|
strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
|
||||||
|
"(");
|
||||||
|
|
||||||
|
bool is_first = true;
|
||||||
|
for (auto& in_arg : input_op_args_) {
|
||||||
|
if (is_first) {
|
||||||
|
is_first = false;
|
||||||
|
} else {
|
||||||
|
strings::StrAppend(&result_, ", ");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto op_def_arg = in_arg.op_def_arg;
|
||||||
|
|
||||||
|
strings::StrAppend(&result_, op_def_arg.name(), ": ");
|
||||||
|
if (IsListAttr(op_def_arg)) {
|
||||||
|
strings::StrAppend(&result_, "tfc.Tensor[]");
|
||||||
|
} else {
|
||||||
|
strings::StrAppend(&result_, "tfc.Tensor");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_outputs_ == 1) {
|
||||||
|
strings::StrAppend(&result_, "): tfc.Tensor {\n");
|
||||||
|
} else {
|
||||||
|
strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GenTypeScriptOp::AddMethodReturnAndClose() {
|
||||||
|
strings::StrAppend(&result_, " return null;\n}\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
|
||||||
|
GenTypeScriptOp ts_op(op_def, api_def);
|
||||||
|
TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code()));
|
||||||
|
}
|
||||||
|
|
||||||
|
void StartFile(WritableFile* ts_file) {
|
||||||
|
const string header =
|
||||||
|
R"header(/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2018 Google Inc. All Rights Reserved.
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
* =============================================================================
|
||||||
|
*/
|
||||||
|
|
||||||
|
// This file is MACHINE GENERATED! Do not edit
|
||||||
|
|
||||||
|
import * as tfc from '@tensorflow/tfjs-core';
|
||||||
|
import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
|
||||||
|
|
||||||
|
)header";
|
||||||
|
|
||||||
|
TF_CHECK_OK(ts_file->Append(header));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||||
|
const string& ts_filename) {
|
||||||
|
Env* env = Env::Default();
|
||||||
|
|
||||||
|
std::unique_ptr<WritableFile> ts_file = nullptr;
|
||||||
|
TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
|
||||||
|
|
||||||
|
StartFile(ts_file.get());
|
||||||
|
|
||||||
|
for (const auto& op_def : ops.op()) {
|
||||||
|
// Skip deprecated ops
|
||||||
|
if (op_def.has_deprecation() &&
|
||||||
|
op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto* api_def = api_def_map.GetApiDef(op_def.name());
|
||||||
|
if (api_def->visibility() == ApiDef::VISIBLE) {
|
||||||
|
WriteTSOp(op_def, *api_def, ts_file.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CHECK_OK(ts_file->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
31
tensorflow/js/ops/ts_op_gen.h
Normal file
31
tensorflow/js/ops/ts_op_gen.h
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_JS_OPS_TS_OP_GEN_H_
|
||||||
|
#define TENSORFLOW_JS_OPS_TS_OP_GEN_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Generated code is written to the file ts_filename:
|
||||||
|
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||||
|
const string& ts_filename);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_JS_OPS_TS_OP_GEN_H_
|
212
tensorflow/js/ops/ts_op_gen_test.cc
Normal file
212
tensorflow/js/ops/ts_op_gen_test.cc
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/js/ops/ts_op_gen.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void ExpectContainsStr(StringPiece s, StringPiece expected) {
|
||||||
|
EXPECT_TRUE(str_util::StrContains(s, expected))
|
||||||
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
|
||||||
|
EXPECT_FALSE(str_util::StrContains(s, expected))
|
||||||
|
<< "'" << s << "' does not contain '" << expected << "'";
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kreeger): Add multiple outputs here?
|
||||||
|
constexpr char kBaseOpDef[] = R"(
|
||||||
|
op {
|
||||||
|
name: "Foo"
|
||||||
|
input_arg {
|
||||||
|
name: "images"
|
||||||
|
type_attr: "T"
|
||||||
|
number_attr: "N"
|
||||||
|
description: "Images to process."
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "dim"
|
||||||
|
description: "Description for dim."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
description: "Description for output."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
description: "Type for images"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_INT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default_value {
|
||||||
|
i: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "N"
|
||||||
|
type: "int"
|
||||||
|
has_minimum: true
|
||||||
|
minimum: 1
|
||||||
|
}
|
||||||
|
summary: "Summary for op Foo."
|
||||||
|
description: "Description for op Foo."
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "DeprecatedFoo"
|
||||||
|
input_arg {
|
||||||
|
name: "input"
|
||||||
|
description: "Description for input."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
description: "Description for output."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
deprecation {
|
||||||
|
explanation: "Deprecated."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "MultiOutputFoo"
|
||||||
|
input_arg {
|
||||||
|
name: "input"
|
||||||
|
description: "Description for input."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output1"
|
||||||
|
description: "Description for output 1."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output2"
|
||||||
|
description: "Description for output 2."
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
summary: "Summary for op MultiOutputFoo."
|
||||||
|
description: "Description for op MultiOutputFoo."
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
// Generate TypeScript code
|
||||||
|
// @param api_def_str TODO doc me.
|
||||||
|
void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
|
||||||
|
Env* env = Env::Default();
|
||||||
|
OpList op_defs;
|
||||||
|
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||||
|
ApiDefMap api_def_map(op_defs);
|
||||||
|
|
||||||
|
if (!api_def_str.empty()) {
|
||||||
|
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
|
||||||
|
}
|
||||||
|
|
||||||
|
const string& tmpdir = testing::TmpDir();
|
||||||
|
const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
|
||||||
|
|
||||||
|
WriteTSOps(op_defs, api_def_map, ts_file_path);
|
||||||
|
TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TsOpGenTest, TestImports) {
|
||||||
|
string ts_file_text;
|
||||||
|
GenerateTsOpFileText("", &ts_file_text);
|
||||||
|
|
||||||
|
const string expected = R"(
|
||||||
|
import * as tfc from '@tensorflow/tfjs-core';
|
||||||
|
import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
|
||||||
|
)";
|
||||||
|
ExpectContainsStr(ts_file_text, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TsOpGenTest, InputSingleAndList) {
|
||||||
|
const string api_def = R"(
|
||||||
|
op {
|
||||||
|
name: "Foo"
|
||||||
|
input_arg {
|
||||||
|
name: "images"
|
||||||
|
type_attr: "T"
|
||||||
|
number_attr: "N"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
string ts_file_text;
|
||||||
|
GenerateTsOpFileText(api_def, &ts_file_text);
|
||||||
|
|
||||||
|
const string expected = R"(
|
||||||
|
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
ExpectContainsStr(ts_file_text, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TsOpGenTest, TestVisibility) {
|
||||||
|
const string api_def = R"(
|
||||||
|
op {
|
||||||
|
graph_op_name: "Foo"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
string ts_file_text;
|
||||||
|
GenerateTsOpFileText(api_def, &ts_file_text);
|
||||||
|
|
||||||
|
const string expected = R"(
|
||||||
|
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
ExpectDoesNotContainStr(ts_file_text, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TsOpGenTest, SkipDeprecated) {
|
||||||
|
string ts_file_text;
|
||||||
|
GenerateTsOpFileText("", &ts_file_text);
|
||||||
|
|
||||||
|
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TsOpGenTest, MultiOutput) {
|
||||||
|
string ts_file_text;
|
||||||
|
GenerateTsOpFileText("", &ts_file_text);
|
||||||
|
|
||||||
|
const string expected = R"(
|
||||||
|
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
ExpectContainsStr(ts_file_text, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user