Add an attribute "is_packed" to TPUReplicatedInput op which indicates whether the per-replica inputs are packed into one input.

PiperOrigin-RevId: 313216599
Change-Id: I9e9a38ee0fcb64caca9f2d1e2de268c9576ca6c8
This commit is contained in:
Yujing Zhang 2020-05-26 10:17:29 -07:00 committed by TensorFlower Gardener
parent 0e80859784
commit 831a555847
3 changed files with 4 additions and 2 deletions

View File

@ -44,6 +44,8 @@ REGISTER_OP("TPUReplicatedInput")
.Attr("is_mirrored_variable: bool = false")
// Index of the input. If is_mirrored_variable is true, this is ignored.
.Attr("index: int = -1")
// All inputs are packed into one input
.Attr("is_packed: bool = false")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle cur = c->input(c->num_inputs() - 1);
for (int i = c->num_inputs() - 2; i >= 0; --i) {

View File

@ -4606,7 +4606,7 @@ tf_module {
}
member_method {
name: "TPUReplicatedInput"
argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], "
argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'is_packed\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'False\', \'None\'], "
}
member_method {
name: "TPUReplicatedOutput"

View File

@ -4606,7 +4606,7 @@ tf_module {
}
member_method {
name: "TPUReplicatedInput"
argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], "
argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'is_packed\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'False\', \'None\'], "
}
member_method {
name: "TPUReplicatedOutput"