Remove the Sequence condition check around input casting.

We do not know why this check was added, and the casting is also required for other input structures, such as dicts.

PiperOrigin-RevId: 261235920
This commit is contained in:
Pavithra Vijay 2019-08-01 17:58:14 -07:00 committed by TensorFlower Gardener
parent 27cc0c9a65
commit dad0d6a856
2 changed files with 40 additions and 10 deletions

View File

@@ -63,7 +63,8 @@ class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
evaluation[-1], places=5)
class PrintTrainingInfoTest(parameterized.TestCase):
class PrintTrainingInfoTest(keras_parameterized.TestCase,
parameterized.TestCase):
@test_util.run_v1_only("Only relevant in graph mode.")
def test_print_info_with_datasets(self):
@@ -110,6 +111,38 @@ class PrintTrainingInfoTest(parameterized.TestCase):
if do_validation:
self.assertIn(", validate on 50 samples", mock_stdout.getvalue())
@keras_parameterized.run_all_keras_modes
def test_dict_float64_input(self):
  """Fits a subclassed model on dict inputs built from float64 numpy arrays.

  Regression test for this commit: input casting must also apply to dict
  inputs, not only to Sequence inputs, so `fit` should succeed without the
  caller casting the float64 arrays beforehand.
  """

  class MyModel(keras.Model):

    def __init__(self):
      super(MyModel, self).__init__(self)
      # Two parallel branches, one per input-dict key, merged before the
      # final sigmoid projection.
      self.dense1 = keras.layers.Dense(10, activation="relu")
      self.dense2 = keras.layers.Dense(10, activation="relu")
      self.concat = keras.layers.Concatenate()
      self.dense3 = keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
      # `inputs` is a dict with keys "one" and "two" (see `fit` below).
      d1 = self.dense1(inputs["one"])
      d2 = self.dense2(inputs["two"])
      concat = self.concat([d1, d2])
      return self.dense3(concat)

  model = MyModel()
  model.compile(
      loss="mae",
      optimizer="adam",
      run_eagerly=testing_utils.should_run_eagerly(),
      experimental_run_tf_function=testing_utils.should_run_tf_function())
  # np.random.rand yields float64 arrays; passing them as a dict exercises
  # the casting path that previously only ran for Sequence inputs.
  model.fit(
      x={
          "one": np.random.rand(100, 10, 1),
          "two": np.random.rand(100, 10, 1)
      },
      y=np.random.rand(100, 10, 1))
def test_dict_validation_input(self):
"""Test case for GitHub issue 30122."""
train_input_0 = np.random.rand(1000, 1)

View File

@@ -31,7 +31,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
def _eager_loss_fn(outputs, targets, loss_fn, output_name):
@@ -287,7 +286,6 @@ def train_on_batch(model,
Returns:
total loss and the loss associated with each output.
"""
if isinstance(inputs, collections_abc.Sequence):
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
if targets:
targets = training_utils.cast_if_floating_dtype(targets)
@@ -333,7 +331,6 @@ def test_on_batch(model,
Returns:
total loss, loss and metrics associated with each output.
"""
if isinstance(inputs, collections_abc.Sequence):
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
if targets:
targets = training_utils.cast_if_floating_dtype(targets)