Allow taking gradient of v2 control flow with external colocations by disabling colocation checks.

Enables LSTMTest with v2 control flow which was blocked on this.

PiperOrigin-RevId: 260776280
This commit is contained in:
Saurabh Saxena 2019-07-30 12:55:55 -07:00 committed by TensorFlower Gardener
parent 33440f37fe
commit 769900b011
6 changed files with 100 additions and 4 deletions

View File

@ -1139,3 +1139,8 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
// TODO(bgogul): Set output_resource_shapes_and_types.
}
void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable) {
opts->opts.validate_colocation_constraints = enable;
}

View File

@ -391,6 +391,11 @@ TF_CAPI_EXPORT extern void TFE_InferShapes(
TF_ShapeAndTypeList** input_resource_shapes_and_types,
TF_ShapeAndTypeList** output_shapes,
TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status);
TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -62,7 +62,7 @@ def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
with func_graph.as_default():
# Add all function nodes to the graph.
importer.import_graph_def(graph_def, name="")
importer.import_graph_def_for_function(graph_def, name="")
# Initialize fields specific to FuncGraph.

View File

@ -202,7 +202,8 @@ def _ConvertInputMapValues(name, input_map):
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
return_elements):
return_elements,
validate_colocation_constraints):
"""Populates the TF_ImportGraphDefOptions `options`."""
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
@ -229,6 +230,9 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
compat.as_str(name))
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
options, validate_colocation_constraints)
def _ProcessNewOps(graph):
"""Processes the newly-added TF_Operations in `graph`."""
@ -384,6 +388,73 @@ def import_graph_def(graph_def,
corresponding to the names in `return_elements`,
and None if `returns_elements` is None.
Raises:
TypeError: If `graph_def` is not a `GraphDef` proto,
`input_map` is not a dictionary mapping strings to `Tensor` objects,
or `return_elements` is not a list of strings.
ValueError: If `input_map`, or `return_elements` contains names that
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
it refers to an unknown tensor).
"""
return _import_graph_def_internal(
graph_def,
input_map=input_map,
return_elements=return_elements,
name=name,
op_dict=op_dict,
producer_op_list=producer_op_list)
def import_graph_def_for_function( # pylint: disable=invalid-name
graph_def, name=None):
"""Like import_graph_def but does not validate colocation constraints."""
return _import_graph_def_internal(
graph_def, validate_colocation_constraints=False, name=name)
def _import_graph_def_internal( # pylint: disable=invalid-name
graph_def,
input_map=None,
return_elements=None,
validate_colocation_constraints=True,
name=None,
op_dict=None,
producer_op_list=None):
"""Imports the graph from `graph_def` into the current default `Graph`.
This function provides a way to import a serialized TensorFlow
[`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
protocol buffer, and extract individual objects in the `GraphDef` as
`tf.Tensor` and `tf.Operation` objects. Once extracted,
these objects are placed into the current default `Graph`. See
`tf.Graph.as_graph_def` for a way to create a `GraphDef`
proto.
Args:
graph_def: A `GraphDef` proto containing operations to be imported into the
default graph.
input_map: A dictionary mapping input names (as strings) in `graph_def` to
`Tensor` objects. The values of the named input tensors in the imported
graph will be re-mapped to the respective `Tensor` values.
return_elements: A list of strings containing operation names in `graph_def`
that will be returned as `Operation` objects; and/or tensor names in
`graph_def` that will be returned as `Tensor` objects.
validate_colocation_constraints: Whether to validate colocation constraints.
name: (Optional.) A prefix that will be prepended to the names in
`graph_def`. Note that this does not apply to imported function names.
Defaults to `"import"`.
op_dict: (Optional.) Deprecated, do not use.
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
list of `OpDef`s used by the producer of the graph. If provided,
unrecognized attrs for ops in `graph_def` that have their default value
according to `producer_op_list` will be removed. This will allow some more
`GraphDef`s produced by later binaries to be accepted by earlier binaries.
Returns:
A list of `Operation` and/or `Tensor` objects from the imported graph,
corresponding to the names in `return_elements`,
and None if `returns_elements` is None.
Raises:
TypeError: If `graph_def` is not a `GraphDef` proto,
`input_map` is not a dictionary mapping strings to `Tensor` objects,
@ -416,8 +487,8 @@ def import_graph_def(graph_def,
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
options = scoped_options.options
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
return_elements)
_PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
validate_colocation_constraints)
# _ProcessNewOps mutates the new operations. _mutation_lock ensures a
# Session.run call cannot occur between creating the TF_Operations in the

View File

@ -579,6 +579,20 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
])
@test_util.run_deprecated_v1
def testExternalColocationGrad(self):
external_t = constant_op.constant(2.)
v0 = constant_op.constant(2.)
def Body(v):
with ops.colocate_with(external_t):
return v * v
ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
grad = gradients_impl.gradients(ret, [v0])[0]
self.assertAllEqual(ret, 16.)
self.assertAllEqual(grad, 32.)
def ScalarShape():
return ops.convert_to_tensor([], dtype=dtypes.int32)

View File

@ -166,6 +166,7 @@ limitations under the License.
%rename("%s") TFE_CancellationManagerIsCancelled;
%rename("%s") TFE_CancellationManagerStartCancel;
%rename("%s") TFE_DeleteCancellationManager;
%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"