diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 7a8c865ca64..ad954c1f972 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2488,8 +2488,11 @@ class TensorFlowOpLayer(Layer): constants=None, trainable=True, dtype=None): + # Pass autocast=False, as if inputs are cast, input types might not match + # Operation type. super(TensorFlowOpLayer, self).__init__( - name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype) + name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype, + autocast=False) if not isinstance(node_def, bytes): node_def = node_def.encode('utf-8') self.node_def = node_def_pb2.NodeDef.FromString(node_def) diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index 54455cad73a..a853ce5eed0 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -126,6 +126,15 @@ def _reuse_op(): return keras.Model(inputs, outputs) +def _float64_op(): + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(10, dtype='float64')(inputs) + x = gen_nn_ops.relu(x) + assert x.dtype == 'float64', 'x has dtype: %s' % x.dtype + outputs = keras.layers.Dense(10)(x) + return keras.Model(inputs, outputs) + + class LayerWithLayer(keras.layers.Layer): def build(self, input_shape): @@ -179,6 +188,7 @@ class AutoLambdaTest(keras_parameterized.TestCase): ('op_with_tensor_list', _op_with_tensor_list), ('add_n', _add_n), ('_reuse_op', _reuse_op), + ('_float64_op', _float64_op), ('_inner_layer', _inner_layer), ('_reuse_ancillary_layer', _reuse_ancillary_layer), )