diff --git a/tensorflow/python/keras/engine/training_arrays_test.py b/tensorflow/python/keras/engine/training_arrays_test.py index a4647ad34da..097d5eef36d 100644 --- a/tensorflow/python/keras/engine/training_arrays_test.py +++ b/tensorflow/python/keras/engine/training_arrays_test.py @@ -63,7 +63,8 @@ class ValidationDatasetNoLimitTest(keras_parameterized.TestCase): evaluation[-1], places=5) -class PrintTrainingInfoTest(parameterized.TestCase): +class PrintTrainingInfoTest(keras_parameterized.TestCase, + parameterized.TestCase): @test_util.run_v1_only("Only relevant in graph mode.") def test_print_info_with_datasets(self): @@ -110,6 +111,38 @@ class PrintTrainingInfoTest(parameterized.TestCase): if do_validation: self.assertIn(", validate on 50 samples", mock_stdout.getvalue()) + @keras_parameterized.run_all_keras_modes + def test_dict_float64_input(self): + + class MyModel(keras.Model): + + def __init__(self): + super(MyModel, self).__init__(self) + self.dense1 = keras.layers.Dense(10, activation="relu") + self.dense2 = keras.layers.Dense(10, activation="relu") + self.concat = keras.layers.Concatenate() + self.dense3 = keras.layers.Dense(1, activation="sigmoid") + + def call(self, inputs): + d1 = self.dense1(inputs["one"]) + d2 = self.dense2(inputs["two"]) + concat = self.concat([d1, d2]) + return self.dense3(concat) + + model = MyModel() + model.compile( + loss="mae", + optimizer="adam", + run_eagerly=testing_utils.should_run_eagerly(), + experimental_run_tf_function=testing_utils.should_run_tf_function()) + + model.fit( + x={ + "one": np.random.rand(100, 10, 1), + "two": np.random.rand(100, 10, 1) + }, + y=np.random.rand(100, 10, 1)) + def test_dict_validation_input(self): """Test case for GitHub issue 30122.""" train_input_0 = np.random.rand(1000, 1) diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index 3e1a03076e8..cd1fd8c6b2d 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import util as tf_losses_utils from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -from tensorflow.python.util.compat import collections_abc def _eager_loss_fn(outputs, targets, loss_fn, output_name): @@ -287,10 +286,9 @@ def train_on_batch(model, Returns: total loss and the loss associated with each output. """ - if isinstance(inputs, collections_abc.Sequence): - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) - if targets: - targets = training_utils.cast_if_floating_dtype(targets) + inputs = training_utils.cast_to_model_input_dtypes(inputs, model) + if targets: + targets = training_utils.cast_if_floating_dtype(targets) if sample_weights: sample_weights = [ training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val)) @@ -333,10 +331,9 @@ def test_on_batch(model, Returns: total loss, loss and metrics associated with each output. """ - if isinstance(inputs, collections_abc.Sequence): - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) - if targets: - targets = training_utils.cast_if_floating_dtype(targets) + inputs = training_utils.cast_to_model_input_dtypes(inputs, model) + if targets: + targets = training_utils.cast_if_floating_dtype(targets) if sample_weights: sample_weights = [ training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))