Move the _TFBufferWrapper helper to a common place instead of in the tpu codebase.
PiperOrigin-RevId: 295996954 Change-Id: I2557fddd7cdc858ce661dd5a8c7bcd9996d519c6
This commit is contained in:
parent
e0575253d1
commit
ca0bd89c9a
tensorflow/python
@ -97,6 +97,16 @@ class ScopedTFFunction(object):
|
||||
self.func = None
|
||||
|
||||
|
||||
class ScopedTFBuffer(object):
|
||||
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||
|
||||
def __init__(self, buf_string):
|
||||
self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))
|
||||
|
||||
def __del__(self):
|
||||
c_api.TF_DeleteBuffer(self.buffer)
|
||||
|
||||
|
||||
class ApiDefMap(object):
|
||||
"""Wrapper around Tf_ApiDefMap that handles querying and deletion.
|
||||
|
||||
|
@ -27,11 +27,11 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.framework import auto_control_deps
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -251,16 +251,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
outside the replicated computation.
|
||||
"""
|
||||
|
||||
class _TFBufferWrapper(object):
|
||||
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||
|
||||
def __init__(self, buf_string):
|
||||
self._buffer = pywrap_tf_session.TF_NewBufferFromString(
|
||||
compat.as_bytes(buf_string))
|
||||
|
||||
def __del__(self):
|
||||
pywrap_tf_session.TF_DeleteBuffer(self._buffer)
|
||||
|
||||
def __init__(self, name, num_replicas, pivot):
|
||||
"""Builds a new TPUReplicateContext.
|
||||
|
||||
@ -285,7 +275,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
self._host_compute_core = []
|
||||
self._name = name
|
||||
self._name_as_bytes = compat.as_bytes(name)
|
||||
self._tpu_relicate_attr_buf = self._TFBufferWrapper(
|
||||
self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
|
||||
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
|
||||
self._unsupported_ops = []
|
||||
self._pivot = pivot
|
||||
@ -534,8 +524,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
"_cloned" not in op.node_def.attr):
|
||||
raise ValueError("TPU computations cannot be nested on op (%s)" %
|
||||
op)
|
||||
op._set_attr_with_buf(
|
||||
_TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer)
|
||||
op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
|
||||
self._tpu_relicate_attr_buf.buffer)
|
||||
if self._outside_compilation_cluster:
|
||||
op._set_attr(
|
||||
_OUTSIDE_COMPILATION_ATTR,
|
||||
|
Loading…
Reference in New Issue
Block a user