Remove Sequence condition check that is around input casting.
We do not know why this was added and the casting is required for other dtypes such as dict as well. PiperOrigin-RevId: 261235920
This commit is contained in:
parent
27cc0c9a65
commit
dad0d6a856
@ -63,7 +63,8 @@ class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
|
|||||||
evaluation[-1], places=5)
|
evaluation[-1], places=5)
|
||||||
|
|
||||||
|
|
||||||
class PrintTrainingInfoTest(parameterized.TestCase):
|
class PrintTrainingInfoTest(keras_parameterized.TestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
@test_util.run_v1_only("Only relevant in graph mode.")
|
@test_util.run_v1_only("Only relevant in graph mode.")
|
||||||
def test_print_info_with_datasets(self):
|
def test_print_info_with_datasets(self):
|
||||||
@ -110,6 +111,38 @@ class PrintTrainingInfoTest(parameterized.TestCase):
|
|||||||
if do_validation:
|
if do_validation:
|
||||||
self.assertIn(", validate on 50 samples", mock_stdout.getvalue())
|
self.assertIn(", validate on 50 samples", mock_stdout.getvalue())
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
def test_dict_float64_input(self):
|
||||||
|
|
||||||
|
class MyModel(keras.Model):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModel, self).__init__(self)
|
||||||
|
self.dense1 = keras.layers.Dense(10, activation="relu")
|
||||||
|
self.dense2 = keras.layers.Dense(10, activation="relu")
|
||||||
|
self.concat = keras.layers.Concatenate()
|
||||||
|
self.dense3 = keras.layers.Dense(1, activation="sigmoid")
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
d1 = self.dense1(inputs["one"])
|
||||||
|
d2 = self.dense2(inputs["two"])
|
||||||
|
concat = self.concat([d1, d2])
|
||||||
|
return self.dense3(concat)
|
||||||
|
|
||||||
|
model = MyModel()
|
||||||
|
model.compile(
|
||||||
|
loss="mae",
|
||||||
|
optimizer="adam",
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly(),
|
||||||
|
experimental_run_tf_function=testing_utils.should_run_tf_function())
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
x={
|
||||||
|
"one": np.random.rand(100, 10, 1),
|
||||||
|
"two": np.random.rand(100, 10, 1)
|
||||||
|
},
|
||||||
|
y=np.random.rand(100, 10, 1))
|
||||||
|
|
||||||
def test_dict_validation_input(self):
|
def test_dict_validation_input(self):
|
||||||
"""Test case for GitHub issue 30122."""
|
"""Test case for GitHub issue 30122."""
|
||||||
train_input_0 = np.random.rand(1000, 1)
|
train_input_0 = np.random.rand(1000, 1)
|
||||||
|
@ -31,7 +31,6 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops.losses import util as tf_losses_utils
|
from tensorflow.python.ops.losses import util as tf_losses_utils
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.compat import collections_abc
|
|
||||||
|
|
||||||
|
|
||||||
def _eager_loss_fn(outputs, targets, loss_fn, output_name):
|
def _eager_loss_fn(outputs, targets, loss_fn, output_name):
|
||||||
@ -287,10 +286,9 @@ def train_on_batch(model,
|
|||||||
Returns:
|
Returns:
|
||||||
total loss and the loss associated with each output.
|
total loss and the loss associated with each output.
|
||||||
"""
|
"""
|
||||||
if isinstance(inputs, collections_abc.Sequence):
|
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
|
||||||
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
|
if targets:
|
||||||
if targets:
|
targets = training_utils.cast_if_floating_dtype(targets)
|
||||||
targets = training_utils.cast_if_floating_dtype(targets)
|
|
||||||
if sample_weights:
|
if sample_weights:
|
||||||
sample_weights = [
|
sample_weights = [
|
||||||
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
||||||
@ -333,10 +331,9 @@ def test_on_batch(model,
|
|||||||
Returns:
|
Returns:
|
||||||
total loss, loss and metrics associated with each output.
|
total loss, loss and metrics associated with each output.
|
||||||
"""
|
"""
|
||||||
if isinstance(inputs, collections_abc.Sequence):
|
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
|
||||||
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
|
if targets:
|
||||||
if targets:
|
targets = training_utils.cast_if_floating_dtype(targets)
|
||||||
targets = training_utils.cast_if_floating_dtype(targets)
|
|
||||||
if sample_weights:
|
if sample_weights:
|
||||||
sample_weights = [
|
sample_weights = [
|
||||||
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
||||||
|
Loading…
Reference in New Issue
Block a user