Speed up graph tracing in tpu replicate/rewrite
PiperOrigin-RevId: 273606655
This commit is contained in:
parent
ac960b6b5a
commit
0f5189f2bd
@ -2252,12 +2252,16 @@ class Operation(object):
|
||||
buf = c_api.TF_NewBufferFromString(
|
||||
compat.as_bytes(attr_value.SerializeToString()))
|
||||
try:
|
||||
# pylint: disable=protected-access
|
||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
|
||||
# pylint: enable=protected-access
|
||||
self._set_attr_with_buf(attr_name, buf)
|
||||
finally:
|
||||
c_api.TF_DeleteBuffer(buf)
|
||||
|
||||
def _set_attr_with_buf(self, attr_name, attr_buf):
|
||||
"""Set an attr in the node_def with a pre-allocated buffer."""
|
||||
# pylint: disable=protected-access
|
||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, attr_buf)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _set_func_attr(self, attr_name, func_name):
|
||||
"""Private method used to set a function attribute in the node_def."""
|
||||
func = attr_value_pb2.NameAttrList(name=func_name)
|
||||
|
@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.compat import compat as api_compat
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.distribute import device_util
|
||||
@ -213,6 +214,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
outside the replicated computation.
|
||||
"""
|
||||
|
||||
class _TFBufferWrapper(object):
|
||||
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||
|
||||
def __init__(self, buf_string):
|
||||
self._buffer = pywrap_tensorflow.TF_NewBufferFromString(
|
||||
compat.as_bytes(buf_string))
|
||||
|
||||
def __del__(self):
|
||||
pywrap_tensorflow.TF_DeleteBuffer(self._buffer)
|
||||
|
||||
def __init__(self, name, num_replicas, pivot):
|
||||
"""Builds a new TPUReplicateContext.
|
||||
|
||||
@ -237,6 +248,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
self._host_compute_core = []
|
||||
self._name = name
|
||||
self._name_as_bytes = compat.as_bytes(name)
|
||||
self._tpu_relicate_attr_buf = self._TFBufferWrapper(
|
||||
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
|
||||
self._unsupported_ops = []
|
||||
self._pivot = pivot
|
||||
self._replicated_vars = {}
|
||||
@ -469,8 +482,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
"(operator name: %s)" % op.name)
|
||||
if _TPU_REPLICATE_ATTR in op.node_def.attr:
|
||||
raise ValueError("TPU computations cannot be nested")
|
||||
op._set_attr(_TPU_REPLICATE_ATTR,
|
||||
attr_value_pb2.AttrValue(s=self._name_as_bytes))
|
||||
op._set_attr_with_buf(
|
||||
_TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer)
|
||||
if self._outside_compilation_cluster:
|
||||
op._set_attr(
|
||||
_OUTSIDE_COMPILATION_ATTR,
|
||||
|
Loading…
Reference in New Issue
Block a user