Only apply check for non-tensor case
PiperOrigin-RevId: 261408908
This commit is contained in:
parent
68ea31c83c
commit
9c14f6ba30
@ -4031,6 +4031,8 @@ def in_train_phase(x, alt, training=None):
|
||||
if training is None:
|
||||
training = learning_phase()
|
||||
|
||||
# TODO(b/138862903): Handle the case when training is tensor.
|
||||
if not tensor_util.is_tensor(training):
|
||||
if training == 1 or training is True:
|
||||
if callable(x):
|
||||
return x()
|
||||
|
Loading…
x
Reference in New Issue
Block a user