Allow op namespacing to work in data inputs and control inputs
PiperOrigin-RevId: 288534578 Change-Id: I7bb5571ebec102a8b3145ec8bb018e9de1b39ab5
This commit is contained in:
parent
e37290d8df
commit
7fda1add7c
@ -716,7 +716,7 @@ bool IsValidNodeName(StringPiece sp) {
|
|||||||
if (scanner.empty()) // No error, but nothing left, good.
|
if (scanner.empty()) // No error, but nothing left, good.
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
// Absorb another piece, starting with a '>'
|
// Absorb another name/namespace, starting with a '>'
|
||||||
scanner.One(Scanner::RANGLE)
|
scanner.One(Scanner::RANGLE)
|
||||||
.One(Scanner::LETTER_DIGIT_DOT)
|
.One(Scanner::LETTER_DIGIT_DOT)
|
||||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||||
@ -728,26 +728,46 @@ bool IsValidDataInputName(StringPiece sp) {
|
|||||||
Scanner scan(sp);
|
Scanner scan(sp);
|
||||||
scan.One(Scanner::LETTER_DIGIT_DOT)
|
scan.One(Scanner::LETTER_DIGIT_DOT)
|
||||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||||
if (scan.Peek() == ':') {
|
|
||||||
scan.OneLiteral(":");
|
while (true) {
|
||||||
if (scan.Peek() == '0') {
|
if (!scan.GetResult()) // Some error in previous iteration.
|
||||||
scan.OneLiteral("0"); // :0
|
return false;
|
||||||
|
if (scan.empty()) // No error, but nothing left, good.
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (scan.Peek() == ':') { // Absorb identifier after the colon
|
||||||
|
scan.OneLiteral(":");
|
||||||
|
if (scan.Peek() == '0') {
|
||||||
|
scan.OneLiteral("0"); // :0
|
||||||
|
} else {
|
||||||
|
scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
|
// Absorb another name/namespace, starting with a '>'
|
||||||
|
scan.One(Scanner::RANGLE)
|
||||||
|
.One(Scanner::LETTER_DIGIT_DOT)
|
||||||
|
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scan.Eos();
|
|
||||||
|
|
||||||
return scan.GetResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsValidControlInputName(StringPiece sp) {
|
bool IsValidControlInputName(StringPiece sp) {
|
||||||
return Scanner(sp)
|
Scanner scan(sp);
|
||||||
.OneLiteral("^")
|
scan.OneLiteral("^")
|
||||||
.One(Scanner::LETTER_DIGIT_DOT)
|
.One(Scanner::LETTER_DIGIT_DOT)
|
||||||
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
|
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||||
.Eos()
|
|
||||||
.GetResult();
|
while (true) {
|
||||||
|
if (!scan.GetResult()) // Some error in previous iteration.
|
||||||
|
return false;
|
||||||
|
if (scan.empty()) // No error, but nothing left, good.
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// Absorb another name/namespace, starting with a '>'
|
||||||
|
scan.One(Scanner::RANGLE)
|
||||||
|
.One(Scanner::LETTER_DIGIT_DOT)
|
||||||
|
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -309,6 +309,24 @@ TEST(NodeDefUtilTest, ValidSyntax) {
|
|||||||
EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
|
EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
|
||||||
SummarizeNodeDef(node_def_explicit_inputs));
|
SummarizeNodeDef(node_def_explicit_inputs));
|
||||||
|
|
||||||
|
const NodeDef node_def_explicit_inputs_namespace = ToNodeDef(R"proto(
|
||||||
|
name: 'Project>n'
|
||||||
|
op: 'Project>AnyIn'
|
||||||
|
input: 'Project>a:0'
|
||||||
|
input: 'Project>b:123'
|
||||||
|
input: '^Project>c'
|
||||||
|
attr {
|
||||||
|
key: 'T'
|
||||||
|
value { list { type: [ DT_INT32, DT_STRING ] } }
|
||||||
|
}
|
||||||
|
)proto");
|
||||||
|
ExpectValidSyntax(node_def_explicit_inputs_namespace);
|
||||||
|
|
||||||
|
EXPECT_EQ(
|
||||||
|
"{{node Project>n}} = Project>AnyIn[T=[DT_INT32, DT_STRING]]"
|
||||||
|
"(Project>a:0, Project>b:123, ^Project>c)",
|
||||||
|
SummarizeNodeDef(node_def_explicit_inputs_namespace));
|
||||||
|
|
||||||
const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
|
const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
|
||||||
name:'n' op:'AnyIn'
|
name:'n' op:'AnyIn'
|
||||||
attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
|
attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
|
||||||
|
@ -40,6 +40,13 @@ class ZeroOut1Test(tf.test.TestCase):
|
|||||||
result = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
|
result = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
|
||||||
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test_namespace_call_op_on_op(self):
|
||||||
|
with self.cached_session():
|
||||||
|
x = zero_out_op_1.namespace_zero_out([5, 4, 3, 2, 1])
|
||||||
|
result = zero_out_op_1.namespace_zero_out(x)
|
||||||
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_namespace_nested(self):
|
def test_namespace_nested(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
Loading…
Reference in New Issue
Block a user