Make Operation.node_def fetch from C API when enabled
Also fixes previously-introduced memory management bugs in graph_def_versions and op_def. PiperOrigin-RevId: 175314829
This commit is contained in:
parent
550f7220aa
commit
80078ae4a9
@ -1639,7 +1639,7 @@ class Operation(object):
|
||||
def colocation_groups(self):
|
||||
"""Returns the list of colocation groups of the op."""
|
||||
default_colocation_group = [
|
||||
compat.as_bytes("loc:@%s" % self._node_def.name)
|
||||
compat.as_bytes("loc:@%s" % self.name)
|
||||
]
|
||||
try:
|
||||
class_attr = self.get_attr("_class")
|
||||
@ -1894,7 +1894,7 @@ class Operation(object):
|
||||
["^%s" % op.name for op in self._control_inputs])
|
||||
|
||||
def __str__(self):
|
||||
return str(self._node_def)
|
||||
return str(self.node_def)
|
||||
|
||||
def __repr__(self):
|
||||
return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
|
||||
@ -2011,7 +2011,7 @@ class Operation(object):
|
||||
@property
|
||||
def node_def(self):
|
||||
# pylint: disable=line-too-long
|
||||
"""Returns a serialized `NodeDef` representation of this operation.
|
||||
"""Returns the `NodeDef` representation of this operation.
|
||||
|
||||
Returns:
|
||||
A
|
||||
@ -2019,7 +2019,16 @@ class Operation(object):
|
||||
protocol buffer.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
return self._node_def
|
||||
if self._c_op:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_OperationToNodeDef(self._c_op, buf, status)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
node_def = node_def_pb2.NodeDef()
|
||||
node_def.ParseFromString(compat.as_bytes(data))
|
||||
return node_def
|
||||
else:
|
||||
return self._node_def
|
||||
|
||||
@property
|
||||
def op_def(self):
|
||||
@ -2033,13 +2042,13 @@ class Operation(object):
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if self._c_op:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
# pylint: disable=protected-access
|
||||
c_api.TF_GraphGetOpDef(self._graph._c_graph,
|
||||
compat.as_bytes(self.type), buf, status)
|
||||
# pylint: enable=protected-access
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
op_def = op_def_pb2.OpDef()
|
||||
op_def.ParseFromString(compat.as_bytes(data))
|
||||
return op_def
|
||||
@ -2750,10 +2759,10 @@ class Graph(object):
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if self._c_graph:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with c_api_util.tf_buffer() as buf:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_GraphVersions(self._c_graph, buf, status)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
data = c_api.TF_GetBuffer(buf)
|
||||
version_def = versions_pb2.VersionDef()
|
||||
version_def.ParseFromString(compat.as_bytes(data))
|
||||
return version_def
|
||||
|
Loading…
Reference in New Issue
Block a user