Do not have Sequential models cast inputs
PiperOrigin-RevId: 276584832 Change-Id: I90c5b0d97ef31393de82a38b6340b3b3dc489d60
This commit is contained in:
parent
00ea430324
commit
b74e8fb4a5
@ -100,7 +100,7 @@ class Sequential(training.Model):
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, layers=None, name=None):
|
||||
super(Sequential, self).__init__(name=name)
|
||||
super(Sequential, self).__init__(name=name, autocast=False)
|
||||
self.supports_masking = True
|
||||
self._build_input_shape = None
|
||||
self._compute_output_and_mask_jointly = True
|
||||
|
@ -449,6 +449,21 @@ class TestSequential(keras_parameterized.TestCase):
|
||||
model.pop()
|
||||
self.assertEqual(model._layers[-1], layer)
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_sequential_does_not_autocast(self):
|
||||
|
||||
class AssertFloat64InputLayer(keras.layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(AssertFloat64InputLayer, self).__init__(autocast=False)
|
||||
|
||||
def call(self, inputs):
|
||||
assert inputs.dtype == 'float64', 'inputs are %s' % inputs.dtype
|
||||
return array_ops.identity(inputs)
|
||||
|
||||
model = keras.Sequential([AssertFloat64InputLayer(), keras.layers.Dense(4)])
|
||||
model(np.random.random((4, 4)))
|
||||
|
||||
|
||||
class TestSequentialEagerIntegration(keras_parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user