From 831a55584749593400807e0baa7478476b5dbc70 Mon Sep 17 00:00:00 2001 From: Yujing Zhang Date: Tue, 26 May 2020 10:17:29 -0700 Subject: [PATCH] Add an attribute "is_packed" to TPUReplicatedInput op which indicates whether the per-replica inputs are packed into one input. PiperOrigin-RevId: 313216599 Change-Id: I9e9a38ee0fcb64caca9f2d1e2de268c9576ca6c8 --- tensorflow/core/ops/tpu_replication_ops.cc | 2 ++ tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/ops/tpu_replication_ops.cc b/tensorflow/core/ops/tpu_replication_ops.cc index 3bb94044e14..a729d3c3b7b 100644 --- a/tensorflow/core/ops/tpu_replication_ops.cc +++ b/tensorflow/core/ops/tpu_replication_ops.cc @@ -44,6 +44,8 @@ REGISTER_OP("TPUReplicatedInput") .Attr("is_mirrored_variable: bool = false") // Index of the input. If is_mirrored_variable is true, this is ignored. .Attr("index: int = -1") + // All inputs are packed into one input + .Attr("is_packed: bool = false") .SetShapeFn([](InferenceContext* c) { ShapeHandle cur = c->input(c->num_inputs() - 1); for (int i = c->num_inputs() - 2; i >= 0; --i) { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index a5fe83e713e..37a95cc88d1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -4606,7 +4606,7 @@ tf_module { } member_method { name: "TPUReplicatedInput" - argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], " + argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'is_packed\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'False\', \'None\'], " } member_method { name: "TPUReplicatedOutput" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index a5fe83e713e..37a95cc88d1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4606,7 +4606,7 @@ tf_module { } member_method { name: "TPUReplicatedInput" - argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], " + argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'is_packed\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'False\', \'None\'], " } member_method { name: "TPUReplicatedOutput"