Include defaults in the generated raw op signature
This change also ensures that a raw op has a __doc__. PiperOrigin-RevId: 238544250
This commit is contained in:
parent
fcf272ac48
commit
29b136397b
@ -144,7 +144,7 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
|
||||
const string& num_outputs_expr);
|
||||
void AddDispatch(const string& prefix);
|
||||
|
||||
void AddRawOpExport();
|
||||
void AddRawOpExport(const string& parameters);
|
||||
|
||||
void AddAttrForArg(const string& attr, int arg_index) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs_, attr,
|
||||
@ -300,6 +300,7 @@ string GenEagerPythonOp::Code() {
|
||||
attrs_.push_back(p.first.GetName());
|
||||
}
|
||||
|
||||
// TODO(slebedev): call AvoidPythonReserved on each param?
|
||||
param_names_.reserve(params_no_default_.size() + params_with_default_.size());
|
||||
param_names_.insert(param_names_.begin(), params_no_default_.begin(),
|
||||
params_no_default_.end());
|
||||
@ -317,8 +318,7 @@ string GenEagerPythonOp::Code() {
|
||||
strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), "=",
|
||||
param_and_default.second);
|
||||
}
|
||||
if (!parameters.empty()) strings::StrAppend(¶meters, ", ");
|
||||
strings::StrAppend(¶meters, "name=None");
|
||||
strings::StrAppend(¶meters, parameters.empty() ? "" : ", ", "name=None");
|
||||
|
||||
// Add attr_expressions_ for attrs that are params.
|
||||
for (int i = 0; i < attrs_.size(); ++i) {
|
||||
@ -639,6 +639,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
if (api_def_.visibility() == ApiDef::VISIBLE) {
|
||||
strings::StrAppend(&result_, "@_dispatch.add_dispatch_list\n");
|
||||
}
|
||||
|
||||
AddExport();
|
||||
AddDefLine(function_name_, parameters);
|
||||
AddDocStringDescription();
|
||||
@ -671,7 +672,7 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
AddEagerFunctionTeardown(" ", output_sizes,
|
||||
true /* execute_record_gradient */);
|
||||
|
||||
AddRawOpExport();
|
||||
AddRawOpExport(parameters);
|
||||
strings::StrAppend(&result_, "\n\n");
|
||||
return true;
|
||||
}
|
||||
@ -679,8 +680,9 @@ bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
|
||||
bool GenEagerPythonOp::AddEagerFallbackCode(
|
||||
const string& parameters, const std::vector<string>& output_sizes,
|
||||
const string& num_outputs_expr, const string& eager_not_allowed_error) {
|
||||
AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix),
|
||||
strings::StrCat(parameters, ", ctx=None"));
|
||||
AddDefLine(
|
||||
strings::StrCat(function_name_, kEagerFallbackSuffix),
|
||||
strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx=None"));
|
||||
|
||||
if (!eager_not_allowed_error.empty()) {
|
||||
strings::StrAppend(&result_, " ", eager_not_allowed_error);
|
||||
@ -925,64 +927,31 @@ void GenEagerPythonOp::AddDispatch(const string& prefix) {
|
||||
strings::StrAppend(&result_, prefix, " raise\n");
|
||||
}
|
||||
|
||||
void GenEagerPythonOp::AddRawOpExport() {
|
||||
// Create function for python op.
|
||||
string raw_parameters;
|
||||
string function_call_parameters;
|
||||
string inputs;
|
||||
string attrs;
|
||||
|
||||
std::map<string, string> renames;
|
||||
|
||||
void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
|
||||
string arguments;
|
||||
for (const auto& param_names : param_names_) {
|
||||
renames.insert({param_names.GetName(), param_names.GetRenameTo()});
|
||||
}
|
||||
|
||||
for (const auto& input_arg : op_def_.input_arg()) {
|
||||
const string input_arg_name =
|
||||
python_op_gen_internal::AvoidPythonReserved(input_arg.name());
|
||||
if (!raw_parameters.empty()) strings::StrAppend(&raw_parameters, ", ");
|
||||
strings::StrAppend(&raw_parameters, input_arg_name);
|
||||
|
||||
if (!inputs.empty()) strings::StrAppend(&inputs, ", ");
|
||||
strings::StrAppend(&inputs, input_arg_name);
|
||||
|
||||
if (!function_call_parameters.empty()) {
|
||||
strings::StrAppend(&function_call_parameters, ", ");
|
||||
}
|
||||
strings::StrAppend(&function_call_parameters, renames[input_arg.name()],
|
||||
"=", input_arg_name);
|
||||
}
|
||||
for (const auto& attr : op_def_.attr()) {
|
||||
if (inferred_attrs_.find(attr.name()) != inferred_attrs_.end()) continue;
|
||||
|
||||
const string attr_name =
|
||||
python_op_gen_internal::AvoidPythonReserved(attr.name());
|
||||
|
||||
if (!raw_parameters.empty()) strings::StrAppend(&raw_parameters, ", ");
|
||||
strings::StrAppend(&raw_parameters, attr_name);
|
||||
|
||||
if (!attrs.empty()) strings::StrAppend(&attrs, ", ");
|
||||
strings::StrAppend(&attrs, "\"", attr_name, "\", ", attr_name);
|
||||
|
||||
if (!function_call_parameters.empty()) {
|
||||
strings::StrAppend(&function_call_parameters, ", ");
|
||||
}
|
||||
strings::StrAppend(&function_call_parameters, renames[attr.name()], "=",
|
||||
attr_name);
|
||||
const string renamed = param_names.GetRenameTo();
|
||||
strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", renamed, "=",
|
||||
renamed);
|
||||
}
|
||||
strings::StrAppend(&arguments, arguments.empty() ? "" : ", ", "name=name");
|
||||
|
||||
const string raw_function_name =
|
||||
python_op_gen_internal::AvoidPythonReserved(op_def_.name());
|
||||
|
||||
strings::StrAppend(&result_,
|
||||
"@_doc_controls.do_not_generate_docs\n@_kwarg_only\ndef ",
|
||||
raw_function_name, "(", raw_parameters, "):\n");
|
||||
strings::StrAppend(&result_, "def ", raw_function_name, "(", parameters,
|
||||
"):\n");
|
||||
strings::StrAppend(&result_, " return ", function_name_, "(", arguments,
|
||||
")\n");
|
||||
|
||||
// Function body.
|
||||
strings::StrAppend(&result_, " return ", function_name_, "(",
|
||||
function_call_parameters, ")\n");
|
||||
// Copy the __doc__ from the original op and apply the decorators.
|
||||
strings::StrAppend(&result_, raw_function_name, ".__doc__", " = ",
|
||||
function_name_, ".__doc__\n");
|
||||
strings::StrAppend(&result_, raw_function_name, " = ",
|
||||
"_doc_controls.do_not_generate_docs(_kwarg_only(",
|
||||
raw_function_name, "))\n");
|
||||
|
||||
// Export.
|
||||
strings::StrAppend(&result_, "tf_export(\"raw_ops.", raw_function_name,
|
||||
"\")(", raw_function_name, ")\n");
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -37,10 +38,27 @@ class RawOpsTest(test.TestCase):
|
||||
gen_math_ops.Add(1., 1.)
|
||||
|
||||
def testRequiresKwargs_providesSuggestion(self):
|
||||
msg = "possible keys: \\['x', 'y'\\]"
|
||||
msg = "possible keys: \\['x', 'y', 'name'\\]"
|
||||
with self.assertRaisesRegexp(TypeError, msg):
|
||||
gen_math_ops.Add(1., y=2.)
|
||||
|
||||
def testName(self):
|
||||
x = constant_op.constant(1)
|
||||
op = gen_math_ops.Add(x=x, y=x, name="double")
|
||||
if not context.executing_eagerly():
|
||||
# `Tensor.name` is not available in eager.
|
||||
self.assertEqual(op.name, "double:0")
|
||||
|
||||
def testDoc(self):
|
||||
self.assertEqual(gen_math_ops.add.__doc__, gen_math_ops.Add.__doc__)
|
||||
|
||||
def testDefaults(self):
|
||||
x = constant_op.constant([[True]])
|
||||
self.assertAllClose(
|
||||
gen_math_ops.Any(input=x, axis=0),
|
||||
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user