Allow op namespacing to work in data inputs and control inputs
PiperOrigin-RevId: 288534578 Change-Id: I7bb5571ebec102a8b3145ec8bb018e9de1b39ab5
This commit is contained in:
parent
e37290d8df
commit
7fda1add7c
@ -716,7 +716,7 @@ bool IsValidNodeName(StringPiece sp) {
|
||||
if (scanner.empty()) // No error, but nothing left, good.
|
||||
return true;
|
||||
|
||||
// Absorb another piece, starting with a '>'
|
||||
// Absorb another name/namespace, starting with a '>'
|
||||
scanner.One(Scanner::RANGLE)
|
||||
.One(Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
@ -728,26 +728,46 @@ bool IsValidDataInputName(StringPiece sp) {
|
||||
Scanner scan(sp);
|
||||
scan.One(Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
if (scan.Peek() == ':') {
|
||||
scan.OneLiteral(":");
|
||||
if (scan.Peek() == '0') {
|
||||
scan.OneLiteral("0"); // :0
|
||||
|
||||
while (true) {
|
||||
if (!scan.GetResult()) // Some error in previous iteration.
|
||||
return false;
|
||||
if (scan.empty()) // No error, but nothing left, good.
|
||||
return true;
|
||||
|
||||
if (scan.Peek() == ':') { // Absorb identifier after the colon
|
||||
scan.OneLiteral(":");
|
||||
if (scan.Peek() == '0') {
|
||||
scan.OneLiteral("0"); // :0
|
||||
} else {
|
||||
scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
|
||||
}
|
||||
} else {
|
||||
scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
|
||||
// Absorb another name/namespace, starting with a '>'
|
||||
scan.One(Scanner::RANGLE)
|
||||
.One(Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
}
|
||||
}
|
||||
scan.Eos();
|
||||
|
||||
return scan.GetResult();
|
||||
}
|
||||
|
||||
bool IsValidControlInputName(StringPiece sp) {
|
||||
return Scanner(sp)
|
||||
.OneLiteral("^")
|
||||
Scanner scan(sp);
|
||||
scan.OneLiteral("^")
|
||||
.One(Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
|
||||
.Eos()
|
||||
.GetResult();
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
|
||||
while (true) {
|
||||
if (!scan.GetResult()) // Some error in previous iteration.
|
||||
return false;
|
||||
if (scan.empty()) // No error, but nothing left, good.
|
||||
return true;
|
||||
|
||||
// Absorb another name/namespace, starting with a '>'
|
||||
scan.One(Scanner::RANGLE)
|
||||
.One(Scanner::LETTER_DIGIT_DOT)
|
||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -309,6 +309,24 @@ TEST(NodeDefUtilTest, ValidSyntax) {
|
||||
EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
|
||||
SummarizeNodeDef(node_def_explicit_inputs));
|
||||
|
||||
const NodeDef node_def_explicit_inputs_namespace = ToNodeDef(R"proto(
|
||||
name: 'Project>n'
|
||||
op: 'Project>AnyIn'
|
||||
input: 'Project>a:0'
|
||||
input: 'Project>b:123'
|
||||
input: '^Project>c'
|
||||
attr {
|
||||
key: 'T'
|
||||
value { list { type: [ DT_INT32, DT_STRING ] } }
|
||||
}
|
||||
)proto");
|
||||
ExpectValidSyntax(node_def_explicit_inputs_namespace);
|
||||
|
||||
EXPECT_EQ(
|
||||
"{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
|
||||
"(Project>a:0, Project>b:123, ^Project>c)",
|
||||
SummarizeNodeDef(node_def_explicit_inputs_namespace));
|
||||
|
||||
const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
|
||||
name:'n' op:'AnyIn'
|
||||
attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
|
||||
|
@ -40,6 +40,13 @@ class ZeroOut1Test(tf.test.TestCase):
|
||||
result = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
|
||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_namespace_call_op_on_op(self):
|
||||
with self.cached_session():
|
||||
x = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
|
||||
result = zero_out_op_1.namespace_zero_out(x)
|
||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_namespace_nested(self):
|
||||
with self.cached_session():
|
||||
|
Loading…
Reference in New Issue
Block a user