Make _USE_C_API = True and_USE_C_SHAPES = False work with import_graph_def.

Without this change, shapes wouldn't be correctly computed for
operations created via import_graph_def.

PiperOrigin-RevId: 189670312
This commit is contained in:
Skye Wanderman-Milne 2018-03-19 17:34:47 -07:00 committed by TensorFlower Gardener
parent b6b4ec642a
commit 2714c07c93
6 changed files with 43 additions and 33 deletions

View File

@ -62,8 +62,7 @@ from tensorflow.python.util import compat
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
# @test_util.with_c_api
@test_util.with_c_api
class SessionTest(test_util.TensorFlowTestCase):
def setUp(self):

View File

@ -489,23 +489,25 @@ def import_graph_def(graph_def,
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
_ProcessNewOps(graph)
# Create _DefinedFunctions for any imported functions.
#
# We do this by creating _DefinedFunctions directly from `graph_def`, and
# adding them to `graph`. Adding an existing function to a TF_Graph is a
# no-op, so this only has the effect of updating the Python state (usually
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
#
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
# TODO(b/74620627): move this after _ProcessNewOps outside the lock once
# _USE_C_SHAPES is removed.
if graph_def.library and graph_def.library.function:
# pylint: disable=protected-access
functions = function._from_library(graph_def.library)
for f in functions:
f.add_to_graph(graph)
# pylint: enable=protected-access
# Create _DefinedFunctions for any imported functions.
#
# We do this by creating _DefinedFunctions directly from `graph_def`, and
# adding them to `graph`. Adding an existing function to a TF_Graph is a
# no-op, so this only has the effect of updating the Python state (usually
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
#
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
if graph_def.library and graph_def.library.function:
# pylint: disable=protected-access
functions = function._from_library(graph_def.library)
for f in functions:
f.add_to_graph(graph)
# pylint: enable=protected-access
_ProcessNewOps(graph)
# Treat input mappings that don't appear in the graph as an error, because
# they are likely to be due to a typo.

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -43,8 +44,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
# @test_util.with_c_api
@test_util.with_c_api
class ImportGraphDefTest(test.TestCase):
def _MakeGraphDef(self,

View File

@ -285,8 +285,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertIs(global_vars[0], trainable_vars[0])
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
# @test_util.with_c_api
@test_util.with_c_api
class ScopedMetaGraphTest(test.TestCase):
def _testScopedExport(self, test_dir, exported_filenames):

View File

@ -3303,6 +3303,20 @@ class Graph(object):
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
# TODO(vrv): Instead of eagerly filling in shape property for every op,
# only populate the shape when requested.
#
# TODO(skyewm): unlike in the original Python implementation, the C API
# always computes shape information (even for function calls, which the
# original Python shape inference code doesn't handle). Deprecate the
# compute_shapes argument.
#
# TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
# is removed
if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
set_shapes_for_outputs(ret)
self._create_op_helper(ret, compute_shapes=compute_shapes,
compute_device=compute_device)
return ret
@ -3336,15 +3350,6 @@ class Graph(object):
def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
"""Common logic for creating an op in this graph."""
# TODO(vrv): Instead of eagerly filling in shape property for every op, only
# populate the shape when requested.
#
# TODO(skyewm): unlike in the original Python implementation, the C API
# always computes shape information (even for function calls, which the
# original Python shape inference code doesn't handle). Deprecate the
# compute_shapes argument.
if (op._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
set_shapes_for_outputs(op)
# TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
self._add_op(op)
@ -3449,6 +3454,12 @@ class Graph(object):
]
for op in new_ops:
# The Python shape inference code does not support imported functions. It
# also needs access to op.inputs, which is why we call it here.
# TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
# is removed.
if not self._is_function(op.type) or _USE_C_SHAPES:
set_shapes_for_outputs(op)
new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
# pylint: disable=protected-access
op._add_control_inputs(new_control_inputs)

View File

@ -1739,8 +1739,7 @@ class CheckpointStateTest(test.TestCase):
os.path.join(save_dir, "./model.ckpt-687529"))
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
# @test_util.with_c_api
@test_util.with_c_api
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):