Speed up graph tracing in tpu replicate/rewrite
PiperOrigin-RevId: 273606655
This commit is contained in:
parent
ac960b6b5a
commit
0f5189f2bd
@ -2252,12 +2252,16 @@ class Operation(object):
|
|||||||
buf = c_api.TF_NewBufferFromString(
|
buf = c_api.TF_NewBufferFromString(
|
||||||
compat.as_bytes(attr_value.SerializeToString()))
|
compat.as_bytes(attr_value.SerializeToString()))
|
||||||
try:
|
try:
|
||||||
# pylint: disable=protected-access
|
self._set_attr_with_buf(attr_name, buf)
|
||||||
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
finally:
|
finally:
|
||||||
c_api.TF_DeleteBuffer(buf)
|
c_api.TF_DeleteBuffer(buf)
|
||||||
|
|
||||||
|
def _set_attr_with_buf(self, attr_name, attr_buf):
|
||||||
|
"""Set an attr in the node_def with a pre-allocated buffer."""
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, attr_buf)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
def _set_func_attr(self, attr_name, func_name):
|
def _set_func_attr(self, attr_name, func_name):
|
||||||
"""Private method used to set a function attribute in the node_def."""
|
"""Private method used to set a function attribute in the node_def."""
|
||||||
func = attr_value_pb2.NameAttrList(name=func_name)
|
func = attr_value_pb2.NameAttrList(name=func_name)
|
||||||
|
@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.compat import compat as api_compat
|
from tensorflow.python.compat import compat as api_compat
|
||||||
from tensorflow.python.compiler.xla import xla
|
from tensorflow.python.compiler.xla import xla
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
@ -213,6 +214,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
outside the replicated computation.
|
outside the replicated computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class _TFBufferWrapper(object):
|
||||||
|
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||||
|
|
||||||
|
def __init__(self, buf_string):
|
||||||
|
self._buffer = pywrap_tensorflow.TF_NewBufferFromString(
|
||||||
|
compat.as_bytes(buf_string))
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
pywrap_tensorflow.TF_DeleteBuffer(self._buffer)
|
||||||
|
|
||||||
def __init__(self, name, num_replicas, pivot):
|
def __init__(self, name, num_replicas, pivot):
|
||||||
"""Builds a new TPUReplicateContext.
|
"""Builds a new TPUReplicateContext.
|
||||||
|
|
||||||
@ -237,6 +248,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
self._host_compute_core = []
|
self._host_compute_core = []
|
||||||
self._name = name
|
self._name = name
|
||||||
self._name_as_bytes = compat.as_bytes(name)
|
self._name_as_bytes = compat.as_bytes(name)
|
||||||
|
self._tpu_relicate_attr_buf = self._TFBufferWrapper(
|
||||||
|
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
|
||||||
self._unsupported_ops = []
|
self._unsupported_ops = []
|
||||||
self._pivot = pivot
|
self._pivot = pivot
|
||||||
self._replicated_vars = {}
|
self._replicated_vars = {}
|
||||||
@ -469,8 +482,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
"(operator name: %s)" % op.name)
|
"(operator name: %s)" % op.name)
|
||||||
if _TPU_REPLICATE_ATTR in op.node_def.attr:
|
if _TPU_REPLICATE_ATTR in op.node_def.attr:
|
||||||
raise ValueError("TPU computations cannot be nested")
|
raise ValueError("TPU computations cannot be nested")
|
||||||
op._set_attr(_TPU_REPLICATE_ATTR,
|
op._set_attr_with_buf(
|
||||||
attr_value_pb2.AttrValue(s=self._name_as_bytes))
|
_TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer)
|
||||||
if self._outside_compilation_cluster:
|
if self._outside_compilation_cluster:
|
||||||
op._set_attr(
|
op._set_attr(
|
||||||
_OUTSIDE_COMPILATION_ATTR,
|
_OUTSIDE_COMPILATION_ATTR,
|
||||||
|
Loading…
Reference in New Issue
Block a user