PR review changes

This commit is contained in:
rahul-kamat 2020-06-23 23:24:20 +00:00
parent 01fdbb866b
commit 7f0e00817f
6 changed files with 67 additions and 70 deletions

View File

@ -45,7 +45,8 @@ const int kRightMargin = 78;
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
std::unordered_map<string, string> dtype_type {
// Dtype enums mapped to dtype classes which is the type of each dtype
const std::unordered_map<string, string> dtype_type {
{"_dtypes.float16", "_dtypes.Float16"},
{"_dtypes.half", "_dtypes.Half"},
{"_dtypes.float32", "_dtypes.Float32"},
@ -133,8 +134,8 @@ string TensorPBString(const TensorProto& pb) {
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
const string& function_name, const bool type_annotate_op)
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, type_annotate_op) {
const string& function_name, bool add_type_annotations)
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, add_type_annotations) {
op_name_ = function_name_;
absl::ConsumePrefix(&op_name_, "_");
}
@ -160,12 +161,12 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
bool AddEagerFastPathAndGraphCode(const string& parameters,
const std::vector<string>& output_sizes,
const string& eager_not_allowed_error,
std::unordered_map<string, string>& type_annotations);
const std::unordered_map<string, string>& type_annotations);
bool AddEagerFallbackCode(const string& parameters,
const std::vector<string>& output_sizes,
const string& num_outputs_expr,
const string& eager_not_allowed_error,
std::unordered_map<string, string>& type_annotations);
const std::unordered_map<string, string>& type_annotations);
void AddEagerFastPathExecute();
void AddEagerInferredAttrs(const string& indentation);
@ -177,11 +178,11 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
void AddRawOpExport(const string& parameters);
std::unordered_map<string, string> GetTypeAnnotationMap();
std::unordered_map<string, string> GetTypeAnnotations();
void GenerateTypeVars(std::unordered_map<string, string>& type_annotations);
void GenerateTypeVars(const std::unordered_map<string, string>& type_annotations);
void AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations);
void AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations);
void AddAttrForArg(const string& attr, int arg_index) {
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
@ -214,8 +215,8 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
};
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
const string& function_name, const bool type_annotate_op) {
return GenEagerPythonOp(op_def, api_def, function_name, type_annotate_op).Code();
const string& function_name, bool add_type_annotations) {
return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations).Code();
}
string GenEagerPythonOp::FlattenInputs(
@ -347,8 +348,8 @@ string GenEagerPythonOp::Code() {
std::unordered_map<string, string> type_annotations;
// Only populate map for whitelisted ops
if (type_annotate_op_) {
type_annotations = GetTypeAnnotationMap();
if (add_type_annotations_) {
type_annotations = GetTypeAnnotations();
}
string parameters;
@ -357,33 +358,28 @@ string GenEagerPythonOp::Code() {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, param.GetRenameTo());
// Add type annotations to param
if (type_annotations.find(param.GetName()) != type_annotations.end()) {
strings::StrAppend(&parameters, ": ", type_annotations[param.GetName()]);
strings::StrAppend(&parameters, ": ", type_annotations.at(param.GetName()));
}
}
// Append to parameters and parameters_with_defaults because multiple functions
// are generated (op and fallback op)
string parameters_with_defaults = parameters;
for (const auto& param_and_default : params_with_default_) {
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
if (!parameters_with_defaults.empty())
strings::StrAppend(&parameters_with_defaults, ", ");
// Add type annotations to param_and_default
strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
strings::StrAppend(&parameters_with_defaults, param_and_default.first.GetRenameTo());
if (type_annotations.find(param_and_default.first.GetName()) != type_annotations.end()) {
const string param_type = type_annotations[param_and_default.first.GetName()];
strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), ": ", param_type);
strings::StrAppend(&parameters_with_defaults,
param_and_default.first.GetRenameTo(), ": ",
param_type, " = ", param_and_default.second);
continue;
const string param_type = type_annotations.at(param_and_default.first.GetName());
// Append to parameters and parameters_with_defaults because multiple functions
// are generated by AddEagerFastPathAndGraphCode() and AddEagerFallbackCode()
strings::StrAppend(&parameters, ": ", param_type);
strings::StrAppend(&parameters_with_defaults, ":", param_type);
}
strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
strings::StrAppend(&parameters_with_defaults,
param_and_default.first.GetRenameTo(), "=",
strings::StrAppend(&parameters_with_defaults, "=",
param_and_default.second);
}
@ -428,9 +424,9 @@ string GenEagerPythonOp::Code() {
return prelude_ + result_;
}
std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() {
std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
std::unordered_map<string, string> type_annotations;
// Mapping attrs to TypeVars
// Map attrs to TypeVars
for (const auto& attr : op_def_.attr()) {
if (attr.type() == "type") {
const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
@ -441,24 +437,26 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() {
}
}
// Mapping input Tensors to their types
// Map input Tensors to their types
for (const auto& arg : op_def_.input_arg()) {
// Do not add type annotations to args that accept a sequence of Tensors
if (!arg.number_attr().empty()) continue;
// TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
}
// Mapping output Tensor to its type
// TODO(rahulkamat): Add type annotations to handle return types of a sequence of Tensors.
// Map output Tensor to its type
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
if (arg.number_attr().empty() && arg.type_list_attr().empty())
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
}
return type_annotations;
}
// Generate TypeVars using attrs
void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type_annotations) {
void GenEagerPythonOp::GenerateTypeVars(const std::unordered_map<string, string>& type_annotations) {
bool added_typevar = false;
for (const auto& attr : op_def_.attr()) {
if (attr.type() == "type") {
@ -466,12 +464,10 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
for (int t : attr.allowed_values().list().type()) {
DataType dtype = static_cast<DataType>(t);
const string py_dtype = python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
if (dtype_type.find(py_dtype) != dtype_type.end()) {
allowed_types.emplace_back(dtype_type[py_dtype]);
}
allowed_types.emplace_back(dtype_type.at(py_dtype));
}
// If all dtypes are allowed, add them all
// When a Tensor does not have any dtypes specified, all dtypes are allowed
if (allowed_types.empty()) {
for (std::pair<string, string> map_dtype : dtype_type) {
allowed_types.emplace_back(map_dtype.second);
@ -486,7 +482,7 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
strings::StrAppend(&typevar_dtypes, *it);
}
const string type_var_name = type_annotations[attr.name()];
const string type_var_name = type_annotations.at(attr.name());
strings::StrAppend(&result_, type_var_name, " = TypeVar(\"", type_var_name, "\", ", typevar_dtypes,")\n");
added_typevar = true;
}
@ -495,14 +491,15 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
if (added_typevar) strings::StrAppend(&result_, "\n");
}
// TODO(rahulkamat): Modify AddDefLine() to add return type annotation
void GenEagerPythonOp::AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations) {
void GenEagerPythonOp::AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations) {
if (op_def_.output_arg_size() == 1) {
const auto& arg = op_def_.output_arg(0);
// Add type annotations to param
if (type_annotations.find(arg.name()) != type_annotations.end()) {
if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
const string return_type = type_annotations.at(arg.name());
// TODO(rahulkamat): Modify AddDefLine() to add return type annotation to avoid
// erasing ":\n" from the end of the def line
result_.erase(result_.length() - 2);
strings::StrAppend(&result_, " -> ", type_annotations[arg.name()], ":\n");
strings::StrAppend(&result_, " -> ", return_type, ":\n");
}
}
}
@ -829,8 +826,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
const string& parameters, const std::vector<string>& output_sizes,
const string& eager_not_allowed_error, std::unordered_map<string, string>& type_annotations) {
if (type_annotate_op_) {
const string& eager_not_allowed_error,
const std::unordered_map<string, string>& type_annotations) {
if (add_type_annotations_) {
GenerateTypeVars(type_annotations);
}
if (api_def_.visibility() == ApiDef::VISIBLE) {
@ -839,7 +837,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
AddExport();
AddDefLine(function_name_, parameters);
if (type_annotate_op_) {
if (add_type_annotations_) {
AddReturnTypeAnnotation(type_annotations);
}
AddDocStringDescription();
@ -877,11 +875,11 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
bool GenEagerPythonOp::AddEagerFallbackCode(
const string& parameters, const std::vector<string>& output_sizes,
const string& num_outputs_expr, const string& eager_not_allowed_error,
std::unordered_map<string, string>& type_annotations) {
const std::unordered_map<string, string>& type_annotations) {
AddDefLine(
strings::StrCat(function_name_, kEagerFallbackSuffix),
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
if (type_annotate_op_) {
if (add_type_annotations_) {
AddReturnTypeAnnotation(type_annotations);
}
if (!eager_not_allowed_error.empty()) {
@ -1133,7 +1131,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name = "",
std::unordered_set<string> type_annotate_ops = {}) {
const std::unordered_set<string> type_annotate_ops = {}) {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
@ -1211,10 +1209,11 @@ from typing import TypeVar
continue;
}
const bool type_annotate_op = type_annotate_ops.find(op_def.name()) != type_annotate_ops.end();
auto iter = type_annotate_ops.find(op_def.name());
bool add_type_annotations = iter != type_annotate_ops.end();
strings::StrAppend(&result,
GetEagerPythonOp(op_def, *api_def, function_name, type_annotate_op));
GetEagerPythonOp(op_def, *api_def, function_name, add_type_annotations));
}
return result;
@ -1225,14 +1224,14 @@ from typing import TypeVar
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name,
std::unordered_set<string> type_annotate_ops) {
const std::unordered_set<string> type_annotate_ops) {
return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops);
}
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name,
std::unordered_set<string> type_annotate_ops) {
const std::unordered_set<string> type_annotate_ops) {
printf("%s",
GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops).c_str());
}
@ -1245,16 +1244,14 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
return GetPythonOpsImpl(ops, api_def_map, {});
}
string GetArgAnnotation(const auto& arg, std::unordered_map<string, string>& type_annotations) {
if (type_annotations.find(arg.type_attr()) != type_annotations.end()) {
string GetArgAnnotation(const auto& arg, const std::unordered_map<string, string>& type_annotations) {
if (!arg.type_attr().empty()) {
// Get the correct TypeVar if arg maps to an attr
return "_ops.Tensor[" + type_annotations[arg.type_attr()] + "]";
return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
} else {
// Get the dtype of the Tensor
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
if (dtype_type.find(py_dtype) != dtype_type.end()) {
return "_ops.Tensor[" + dtype_type[py_dtype] + "]";
}
return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
}
return "Any";

View File

@ -33,7 +33,7 @@ namespace tensorflow {
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name,
std::unordered_set<string> type_annotate_ops);
const std::unordered_set<string> type_annotate_ops);
// Prints the output of GetPrintOps to stdout.
// hidden_ops should be a list of Op names that should get a leading _
@ -43,7 +43,7 @@ string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
const std::vector<string>& hidden_ops,
const string& source_file_name,
std::unordered_set<string> type_annotate_ops);
const std::unordered_set<string> type_annotate_ops);
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing
@ -55,7 +55,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
// `arg` should be an input or output of an op
// `type_annotations` should contain attr names mapped to TypeVar names
string GetArgAnnotation(const auto& arg,
std::unordered_map<string, string>& type_annotations);
const std::unordered_map<string, string>& type_annotations);
} // namespace tensorflow

View File

@ -513,11 +513,11 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
}
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
const string& function_name, const bool type_annotate_op)
const string& function_name, bool add_type_annotations)
: op_def_(op_def),
api_def_(api_def),
function_name_(function_name),
type_annotate_op_(type_annotate_op),
add_type_annotations_(add_type_annotations),
num_outs_(op_def.output_arg_size()) {}
GenPythonOp::~GenPythonOp() {}

View File

@ -71,7 +71,7 @@ class ParamNames {
class GenPythonOp {
public:
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
const string& function_name, const bool type_annotate_op_);
const string& function_name, bool add_type_annotations_);
virtual ~GenPythonOp();
virtual string Code();
@ -98,7 +98,7 @@ class GenPythonOp {
const OpDef& op_def_;
const ApiDef& api_def_;
const string function_name_;
const bool type_annotate_op_;
bool add_type_annotations_;
const int num_outs_;
// Return value from Code() is prelude_ + result_.

View File

@ -109,7 +109,7 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
const std::vector<string>& api_def_dirs,
const string& source_file_name,
bool op_list_is_whitelist,
std::unordered_set<string> type_annotate_ops) {
const std::unordered_set<string> type_annotate_ops) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
@ -159,7 +159,7 @@ int main(int argc, char* argv[]) {
argv[1], ",", tensorflow::str_util::SkipEmpty());
// Add op name to this set to add type annotations
std::unordered_set<tensorflow::string> type_annotate_ops {
const std::unordered_set<tensorflow::string> type_annotate_ops {
};
if (argc == 2) {

View File

@ -261,7 +261,7 @@ TEST(PythonOpGen, TypeAnnotateDefaultParams) {
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool = False, var2: int = 0, name=None)";
const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1:bool=False, var2:int=0, name=None)";
const string params_fallback = "def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool, var2: int, name, ctx)";
ExpectHasSubstr(code, params);
ExpectHasSubstr(code, params_fallback);