Internal change: Remove tensor_shape.as_dimension from Keras.
To facilitate splitting Keras without relying on private apis. PiperOrigin-RevId: 341892302 Change-Id: I0713e2ea8409263e0e0e68fe7cfaf7c5074c7871
This commit is contained in:
parent
2889cec62c
commit
c11d0dea22
@ -214,7 +214,7 @@ def get_static_batch_size(layer):
|
||||
"""
|
||||
batch_input_shape, _ = get_input_shape_and_dtype(layer)
|
||||
if batch_input_shape is not None:
|
||||
return tensor_shape.as_dimension(batch_input_shape[0]).value
|
||||
return tensor_shape.Dimension(batch_input_shape[0]).value
|
||||
return None
|
||||
|
||||
|
||||
|
@ -1750,7 +1750,7 @@ class Model(training_lib.Model):
|
||||
# Check Dataset/Iterator batch size is consistent with InputLayer.
|
||||
if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
|
||||
iterator_ops.IteratorBase)):
|
||||
ds_batch_size = tensor_shape.as_dimension(
|
||||
ds_batch_size = tensor_shape.Dimension(
|
||||
nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
|
||||
if ds_batch_size is not None:
|
||||
if ds_batch_size % num_splits_for_ds != 0:
|
||||
|
Loading…
Reference in New Issue
Block a user