Move the _TFBufferWrapper helper to a common place instead of in the tpu codebase.
PiperOrigin-RevId: 295996954 Change-Id: I2557fddd7cdc858ce661dd5a8c7bcd9996d519c6
This commit is contained in:
parent
e0575253d1
commit
ca0bd89c9a
@ -97,6 +97,16 @@ class ScopedTFFunction(object):
|
|||||||
self.func = None
|
self.func = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScopedTFBuffer(object):
|
||||||
|
"""An internal class to help manage the TF_Buffer lifetime."""
|
||||||
|
|
||||||
|
def __init__(self, buf_string):
|
||||||
|
self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
c_api.TF_DeleteBuffer(self.buffer)
|
||||||
|
|
||||||
|
|
||||||
class ApiDefMap(object):
|
class ApiDefMap(object):
|
||||||
"""Wrapper around Tf_ApiDefMap that handles querying and deletion.
|
"""Wrapper around Tf_ApiDefMap that handles querying and deletion.
|
||||||
|
|
||||||
|
|||||||
@ -27,11 +27,11 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
|
||||||
from tensorflow.python.client import pywrap_tf_session
|
|
||||||
from tensorflow.python.compiler.xla import xla
|
from tensorflow.python.compiler.xla import xla
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.framework import auto_control_deps
|
from tensorflow.python.framework import auto_control_deps
|
||||||
|
from tensorflow.python.framework import c_api_util
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import device as pydev
|
from tensorflow.python.framework import device as pydev
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -251,16 +251,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
outside the replicated computation.
|
outside the replicated computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class _TFBufferWrapper(object):
|
|
||||||
"""An internal class to help manage the TF_Buffer lifetime."""
|
|
||||||
|
|
||||||
def __init__(self, buf_string):
|
|
||||||
self._buffer = pywrap_tf_session.TF_NewBufferFromString(
|
|
||||||
compat.as_bytes(buf_string))
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
pywrap_tf_session.TF_DeleteBuffer(self._buffer)
|
|
||||||
|
|
||||||
def __init__(self, name, num_replicas, pivot):
|
def __init__(self, name, num_replicas, pivot):
|
||||||
"""Builds a new TPUReplicateContext.
|
"""Builds a new TPUReplicateContext.
|
||||||
|
|
||||||
@ -285,7 +275,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
self._host_compute_core = []
|
self._host_compute_core = []
|
||||||
self._name = name
|
self._name = name
|
||||||
self._name_as_bytes = compat.as_bytes(name)
|
self._name_as_bytes = compat.as_bytes(name)
|
||||||
self._tpu_relicate_attr_buf = self._TFBufferWrapper(
|
self._tpu_relicate_attr_buf = c_api_util.ScopedTFBuffer(
|
||||||
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
|
attr_value_pb2.AttrValue(s=self._name_as_bytes).SerializeToString())
|
||||||
self._unsupported_ops = []
|
self._unsupported_ops = []
|
||||||
self._pivot = pivot
|
self._pivot = pivot
|
||||||
@ -534,8 +524,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
"_cloned" not in op.node_def.attr):
|
"_cloned" not in op.node_def.attr):
|
||||||
raise ValueError("TPU computations cannot be nested on op (%s)" %
|
raise ValueError("TPU computations cannot be nested on op (%s)" %
|
||||||
op)
|
op)
|
||||||
op._set_attr_with_buf(
|
op._set_attr_with_buf(_TPU_REPLICATE_ATTR,
|
||||||
_TPU_REPLICATE_ATTR, self._tpu_relicate_attr_buf._buffer)
|
self._tpu_relicate_attr_buf.buffer)
|
||||||
if self._outside_compilation_cluster:
|
if self._outside_compilation_cluster:
|
||||||
op._set_attr(
|
op._set_attr(
|
||||||
_OUTSIDE_COMPILATION_ATTR,
|
_OUTSIDE_COMPILATION_ATTR,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user