Allow returning attr types from TPUStrategy.run
PiperOrigin-RevId: 315928063 Change-Id: I60a1596d2f7ea7542d1bb56c34c8ac7766ec1d84
This commit is contained in:
parent
920bf272c8
commit
9033264944
@ -414,10 +414,11 @@ def is_flat(outputs):
|
||||
2) A single object
|
||||
3) A list or tuple of Tensors/Operations
|
||||
|
||||
The only structures that this function understands are sequences and
|
||||
dictionaries. E.g. this means that if outputs contains a single
|
||||
user-defined Object, it is considered to be flat. Errors are raised later on
|
||||
if that Object cannot be converted to a Tensor.
|
||||
The only structures that this function understands are sequences,
|
||||
dictionaries and types defined using the attrs library. E.g. this means
|
||||
that if outputs contains a single user-defined Object, it is considered to
|
||||
be flat. Errors are raised later on if that Object cannot be converted to a
|
||||
Tensor.
|
||||
|
||||
Args:
|
||||
outputs: Output from `computation` inside `xla.compile`.
|
||||
@ -429,14 +430,19 @@ def is_flat(outputs):
|
||||
# there is, then outputs is non-flat.
|
||||
if isinstance(outputs, collections.Sequence):
|
||||
for o in outputs:
|
||||
if isinstance(o, collections.Sequence) or isinstance(
|
||||
o, collections.Mapping):
|
||||
if (isinstance(o, collections.Sequence) or
|
||||
isinstance(o, collections.Mapping) or
|
||||
hasattr(o.__class__, '__attrs_attrs__')):
|
||||
return False
|
||||
|
||||
# If outputs is a dict, it is non-flat.
|
||||
if isinstance(outputs, collections.Mapping):
|
||||
return False
|
||||
|
||||
# If outputs is from the attrs library, it is non-flat.
|
||||
if hasattr(outputs.__class__, '__attrs_attrs__'):
|
||||
return False
|
||||
|
||||
# Getting here means either outputs itself is a single non-structured value
|
||||
# or it is a flat list of single non-structured values.
|
||||
return True
|
||||
|
Loading…
x
Reference in New Issue
Block a user