Workaround for PartitionedCall trying and failing to run on TPUs when saving

Just omits the function decorator for now. This is pretty terrible and we should fix it, but it will need some work on the TPU side.

Spoke to iga@. Apparently the CPU annotations don't work because the function captures a resource which is on the TPU (and so the eager placer puts the call op on the TPU). One option is to then XLA-compile the function, although that fails right now because we're trying to save strings and XLA doesn't have a kernel for that.

I should also follow up with TPU+checkpointing integration tests.

PiperOrigin-RevId: 226390521
This commit is contained in:
Allen Lavoie 2018-12-20 14:05:51 -08:00 committed by TensorFlower Gardener
parent e9821ef1df
commit 2f4d4da52f

View File

@ -55,22 +55,24 @@ class Saver(object):
filename_tensor = array_ops.placeholder(
shape=[], dtype=dtypes.string, name="saver_filename")
# TODO(allenl): Add save and restore function names to the proto directly.
save_tensor = self.save(filename_tensor)
restore_op = self.restore(filename_tensor).op
signature = (tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),)
# Autograph is off because of reference cycles which must be collected when
# a function is created and destroyed (as in tf.saved_model.save). It's also
# not necessary, so having it off may be slightly faster.
#
# TODO(b/121302372): We should be able to decorate save() and restore()
# unconditionally.
save_tensor = def_function.function(
self.save, input_signature=signature, autograph=False)(filename_tensor)
restore_op = def_function.function(
self.restore, input_signature=signature, autograph=False)(
filename_tensor).op
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
restore_op_name=restore_op.name,
version=saver_pb2.SaverDef.V2)
@def_function.function(
input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
# Autograph is off because of reference cycles which must be collected
# when a function is created and destroyed (as in
# tf.saved_model.save). It's also not necessary, so having it off may be
# slightly faster.
autograph=False,
)
def save(self, file_prefix):
"""Save the saveable objects to a checkpoint with `file_prefix`.
@ -89,13 +91,11 @@ class Saver(object):
tensor_names.append(spec.name)
tensors.append(spec.tensor)
tensor_slices.append(spec.slice_spec)
io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
return file_prefix
with ops.device("cpu:0"):
with ops.control_dependencies([io_ops.save_v2(
file_prefix, tensor_names, tensor_slices, tensors)]):
return array_ops.identity(file_prefix)
@def_function.function(
input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
autograph=False,
)
def restore(self, file_prefix):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
@ -121,8 +121,11 @@ class Saver(object):
file_prefix, tensor_names, tensor_slices, tensor_dtypes)
structured_restored_tensors = nest.pack_sequence_as(
tensor_structure, restored_tensors)
restore_ops = []
for saveable, restored_tensors in zip(self._saveable_objects,
structured_restored_tensors):
saveable.restore(restored_tensors,
restored_shapes=None)
return file_prefix
restore_ops.append(saveable.restore(restored_tensors,
restored_shapes=None))
with ops.device("cpu:0"):
with ops.control_dependencies(restore_ops):
return array_ops.identity(file_prefix)