Make Operation.node_def fetch from C API when enabled

Also fixes previously-introduced memory management bugs in graph_def_versions and op_def.

PiperOrigin-RevId: 175314829
This commit is contained in:
Skye Wanderman-Milne 2017-11-10 11:28:04 -08:00 committed by Andrew Selle
parent 550f7220aa
commit 80078ae4a9

View File

@ -1639,7 +1639,7 @@ class Operation(object):
def colocation_groups(self):
"""Returns the list of colocation groups of the op."""
default_colocation_group = [
compat.as_bytes("loc:@%s" % self._node_def.name)
compat.as_bytes("loc:@%s" % self.name)
]
try:
class_attr = self.get_attr("_class")
@ -1894,7 +1894,7 @@ class Operation(object):
["^%s" % op.name for op in self._control_inputs])
def __str__(self):
return str(self._node_def)
return str(self.node_def)
def __repr__(self):
return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
@ -2011,7 +2011,7 @@ class Operation(object):
@property
def node_def(self):
# pylint: disable=line-too-long
"""Returns a serialized `NodeDef` representation of this operation.
"""Returns the `NodeDef` representation of this operation.
Returns:
A
@ -2019,7 +2019,16 @@ class Operation(object):
protocol buffer.
"""
# pylint: enable=line-too-long
return self._node_def
if self._c_op:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_OperationToNodeDef(self._c_op, buf, status)
data = c_api.TF_GetBuffer(buf)
node_def = node_def_pb2.NodeDef()
node_def.ParseFromString(compat.as_bytes(data))
return node_def
else:
return self._node_def
@property
def op_def(self):
@ -2033,13 +2042,13 @@ class Operation(object):
"""
# pylint: enable=line-too-long
if self._c_op:
with errors.raise_exception_on_not_ok_status() as status:
with c_api_util.tf_buffer() as buf:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
# pylint: disable=protected-access
c_api.TF_GraphGetOpDef(self._graph._c_graph,
compat.as_bytes(self.type), buf, status)
# pylint: enable=protected-access
data = c_api.TF_GetBuffer(buf)
data = c_api.TF_GetBuffer(buf)
op_def = op_def_pb2.OpDef()
op_def.ParseFromString(compat.as_bytes(data))
return op_def
@ -2750,10 +2759,10 @@ class Graph(object):
"""
# pylint: enable=line-too-long
if self._c_graph:
with errors.raise_exception_on_not_ok_status() as status:
with c_api_util.tf_buffer() as buf:
with c_api_util.tf_buffer() as buf:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_GraphVersions(self._c_graph, buf, status)
data = c_api.TF_GetBuffer(buf)
data = c_api.TF_GetBuffer(buf)
version_def = versions_pb2.VersionDef()
version_def.ParseFromString(compat.as_bytes(data))
return version_def