Ignore partitioned variable in TPU computation.

PiperOrigin-RevId: 211833891
This commit is contained in:
A. Unique TensorFlower 2018-09-06 10:55:42 -07:00 committed by TensorFlower Gardener
parent 58857d06e6
commit 6d893ecfb9

View File

@ -652,13 +652,28 @@ def split_compile_and_replicate(computation,
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
# Partitioned variables is not supported (b/112311320).
def custom_getter(getter, name, *args, **kwargs):
partitioner = kwargs["partitioner"]
if partitioner is None:
return getter(name, *args, **kwargs)
else:
raise ValueError(
"Partitioned variables are not supported on TPU. Got "
"`partitioner` that is {}.".format(partitioner))
vscope = variable_scope.get_variable_scope()
saved_use_resource = vscope.use_resource
saved_custom_getter = vscope.custom_getter
vscope.set_use_resource(True)
vscope.set_custom_getter(custom_getter)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
vscope.set_custom_getter(saved_custom_getter)
# If the computation returns `None`, make it an empty tuple.
if outputs is None: