Speed up graph tracing in tpu replicate/rewrite

PiperOrigin-RevId: 273606655
This commit is contained in:
Akshay Modi 2019-10-08 14:33:05 -07:00 committed by TensorFlower Gardener
parent ac960b6b5a
commit 0f5189f2bd
2 changed files with 22 additions and 5 deletions

View File

@ -2252,12 +2252,16 @@ class Operation(object):
buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
# pylint: disable=protected-access
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
# pylint: enable=protected-access
self._set_attr_with_buf(attr_name, buf)
finally:
c_api.TF_DeleteBuffer(buf)
def _set_attr_with_buf(self, attr_name, attr_buf):
  """Sets `attr_name` in this op's node_def from an already-built buffer.

  Args:
    attr_name: Name of the attr to set.
    attr_buf: Pre-allocated buffer holding the attr value. It is handed to
      `c_api.SetAttr` unchanged and is not deleted here.
  """
  graph_handle = self._graph._c_graph  # pylint: disable=protected-access
  c_api.SetAttr(graph_handle, self._c_op, attr_name, attr_buf)
def _set_func_attr(self, attr_name, func_name):
"""Private method used to set a function attribute in the node_def."""
func = attr_value_pb2.NameAttrList(name=func_name)

View File

@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.compat import compat as api_compat
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
@ -213,6 +214,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""
class _TFBufferWrapper(object):
  """An internal class to help manage the TF_Buffer lifetime.

  The buffer is created once from `buf_string`, exposed through the `_buffer`
  attribute, and deleted when this wrapper is finalized.
  """

  def __init__(self, buf_string):
    """Creates the underlying TF_Buffer.

    Args:
      buf_string: Contents of the buffer; converted with `compat.as_bytes`
        before being passed to `TF_NewBufferFromString`.
    """
    # Bind the attribute before any call that can raise, so that `__del__`
    # always finds it, even on a partially constructed instance.
    self._buffer = None
    self._buffer = pywrap_tensorflow.TF_NewBufferFromString(
        compat.as_bytes(buf_string))

  def __del__(self):
    """Deletes the underlying TF_Buffer if one was created."""
    # The finalizer also runs when `__init__` raised part-way through (or when
    # the instance was created without `__init__`); there is nothing to free
    # in that case.
    buf = getattr(self, "_buffer", None)
    if buf is None:
      return
    # Drop the reference first so a repeated finalizer call cannot delete the
    # same buffer twice.
    self._buffer = None
    pywrap_tensorflow.TF_DeleteBuffer(buf)
def __init__(self, name, num_replicas, pivot):
"""Builds a new TPUReplicateContext.
@ -237,6 +248,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._host_compute_core = []
self._name = name
self._name_as_bytes = compat.as_bytes(name)
self._tpu_relicate_attr_buf = self._TFBufferWrapper(
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
self._unsupported_ops = []
self._pivot = pivot
self._replicated_vars = {}
@ -469,8 +482,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
"(operator name: %s)" % op.name)
if _TPU_REPLICATE_ATTR in op.node_def.attr:
raise ValueError("TPU computations cannot be nested")
op._set_attr(_TPU_REPLICATE_ATTR,
attr_value_pb2.AttrValue(s=self._name_as_bytes))
op._set_attr_with_buf(
_TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer)
if self._outside_compilation_cluster:
op._set_attr(
_OUTSIDE_COMPILATION_ATTR,