In resource_variable_ops.copy_to_graph_uninitialized, instantiate

`UninitializedVariable` under var.device's scope.

It was discovered that when using tf.saved_model.save with MultiWorkerMirroredStrategy, it is possible that a variable does not have device attribute and tf.distribute code errors out. This fixes it and a test that would have failed is added to cover such use case.

PiperOrigin-RevId: 276383534
Change-Id: I467bdaf5946da7dd1f3722c91a586830566608de
This commit is contained in:
Rick Chao 2019-10-23 17:13:28 -07:00 committed by TensorFlower Gardener
parent 2d3b88b164
commit 27935b664b

View File

@ -259,7 +259,9 @@ class _SaveableView(object):
# created component variables.
new_vars = []
for v in obj.values:
new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
# Ensure the variables are created with device attribute set.
with ops.device(v.device):
new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
object_map[v] = new_variable
new_vars.append(new_variable)
resource_map[v.handle] = new_variable.handle