Allow op namespacing to work in data inputs and control inputs

PiperOrigin-RevId: 288534578
Change-Id: I7bb5571ebec102a8b3145ec8bb018e9de1b39ab5
This commit is contained in:
A. Unique TensorFlower 2020-01-07 11:16:42 -08:00 committed by TensorFlower Gardener
parent e37290d8df
commit 7fda1add7c
3 changed files with 59 additions and 14 deletions

View File

@ -716,7 +716,7 @@ bool IsValidNodeName(StringPiece sp) {
if (scanner.empty()) // No error, but nothing left, good.
return true;
// Absorb another piece, starting with a '>'
// Absorb another name/namespace, starting with a '>'
scanner.One(Scanner::RANGLE)
.One(Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
@ -728,26 +728,46 @@ bool IsValidDataInputName(StringPiece sp) {
Scanner scan(sp);
scan.One(Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
if (scan.Peek() == ':') {
scan.OneLiteral(":");
if (scan.Peek() == '0') {
scan.OneLiteral("0"); // :0
while (true) {
if (!scan.GetResult()) // Some error in previous iteration.
return false;
if (scan.empty()) // No error, but nothing left, good.
return true;
if (scan.Peek() == ':') { // Absorb identifier after the colon
scan.OneLiteral(":");
if (scan.Peek() == '0') {
scan.OneLiteral("0"); // :0
} else {
scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
}
} else {
scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
// Absorb another name/namespace, starting with a '>'
scan.One(Scanner::RANGLE)
.One(Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
}
}
scan.Eos();
return scan.GetResult();
}
bool IsValidControlInputName(StringPiece sp) {
return Scanner(sp)
.OneLiteral("^")
Scanner scan(sp);
scan.OneLiteral("^")
.One(Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
.Eos()
.GetResult();
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
while (true) {
if (!scan.GetResult()) // Some error in previous iteration.
return false;
if (scan.empty()) // No error, but nothing left, good.
return true;
// Absorb another name/namespace, starting with a '>'
scan.One(Scanner::RANGLE)
.One(Scanner::LETTER_DIGIT_DOT)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
}
}
} // namespace

View File

@ -309,6 +309,24 @@ TEST(NodeDefUtilTest, ValidSyntax) {
EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
SummarizeNodeDef(node_def_explicit_inputs));
const NodeDef node_def_explicit_inputs_namespace = ToNodeDef(R"proto(
name: 'Project>n'
op: 'Project>AnyIn'
input: 'Project>a:0'
input: 'Project>b:123'
input: '^Project>c'
attr {
key: 'T'
value { list { type: [ DT_INT32, DT_STRING ] } }
}
)proto");
ExpectValidSyntax(node_def_explicit_inputs_namespace);
EXPECT_EQ(
"{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
"(Project>a:0, Project>b:123, ^Project>c)",
SummarizeNodeDef(node_def_explicit_inputs_namespace));
const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
name:'n' op:'AnyIn'
attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }

View File

@ -40,6 +40,13 @@ class ZeroOut1Test(tf.test.TestCase):
result = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
@test_util.run_deprecated_v1
def test_namespace_call_op_on_op(self):
with self.cached_session():
x = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
result = zero_out_op_1.namespace_zero_out(x)
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
@test_util.run_deprecated_v1
def test_namespace_nested(self):
with self.cached_session():