Only apply check for non-tensor case

PiperOrigin-RevId: 261408908
This commit is contained in:
Yanhua Sun 2019-08-02 16:00:01 -07:00 committed by TensorFlower Gardener
parent 68ea31c83c
commit 9c14f6ba30

View File

@ -4031,17 +4031,19 @@ def in_train_phase(x, alt, training=None):
if training is None:
training = learning_phase()
if training == 1 or training is True:
if callable(x):
return x()
else:
return x
# TODO(b/138862903): Handle the case when training is tensor.
if not tensor_util.is_tensor(training):
if training == 1 or training is True:
if callable(x):
return x()
else:
return x
elif training == 0 or training is False:
if callable(alt):
return alt()
else:
return alt
elif training == 0 or training is False:
if callable(alt):
return alt()
else:
return alt
# else: assume learning phase is a placeholder tensor.
x = switch(training, x, alt)