STT-tensorflow/tensorflow/js/ops/ts_op_gen_test.cc
Gaurav Jain 4213d5c1bd Use absl instead of deprecated str_util
PiperOrigin-RevId: 249248716
2019-05-21 08:10:10 -07:00

240 lines
5.6 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/js/ops/ts_op_gen.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
void ExpectContainsStr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
EXPECT_FALSE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "images"
type_attr: "T"
number_attr: "N"
description: "Images to process."
}
input_arg {
name: "dim"
description: "Description for dim."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for images"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
default_value {
i: 1
}
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
// Generate TypeScript code
void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
string* ts_file_text) {
Env* env = Env::Default();
OpList op_defs;
protobuf::TextFormat::ParseFromString(
op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
ApiDefMap api_def_map(op_defs);
if (!api_def_str.empty()) {
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
}
const string& tmpdir = testing::TmpDir();
const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
WriteTSOps(op_defs, api_def_map, ts_file_path);
TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
}
TEST(TsOpGenTest, TestImports) {
string ts_file_text;
GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
ExpectContainsStr(ts_file_text, expected);
}
TEST(TsOpGenTest, InputSingleAndList) {
const string api_def = R"pb(
op { graph_op_name: "Foo" arg_order: "dim" arg_order: "images" }
)pb";
string ts_file_text;
GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(dim: tfc.Tensor, images: tfc.Tensor[]): tfc.Tensor {
)";
ExpectContainsStr(ts_file_text, expected);
}
TEST(TsOpGenTest, TestVisibility) {
const string api_def = R"(
op {
graph_op_name: "Foo"
visibility: HIDDEN
}
)";
string ts_file_text;
GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
)";
ExpectDoesNotContainStr(ts_file_text, expected);
}
TEST(TsOpGenTest, SkipDeprecated) {
const string op_def = R"(
op {
name: "DeprecatedFoo"
input_arg {
name: "input"
type_attr: "T"
description: "Description for input."
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for input"
allowed_values {
list {
type: DT_FLOAT
}
}
}
deprecation {
explanation: "Deprecated."
}
}
)";
string ts_file_text;
GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
TEST(TsOpGenTest, MultiOutput) {
const string op_def = R"(
op {
name: "MultiOutputFoo"
input_arg {
name: "input"
description: "Description for input."
type_attr: "T"
}
output_arg {
name: "output1"
description: "Description for output 1."
type: DT_FLOAT
}
output_arg {
name: "output2"
description: "Description for output 2."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for input"
allowed_values {
list {
type: DT_FLOAT
}
}
}
summary: "Summary for op MultiOutputFoo."
description: "Description for op MultiOutputFoo."
}
)";
string ts_file_text;
GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
)";
ExpectContainsStr(ts_file_text, expected);
}
TEST(TsOpGenTest, OpAttrs) {
string ts_file_text;
GenerateTsOpFileText("", "", &ts_file_text);
const string expectedFooAttrs = R"(
const opAttrs = [
createTensorsTypeOpAttr('T', images),
{name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
];
)";
ExpectContainsStr(ts_file_text, expectedFooAttrs);
}
} // namespace
} // namespace tensorflow