Fix bug where TensorFlowOpLayer would autocast inputs.
PiperOrigin-RevId: 261395761
This commit is contained in:
parent
6fba1efce3
commit
686c123392
@ -2488,8 +2488,11 @@ class TensorFlowOpLayer(Layer):
|
||||
constants=None,
|
||||
trainable=True,
|
||||
dtype=None):
|
||||
# Pass autocast=False, as if inputs are cast, input types might not match
|
||||
# Operation type.
|
||||
super(TensorFlowOpLayer, self).__init__(
|
||||
name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype)
|
||||
name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype,
|
||||
autocast=False)
|
||||
if not isinstance(node_def, bytes):
|
||||
node_def = node_def.encode('utf-8')
|
||||
self.node_def = node_def_pb2.NodeDef.FromString(node_def)
|
||||
|
@ -126,6 +126,15 @@ def _reuse_op():
|
||||
return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
def _float64_op():
|
||||
inputs = keras.Input(shape=(10,))
|
||||
x = keras.layers.Dense(10, dtype='float64')(inputs)
|
||||
x = gen_nn_ops.relu(x)
|
||||
assert x.dtype == 'float64', 'x has dtype: %s' % x.dtype
|
||||
outputs = keras.layers.Dense(10)(x)
|
||||
return keras.Model(inputs, outputs)
|
||||
|
||||
|
||||
class LayerWithLayer(keras.layers.Layer):
|
||||
|
||||
def build(self, input_shape):
|
||||
@ -179,6 +188,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
|
||||
('op_with_tensor_list', _op_with_tensor_list),
|
||||
('add_n', _add_n),
|
||||
('_reuse_op', _reuse_op),
|
||||
('_float64_op', _float64_op),
|
||||
('_inner_layer', _inner_layer),
|
||||
('_reuse_ancillary_layer', _reuse_ancillary_layer),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user