Make _USE_C_API = True and_USE_C_SHAPES = False work with import_graph_def.
Without this change, shapes wouldn't be correctly computed for operations created via import_graph_def. PiperOrigin-RevId: 189670312
This commit is contained in:
parent
b6b4ec642a
commit
2714c07c93
@ -62,8 +62,7 @@ from tensorflow.python.util import compat
|
||||
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
|
||||
|
||||
|
||||
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
|
||||
# @test_util.with_c_api
|
||||
@test_util.with_c_api
|
||||
class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -489,23 +489,25 @@ def import_graph_def(graph_def,
|
||||
# Convert to ValueError for backwards compatibility.
|
||||
raise ValueError(str(e))
|
||||
|
||||
_ProcessNewOps(graph)
|
||||
# Create _DefinedFunctions for any imported functions.
|
||||
#
|
||||
# We do this by creating _DefinedFunctions directly from `graph_def`, and
|
||||
# adding them to `graph`. Adding an existing function to a TF_Graph is a
|
||||
# no-op, so this only has the effect of updating the Python state (usually
|
||||
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
|
||||
#
|
||||
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
|
||||
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
|
||||
# TODO(b/74620627): move this after _ProcessNewOps outside the lock once
|
||||
# _USE_C_SHAPES is removed.
|
||||
if graph_def.library and graph_def.library.function:
|
||||
# pylint: disable=protected-access
|
||||
functions = function._from_library(graph_def.library)
|
||||
for f in functions:
|
||||
f.add_to_graph(graph)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# Create _DefinedFunctions for any imported functions.
|
||||
#
|
||||
# We do this by creating _DefinedFunctions directly from `graph_def`, and
|
||||
# adding them to `graph`. Adding an existing function to a TF_Graph is a
|
||||
# no-op, so this only has the effect of updating the Python state (usually
|
||||
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
|
||||
#
|
||||
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
|
||||
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
|
||||
if graph_def.library and graph_def.library.function:
|
||||
# pylint: disable=protected-access
|
||||
functions = function._from_library(graph_def.library)
|
||||
for f in functions:
|
||||
f.add_to_graph(graph)
|
||||
# pylint: enable=protected-access
|
||||
_ProcessNewOps(graph)
|
||||
|
||||
# Treat input mappings that don't appear in the graph as an error, because
|
||||
# they are likely to be due to a typo.
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -43,8 +44,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
|
||||
# @test_util.with_c_api
|
||||
@test_util.with_c_api
|
||||
class ImportGraphDefTest(test.TestCase):
|
||||
|
||||
def _MakeGraphDef(self,
|
||||
|
@ -285,8 +285,7 @@ class SimpleMetaGraphTest(test.TestCase):
|
||||
self.assertIs(global_vars[0], trainable_vars[0])
|
||||
|
||||
|
||||
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
|
||||
# @test_util.with_c_api
|
||||
@test_util.with_c_api
|
||||
class ScopedMetaGraphTest(test.TestCase):
|
||||
|
||||
def _testScopedExport(self, test_dir, exported_filenames):
|
||||
|
@ -3303,6 +3303,20 @@ class Graph(object):
|
||||
input_types=input_types,
|
||||
original_op=self._default_original_op,
|
||||
op_def=op_def)
|
||||
|
||||
# TODO(vrv): Instead of eagerly filling in shape property for every op,
|
||||
# only populate the shape when requested.
|
||||
#
|
||||
# TODO(skyewm): unlike in the original Python implementation, the C API
|
||||
# always computes shape information (even for function calls, which the
|
||||
# original Python shape inference code doesn't handle). Deprecate the
|
||||
# compute_shapes argument.
|
||||
#
|
||||
# TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
|
||||
# is removed
|
||||
if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
|
||||
set_shapes_for_outputs(ret)
|
||||
|
||||
self._create_op_helper(ret, compute_shapes=compute_shapes,
|
||||
compute_device=compute_device)
|
||||
return ret
|
||||
@ -3336,15 +3350,6 @@ class Graph(object):
|
||||
|
||||
def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
|
||||
"""Common logic for creating an op in this graph."""
|
||||
# TODO(vrv): Instead of eagerly filling in shape property for every op, only
|
||||
# populate the shape when requested.
|
||||
#
|
||||
# TODO(skyewm): unlike in the original Python implementation, the C API
|
||||
# always computes shape information (even for function calls, which the
|
||||
# original Python shape inference code doesn't handle). Deprecate the
|
||||
# compute_shapes argument.
|
||||
if (op._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
|
||||
set_shapes_for_outputs(op)
|
||||
# TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
|
||||
self._add_op(op)
|
||||
|
||||
@ -3449,6 +3454,12 @@ class Graph(object):
|
||||
]
|
||||
|
||||
for op in new_ops:
|
||||
# The Python shape inference code does not support imported functions. It
|
||||
# also needs access to op.inputs, which is why we call it here.
|
||||
# TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
|
||||
# is removed.
|
||||
if not self._is_function(op.type) or _USE_C_SHAPES:
|
||||
set_shapes_for_outputs(op)
|
||||
new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
|
||||
# pylint: disable=protected-access
|
||||
op._add_control_inputs(new_control_inputs)
|
||||
|
@ -1739,8 +1739,7 @@ class CheckpointStateTest(test.TestCase):
|
||||
os.path.join(save_dir, "./model.ckpt-687529"))
|
||||
|
||||
|
||||
# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
|
||||
# @test_util.with_c_api
|
||||
@test_util.with_c_api
|
||||
class MetaGraphTest(test.TestCase):
|
||||
|
||||
def _get_test_dir(self, dirname):
|
||||
|
Loading…
Reference in New Issue
Block a user