240 lines
5.6 KiB
C++
240 lines
5.6 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#include "tensorflow/js/ops/ts_op_gen.h"
|
|
|
|
#include "tensorflow/core/framework/op_def.pb.h"
|
|
#include "tensorflow/core/framework/op_gen_lib.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
#include "tensorflow/core/lib/io/path.h"
|
|
#include "tensorflow/core/lib/strings/str_util.h"
|
|
#include "tensorflow/core/platform/env.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
// Records a non-fatal test failure unless `expected` occurs verbatim somewhere
// inside `s`. The failure message prints both strings so the mismatch can be
// inspected directly in the test log.
void ExpectContainsStr(StringPiece s, StringPiece expected) {
  EXPECT_TRUE(absl::StrContains(s, expected))
      << "'" << s << "' does not contain '" << expected << "'";
}
|
|
|
|
// Records a non-fatal test failure if `expected` occurs verbatim anywhere
// inside `s`.
//
// Despite its name, `expected` is the substring that must be ABSENT; the
// parameter name is kept for symmetry with ExpectContainsStr.
void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
  // This assertion fails when the substring IS present, so the diagnostic has
  // to say "contains" (it previously claimed "does not contain").
  EXPECT_FALSE(absl::StrContains(s, expected))
      << "'" << s << "' unexpectedly contains '" << expected << "'";
}
|
|
|
|
// Text-format OpList used whenever a test does not supply its own op
// definition. It declares a single op "Foo" with:
//   * input "images": a list of N tensors of type T,
//   * input "dim":    a DT_FLOAT tensor,
//   * output "output": a DT_FLOAT tensor,
//   * attr "T": a type attr restricted to DT_UINT8 / DT_INT8,
//   * attr "N": an int attr with minimum 1.
//
// NOTE(review): attr "T" is declared with type "type" but its default_value
// sets the integer field `i` rather than a `type` field. Presumably the
// generator under test never reads this default -- confirm before relying on
// it.
constexpr char kBaseOpDef[] = R"(
op {
  name: "Foo"
  input_arg {
    name: "images"
    type_attr: "T"
    number_attr: "N"
    description: "Images to process."
  }
  input_arg {
    name: "dim"
    description: "Description for dim."
    type: DT_FLOAT
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for images"
    allowed_values {
      list {
        type: DT_UINT8
        type: DT_INT8
      }
    }
    default_value {
      i: 1
    }
  }
  attr {
    name: "N"
    type: "int"
    has_minimum: true
    minimum: 1
  }
  summary: "Summary for op Foo."
  description: "Description for op Foo."
}
)";
|
|
|
|
// Generate TypeScript code
|
|
void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
|
|
string* ts_file_text) {
|
|
Env* env = Env::Default();
|
|
OpList op_defs;
|
|
protobuf::TextFormat::ParseFromString(
|
|
op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
|
|
ApiDefMap api_def_map(op_defs);
|
|
|
|
if (!api_def_str.empty()) {
|
|
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
|
|
}
|
|
|
|
const string& tmpdir = testing::TmpDir();
|
|
const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
|
|
|
|
WriteTSOps(op_defs, api_def_map, ts_file_path);
|
|
TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
|
|
}
|
|
|
|
// The file generated from the default op definition must contain the two
// import statements below.
TEST(TsOpGenTest, TestImports) {
  // Matched verbatim, including the leading and trailing newline.
  const string expected_imports = R"(
import * as tfc from '@tensorflow/tfjs-core';
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";

  string generated;
  GenerateTsOpFileText("", "", &generated);

  ExpectContainsStr(generated, expected_imports);
}
|
|
|
|
// With an ApiDef that lists "dim" before "images" in arg_order, the generated
// signature for Foo must take the single tensor first and the tensor list
// second.
TEST(TsOpGenTest, InputSingleAndList) {
  const string arg_order_api_def = R"pb(
    op { graph_op_name: "Foo" arg_order: "dim" arg_order: "images" }
  )pb";

  // Matched verbatim, including the leading and trailing newline.
  const string expected_signature = R"(
export function Foo(dim: tfc.Tensor, images: tfc.Tensor[]): tfc.Tensor {
)";

  string generated;
  GenerateTsOpFileText("", arg_order_api_def, &generated);

  ExpectContainsStr(generated, expected_signature);
}
|
|
|
|
// When the ApiDef marks Foo as HIDDEN, the generated file must not contain
// Foo's exported function signature.
TEST(TsOpGenTest, TestVisibility) {
  const string hidden_api_def = R"(
op {
  graph_op_name: "Foo"
  visibility: HIDDEN
}
)";

  // The signature Foo would have with the default argument order. It is
  // matched verbatim, including the leading and trailing newline.
  const string hidden_signature = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
)";

  string generated;
  GenerateTsOpFileText("", hidden_api_def, &generated);

  ExpectDoesNotContainStr(generated, hidden_signature);
}
|
|
|
|
// An op definition carrying a `deprecation` block must not show up anywhere in
// the generated file.
TEST(TsOpGenTest, SkipDeprecated) {
  const string deprecated_op_def = R"(
op {
  name: "DeprecatedFoo"
  input_arg {
    name: "input"
    type_attr: "T"
    description: "Description for input."
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for input"
    allowed_values {
      list {
        type: DT_FLOAT
      }
    }
  }
  deprecation {
    explanation: "Deprecated."
  }
}
)";

  string generated;
  GenerateTsOpFileText(deprecated_op_def, "", &generated);

  // Checking for the bare op name covers every place it could be emitted.
  ExpectDoesNotContainStr(generated, "DeprecatedFoo");
}
|
|
|
|
// An op with two output_args must be generated as a function returning an
// array of tensors (tfc.Tensor[]).
TEST(TsOpGenTest, MultiOutput) {
  const string multi_output_op_def = R"(
op {
  name: "MultiOutputFoo"
  input_arg {
    name: "input"
    description: "Description for input."
    type_attr: "T"
  }
  output_arg {
    name: "output1"
    description: "Description for output 1."
    type: DT_FLOAT
  }
  output_arg {
    name: "output2"
    description: "Description for output 2."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for input"
    allowed_values {
      list {
        type: DT_FLOAT
      }
    }
  }
  summary: "Summary for op MultiOutputFoo."
  description: "Description for op MultiOutputFoo."
}
)";

  // Matched verbatim, including the leading and trailing newline.
  const string expected_signature = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
)";

  string generated;
  GenerateTsOpFileText(multi_output_op_def, "", &generated);

  ExpectContainsStr(generated, expected_signature);
}
|
|
|
|
// For the default Foo op, the generated file must contain an `opAttrs` array
// with an entry for the type attr "T" (derived from `images`) and one for the
// int attr "N" (set to `images.length`).
TEST(TsOpGenTest, OpAttrs) {
  // NOTE(review): this snippet is compared as an exact substring, so the
  // leading whitespace on every line must match what the generator emits --
  // confirm the indentation below against real generator output.
  const string expected_foo_attrs = R"(
const opAttrs = [
createTensorsTypeOpAttr('T', images),
{name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
];
)";

  string generated;
  GenerateTsOpFileText("", "", &generated);

  ExpectContainsStr(generated, expected_foo_attrs);
}
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|