Only apply check for non-tensor case
PiperOrigin-RevId: 261408908
This commit is contained in:
parent
68ea31c83c
commit
9c14f6ba30
@ -4031,17 +4031,19 @@ def in_train_phase(x, alt, training=None):
|
||||
if training is None:
|
||||
training = learning_phase()
|
||||
|
||||
if training == 1 or training is True:
|
||||
if callable(x):
|
||||
return x()
|
||||
else:
|
||||
return x
|
||||
# TODO(b/138862903): Handle the case when training is tensor.
|
||||
if not tensor_util.is_tensor(training):
|
||||
if training == 1 or training is True:
|
||||
if callable(x):
|
||||
return x()
|
||||
else:
|
||||
return x
|
||||
|
||||
elif training == 0 or training is False:
|
||||
if callable(alt):
|
||||
return alt()
|
||||
else:
|
||||
return alt
|
||||
elif training == 0 or training is False:
|
||||
if callable(alt):
|
||||
return alt()
|
||||
else:
|
||||
return alt
|
||||
|
||||
# else: assume learning phase is a placeholder tensor.
|
||||
x = switch(training, x, alt)
|
||||
|
Loading…
x
Reference in New Issue
Block a user