Merge pull request from rahul-kamat:gen-ops-with-type-annotations

PiperOrigin-RevId: 318862240
Change-Id: Idbc6c43c0ab9418930298d4dc71c31a8abf5df65
TensorFlower Gardener 2020-06-29 12:04:02 -07:00
commit a8eb45b4af
6 changed files with 692 additions and 41 deletions

tensorflow/python/framework/python_op_gen.cc

@@ -45,6 +45,33 @@ const int kRightMargin = 78;
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
// Maps the Python dtype instance names emitted by DataTypeToPython (e.g.
// "_dtypes.float16") to the corresponding Python DType class names
const std::unordered_map<string, string> dtype_type{
{"_dtypes.float16", "_dtypes.Float16"},
{"_dtypes.half", "_dtypes.Half"},
{"_dtypes.float32", "_dtypes.Float32"},
{"_dtypes.float64", "_dtypes.Float64"},
{"_dtypes.bfloat16", "_dtypes.BFloat16"},
{"_dtypes.complex64", "_dtypes.Complex64"},
{"_dtypes.complex128", "_dtypes.Complex128"},
{"_dtypes.int8", "_dtypes.Int8"},
{"_dtypes.uint8", "_dtypes.UInt8"},
{"_dtypes.uint16", "_dtypes.UInt16"},
{"_dtypes.uint32", "_dtypes.UInt32"},
{"_dtypes.uint64", "_dtypes.UInt64"},
{"_dtypes.int16", "_dtypes.Int16"},
{"_dtypes.int32", "_dtypes.Int32"},
{"_dtypes.int64", "_dtypes.Int64"},
{"_dtypes.bool", "_dtypes.Bool"},
{"_dtypes.string", "_dtypes.String"},
{"_dtypes.qint8", "_dtypes.QInt8"},
{"_dtypes.quint8", "_dtypes.QUInt8"},
{"_dtypes.qint16", "_dtypes.QInt16"},
{"_dtypes.quint16", "_dtypes.QUInt16"},
{"_dtypes.qint32", "_dtypes.QInt32"},
{"_dtypes.resource", "_dtypes.Resource"},
{"_dtypes.variant", "_dtypes.Variant"}};
string AttrVarName(const string& attr_name,
std::unordered_map<string, string>* attr_expressions) {
const string var = strings::StrCat("_attr_", attr_name);
@@ -106,8 +133,9 @@ string TensorPBString(const TensorProto& pb) {
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name)
- : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
+ const string& function_name, bool add_type_annotations)
+ : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
+ add_type_annotations) {
op_name_ = function_name_;
absl::ConsumePrefix(&op_name_, "_");
}
@@ -130,13 +158,14 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
const std::vector<string>& output_sizes,
bool execute_record_gradient);
- bool AddEagerFastPathAndGraphCode(const string& parameters,
- const std::vector<string>& output_sizes,
- const string& eager_not_allowed_error);
- bool AddEagerFallbackCode(const string& parameters,
- const std::vector<string>& output_sizes,
- const string& num_outputs_expr,
- const string& eager_not_allowed_error);
+ bool AddEagerFastPathAndGraphCode(
+ const string& parameters, const std::vector<string>& output_sizes,
+ const string& eager_not_allowed_error,
+ const std::unordered_map<string, string>& type_annotations);
+ bool AddEagerFallbackCode(
+ const string& parameters, const std::vector<string>& output_sizes,
+ const string& num_outputs_expr, const string& eager_not_allowed_error,
+ const std::unordered_map<string, string>& type_annotations);
void AddEagerFastPathExecute();
void AddEagerInferredAttrs(const string& indentation);
@@ -148,6 +177,14 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
void AddRawOpExport(const string& parameters);
std::unordered_map<string, string> GetTypeAnnotations();
void GenerateTypeVars(
const std::unordered_map<string, string>& type_annotations);
void AddReturnTypeAnnotation(
const std::unordered_map<string, string>& type_annotations);
void AddAttrForArg(const string& attr, int arg_index) {
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
op_def_.input_arg(arg_index).name());
@@ -179,8 +216,10 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
};
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name) {
- return GenEagerPythonOp(op_def, api_def, function_name).Code();
+ const string& function_name,
+ bool add_type_annotations) {
+ return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations)
+ .Code();
}
string GenEagerPythonOp::FlattenInputs(
@@ -311,19 +350,45 @@ string GenEagerPythonOp::Code() {
param_names_.push_back(param_and_default.first);
}
std::unordered_map<string, string> type_annotations;
// Only populate map for whitelisted ops
if (add_type_annotations_) {
type_annotations = GetTypeAnnotations();
}
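// For ops that are not whitelisted the map stays empty, so every annotation
// lookup below misses and the def lines are emitted without annotations.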
string parameters;
// Param can be an input or an attr
for (const auto& param : params_no_default_) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, param.GetRenameTo());
if (type_annotations.find(param.GetName()) != type_annotations.end()) {
strings::StrAppend(&parameters, ": ",
type_annotations.at(param.GetName()));
}
}
string parameters_with_defaults = parameters;
for (const auto& param_and_default : params_with_default_) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
if (!parameters_with_defaults.empty())
strings::StrAppend(&parameters_with_defaults, ", ");
strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
strings::StrAppend(&parameters_with_defaults,
param_and_default.first.GetRenameTo(), "=",
param_and_default.first.GetRenameTo());
if (type_annotations.find(param_and_default.first.GetName()) !=
type_annotations.end()) {
const string param_type =
type_annotations.at(param_and_default.first.GetName());
// Append to parameters and parameters_with_defaults because multiple
// functions are generated by AddEagerFastPathAndGraphCode() and
// AddEagerFallbackCode()
strings::StrAppend(&parameters, ": ", param_type);
strings::StrAppend(&parameters_with_defaults, ":", param_type);
}
strings::StrAppend(&parameters_with_defaults, "=",
param_and_default.second);
}
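// Illustrative result for the synthetic FooBar op in the tests below:
// parameters becomes
//   "x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool, var2: int"
// while parameters_with_defaults renders the defaulted attrs as
// "var1:bool=False, var2:int=0".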
@@ -356,18 +421,108 @@ string GenEagerPythonOp::Code() {
string eager_not_allowed_error = GetEagerNotAllowedError();
if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
- eager_not_allowed_error)) {
+ eager_not_allowed_error,
+ type_annotations)) {
return result_;
}
if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
- eager_not_allowed_error)) {
+ eager_not_allowed_error, type_annotations)) {
return result_;
}
return prelude_ + result_;
}
std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
std::unordered_map<string, string> type_annotations;
// Map attrs to TypeVars
for (const auto& attr : op_def_.attr()) {
if (attr.type() == "type") {
const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
type_annotations[attr.name()] = type_var_name;
} else if (attr.type() == "bool" || attr.type() == "float" ||
attr.type() == "int" || attr.type() == "bytes") {
type_annotations[attr.name()] = attr.type();
} else if (attr.type() == "string") {
type_annotations[attr.name()] = "str";
}
}
// Map input Tensors to their types
for (const auto& arg : op_def_.input_arg()) {
// TODO(rahulkamat): Add type annotations to args that accept a sequence of
// Tensors
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
}
// TODO(rahulkamat): Add type annotations to handle return types of a sequence
// of Tensors. Map output Tensor to its type
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
if (arg.number_attr().empty() && arg.type_list_attr().empty())
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
}
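// Illustrative contents at this point, for the synthetic Foo op used in the
// tests below:
//   {"T": "TV_Foo_T", "T2": "TV_Foo_T2", "x": "_ops.Tensor[TV_Foo_T]",
//    "y": "_ops.Tensor[TV_Foo_T2]", "output": "_ops.Tensor[TV_Foo_T]"}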
return type_annotations;
}
// Generate TypeVars using attrs
void GenEagerPythonOp::GenerateTypeVars(
const std::unordered_map<string, string>& type_annotations) {
bool added_typevar = false;
for (const auto& attr : op_def_.attr()) {
if (attr.type() == "type") {
std::vector<string> allowed_types;
for (int t : attr.allowed_values().list().type()) {
DataType dtype = static_cast<DataType>(t);
const string py_dtype =
python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
allowed_types.emplace_back(dtype_type.at(py_dtype));
}
// When an attr does not restrict the allowed dtypes, every dtype is allowed
if (allowed_types.empty()) {
for (std::pair<string, string> map_dtype : dtype_type) {
allowed_types.emplace_back(map_dtype.second);
}
}
std::sort(allowed_types.begin(), allowed_types.end());
string typevar_dtypes;
for (std::vector<string>::iterator it = allowed_types.begin();
it != allowed_types.end(); ++it) {
if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
strings::StrAppend(&typevar_dtypes, *it);
}
const string type_var_name = type_annotations.at(attr.name());
strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
type_var_name, "\", ", typevar_dtypes, ")\n");
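// e.g. (from the GenerateCorrectTypeVars test below):
//   TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Int8, _dtypes.UInt8)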
added_typevar = true;
}
}
if (added_typevar) strings::StrAppend(&result_, "\n");
}
void GenEagerPythonOp::AddReturnTypeAnnotation(
const std::unordered_map<string, string>& type_annotations) {
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
const string return_type = type_annotations.at(arg.name());
// TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
// avoid erasing ":\n" from the end of the def line
result_.erase(result_.length() - 2);
strings::StrAppend(&result_, " -> ", return_type, ":\n");
}
}
}
void GenEagerPythonOp::HandleGraphMode(
const string& function_setup, const std::vector<string>& output_sizes) {
strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
@@ -690,13 +845,20 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
const string& parameters, const std::vector<string>& output_sizes,
- const string& eager_not_allowed_error) {
+ const string& eager_not_allowed_error,
+ const std::unordered_map<string, string>& type_annotations) {
if (add_type_annotations_) {
GenerateTypeVars(type_annotations);
}
if (api_def_.visibility() == ApiDef::VISIBLE) {
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
}
AddExport();
AddDefLine(function_name_, parameters);
if (add_type_annotations_) {
AddReturnTypeAnnotation(type_annotations);
}
AddDocStringDescription();
AddDocStringArgs();
AddDocStringInputs();
@@ -731,11 +893,14 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
bool GenEagerPythonOp::AddEagerFallbackCode(
const string& parameters, const std::vector<string>& output_sizes,
- const string& num_outputs_expr, const string& eager_not_allowed_error) {
+ const string& num_outputs_expr, const string& eager_not_allowed_error,
+ const std::unordered_map<string, string>& type_annotations) {
AddDefLine(
strings::StrCat(function_name_, kEagerFallbackSuffix),
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
if (add_type_annotations_) {
AddReturnTypeAnnotation(type_annotations);
}
if (!eager_not_allowed_error.empty()) {
strings::StrAppend(&result_, " ", eager_not_allowed_error);
return true;
@@ -982,9 +1147,10 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
function_name_, "))\n");
}
- string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops,
- const string& source_file_name = "") {
+ string GetPythonOpsImpl(
+ const OpList& ops, const ApiDefMap& api_defs,
+ const std::vector<string>& hidden_ops, const string& source_file_name = "",
+ const std::unordered_set<string> type_annotate_ops = {}) {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
@@ -1018,6 +1184,7 @@ from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar
)");
for (const auto& op_def : ops.op()) {
@@ -1061,8 +1228,12 @@ from tensorflow.python.util.tf_export import tf_export
continue;
}
auto iter = type_annotate_ops.find(op_def.name());
bool add_type_annotations = iter != type_annotate_ops.end();
strings::StrAppend(&result,
- GetEagerPythonOp(op_def, *api_def, function_name));
+ GetEagerPythonOp(op_def, *api_def, function_name,
+ add_type_annotations));
}
return result;
@@ -1072,15 +1243,19 @@ from tensorflow.python.util.tf_export import tf_export
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
- const string& source_file_name) {
- return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name);
+ const string& source_file_name,
+ const std::unordered_set<string> type_annotate_ops) {
+ return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
+ type_annotate_ops);
}
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
- const string& source_file_name) {
- printf("%s",
- GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name).c_str());
+ const string& source_file_name,
+ const std::unordered_set<string> type_annotate_ops) {
+ printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
+ type_annotate_ops)
+ .c_str());
}
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
@@ -1091,4 +1266,20 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
return GetPythonOpsImpl(ops, api_def_map, {});
}
string GetArgAnnotation(
const OpDef::ArgDef& arg,
const std::unordered_map<string, string>& type_annotations) {
if (!arg.type_attr().empty()) {
// Get the correct TypeVar if arg maps to an attr
return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
}
// Otherwise the arg has a fixed dtype; map it to its DType class directly.
const string py_dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
}
} // namespace tensorflow
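For reference, the generated wrapper for the synthetic Foo op exercised by the
tests below has roughly this shape (an illustrative sketch assembled from the
expected substrings in python_op_gen_test.cc, with bodies elided; _dtypes and
_ops are bound in the generated module's preamble):

    from typing import TypeVar

    TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Int8, _dtypes.UInt8)
    TV_Foo_T2 = TypeVar("TV_Foo_T2", _dtypes.Float32, _dtypes.Float64, _dtypes.String)

    def foo(x: _ops.Tensor[TV_Foo_T], y: _ops.Tensor[TV_Foo_T2], name=None) -> _ops.Tensor[TV_Foo_T]:
      ...

    def foo_eager_fallback(x: _ops.Tensor[TV_Foo_T], y: _ops.Tensor[TV_Foo_T2], name, ctx) -> _ops.Tensor[TV_Foo_T]:
      ...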

tensorflow/python/framework/python_op_gen.h

@@ -16,7 +16,9 @@ limitations under the License.
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
@@ -32,7 +34,8 @@ namespace tensorflow {
// file where the ops' REGISTER_OP() calls reside.
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
- const string& source_file_name);
+ const string& source_file_name,
+ const std::unordered_set<string> type_annotate_ops);
// Prints the output of GetPrintOps to stdout.
// hidden_ops should be a list of Op names that should get a leading _
@@ -41,7 +44,8 @@ string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
// where the ops' REGISTER_OP() calls reside.
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
- const string& source_file_name);
+ const string& source_file_name,
+ const std::unordered_set<string> type_annotate_ops);
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing
@@ -49,6 +53,13 @@ void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
// length of that buffer.
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
// Gets the type annotation for an arg.
// `arg` should be an input or output of an op.
// `type_annotations` should contain attr names mapped to TypeVar names.
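// Example: with type_annotations = {{"T", "TV_Foo_T"}}, an arg whose type_attr
// is "T" yields "_ops.Tensor[TV_Foo_T]", while an arg with the fixed type
// DT_BOOL yields "_ops.Tensor[_dtypes.Bool]".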
string GetArgAnnotation(
const OpDef::ArgDef& arg,
const std::unordered_map<string, string>& type_annotations);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_

tensorflow/python/framework/python_op_gen_internal.cc

@@ -513,10 +513,11 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
}
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name)
+ const string& function_name, bool add_type_annotations)
: op_def_(op_def),
api_def_(api_def),
function_name_(function_name),
add_type_annotations_(add_type_annotations),
num_outs_(op_def.output_arg_size()) {}
GenPythonOp::~GenPythonOp() {}

tensorflow/python/framework/python_op_gen_internal.h

@@ -71,7 +71,7 @@ class ParamNames {
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name);
+ const string& function_name, bool add_type_annotations);
virtual ~GenPythonOp();
virtual string Code();
@@ -98,6 +98,7 @@ class GenPythonOp {
const OpDef& op_def_;
const ApiDef& api_def_;
const string function_name_;
bool add_type_annotations_;
const int num_outs_;
// Return value from Code() is prelude_ + result_.

tensorflow/python/framework/python_op_gen_main.cc

@@ -108,7 +108,8 @@ string InferSourceFileName(const char* argv_zero) {
void PrintAllPythonOps(const std::vector<string>& op_list,
const std::vector<string>& api_def_dirs,
const string& source_file_name,
- bool op_list_is_whitelist) {
+ bool op_list_is_whitelist,
+ const std::unordered_set<string> type_annotate_ops) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
@@ -133,9 +134,11 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
*pruned_ops.mutable_op()->Add() = op_def;
}
}
- PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name);
+ PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name,
+ type_annotate_ops);
} else {
- PrintPythonOps(ops, api_def_map, op_list, source_file_name);
+ PrintPythonOps(ops, api_def_map, op_list, source_file_name,
+ type_annotate_ops);
}
}
@@ -157,19 +160,25 @@ int main(int argc, char* argv[]) {
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[1], ",", tensorflow::str_util::SkipEmpty());
// Add op name here to generate type annotations for it
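// e.g. (illustrative): type_annotate_ops{"FakeParam", "ToBool"} would
// annotate only those two generated wrappers.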
const std::unordered_set<tensorflow::string> type_annotate_ops{};
if (argc == 2) {
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
- false /* op_list_is_whitelist */);
+ false /* op_list_is_whitelist */,
+ type_annotate_ops);
} else if (argc == 3) {
std::vector<tensorflow::string> hidden_ops;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
- false /* op_list_is_whitelist */);
+ false /* op_list_is_whitelist */,
+ type_annotate_ops);
} else if (argc == 4) {
std::vector<tensorflow::string> op_list;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
tensorflow::string(argv[3]) == "1");
tensorflow::string(argv[3]) == "1",
type_annotate_ops);
} else {
return -1;
}

tensorflow/python/framework/python_op_gen_test.cc

@@ -23,20 +23,458 @@ limitations under the License.
namespace tensorflow {
namespace {
- TEST(PythonOpGen, Basic) {
void ExpectHasSubstr(const string& s, const string& expected) {
EXPECT_TRUE(absl::StrContains(s, expected))
<< "Generated code does not contain '" << expected << "'";
}
void ExpectDoesNotHaveSubstr(const string& s, const string& expected) {
EXPECT_FALSE(absl::StrContains(s, expected))
<< "Generated code contains '" << expected << "'";
}
void ExpectSubstrOrder(const string& s, const string& before,
const string& after) {
const size_t before_pos = s.find(before);
const size_t after_pos = s.find(after);
ASSERT_NE(std::string::npos, before_pos);
ASSERT_NE(std::string::npos, after_pos);
EXPECT_LT(before_pos, after_pos)
<< "'" << before << "' is not before '" << after << "'";
}
TEST(PythonOpGen, TypeAnnotateAllOps) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
string code = GetPythonOps(ops, api_def_map, {}, "");
std::unordered_set<string> type_annotate_ops;
for (const auto& op : ops.op()) {
type_annotate_ops.insert(op.name());
}
- EXPECT_TRUE(absl::StrContains(code, "def case"));
+ string code = GetPythonOps(ops, api_def_map, {}, "", type_annotate_ops);
// TODO(mdan): Add tests to verify type annotations are correctly added.
const string all_types =
", _dtypes.BFloat16, _dtypes.Bool, _dtypes.Complex128, "
"_dtypes.Complex64, "
"_dtypes.Float16, _dtypes.Float32, _dtypes.Float64, _dtypes.Half, "
"_dtypes.Int16, "
"_dtypes.Int32, _dtypes.Int64, _dtypes.Int8, _dtypes.QInt16, "
"_dtypes.QInt32, "
"_dtypes.QInt8, _dtypes.QUInt16, _dtypes.QUInt8, _dtypes.Resource, "
"_dtypes.String, "
"_dtypes.UInt16, _dtypes.UInt32, _dtypes.UInt64, _dtypes.UInt8, "
"_dtypes.Variant)";
const string fake_param_typevar =
"TV_FakeParam_dtype = TypeVar(\"TV_FakeParam_dtype\"" + all_types;
const string fake_param =
"def fake_param(dtype: TV_FakeParam_dtype, shape, name=None) -> "
"_ops.Tensor[TV_FakeParam_dtype]:";
const string fake_param_fallback =
"def fake_param_eager_fallback(dtype: TV_FakeParam_dtype, shape, name, "
"ctx) -> _ops.Tensor[TV_FakeParam_dtype]:";
ExpectHasSubstr(code, fake_param_typevar);
ExpectHasSubstr(code, fake_param);
ExpectHasSubstr(code, fake_param_fallback);
const string to_bool_typevar =
"TV_ToBool_T = TypeVar(\"TV_ToBool_T\"" + all_types;
const string to_bool_ =
"def to_bool(input: _ops.Tensor[TV_ToBool_T], name=None) -> "
"_ops.Tensor[_dtypes.Bool]:";
const string to_bool_fallback =
"def to_bool_eager_fallback(input: _ops.Tensor[TV_ToBool_T], name, ctx) "
"-> _ops.Tensor[_dtypes.Bool]:";
ExpectHasSubstr(code, to_bool_typevar);
ExpectHasSubstr(code, to_bool_);
ExpectHasSubstr(code, to_bool_fallback);
}
// TODO(mdan): Include more tests with synthetic ops and api defs.
TEST(PythonOpGen, TypeAnnotateSingleTypeTensor) {
constexpr char kBaseOpDef[] = R"(
op {
name: "Bar"
input_arg {
name: "x"
type: DT_STRING
}
input_arg {
name: "y"
type: DT_QINT8
}
output_arg {
name: "output"
type: DT_BOOL
}
summary: "Summary for op Bar."
description: "Description for op Bar."
}
)";
std::unordered_set<string> type_annotate_ops{"Bar"};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string typed_bar =
"def bar(x: _ops.Tensor[_dtypes.String], y: _ops.Tensor[_dtypes.QInt8], "
"name=None) -> _ops.Tensor[_dtypes.Bool]:";
ExpectHasSubstr(code, typed_bar);
const string untyped_bar = "def bar(x, y, name=None):";
ExpectDoesNotHaveSubstr(code, untyped_bar);
}
TEST(PythonOpGen, TypeAnnotateMultiTypeTensor) {
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T2"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
}
attr {
name: "T2"
type: "type"
allowed_values {
list {
type: DT_STRING
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
std::unordered_set<string> type_annotate_ops{
"Foo",
};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string typed_foo =
"def foo(x: _ops.Tensor[TV_Foo_T], y: _ops.Tensor[TV_Foo_T2], name=None) "
"-> _ops.Tensor[TV_Foo_T]:";
ExpectHasSubstr(code, typed_foo);
}
TEST(PythonOpGen, GenerateCorrectTypeVars) {
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T2"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
}
attr {
name: "T2"
type: "type"
allowed_values {
list {
type: DT_STRING
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
std::unordered_set<string> type_annotate_ops{
"Foo",
};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string typevars_foo = R"(
TV_Foo_T = TypeVar("TV_Foo_T", _dtypes.Int8, _dtypes.UInt8)
TV_Foo_T2 = TypeVar("TV_Foo_T2", _dtypes.Float32, _dtypes.Float64, _dtypes.String)
)";
ExpectHasSubstr(code, typevars_foo);
}
TEST(PythonOpGen, TypeAnnotateFallback) {
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T2"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
}
attr {
name: "T2"
type: "type"
allowed_values {
list {
type: DT_STRING
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
std::unordered_set<string> type_annotate_ops{
"Foo",
};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string typed_foo_fallback =
"def foo_eager_fallback(x: _ops.Tensor[TV_Foo_T], y: "
"_ops.Tensor[TV_Foo_T2], name, ctx) -> _ops.Tensor[TV_Foo_T]:";
ExpectHasSubstr(code, typed_foo_fallback);
}
TEST(PythonOpGen, GenerateTypeVarAboveOp) {
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T2"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
}
attr {
name: "T2"
type: "type"
allowed_values {
list {
type: DT_STRING
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
std::unordered_set<string> type_annotate_ops{
"Foo",
};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string typevar_foo = "TV_Foo_";
const string def_foo = "def foo";
ExpectSubstrOrder(code, typevar_foo, def_foo);
}
TEST(PythonOpGen, TypeAnnotateDefaultParams) {
constexpr char kBaseOpDef[] = R"(
op {
name: "FooBar"
input_arg {
name: "x"
type: DT_FLOAT
}
output_arg {
name: "output"
type: DT_BOOL
}
attr {
name: "t"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_INT8
}
}
}
attr {
name: "var1"
type: "bool"
default_value {
b: false
}
}
attr {
name: "var2"
type: "int"
default_value {
i: 0
}
}
summary: "Summary for op FooBar."
description: "Description for op FooBar."
}
)";
std::unordered_set<string> type_annotate_ops{
"FooBar",
};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string params =
"def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, "
"var1:bool=False, var2:int=0, name=None)";
const string params_fallback =
"def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: "
"TV_FooBar_t, var1: bool, var2: int, name, ctx)";
ExpectHasSubstr(code, params);
ExpectHasSubstr(code, params_fallback);
}
TEST(PythonOpGen, NoTypingSequenceTensors) {
constexpr char kBaseOpDef[] = R"(
op {
name: "Baz"
input_arg {
name: "inputs"
number_attr: "N"
type_list_attr: "T"
}
output_arg {
name: "output1"
type: DT_BOOL
}
output_arg {
name: "output2"
type: DT_BOOL
}
attr {
name: "T"
type: "bool"
}
attr {
name: "N"
type: "int"
}
summary: "Summary for op Baz."
description: "Description for op Baz."
}
)";
std::unordered_set<string> type_annotate_ops{"Baz"};
OpList op_defs;
OpRegistry::Global()->Export(false, &op_defs);
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
ApiDefMap api_def_map(op_defs);
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string baz_def_line = "def baz(inputs, name=None):";
ExpectHasSubstr(code, baz_def_line);
}
} // namespace
} // namespace tensorflow