Only apply check for non-tensor case
PiperOrigin-RevId: 261408908
This commit is contained in:
		
							parent
							
								
									68ea31c83c
								
							
						
					
					
						commit
						9c14f6ba30
					
				| @ -4031,6 +4031,8 @@ def in_train_phase(x, alt, training=None): | |||||||
|   if training is None: |   if training is None: | ||||||
|     training = learning_phase() |     training = learning_phase() | ||||||
| 
 | 
 | ||||||
|  |   # TODO(b/138862903): Handle the case when training is tensor. | ||||||
|  |   if not tensor_util.is_tensor(training): | ||||||
|     if training == 1 or training is True: |     if training == 1 or training is True: | ||||||
|       if callable(x): |       if callable(x): | ||||||
|         return x() |         return x() | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user