Mark bfloat16 as trainable.

PiperOrigin-RevId: 303816514
Change-Id: Iebaed635ea7adaf8ded604aba4d81907bce86395
This commit is contained in:
Guangda Lai 2020-03-30 13:58:20 -07:00 committed by TensorFlower Gardener
parent f9a22e49b4
commit d83a344d76

View File

@ -30,4 +30,4 @@ def IsTrainable(tensor_or_dtype):
dtype = dtypes.as_dtype(dtype)
return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128,
dtypes.resource, dtypes.variant)
dtypes.resource, dtypes.variant, dtypes.bfloat16)