Move the _TFBufferWrapper helper to a common place instead of in the tpu codebase.

PiperOrigin-RevId: 295996954
Change-Id: I2557fddd7cdc858ce661dd5a8c7bcd9996d519c6
This commit is contained in:
Akshay Modi 2020-02-19 10:15:08 -08:00 committed by TensorFlower Gardener
parent e0575253d1
commit ca0bd89c9a
2 changed files with 14 additions and 14 deletions
tensorflow/python

View File

@ -97,6 +97,16 @@ class ScopedTFFunction(object):
self.func = None
class ScopedTFBuffer(object):
"""An internal class to help manage the TF_Buffer lifetime."""
def __init__(self, buf_string):
self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))
def __del__(self):
c_api.TF_DeleteBuffer(self.buffer)
class ApiDefMap(object):
"""Wrapper around Tf_ApiDefMap that handles querying and deletion.

View File

@ -27,11 +27,11 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import config
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
@ -251,16 +251,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""
class _TFBufferWrapper(object):
"""An internal class to help manage the TF_Buffer lifetime."""
def __init__(self, buf_string):
self._buffer = pywrap_tf_session.TF_NewBufferFromString(
compat.as_bytes(buf_string))
def __del__(self):
pywrap_tf_session.TF_DeleteBuffer(self._buffer)
def __init__(self, name, num_replicas, pivot):
"""Builds a new TPUReplicateContext.
@ -285,7 +275,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._host_compute_core = []
self._name = name
self._name_as_bytes = compat.as_bytes(name)
self._tpu_relicate_attr_buf = self._TFBufferWrapper(
self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
self._unsupported_ops = []
self._pivot = pivot
@ -534,8 +524,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
"_cloned" not in op.node_def.attr):
raise ValueError("TPU computations cannot be nested on op (%s)" %
op)
op._set_attr_with_buf(
_TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer)
op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
self._tpu_relicate_attr_buf.buffer)
if self._outside_compilation_cluster:
op._set_attr(
_OUTSIDE_COMPILATION_ATTR,