PR review changes
This commit is contained in:
parent
01fdbb866b
commit
7f0e00817f
@ -45,7 +45,8 @@ const int kRightMargin = 78;
|
||||
|
||||
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
|
||||
|
||||
std::unordered_map<string, string> dtype_type {
|
||||
// Dtype enums mapped to dtype classes which is the type of each dtype
|
||||
const std::unordered_map<string, string> dtype_type {
|
||||
{"_dtypes.float16", "_dtypes.Float16"},
|
||||
{"_dtypes.half", "_dtypes.Half"},
|
||||
{"_dtypes.float32", "_dtypes.Float32"},
|
||||
@ -133,8 +134,8 @@ string TensorPBString(const TensorProto& pb) {
|
||||
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
public:
|
||||
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name, const bool type_annotate_op)
|
||||
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, type_annotate_op) {
|
||||
const string& function_name, bool add_type_annotations)
|
||||
: python_op_gen_internal::GenPythonOp(op_def, api_def, function_name, add_type_annotations) {
|
||||
op_name_ = function_name_;
|
||||
absl::ConsumePrefix(&op_name_, "_");
|
||||
}
|
||||
@ -160,12 +161,12 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
bool AddEagerFastPathAndGraphCode(const string& parameters,
|
||||
const std::vector<string>& output_sizes,
|
||||
const string& eager_not_allowed_error,
|
||||
std::unordered_map<string, string>& type_annotations);
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
bool AddEagerFallbackCode(const string& parameters,
|
||||
const std::vector<string>& output_sizes,
|
||||
const string& num_outputs_expr,
|
||||
const string& eager_not_allowed_error,
|
||||
std::unordered_map<string, string>& type_annotations);
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
void AddEagerFastPathExecute();
|
||||
|
||||
void AddEagerInferredAttrs(const string& indentation);
|
||||
@ -177,11 +178,11 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
|
||||
void AddRawOpExport(const string& parameters);
|
||||
|
||||
std::unordered_map<string, string> GetTypeAnnotationMap();
|
||||
std::unordered_map<string, string> GetTypeAnnotations();
|
||||
|
||||
void GenerateTypeVars(std::unordered_map<string, string>& type_annotations);
|
||||
void GenerateTypeVars(const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
void AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations);
|
||||
void AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
void AddAttrForArg(const string& attr, int arg_index) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
|
||||
@ -214,8 +215,8 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
};
|
||||
|
||||
string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name, const bool type_annotate_op) {
|
||||
return GenEagerPythonOp(op_def, api_def, function_name, type_annotate_op).Code();
|
||||
const string& function_name, bool add_type_annotations) {
|
||||
return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations).Code();
|
||||
}
|
||||
|
||||
string GenEagerPythonOp::FlattenInputs(
|
||||
@ -347,8 +348,8 @@ string GenEagerPythonOp::Code() {
|
||||
|
||||
std::unordered_map<string, string> type_annotations;
|
||||
// Only populate map for whitelisted ops
|
||||
if (type_annotate_op_) {
|
||||
type_annotations = GetTypeAnnotationMap();
|
||||
if (add_type_annotations_) {
|
||||
type_annotations = GetTypeAnnotations();
|
||||
}
|
||||
|
||||
string parameters;
|
||||
@ -357,33 +358,28 @@ string GenEagerPythonOp::Code() {
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
strings::StrAppend(¶meters, param.GetRenameTo());
|
||||
|
||||
// Add type annotations to param
|
||||
if (type_annotations.find(param.GetName()) != type_annotations.end()) {
|
||||
strings::StrAppend(¶meters, ": ", type_annotations[param.GetName()]);
|
||||
strings::StrAppend(¶meters, ": ", type_annotations.at(param.GetName()));
|
||||
}
|
||||
}
|
||||
|
||||
// Append to parameters and parameters_with_defaults because multiple functions
|
||||
// are generated (op and fallback op)
|
||||
string parameters_with_defaults = parameters;
|
||||
for (const auto& param_and_default : params_with_default_) {
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
if (!parameters_with_defaults.empty())
|
||||
strings::StrAppend(¶meters_with_defaults, ", ");
|
||||
|
||||
// Add type annotations to param_and_default
|
||||
strings::StrAppend(¶meters, param_and_default.first.GetRenameTo());
|
||||
strings::StrAppend(¶meters_with_defaults, param_and_default.first.GetRenameTo());
|
||||
if (type_annotations.find(param_and_default.first.GetName()) != type_annotations.end()) {
|
||||
const string param_type = type_annotations[param_and_default.first.GetName()];
|
||||
strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), ": ", param_type);
|
||||
strings::StrAppend(¶meters_with_defaults,
|
||||
param_and_default.first.GetRenameTo(), ": ",
|
||||
param_type, " = ", param_and_default.second);
|
||||
continue;
|
||||
const string param_type = type_annotations.at(param_and_default.first.GetName());
|
||||
// Append to parameters and parameters_with_defaults because multiple functions
|
||||
// are generated by AddEagerFastPathAndGraphCode() and AddEagerFallbackCode()
|
||||
strings::StrAppend(¶meters, ": ", param_type);
|
||||
strings::StrAppend(¶meters_with_defaults, ":", param_type);
|
||||
}
|
||||
|
||||
strings::StrAppend(¶meters, param_and_default.first.GetRenameTo());
|
||||
strings::StrAppend(¶meters_with_defaults,
|
||||
param_and_default.first.GetRenameTo(), "=",
|
||||
strings::StrAppend(¶meters_with_defaults, "=",
|
||||
param_and_default.second);
|
||||
}
|
||||
|
||||
@ -428,9 +424,9 @@ string GenEagerPythonOp::Code() {
|
||||
return prelude_ + result_;
|
||||
}
|
||||
|
||||
std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() {
|
||||
std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
|
||||
std::unordered_map<string, string> type_annotations;
|
||||
// Mapping attrs to TypeVars
|
||||
// Map attrs to TypeVars
|
||||
for (const auto& attr : op_def_.attr()) {
|
||||
if (attr.type() == "type") {
|
||||
const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
|
||||
@ -441,24 +437,26 @@ std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotationMap() {
|
||||
}
|
||||
}
|
||||
|
||||
// Mapping input Tensors to their types
|
||||
// Map input Tensors to their types
|
||||
for (const auto& arg : op_def_.input_arg()) {
|
||||
// Do not add type annotations to args that accept a sequence of Tensors
|
||||
if (!arg.number_attr().empty()) continue;
|
||||
// TODO(rahulkamat): Add type annotations to args that accept a sequence of Tensors
|
||||
if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
}
|
||||
|
||||
// Mapping output Tensor to its type
|
||||
// TODO(rahulkamat): Add type annotations to handle return types of a sequence of Tensors.
|
||||
// Map output Tensor to its type
|
||||
if (op_def_.output_arg_size() == 1) {
|
||||
const auto& arg = op_def_.output_arg(0);
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
if (arg.number_attr().empty() && arg.type_list_attr().empty())
|
||||
type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
|
||||
}
|
||||
|
||||
return type_annotations;
|
||||
}
|
||||
|
||||
// Generate TypeVars using attrs
|
||||
void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type_annotations) {
|
||||
void GenEagerPythonOp::GenerateTypeVars(const std::unordered_map<string, string>& type_annotations) {
|
||||
bool added_typevar = false;
|
||||
for (const auto& attr : op_def_.attr()) {
|
||||
if (attr.type() == "type") {
|
||||
@ -466,12 +464,10 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
|
||||
for (int t : attr.allowed_values().list().type()) {
|
||||
DataType dtype = static_cast<DataType>(t);
|
||||
const string py_dtype = python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
|
||||
if (dtype_type.find(py_dtype) != dtype_type.end()) {
|
||||
allowed_types.emplace_back(dtype_type[py_dtype]);
|
||||
}
|
||||
allowed_types.emplace_back(dtype_type.at(py_dtype));
|
||||
}
|
||||
|
||||
// If all dtypes are allowed, add them all
|
||||
// When a Tensor does not have any dtypes specified, all dtypes are allowed
|
||||
if (allowed_types.empty()) {
|
||||
for (std::pair<string, string> map_dtype : dtype_type) {
|
||||
allowed_types.emplace_back(map_dtype.second);
|
||||
@ -486,7 +482,7 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
|
||||
strings::StrAppend(&typevar_dtypes, *it);
|
||||
}
|
||||
|
||||
const string type_var_name = type_annotations[attr.name()];
|
||||
const string type_var_name = type_annotations.at(attr.name());
|
||||
strings::StrAppend(&result_, type_var_name, " = TypeVar(\"", type_var_name, "\", ", typevar_dtypes,")\n");
|
||||
added_typevar = true;
|
||||
}
|
||||
@ -495,14 +491,15 @@ void GenEagerPythonOp::GenerateTypeVars(std::unordered_map<string, string>& type
|
||||
if (added_typevar) strings::StrAppend(&result_, "\n");
|
||||
}
|
||||
|
||||
// TODO(rahulkamat): Modify AddDefLine() to add return type annotation
|
||||
void GenEagerPythonOp::AddReturnTypeAnnotation(std::unordered_map<string, string>& type_annotations) {
|
||||
void GenEagerPythonOp::AddReturnTypeAnnotation(const std::unordered_map<string, string>& type_annotations) {
|
||||
if (op_def_.output_arg_size() == 1) {
|
||||
const auto& arg = op_def_.output_arg(0);
|
||||
// Add type annotations to param
|
||||
if (type_annotations.find(arg.name()) != type_annotations.end()) {
|
||||
if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
|
||||
const string return_type = type_annotations.at(arg.name());
|
||||
// TODO(rahulkamat): Modify AddDefLine() to add return type annotation to avoid
|
||||
// erasing ":\n" from the end of the def line
|
||||
result_.erase(result_.length() - 2);
|
||||
strings::StrAppend(&result_, " -> ", type_annotations[arg.name()], ":\n");
|
||||
strings::StrAppend(&result_, " -> ", return_type, ":\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -829,8 +826,9 @@ void GenEagerPythonOp::AddEagerFunctionTeardown(
|
||||
|
||||
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& eager_not_allowed_error, std::unordered_map<string, string>& type_annotations) {
|
||||
if (type_annotate_op_) {
|
||||
const string& eager_not_allowed_error,
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
if (add_type_annotations_) {
|
||||
GenerateTypeVars(type_annotations);
|
||||
}
|
||||
if (api_def_.visibility() == ApiDef::VISIBLE) {
|
||||
@ -839,7 +837,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
|
||||
AddExport();
|
||||
AddDefLine(function_name_, parameters);
|
||||
if (type_annotate_op_) {
|
||||
if (add_type_annotations_) {
|
||||
AddReturnTypeAnnotation(type_annotations);
|
||||
}
|
||||
AddDocStringDescription();
|
||||
@ -877,11 +875,11 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
bool GenEagerPythonOp::AddEagerFallbackCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& num_outputs_expr, const string& eager_not_allowed_error,
|
||||
std::unordered_map<string, string>& type_annotations) {
|
||||
const std::unordered_map<string, string>& type_annotations) {
|
||||
AddDefLine(
|
||||
strings::StrCat(function_name_, kEagerFallbackSuffix),
|
||||
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
|
||||
if (type_annotate_op_) {
|
||||
if (add_type_annotations_) {
|
||||
AddReturnTypeAnnotation(type_annotations);
|
||||
}
|
||||
if (!eager_not_allowed_error.empty()) {
|
||||
@ -1133,7 +1131,7 @@ void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
|
||||
string GetPythonOpsImpl(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name = "",
|
||||
std::unordered_set<string> type_annotate_ops = {}) {
|
||||
const std::unordered_set<string> type_annotate_ops = {}) {
|
||||
string result;
|
||||
// Header
|
||||
// TODO(josh11b): Mention the library for which wrappers are being generated.
|
||||
@ -1211,10 +1209,11 @@ from typing import TypeVar
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool type_annotate_op = type_annotate_ops.find(op_def.name()) != type_annotate_ops.end();
|
||||
auto iter = type_annotate_ops.find(op_def.name());
|
||||
bool add_type_annotations = iter != type_annotate_ops.end();
|
||||
|
||||
strings::StrAppend(&result,
|
||||
GetEagerPythonOp(op_def, *api_def, function_name, type_annotate_op));
|
||||
GetEagerPythonOp(op_def, *api_def, function_name, add_type_annotations));
|
||||
}
|
||||
|
||||
return result;
|
||||
@ -1225,14 +1224,14 @@ from typing import TypeVar
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name,
|
||||
std::unordered_set<string> type_annotate_ops) {
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops);
|
||||
}
|
||||
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name,
|
||||
std::unordered_set<string> type_annotate_ops) {
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
printf("%s",
|
||||
GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, type_annotate_ops).c_str());
|
||||
}
|
||||
@ -1245,16 +1244,14 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
|
||||
return GetPythonOpsImpl(ops, api_def_map, {});
|
||||
}
|
||||
|
||||
string GetArgAnnotation(const auto& arg, std::unordered_map<string, string>& type_annotations) {
|
||||
if (type_annotations.find(arg.type_attr()) != type_annotations.end()) {
|
||||
string GetArgAnnotation(const auto& arg, const std::unordered_map<string, string>& type_annotations) {
|
||||
if (!arg.type_attr().empty()) {
|
||||
// Get the correct TypeVar if arg maps to an attr
|
||||
return "_ops.Tensor[" + type_annotations[arg.type_attr()] + "]";
|
||||
return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
|
||||
} else {
|
||||
// Get the dtype of the Tensor
|
||||
const string py_dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
|
||||
if (dtype_type.find(py_dtype) != dtype_type.end()) {
|
||||
return "_ops.Tensor[" + dtype_type[py_dtype] + "]";
|
||||
}
|
||||
return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
|
||||
}
|
||||
|
||||
return "Any";
|
||||
|
@ -33,7 +33,7 @@ namespace tensorflow {
|
||||
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name,
|
||||
std::unordered_set<string> type_annotate_ops);
|
||||
const std::unordered_set<string> type_annotate_ops);
|
||||
|
||||
// Prints the output of GetPrintOps to stdout.
|
||||
// hidden_ops should be a list of Op names that should get a leading _
|
||||
@ -43,7 +43,7 @@ string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
|
||||
const std::vector<string>& hidden_ops,
|
||||
const string& source_file_name,
|
||||
std::unordered_set<string> type_annotate_ops);
|
||||
const std::unordered_set<string> type_annotate_ops);
|
||||
|
||||
// Get the python wrappers for a list of ops in a OpList.
|
||||
// `op_list_buf` should be a pointer to a buffer containing
|
||||
@ -55,7 +55,7 @@ string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
|
||||
// `arg` should be an input or output of an op
|
||||
// `type_annotations` should contain attr names mapped to TypeVar names
|
||||
string GetArgAnnotation(const auto& arg,
|
||||
std::unordered_map<string, string>& type_annotations);
|
||||
const std::unordered_map<string, string>& type_annotations);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -513,11 +513,11 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
|
||||
}
|
||||
|
||||
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name, const bool type_annotate_op)
|
||||
const string& function_name, bool add_type_annotations)
|
||||
: op_def_(op_def),
|
||||
api_def_(api_def),
|
||||
function_name_(function_name),
|
||||
type_annotate_op_(type_annotate_op),
|
||||
add_type_annotations_(add_type_annotations),
|
||||
num_outs_(op_def.output_arg_size()) {}
|
||||
|
||||
GenPythonOp::~GenPythonOp() {}
|
||||
|
@ -71,7 +71,7 @@ class ParamNames {
|
||||
class GenPythonOp {
|
||||
public:
|
||||
GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
|
||||
const string& function_name, const bool type_annotate_op_);
|
||||
const string& function_name, bool add_type_annotations_);
|
||||
virtual ~GenPythonOp();
|
||||
|
||||
virtual string Code();
|
||||
@ -98,7 +98,7 @@ class GenPythonOp {
|
||||
const OpDef& op_def_;
|
||||
const ApiDef& api_def_;
|
||||
const string function_name_;
|
||||
const bool type_annotate_op_;
|
||||
bool add_type_annotations_;
|
||||
const int num_outs_;
|
||||
|
||||
// Return value from Code() is prelude_ + result_.
|
||||
|
@ -109,7 +109,7 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
const std::vector<string>& api_def_dirs,
|
||||
const string& source_file_name,
|
||||
bool op_list_is_whitelist,
|
||||
std::unordered_set<string> type_annotate_ops) {
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
|
||||
@ -159,7 +159,7 @@ int main(int argc, char* argv[]) {
|
||||
argv[1], ",", tensorflow::str_util::SkipEmpty());
|
||||
|
||||
// Add op name to this set to add type annotations
|
||||
std::unordered_set<tensorflow::string> type_annotate_ops {
|
||||
const std::unordered_set<tensorflow::string> type_annotate_ops {
|
||||
};
|
||||
|
||||
if (argc == 2) {
|
||||
|
@ -261,7 +261,7 @@ TEST(PythonOpGen, TypeAnnotateDefaultParams) {
|
||||
|
||||
string code = GetPythonOps(op_defs, api_def_map, {}, "", type_annotate_ops);
|
||||
|
||||
const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool = False, var2: int = 0, name=None)";
|
||||
const string params = "def foo_bar(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1:bool=False, var2:int=0, name=None)";
|
||||
const string params_fallback = "def foo_bar_eager_fallback(x: _ops.Tensor[_dtypes.Float32], t: TV_FooBar_t, var1: bool, var2: int, name, ctx)";
|
||||
ExpectHasSubstr(code, params);
|
||||
ExpectHasSubstr(code, params_fallback);
|
||||
|
Loading…
Reference in New Issue
Block a user