Allows calling keras layers in eager mode.
PiperOrigin-RevId: 173129805
This commit is contained in:
parent
4ec6f2b07c
commit
3ed049b673
@ -26,6 +26,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import zip # pylint: disable=redefined-builtin
|
from six.moves import zip # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.keras._impl.keras import backend as K
|
from tensorflow.python.keras._impl.keras import backend as K
|
||||||
from tensorflow.python.keras._impl.keras.utils import conv_utils
|
from tensorflow.python.keras._impl.keras.utils import conv_utils
|
||||||
@ -250,6 +251,8 @@ class Layer(tf_base_layers.Layer):
|
|||||||
"""
|
"""
|
||||||
# Actually call the layer (optionally building it).
|
# Actually call the layer (optionally building it).
|
||||||
output = super(Layer, self).__call__(inputs, **kwargs)
|
output = super(Layer, self).__call__(inputs, **kwargs)
|
||||||
|
if context.in_eager_mode():
|
||||||
|
return output
|
||||||
|
|
||||||
# Update learning phase info.
|
# Update learning phase info.
|
||||||
output_tensors = _to_list(output)
|
output_tensors = _to_list(output)
|
||||||
|
@ -20,8 +20,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.keras._impl import keras
|
from tensorflow.python.keras._impl import keras
|
||||||
from tensorflow.python.keras._impl.keras import testing_utils
|
from tensorflow.python.keras._impl.keras import testing_utils
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -198,6 +201,12 @@ class CoreLayersTest(test.TestCase):
|
|||||||
self.assertEqual(layer.kernel.constraint, k_constraint)
|
self.assertEqual(layer.kernel.constraint, k_constraint)
|
||||||
self.assertEqual(layer.bias.constraint, b_constraint)
|
self.assertEqual(layer.bias.constraint, b_constraint)
|
||||||
|
|
||||||
|
def test_eager_dense(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
l = keras.layers.Dense(units=3,
|
||||||
|
kernel_initializer=init_ops.zeros_initializer())
|
||||||
|
self.assertAllEqual(l(constant_op.constant([[1.0]])), [[0., 0., 0.]])
|
||||||
|
|
||||||
def test_activity_regularization(self):
|
def test_activity_regularization(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
layer = keras.layers.ActivityRegularization(l1=0.1)
|
layer = keras.layers.ActivityRegularization(l1=0.1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user