From dad0d6a8562401028ab0dd42e145fe8afebe6d7d Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Thu, 1 Aug 2019 17:58:14 -0700 Subject: [PATCH] Remove Sequence condition check that is around input casting. We do not know why this check was added, and the casting is also required for other input types, such as dicts. PiperOrigin-RevId: 261235920 --- .../keras/engine/training_arrays_test.py | 35 ++++++++++++++++++- .../python/keras/engine/training_eager.py | 15 ++++---- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/engine/training_arrays_test.py b/tensorflow/python/keras/engine/training_arrays_test.py index a4647ad34da..097d5eef36d 100644 --- a/tensorflow/python/keras/engine/training_arrays_test.py +++ b/tensorflow/python/keras/engine/training_arrays_test.py @@ -63,7 +63,8 @@ class ValidationDatasetNoLimitTest(keras_parameterized.TestCase): evaluation[-1], places=5) -class PrintTrainingInfoTest(parameterized.TestCase): +class PrintTrainingInfoTest(keras_parameterized.TestCase, + parameterized.TestCase): @test_util.run_v1_only("Only relevant in graph mode.") def test_print_info_with_datasets(self): @@ -110,6 +111,38 @@ class PrintTrainingInfoTest(parameterized.TestCase): if do_validation: self.assertIn(", validate on 50 samples", mock_stdout.getvalue()) + @keras_parameterized.run_all_keras_modes + def test_dict_float64_input(self): + + class MyModel(keras.Model): + + def __init__(self): + super(MyModel, self).__init__(self) + self.dense1 = keras.layers.Dense(10, activation="relu") + self.dense2 = keras.layers.Dense(10, activation="relu") + self.concat = keras.layers.Concatenate() + self.dense3 = keras.layers.Dense(1, activation="sigmoid") + + def call(self, inputs): + d1 = self.dense1(inputs["one"]) + d2 = self.dense2(inputs["two"]) + concat = self.concat([d1, d2]) + return self.dense3(concat) + + model = MyModel() + model.compile( + loss="mae", + optimizer="adam", + run_eagerly=testing_utils.should_run_eagerly(), 
experimental_run_tf_function=testing_utils.should_run_tf_function()) + + model.fit( + x={ + "one": np.random.rand(100, 10, 1), + "two": np.random.rand(100, 10, 1) + }, + y=np.random.rand(100, 10, 1)) + def test_dict_validation_input(self): """Test case for GitHub issue 30122.""" train_input_0 = np.random.rand(1000, 1) diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index 3e1a03076e8..cd1fd8c6b2d 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import util as tf_losses_utils from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -from tensorflow.python.util.compat import collections_abc def _eager_loss_fn(outputs, targets, loss_fn, output_name): @@ -287,10 +286,9 @@ def train_on_batch(model, Returns: total loss and the loss associated with each output. """ - if isinstance(inputs, collections_abc.Sequence): - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) - if targets: - targets = training_utils.cast_if_floating_dtype(targets) + inputs = training_utils.cast_to_model_input_dtypes(inputs, model) + if targets: + targets = training_utils.cast_if_floating_dtype(targets) if sample_weights: sample_weights = [ training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val)) @@ -333,10 +331,9 @@ def test_on_batch(model, Returns: total loss, loss and metrics associated with each output. 
""" - if isinstance(inputs, collections_abc.Sequence): - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) - if targets: - targets = training_utils.cast_if_floating_dtype(targets) + inputs = training_utils.cast_to_model_input_dtypes(inputs, model) + if targets: + targets = training_utils.cast_if_floating_dtype(targets) if sample_weights: sample_weights = [ training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))