Allow taking gradient of v2 control flow with external colocations by disabling colocation checks.
Enables LSTMTest with v2 control flow which was blocked on this. PiperOrigin-RevId: 260776280
This commit is contained in:
parent
33440f37fe
commit
769900b011
@ -1139,3 +1139,8 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
// TODO(bgogul): Set output_resource_shapes_and_types.
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
||||
opts->opts.validate_colocation_constraints = enable;
|
||||
}
|
||||
|
@ -391,6 +391,11 @@ TF_CAPI_EXPORT extern void TFE_InferShapes(
|
||||
TF_ShapeAndTypeList** input_resource_shapes_and_types,
|
||||
TF_ShapeAndTypeList** output_shapes,
|
||||
TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -62,7 +62,7 @@ def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
|
||||
|
||||
with func_graph.as_default():
|
||||
# Add all function nodes to the graph.
|
||||
importer.import_graph_def(graph_def, name="")
|
||||
importer.import_graph_def_for_function(graph_def, name="")
|
||||
|
||||
# Initialize fields specific to FuncGraph.
|
||||
|
||||
|
@ -202,7 +202,8 @@ def _ConvertInputMapValues(name, input_map):
|
||||
|
||||
|
||||
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
|
||||
return_elements):
|
||||
return_elements,
|
||||
validate_colocation_constraints):
|
||||
"""Populates the TF_ImportGraphDefOptions `options`."""
|
||||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
|
||||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
|
||||
@ -229,6 +230,9 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
|
||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
|
||||
compat.as_str(name))
|
||||
|
||||
c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
options, validate_colocation_constraints)
|
||||
|
||||
|
||||
def _ProcessNewOps(graph):
|
||||
"""Processes the newly-added TF_Operations in `graph`."""
|
||||
@ -384,6 +388,73 @@ def import_graph_def(graph_def,
|
||||
corresponding to the names in `return_elements`,
|
||||
and None if `returns_elements` is None.
|
||||
|
||||
Raises:
|
||||
TypeError: If `graph_def` is not a `GraphDef` proto,
|
||||
`input_map` is not a dictionary mapping strings to `Tensor` objects,
|
||||
or `return_elements` is not a list of strings.
|
||||
ValueError: If `input_map`, or `return_elements` contains names that
|
||||
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
|
||||
it refers to an unknown tensor).
|
||||
"""
|
||||
return _import_graph_def_internal(
|
||||
graph_def,
|
||||
input_map=input_map,
|
||||
return_elements=return_elements,
|
||||
name=name,
|
||||
op_dict=op_dict,
|
||||
producer_op_list=producer_op_list)
|
||||
|
||||
|
||||
def import_graph_def_for_function( # pylint: disable=invalid-name
|
||||
graph_def, name=None):
|
||||
"""Like import_graph_def but does not validate colocation constraints."""
|
||||
return _import_graph_def_internal(
|
||||
graph_def, validate_colocation_constraints=False, name=name)
|
||||
|
||||
|
||||
def _import_graph_def_internal( # pylint: disable=invalid-name
|
||||
graph_def,
|
||||
input_map=None,
|
||||
return_elements=None,
|
||||
validate_colocation_constraints=True,
|
||||
name=None,
|
||||
op_dict=None,
|
||||
producer_op_list=None):
|
||||
"""Imports the graph from `graph_def` into the current default `Graph`.
|
||||
|
||||
This function provides a way to import a serialized TensorFlow
|
||||
[`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
|
||||
protocol buffer, and extract individual objects in the `GraphDef` as
|
||||
`tf.Tensor` and `tf.Operation` objects. Once extracted,
|
||||
these objects are placed into the current default `Graph`. See
|
||||
`tf.Graph.as_graph_def` for a way to create a `GraphDef`
|
||||
proto.
|
||||
|
||||
Args:
|
||||
graph_def: A `GraphDef` proto containing operations to be imported into the
|
||||
default graph.
|
||||
input_map: A dictionary mapping input names (as strings) in `graph_def` to
|
||||
`Tensor` objects. The values of the named input tensors in the imported
|
||||
graph will be re-mapped to the respective `Tensor` values.
|
||||
return_elements: A list of strings containing operation names in `graph_def`
|
||||
that will be returned as `Operation` objects; and/or tensor names in
|
||||
`graph_def` that will be returned as `Tensor` objects.
|
||||
validate_colocation_constraints: Whether to validate colocation constraints.
|
||||
name: (Optional.) A prefix that will be prepended to the names in
|
||||
`graph_def`. Note that this does not apply to imported function names.
|
||||
Defaults to `"import"`.
|
||||
op_dict: (Optional.) Deprecated, do not use.
|
||||
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
|
||||
list of `OpDef`s used by the producer of the graph. If provided,
|
||||
unrecognized attrs for ops in `graph_def` that have their default value
|
||||
according to `producer_op_list` will be removed. This will allow some more
|
||||
`GraphDef`s produced by later binaries to be accepted by earlier binaries.
|
||||
|
||||
Returns:
|
||||
A list of `Operation` and/or `Tensor` objects from the imported graph,
|
||||
corresponding to the names in `return_elements`,
|
||||
and None if `returns_elements` is None.
|
||||
|
||||
Raises:
|
||||
TypeError: If `graph_def` is not a `GraphDef` proto,
|
||||
`input_map` is not a dictionary mapping strings to `Tensor` objects,
|
||||
@ -416,8 +487,8 @@ def import_graph_def(graph_def,
|
||||
|
||||
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
|
||||
options = scoped_options.options
|
||||
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
|
||||
return_elements)
|
||||
_PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
|
||||
validate_colocation_constraints)
|
||||
|
||||
# _ProcessNewOps mutates the new operations. _mutation_lock ensures a
|
||||
# Session.run call cannot occur between creating the TF_Operations in the
|
||||
|
@ -579,6 +579,20 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
|
||||
])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testExternalColocationGrad(self):
|
||||
external_t = constant_op.constant(2.)
|
||||
v0 = constant_op.constant(2.)
|
||||
|
||||
def Body(v):
|
||||
with ops.colocate_with(external_t):
|
||||
return v * v
|
||||
|
||||
ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
|
||||
grad = gradients_impl.gradients(ret, [v0])[0]
|
||||
self.assertAllEqual(ret, 16.)
|
||||
self.assertAllEqual(grad, 32.)
|
||||
|
||||
|
||||
def ScalarShape():
|
||||
return ops.convert_to_tensor([], dtype=dtypes.int32)
|
||||
|
@ -166,6 +166,7 @@ limitations under the License.
|
||||
%rename("%s") TFE_CancellationManagerIsCancelled;
|
||||
%rename("%s") TFE_CancellationManagerStartCancel;
|
||||
%rename("%s") TFE_DeleteCancellationManager;
|
||||
%rename("%s") TF_ImportGraphDefOptionsSetValidateColocationConstraints;
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
|
Loading…
Reference in New Issue
Block a user