Introduce basic CC library for generating TypeScript files for TensorFlow.js from registered Ops.
This initial change provides the very basics to start generating TypeScript. Non-deprecated and visible Ops are exported as a typescript function using internal functionality that is used the @tensorflow/tfjs-node repo (https://github.com/tensorflow/tfjs-node). Future changes will introduce more code generation + tests. This initial change will help set the foundation for those upcoming changes. PiperOrigin-RevId: 209528126
This commit is contained in:
parent
debd8b6b4e
commit
49115abfd3
52
tensorflow/js/BUILD
Normal file
52
tensorflow/js/BUILD
Normal file
@ -0,0 +1,52 @@
|
||||
# Description:
|
||||
# JavaScript/TypeScript code generation for TensorFlow.js
|
||||
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
]
|
||||
|
||||
package(default_visibility = visibility)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ts_op_gen",
|
||||
srcs = [
|
||||
"ops/ts_op_gen.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ops/ts_op_gen.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "ts_op_gen_test",
|
||||
srcs = [
|
||||
"ops/ts_op_gen.cc",
|
||||
"ops/ts_op_gen.h",
|
||||
"ops/ts_op_gen_test.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
199
tensorflow/js/ops/ts_op_gen.cc
Normal file
199
tensorflow/js/ops/ts_op_gen.cc
Normal file
@ -0,0 +1,199 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/js/ops/ts_op_gen.h"
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static bool IsListAttr(const OpDef_ArgDef& arg) {
|
||||
return !arg.type_list_attr().empty() || !arg.number_attr().empty();
|
||||
}
|
||||
|
||||
// Struct to hold a combo OpDef and ArgDef for a given Op argument:
|
||||
struct ArgDefs {
|
||||
ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
|
||||
: op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
|
||||
|
||||
const OpDef::ArgDef& op_def_arg;
|
||||
const ApiDef::Arg& api_def_arg;
|
||||
};
|
||||
|
||||
// Helper class to generate TypeScript code for a given OpDef:
|
||||
class GenTypeScriptOp {
|
||||
public:
|
||||
GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
|
||||
~GenTypeScriptOp();
|
||||
|
||||
// Returns the generated code as a string:
|
||||
string Code();
|
||||
|
||||
private:
|
||||
void ProcessArgs();
|
||||
|
||||
void AddMethodSignature();
|
||||
void AddMethodReturnAndClose();
|
||||
|
||||
const OpDef& op_def_;
|
||||
const ApiDef& api_def_;
|
||||
|
||||
// Placeholder string for all generated code:
|
||||
string result_;
|
||||
|
||||
// Holds in-order vector of Op inputs:
|
||||
std::vector<ArgDefs> input_op_args_;
|
||||
|
||||
// Holds number of outputs:
|
||||
int num_outputs_;
|
||||
};
|
||||
|
||||
GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
|
||||
: op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
|
||||
|
||||
GenTypeScriptOp::~GenTypeScriptOp() {}
|
||||
|
||||
string GenTypeScriptOp::Code() {
|
||||
ProcessArgs();
|
||||
|
||||
// Generate exported function for Op:
|
||||
AddMethodSignature();
|
||||
AddMethodReturnAndClose();
|
||||
|
||||
strings::StrAppend(&result_, "\n");
|
||||
return result_;
|
||||
}
|
||||
|
||||
void GenTypeScriptOp::ProcessArgs() {
|
||||
for (int i = 0; i < api_def_.arg_order_size(); i++) {
|
||||
auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
|
||||
if (op_def_arg == nullptr) {
|
||||
LOG(WARNING) << "Could not find OpDef::ArgDef for "
|
||||
<< api_def_.arg_order(i);
|
||||
continue;
|
||||
}
|
||||
auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
|
||||
if (api_def_arg == nullptr) {
|
||||
LOG(WARNING) << "Could not find ApiDef::Arg for "
|
||||
<< api_def_.arg_order(i);
|
||||
continue;
|
||||
}
|
||||
input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
|
||||
}
|
||||
|
||||
num_outputs_ = api_def_.out_arg_size();
|
||||
}
|
||||
|
||||
void GenTypeScriptOp::AddMethodSignature() {
|
||||
strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
|
||||
"(");
|
||||
|
||||
bool is_first = true;
|
||||
for (auto& in_arg : input_op_args_) {
|
||||
if (is_first) {
|
||||
is_first = false;
|
||||
} else {
|
||||
strings::StrAppend(&result_, ", ");
|
||||
}
|
||||
|
||||
auto op_def_arg = in_arg.op_def_arg;
|
||||
|
||||
strings::StrAppend(&result_, op_def_arg.name(), ": ");
|
||||
if (IsListAttr(op_def_arg)) {
|
||||
strings::StrAppend(&result_, "tfc.Tensor[]");
|
||||
} else {
|
||||
strings::StrAppend(&result_, "tfc.Tensor");
|
||||
}
|
||||
}
|
||||
|
||||
if (num_outputs_ == 1) {
|
||||
strings::StrAppend(&result_, "): tfc.Tensor {\n");
|
||||
} else {
|
||||
strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
|
||||
}
|
||||
}
|
||||
|
||||
void GenTypeScriptOp::AddMethodReturnAndClose() {
|
||||
strings::StrAppend(&result_, " return null;\n}\n");
|
||||
}
|
||||
|
||||
void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
|
||||
GenTypeScriptOp ts_op(op_def, api_def);
|
||||
TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code()));
|
||||
}
|
||||
|
||||
void StartFile(WritableFile* ts_file) {
|
||||
const string header =
|
||||
R"header(/**
|
||||
* @license
|
||||
* Copyright 2018 Google Inc. All Rights Reserved.
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* =============================================================================
|
||||
*/
|
||||
|
||||
// This file is MACHINE GENERATED! Do not edit
|
||||
|
||||
import * as tfc from '@tensorflow/tfjs-core';
|
||||
import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
|
||||
|
||||
)header";
|
||||
|
||||
TF_CHECK_OK(ts_file->Append(header));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||
const string& ts_filename) {
|
||||
Env* env = Env::Default();
|
||||
|
||||
std::unique_ptr<WritableFile> ts_file = nullptr;
|
||||
TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
|
||||
|
||||
StartFile(ts_file.get());
|
||||
|
||||
for (const auto& op_def : ops.op()) {
|
||||
// Skip deprecated ops
|
||||
if (op_def.has_deprecation() &&
|
||||
op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto* api_def = api_def_map.GetApiDef(op_def.name());
|
||||
if (api_def->visibility() == ApiDef::VISIBLE) {
|
||||
WriteTSOp(op_def, *api_def, ts_file.get());
|
||||
}
|
||||
}
|
||||
|
||||
TF_CHECK_OK(ts_file->Close());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/js/ops/ts_op_gen.h
Normal file
31
tensorflow/js/ops/ts_op_gen.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_JS_OPS_TS_OP_GEN_H_
|
||||
#define TENSORFLOW_JS_OPS_TS_OP_GEN_H_
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Generated code is written to the file ts_filename:
|
||||
void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||
const string& ts_filename);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_JS_OPS_TS_OP_GEN_H_
|
212
tensorflow/js/ops/ts_op_gen_test.cc
Normal file
212
tensorflow/js/ops/ts_op_gen_test.cc
Normal file
@ -0,0 +1,212 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/js/ops/ts_op_gen.h"
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void ExpectContainsStr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_TRUE(str_util::StrContains(s, expected))
|
||||
<< "'" << s << "' does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_FALSE(str_util::StrContains(s, expected))
|
||||
<< "'" << s << "' does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
// TODO(kreeger): Add multiple outputs here?
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "images"
|
||||
type_attr: "T"
|
||||
number_attr: "N"
|
||||
description: "Images to process."
|
||||
}
|
||||
input_arg {
|
||||
name: "dim"
|
||||
description: "Description for dim."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Description for output."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
description: "Type for images"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
default_value {
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
summary: "Summary for op Foo."
|
||||
description: "Description for op Foo."
|
||||
}
|
||||
op {
|
||||
name: "DeprecatedFoo"
|
||||
input_arg {
|
||||
name: "input"
|
||||
description: "Description for input."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Description for output."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
deprecation {
|
||||
explanation: "Deprecated."
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "MultiOutputFoo"
|
||||
input_arg {
|
||||
name: "input"
|
||||
description: "Description for input."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output1"
|
||||
description: "Description for output 1."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output2"
|
||||
description: "Description for output 2."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
summary: "Summary for op MultiOutputFoo."
|
||||
description: "Description for op MultiOutputFoo."
|
||||
}
|
||||
)";
|
||||
|
||||
// Generate TypeScript code
|
||||
// @param api_def_str TODO doc me.
|
||||
void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
|
||||
Env* env = Env::Default();
|
||||
OpList op_defs;
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
if (!api_def_str.empty()) {
|
||||
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
|
||||
}
|
||||
|
||||
const string& tmpdir = testing::TmpDir();
|
||||
const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
|
||||
|
||||
WriteTSOps(op_defs, api_def_map, ts_file_path);
|
||||
TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, TestImports) {
|
||||
string ts_file_text;
|
||||
GenerateTsOpFileText("", &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
import * as tfc from '@tensorflow/tfjs-core';
|
||||
import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
|
||||
)";
|
||||
ExpectContainsStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, InputSingleAndList) {
|
||||
const string api_def = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "images"
|
||||
type_attr: "T"
|
||||
number_attr: "N"
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
string ts_file_text;
|
||||
GenerateTsOpFileText(api_def, &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
|
||||
return null;
|
||||
}
|
||||
)";
|
||||
ExpectContainsStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, TestVisibility) {
|
||||
const string api_def = R"(
|
||||
op {
|
||||
graph_op_name: "Foo"
|
||||
visibility: HIDDEN
|
||||
}
|
||||
)";
|
||||
|
||||
string ts_file_text;
|
||||
GenerateTsOpFileText(api_def, &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
|
||||
return null;
|
||||
}
|
||||
)";
|
||||
ExpectDoesNotContainStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, SkipDeprecated) {
|
||||
string ts_file_text;
|
||||
GenerateTsOpFileText("", &ts_file_text);
|
||||
|
||||
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
|
||||
}
|
||||
|
||||
TEST(TsOpGenTest, MultiOutput) {
|
||||
string ts_file_text;
|
||||
GenerateTsOpFileText("", &ts_file_text);
|
||||
|
||||
const string expected = R"(
|
||||
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
|
||||
return null;
|
||||
}
|
||||
)";
|
||||
ExpectContainsStr(ts_file_text, expected);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user