(Hopefully) remaining work to get custom ops working w/ namespaces. Also makes some example tests about adding ops create & run namespaced ops.
PiperOrigin-RevId: 268944869
This commit is contained in:
parent
4112ac2be0
commit
b4dcff4480
@ -292,7 +292,9 @@ string ToCamelCase(const string& str) {
|
||||
bool cap = true;
|
||||
while (i < str.size()) {
|
||||
const char c = str[i++];
|
||||
if (c == joiner) {
|
||||
if (c == '>') {
|
||||
cap = true;
|
||||
} else if (c == joiner) {
|
||||
cap = true;
|
||||
} else if (cap) {
|
||||
result += toupper(c);
|
||||
@ -304,6 +306,21 @@ string ToCamelCase(const string& str) {
|
||||
return result;
|
||||
}
|
||||
|
||||
string SeparateNamespaces(const string& str) {
|
||||
string result;
|
||||
const char joiner = '_';
|
||||
size_t i = 0;
|
||||
while (i < str.size()) {
|
||||
const char c = str[i++];
|
||||
if (c == '>') {
|
||||
result += joiner;
|
||||
} else {
|
||||
result += c;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a <string, bool> pair. The string is the C++ type name to be used for
|
||||
// attr_type when defining an object of that type. The bool is a flag to
|
||||
// indicate whether to treat the type as const when accepting the C++ type as an
|
||||
@ -549,7 +566,7 @@ struct OpInfo {
|
||||
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
|
||||
const std::vector<string>& aliases)
|
||||
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
|
||||
op_name = api_def.endpoint(0).name();
|
||||
op_name = SeparateNamespaces(api_def.endpoint(0).name());
|
||||
InferOpAttributes(graph_op_def, &inferred_input_attrs);
|
||||
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
|
||||
arg_types.push_back("const ::tensorflow::Scope&");
|
||||
|
@ -66,12 +66,23 @@ inline bool IsNextIteration(const NodeDef& node_def) {
|
||||
|
||||
bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
|
||||
using ::tensorflow::strings::Scanner;
|
||||
return Scanner(s)
|
||||
Scanner scanner(s);
|
||||
scanner
|
||||
.One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
|
||||
: Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
|
||||
.Eos()
|
||||
.GetResult();
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
|
||||
while (true) {
|
||||
if (!scanner.GetResult()) // Some error in previous iteration.
|
||||
return false;
|
||||
if (scanner.empty()) // No error, but nothing left, good.
|
||||
return true;
|
||||
|
||||
// Absorb another piece, starting with a '>'
|
||||
scanner.One(Scanner::RANGLE)
|
||||
.One(Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
}
|
||||
}
|
||||
|
||||
class GraphConstructor {
|
||||
|
@ -34,6 +34,18 @@ class ZeroOut1Test(tf.test.TestCase):
|
||||
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_namespace(self):
|
||||
with self.cached_session():
|
||||
result = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_namespace_nested(self):
|
||||
with self.cached_session():
|
||||
result = zero_out_op_1.namespace_nested_zero_out([5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||
|
||||
def testLoadTwice(self):
|
||||
zero_out_loaded_again = tf.load_op_library(os.path.join(
|
||||
tf.resource_loader.get_data_files_path(), 'zero_out_op_kernel_1.so'))
|
||||
|
@ -25,3 +25,5 @@ _zero_out_module = tf.load_op_library(
|
||||
os.path.join(tf.resource_loader.get_data_files_path(),
|
||||
'zero_out_op_kernel_1.so'))
|
||||
zero_out = _zero_out_module.zero_out
|
||||
namespace_zero_out = _zero_out_module.namespace_zero_out
|
||||
namespace_nested_zero_out = _zero_out_module.namespace_nested_zero_out
|
||||
|
@ -60,3 +60,37 @@ class ZeroOutOp : public OpKernel {
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
|
||||
|
||||
REGISTER_OP("Namespace>ZeroOut")
|
||||
.Input("to_zero: int32")
|
||||
.Output("zeroed: int32")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Zeros out all but the first value of a Tensor.
|
||||
|
||||
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
|
||||
otherwise.
|
||||
)doc");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Namespace>ZeroOut").Device(DEVICE_CPU),
|
||||
ZeroOutOp);
|
||||
|
||||
REGISTER_OP("Namespace>Nested>ZeroOut")
|
||||
.Input("to_zero: int32")
|
||||
.Output("zeroed: int32")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Zeros out all but the first value of a Tensor.
|
||||
|
||||
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
|
||||
otherwise.
|
||||
)doc");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Namespace>Nested>ZeroOut").Device(DEVICE_CPU),
|
||||
ZeroOutOp);
|
||||
|
@ -1578,8 +1578,8 @@ def _NodeDef(op_type, name, attrs=None):
|
||||
|
||||
# Copied from core/framework/node_def_util.cc
|
||||
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
|
||||
_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
|
||||
_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
|
||||
_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/>]*$")
|
||||
_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/>]*$")
|
||||
|
||||
|
||||
def _create_c_op(graph, node_def, inputs, control_inputs):
|
||||
|
@ -108,8 +108,22 @@ bool IsOpWithUnderscorePrefix(const string& s) {
|
||||
}
|
||||
|
||||
string AvoidPythonReserved(const string& s) {
|
||||
if (IsPythonReserved(s)) return strings::StrCat(s, "_");
|
||||
return s;
|
||||
const char namespace_separator = '>';
|
||||
const char joiner = '_';
|
||||
const int last_index = s.size();
|
||||
string result;
|
||||
for (int i = 0; i < last_index; ++i) {
|
||||
const char c = s[i];
|
||||
// Convert namespace separators ('>' characters) to joiners
|
||||
if (c == namespace_separator) {
|
||||
result.push_back(joiner);
|
||||
} else {
|
||||
result.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
if (IsPythonReserved(result)) return strings::StrCat(result, "_");
|
||||
return result;
|
||||
}
|
||||
|
||||
// Indent the first line by "initial" spaces and all following lines
|
||||
@ -467,20 +481,24 @@ string AttrValueToPython(const string& type, const AttrValue& value,
|
||||
|
||||
void GenerateLowerCaseOpName(const string& str, string* result) {
|
||||
const char joiner = '_';
|
||||
const char namespace_separator = '>';
|
||||
const int last_index = str.size() - 1;
|
||||
for (int i = 0; i <= last_index; ++i) {
|
||||
const char c = str[i];
|
||||
// Convert namespace separators ('>' characters) to joiners
|
||||
if (c == '>') {
|
||||
if (c == namespace_separator) {
|
||||
result->push_back(joiner);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Emit a joiner only if a previous-lower-to-now-upper or a
|
||||
// now-upper-to-next-lower transition happens.
|
||||
// (But don't emit an extra joiner if we just saw a namespace separator
|
||||
if (isupper(c) && (i > 0)) {
|
||||
if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
|
||||
result->push_back(joiner);
|
||||
if (!(str[i - 1] == namespace_separator)) {
|
||||
result->push_back(joiner);
|
||||
}
|
||||
}
|
||||
}
|
||||
result->push_back(tolower(c));
|
||||
|
Loading…
Reference in New Issue
Block a user