Include defaults in the generated raw op signature

This change also ensures that a raw op has a __doc__.

PiperOrigin-RevId: 238544250
This commit is contained in:
Sergei Lebedev 2019-03-14 16:33:08 -07:00 committed by TensorFlower Gardener
parent fcf272ac48
commit 29b136397b
4 changed files with 2218 additions and 2231 deletions

View File

@ -144,7 +144,7 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
const string& num_outputs_expr);
void AddDispatch(const string& prefix);
void AddRawOpExport();
void AddRawOpExport(const string& parameters);
void AddAttrForArg(const string& attr, int arg_index) {
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
@ -300,6 +300,7 @@ string GenEagerPythonOp::Code() {
attrs_.push_back(p.first.GetName());
}
// TODO(slebedev): call AvoidPythonReserved on each param?
param_names_.reserve(params_no_default_.size() + params_with_default_.size());
param_names_.insert(param_names_.begin(), params_no_default_.begin(),
params_no_default_.end());
@ -317,8 +318,7 @@ string GenEagerPythonOp::Code() {
strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
param_and_default.second);
}
if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
strings::StrAppend(&parameters, "name=None");
strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name=None");
// Add attr_expressions_ for attrs that are params.
for (int i = 0; i < attrs_.size(); ++i) {
@ -639,6 +639,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
if (api_def_.visibility() == ApiDef::VISIBLE) {
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
}
AddExport();
AddDefLine(function_name_, parameters);
AddDocStringDescription();
@ -671,7 +672,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
AddEagerFunctionTeardown(" ", output_sizes,
true /* execute_record_gradient */);
AddRawOpExport();
AddRawOpExport(parameters);
strings::StrAppend(&result_, "\n\n");
return true;
}
@ -679,8 +680,9 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
bool GenEagerPythonOp::AddEagerFallbackCode(
const string& parameters, const std::vector<string>& output_sizes,
const string& num_outputs_expr, const string& eager_not_allowed_error) {
AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix),
strings::StrCat(parameters, ", ctx=None"));
AddDefLine(
strings::StrCat(function_name_, kEagerFallbackSuffix),
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx=None"));
if (!eager_not_allowed_error.empty()) {
strings::StrAppend(&result_, " ", eager_not_allowed_error);
@ -925,64 +927,31 @@ void GenEagerPythonOp::AddDispatch(const string& prefix) {
strings::StrAppend(&result_, prefix, " raise\n");
}
void GenEagerPythonOp::AddRawOpExport() {
// Create function for python op.
string raw_parameters;
string function_call_parameters;
string inputs;
string attrs;
std::map<string, string> renames;
void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
string arguments;
for (const auto& param_names : param_names_) {
renames.insert({param_names.GetName(), param_names.GetRenameTo()});
}
for (const auto& input_arg : op_def_.input_arg()) {
const string input_arg_name =
python_op_gen_internal::AvoidPythonReserved(input_arg.name());
if (!raw_parameters.empty()) strings::StrAppend(&raw_parameters, ", ");
strings::StrAppend(&raw_parameters, input_arg_name);
if (!inputs.empty()) strings::StrAppend(&inputs, ", ");
strings::StrAppend(&inputs, input_arg_name);
if (!function_call_parameters.empty()) {
strings::StrAppend(&function_call_parameters, ", ");
}
strings::StrAppend(&function_call_parameters, renames[input_arg.name()],
"=", input_arg_name);
}
for (const auto& attr : op_def_.attr()) {
if (inferred_attrs_.find(attr.name()) != inferred_attrs_.end()) continue;
const string attr_name =
python_op_gen_internal::AvoidPythonReserved(attr.name());
if (!raw_parameters.empty()) strings::StrAppend(&raw_parameters, ", ");
strings::StrAppend(&raw_parameters, attr_name);
if (!attrs.empty()) strings::StrAppend(&attrs, ", ");
strings::StrAppend(&attrs, "\"", attr_name, "\", ", attr_name);
if (!function_call_parameters.empty()) {
strings::StrAppend(&function_call_parameters, ", ");
}
strings::StrAppend(&function_call_parameters, renames[attr.name()], "=",
attr_name);
const string renamed = param_names.GetRenameTo();
strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", renamed, "=",
renamed);
}
strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", "name=name");
const string raw_function_name =
python_op_gen_internal::AvoidPythonReserved(op_def_.name());
strings::StrAppend(&result_,
"@_doc_controls.do_not_generate_docs\n@_kwarg_only\ndef ",
raw_function_name, "(", raw_parameters, "):\n");
strings::StrAppend(&result_, "def ", raw_function_name, "(", parameters,
"):\n");
strings::StrAppend(&result_, " return ", function_name_, "(", arguments,
")\n");
// Function body.
strings::StrAppend(&result_, " return ", function_name_, "(",
function_call_parameters, ")\n");
// Copy the __doc__ from the original op and apply the decorators.
strings::StrAppend(&result_, raw_function_name, ".__doc__", " = ",
function_name_, ".__doc__\n");
strings::StrAppend(&result_, raw_function_name, " = ",
"_doc_controls.do_not_generate_docs(_kwarg_only(",
raw_function_name, "))\n");
// Export.
strings::StrAppend(&result_, "tf_export(\"raw_ops.", raw_function_name,
"\")(", raw_function_name, ")\n");
}

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -37,10 +38,27 @@ class RawOpsTest(test.TestCase):
gen_math_ops.Add(1., 1.)
def testRequiresKwargs_providesSuggestion(self):
msg = "possible keys: \\['x', 'y'\\]"
msg = "possible keys: \\['x', 'y', 'name'\\]"
with self.assertRaisesRegexp(TypeError, msg):
gen_math_ops.Add(1., y=2.)
def testName(self):
x = constant_op.constant(1)
op = gen_math_ops.Add(x=x, y=x, name="double")
if not context.executing_eagerly():
# `Tensor.name` is not available in eager.
self.assertEqual(op.name, "double:0")
def testDoc(self):
self.assertEqual(gen_math_ops.add.__doc__, gen_math_ops.Add.__doc__)
def testDefaults(self):
x = constant_op.constant([[True]])
self.assertAllClose(
gen_math_ops.Any(input=x, axis=0),
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff