From 49115abfd39d30506679d9fdc572ccd2f7c22dbe Mon Sep 17 00:00:00 2001 From: Nick Kreeger Date: Mon, 20 Aug 2018 19:39:44 -0700 Subject: [PATCH] Introduce basic CC library for generating TypeScript files for TensorFlow.js from registered Ops. This initial change provides the very basics to start generating TypeScript. Non-deprecated and visible Ops are exported as a typescript function using internal functionality that is used the @tensorflow/tfjs-node repo (https://github.com/tensorflow/tfjs-node). Future changes will introduce more code generation + tests. This initial change will help set the foundation for those upcoming changes. PiperOrigin-RevId: 209528126 --- tensorflow/js/BUILD | 52 +++++++ tensorflow/js/ops/ts_op_gen.cc | 199 ++++++++++++++++++++++++++ tensorflow/js/ops/ts_op_gen.h | 31 ++++ tensorflow/js/ops/ts_op_gen_test.cc | 212 ++++++++++++++++++++++++++++ 4 files changed, 494 insertions(+) create mode 100644 tensorflow/js/BUILD create mode 100644 tensorflow/js/ops/ts_op_gen.cc create mode 100644 tensorflow/js/ops/ts_op_gen.h create mode 100644 tensorflow/js/ops/ts_op_gen_test.cc diff --git a/tensorflow/js/BUILD b/tensorflow/js/BUILD new file mode 100644 index 00000000000..ad0dc44f549 --- /dev/null +++ b/tensorflow/js/BUILD @@ -0,0 +1,52 @@ +# Description: +# JavaScript/TypeScript code generation for TensorFlow.js + +visibility = [ + "//tensorflow:internal", +] + +package(default_visibility = visibility) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "ts_op_gen", + srcs = [ + "ops/ts_op_gen.cc", + ], + hdrs = [ + "ops/ts_op_gen.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "ts_op_gen_test", + srcs = [ + "ops/ts_op_gen.cc", + "ops/ts_op_gen.h", + "ops/ts_op_gen_test.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc new file mode 100644 index 00000000000..babf55cd5f2 --- /dev/null +++ b/tensorflow/js/ops/ts_op_gen.cc @@ -0,0 +1,199 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/js/ops/ts_op_gen.h" +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { + +static bool IsListAttr(const OpDef_ArgDef& arg) { + return !arg.type_list_attr().empty() || !arg.number_attr().empty(); +} + +// Struct to hold a combo OpDef and ArgDef for a given Op argument: +struct ArgDefs { + ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg) + : op_def_arg(op_def_arg), api_def_arg(api_def_arg) {} + + const OpDef::ArgDef& op_def_arg; + const ApiDef::Arg& api_def_arg; +}; + +// Helper class to generate TypeScript code for a given OpDef: +class GenTypeScriptOp { + public: + GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def); + ~GenTypeScriptOp(); + + // Returns the generated code as a string: + string Code(); + + private: + void ProcessArgs(); + + void AddMethodSignature(); + void AddMethodReturnAndClose(); + + const OpDef& op_def_; + const ApiDef& api_def_; + + // Placeholder string for all generated code: + string result_; + + // Holds in-order vector of Op inputs: + std::vector input_op_args_; + + // Holds number of outputs: + int num_outputs_; +}; + +GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def) + : op_def_(op_def), api_def_(api_def), num_outputs_(0) {} + +GenTypeScriptOp::~GenTypeScriptOp() {} + +string GenTypeScriptOp::Code() { + ProcessArgs(); + + // Generate exported function for Op: + AddMethodSignature(); + AddMethodReturnAndClose(); + + strings::StrAppend(&result_, "\n"); + return result_; +} + +void GenTypeScriptOp::ProcessArgs() { + for (int i = 0; i < api_def_.arg_order_size(); i++) { + auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_); + if (op_def_arg == nullptr) { + LOG(WARNING) << "Could not find OpDef::ArgDef for " + << api_def_.arg_order(i); + continue; + } + auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_); + if (api_def_arg == nullptr) { + LOG(WARNING) << "Could not find ApiDef::Arg for " + << api_def_.arg_order(i); + continue; + } + input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg)); + } + + num_outputs_ = api_def_.out_arg_size(); +} + +void GenTypeScriptOp::AddMethodSignature() { + strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(), + "("); + + bool is_first = true; + for (auto& in_arg : input_op_args_) { + if (is_first) { + is_first = false; + } else { + strings::StrAppend(&result_, ", "); + } + + auto op_def_arg = in_arg.op_def_arg; + + strings::StrAppend(&result_, op_def_arg.name(), ": "); + if (IsListAttr(op_def_arg)) { + strings::StrAppend(&result_, "tfc.Tensor[]"); + } else { + strings::StrAppend(&result_, "tfc.Tensor"); + } + } + + if (num_outputs_ == 1) { + strings::StrAppend(&result_, "): tfc.Tensor {\n"); + } else { + strings::StrAppend(&result_, "): tfc.Tensor[] {\n"); + } +} + +void GenTypeScriptOp::AddMethodReturnAndClose() { + strings::StrAppend(&result_, " return null;\n}\n"); +} + +void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) { + GenTypeScriptOp ts_op(op_def, api_def); + TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code())); +} + +void StartFile(WritableFile* ts_file) { + const string header = + R"header(/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// This file is MACHINE GENERATED! Do not edit + +import * as tfc from '@tensorflow/tfjs-core'; +import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; + +)header"; + + TF_CHECK_OK(ts_file->Append(header)); +} + +} // namespace + +void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& ts_filename) { + Env* env = Env::Default(); + + std::unique_ptr ts_file = nullptr; + TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file)); + + StartFile(ts_file.get()); + + for (const auto& op_def : ops.op()) { + // Skip deprecated ops + if (op_def.has_deprecation() && + op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) { + continue; + } + + const auto* api_def = api_def_map.GetApiDef(op_def.name()); + if (api_def->visibility() == ApiDef::VISIBLE) { + WriteTSOp(op_def, *api_def, ts_file.get()); + } + } + + TF_CHECK_OK(ts_file->Close()); +} + +} // namespace tensorflow diff --git a/tensorflow/js/ops/ts_op_gen.h b/tensorflow/js/ops/ts_op_gen.h new file mode 100644 index 00000000000..fcd46a17a77 --- /dev/null +++ b/tensorflow/js/ops/ts_op_gen.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JS_OPS_TS_OP_GEN_H_ +#define TENSORFLOW_JS_OPS_TS_OP_GEN_H_ + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Generated code is written to the file ts_filename: +void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& ts_filename); + +} // namespace tensorflow + +#endif // TENSORFLOW_JS_OPS_TS_OP_GEN_H_ diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc new file mode 100644 index 00000000000..9a85c021b09 --- /dev/null +++ b/tensorflow/js/ops/ts_op_gen_test.cc @@ -0,0 +1,212 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/js/ops/ts_op_gen.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void ExpectContainsStr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(str_util::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + +void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) { + EXPECT_FALSE(str_util::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + +// TODO(kreeger): Add multiple outputs here? +constexpr char kBaseOpDef[] = R"( +op { + name: "Foo" + input_arg { + name: "images" + type_attr: "T" + number_attr: "N" + description: "Images to process." + } + input_arg { + name: "dim" + description: "Description for dim." + type: DT_FLOAT + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for images" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + } + } + default_value { + i: 1 + } + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + summary: "Summary for op Foo." + description: "Description for op Foo." +} +op { + name: "DeprecatedFoo" + input_arg { + name: "input" + description: "Description for input." + type: DT_FLOAT + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + deprecation { + explanation: "Deprecated." + } +} +op { + name: "MultiOutputFoo" + input_arg { + name: "input" + description: "Description for input." + type: DT_FLOAT + } + output_arg { + name: "output1" + description: "Description for output 1." + type: DT_FLOAT + } + output_arg { + name: "output2" + description: "Description for output 2." + type: DT_FLOAT + } + summary: "Summary for op MultiOutputFoo." + description: "Description for op MultiOutputFoo." +} +)"; + +// Generate TypeScript code +// @param api_def_str TODO doc me. +void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); + ApiDefMap api_def_map(op_defs); + + if (!api_def_str.empty()) { + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str)); + } + + const string& tmpdir = testing::TmpDir(); + const auto ts_file_path = io::JoinPath(tmpdir, "test.ts"); + + WriteTSOps(op_defs, api_def_map, ts_file_path); + TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text)); +} + +TEST(TsOpGenTest, TestImports) { + string ts_file_text; + GenerateTsOpFileText("", &ts_file_text); + + const string expected = R"( +import * as tfc from '@tensorflow/tfjs-core'; +import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; +)"; + ExpectContainsStr(ts_file_text, expected); +} + +TEST(TsOpGenTest, InputSingleAndList) { + const string api_def = R"( +op { + name: "Foo" + input_arg { + name: "images" + type_attr: "T" + number_attr: "N" + } +} +)"; + + string ts_file_text; + GenerateTsOpFileText(api_def, &ts_file_text); + + const string expected = R"( +export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { + return null; +} +)"; + ExpectContainsStr(ts_file_text, expected); +} + +TEST(TsOpGenTest, TestVisibility) { + const string api_def = R"( +op { + graph_op_name: "Foo" + visibility: HIDDEN +} +)"; + + string ts_file_text; + GenerateTsOpFileText(api_def, &ts_file_text); + + const string expected = R"( +export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { + return null; +} +)"; + ExpectDoesNotContainStr(ts_file_text, expected); +} + +TEST(TsOpGenTest, SkipDeprecated) { + string ts_file_text; + GenerateTsOpFileText("", &ts_file_text); + + ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo"); +} + +TEST(TsOpGenTest, MultiOutput) { + string ts_file_text; + GenerateTsOpFileText("", &ts_file_text); + + const string expected = R"( +export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] { + return null; +} +)"; + ExpectContainsStr(ts_file_text, expected); +} + +} // namespace +} // namespace tensorflow