diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 4180c0b7e1d..d75b6a125bf 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -214,7 +214,7 @@ def get_static_batch_size(layer): """ batch_input_shape, _ = get_input_shape_and_dtype(layer) if batch_input_shape is not None: - return tensor_shape.as_dimension(batch_input_shape[0]).value + return tensor_shape.Dimension(batch_input_shape[0]).value return None diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index dbf1703136b..6617c2dae09 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -1750,7 +1750,7 @@ class Model(training_lib.Model): # Check Dataset/Iterator batch size is consistent with InputLayer. if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator, iterator_ops.IteratorBase)): - ds_batch_size = tensor_shape.as_dimension( + ds_batch_size = tensor_shape.Dimension( nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value if ds_batch_size is not None: if ds_batch_size % num_splits_for_ds != 0: