Ignore partitioned variable in TPU computation.
PiperOrigin-RevId: 211833891
This commit is contained in:
parent
58857d06e6
commit
6d893ecfb9
@ -652,13 +652,28 @@ def split_compile_and_replicate(computation,
|
||||
# TODO(phawkins): consider removing this code. It will
|
||||
# be less confusing to clients if they knowingly choose to use resource
|
||||
# variables.
|
||||
# Partitioned variables is not supported (b/112311320).
|
||||
def custom_getter(getter, name, *args, **kwargs):
|
||||
partitioner = kwargs["partitioner"]
|
||||
if partitioner is None:
|
||||
return getter(name, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Partitioned variables are not supported on TPU. Got "
|
||||
"`partitioner` that is {}.".format(partitioner))
|
||||
|
||||
vscope = variable_scope.get_variable_scope()
|
||||
|
||||
saved_use_resource = vscope.use_resource
|
||||
saved_custom_getter = vscope.custom_getter
|
||||
|
||||
vscope.set_use_resource(True)
|
||||
vscope.set_custom_getter(custom_getter)
|
||||
|
||||
outputs = computation(*computation_inputs)
|
||||
|
||||
vscope.set_use_resource(saved_use_resource)
|
||||
vscope.set_custom_getter(saved_custom_getter)
|
||||
|
||||
# If the computation returns `None`, make it an empty tuple.
|
||||
if outputs is None:
|
||||
|
Loading…
Reference in New Issue
Block a user