Mark bfloat16 as trainable.
PiperOrigin-RevId: 303816514 Change-Id: Iebaed635ea7adaf8ded604aba4d81907bce86395
This commit is contained in:
parent
f9a22e49b4
commit
d83a344d76
@ -30,4 +30,4 @@ def IsTrainable(tensor_or_dtype):
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
|
||||
dtypes.complex64, dtypes.complex128,
|
||||
dtypes.resource, dtypes.variant)
|
||||
dtypes.resource, dtypes.variant, dtypes.bfloat16)
|
||||
|
Loading…
Reference in New Issue
Block a user