(Hopefully) remaining work to get custom ops working w/ namespaces. Also makes some example tests about adding ops create & run namespaced ops.

PiperOrigin-RevId: 268944869
This commit is contained in:
A. Unique TensorFlower 2019-09-13 11:25:02 -07:00 committed by TensorFlower Gardener
parent 4112ac2be0
commit b4dcff4480
7 changed files with 106 additions and 12 deletions

View File

@ -292,7 +292,9 @@ string ToCamelCase(const string& str) {
bool cap = true;
while (i < str.size()) {
const char c = str[i++];
if (c == joiner) {
if (c == '>') {
cap = true;
} else if (c == joiner) {
cap = true;
} else if (cap) {
result += toupper(c);
@ -304,6 +306,21 @@ string ToCamelCase(const string& str) {
return result;
}
string SeparateNamespaces(const string& str) {
string result;
const char joiner = '_';
size_t i = 0;
while (i < str.size()) {
const char c = str[i++];
if (c == '>') {
result += joiner;
} else {
result += c;
}
}
return result;
}
// Returns a <string, bool> pair. The string is the C++ type name to be used for
// attr_type when defining an object of that type. The bool is a flag to
// indicate whether to treat the type as const when accepting the C++ type as an
@ -549,7 +566,7 @@ struct OpInfo {
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases)
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
op_name = api_def.endpoint(0).name();
op_name = SeparateNamespaces(api_def.endpoint(0).name());
InferOpAttributes(graph_op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");

View File

@ -66,12 +66,23 @@ inline bool IsNextIteration(const NodeDef& node_def) {
bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
using ::tensorflow::strings::Scanner;
return Scanner(s)
Scanner scanner(s);
scanner
.One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
: Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
.Eos()
.GetResult();
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
while (true) {
if (!scanner.GetResult()) // Some error in previous iteration.
return false;
if (scanner.empty()) // No error, but nothing left, good.
return true;
// Absorb another piece, starting with a '>'
scanner.One(Scanner::RANGLE)
.One(Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
}
}
class GraphConstructor {

View File

@ -34,6 +34,18 @@ class ZeroOut1Test(tf.test.TestCase):
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
@test_util.run_deprecated_v1
def test_namespace(self):
with self.cached_session():
result = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
@test_util.run_deprecated_v1
def test_namespace_nested(self):
with self.cached_session():
result = zero_out_op_1.namespace_nested_zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def testLoadTwice(self):
zero_out_loaded_again = tf.load_op_library(os.path.join(
tf.resource_loader.get_data_files_path(), 'zero_out_op_kernel_1.so'))

View File

@ -25,3 +25,5 @@ _zero_out_module = tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(),
'zero_out_op_kernel_1.so'))
zero_out = _zero_out_module.zero_out
namespace_zero_out = _zero_out_module.namespace_zero_out
namespace_nested_zero_out = _zero_out_module.namespace_nested_zero_out

View File

@ -60,3 +60,37 @@ class ZeroOutOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
REGISTER_OP("Namespace>ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.Doc(R"doc(
Zeros out all but the first value of a Tensor.
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
otherwise.
)doc");
REGISTER_KERNEL_BUILDER(Name("Namespace>ZeroOut").Device(DEVICE_CPU),
ZeroOutOp);
REGISTER_OP("Namespace>Nested>ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.Doc(R"doc(
Zeros out all but the first value of a Tensor.
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
otherwise.
)doc");
REGISTER_KERNEL_BUILDER(Name("Namespace>Nested>ZeroOut").Device(DEVICE_CPU),
ZeroOutOp);

View File

@ -1578,8 +1578,8 @@ def _NodeDef(op_type, name, attrs=None):
# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/>]*$")
_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/>]*$")
def _create_c_op(graph, node_def, inputs, control_inputs):

View File

@ -108,8 +108,22 @@ bool IsOpWithUnderscorePrefix(const string& s) {
}
string AvoidPythonReserved(const string& s) {
if (IsPythonReserved(s)) return strings::StrCat(s, "_");
return s;
const char namespace_separator = '>';
const char joiner = '_';
const int last_index = s.size();
string result;
for (int i = 0; i < last_index; ++i) {
const char c = s[i];
// Convert namespace separators ('>' characters) to joiners
if (c == namespace_separator) {
result.push_back(joiner);
} else {
result.push_back(c);
}
}
if (IsPythonReserved(result)) return strings::StrCat(result, "_");
return result;
}
// Indent the first line by "initial" spaces and all following lines
@ -467,20 +481,24 @@ string AttrValueToPython(const string& type, const AttrValue& value,
void GenerateLowerCaseOpName(const string& str, string* result) {
const char joiner = '_';
const char namespace_separator = '>';
const int last_index = str.size() - 1;
for (int i = 0; i <= last_index; ++i) {
const char c = str[i];
// Convert namespace separators ('>' characters) to joiners
if (c == '>') {
if (c == namespace_separator) {
result->push_back(joiner);
continue;
}
// Emit a joiner only if a previous-lower-to-now-upper or a
// now-upper-to-next-lower transition happens.
// (But don't emit an extra joiner if we just saw a namespace separator
if (isupper(c) && (i > 0)) {
if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
result->push_back(joiner);
if (!(str[i - 1] == namespace_separator)) {
result->push_back(joiner);
}
}
}
result->push_back(tolower(c));