Ignore partitioned variable in TPU computation.
PiperOrigin-RevId: 211833891
This commit is contained in:
parent
58857d06e6
commit
6d893ecfb9
@ -652,13 +652,28 @@ def split_compile_and_replicate(computation,
|
|||||||
# TODO(phawkins): consider removing this code. It will
|
# TODO(phawkins): consider removing this code. It will
|
||||||
# be less confusing to clients if they knowingly choose to use resource
|
# be less confusing to clients if they knowingly choose to use resource
|
||||||
# variables.
|
# variables.
|
||||||
|
# Partitioned variables is not supported (b/112311320).
|
||||||
|
def custom_getter(getter, name, *args, **kwargs):
|
||||||
|
partitioner = kwargs["partitioner"]
|
||||||
|
if partitioner is None:
|
||||||
|
return getter(name, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Partitioned variables are not supported on TPU. Got "
|
||||||
|
"`partitioner` that is {}.".format(partitioner))
|
||||||
|
|
||||||
vscope = variable_scope.get_variable_scope()
|
vscope = variable_scope.get_variable_scope()
|
||||||
|
|
||||||
saved_use_resource = vscope.use_resource
|
saved_use_resource = vscope.use_resource
|
||||||
|
saved_custom_getter = vscope.custom_getter
|
||||||
|
|
||||||
vscope.set_use_resource(True)
|
vscope.set_use_resource(True)
|
||||||
|
vscope.set_custom_getter(custom_getter)
|
||||||
|
|
||||||
outputs = computation(*computation_inputs)
|
outputs = computation(*computation_inputs)
|
||||||
|
|
||||||
vscope.set_use_resource(saved_use_resource)
|
vscope.set_use_resource(saved_use_resource)
|
||||||
|
vscope.set_custom_getter(saved_custom_getter)
|
||||||
|
|
||||||
# If the computation returns `None`, make it an empty tuple.
|
# If the computation returns `None`, make it an empty tuple.
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user