Allow returning attr types from TPUStrategy.run

PiperOrigin-RevId: 315928063
Change-Id: I60a1596d2f7ea7542d1bb56c34c8ac7766ec1d84
This commit is contained in:
A. Unique TensorFlower 2020-06-11 10:33:09 -07:00 committed by TensorFlower Gardener
parent 920bf272c8
commit 9033264944

View File

@ -414,10 +414,11 @@ def is_flat(outputs):
2) A single object
3) A list or tuple of Tensors/Operations
The only structures that this function understands are sequences and
dictionaries. E.g. this means that if outputs contains a single
user-defined Object, it is considered to be flat. Errors are raised later on
if that Object cannot be converted to a Tensor.
The only structures that this function understands are sequences,
dictionaries and types defined using the attrs library. E.g. this means
that if outputs contains a single user-defined Object, it is considered to
be flat. Errors are raised later on if that Object cannot be converted to a
Tensor.
Args:
outputs: Output from `computation` inside `xla.compile`.
@ -429,14 +430,19 @@ def is_flat(outputs):
# there is, then outputs is non-flat.
if isinstance(outputs, collections.Sequence):
for o in outputs:
if isinstance(o, collections.Sequence) or isinstance(
o, collections.Mapping):
if (isinstance(o, collections.Sequence) or
isinstance(o, collections.Mapping) or
hasattr(o.__class__, '__attrs_attrs__')):
return False
# If outputs is a dict, it is non-flat.
if isinstance(outputs, collections.Mapping):
return False
# If outputs is from the attrs library, it is non-flat.
if hasattr(outputs.__class__, '__attrs_attrs__'):
return False
# Getting here means either outputs itself is a single non-structured value
# or it is a flat list of single non-structured values.
return True