Internal change: Remove tensor_shape.as_dimension from Keras.
To facilitate splitting Keras without relying on private apis. PiperOrigin-RevId: 341892302 Change-Id: I0713e2ea8409263e0e0e68fe7cfaf7c5074c7871
This commit is contained in:
parent
2889cec62c
commit
c11d0dea22
@ -214,7 +214,7 @@ def get_static_batch_size(layer):
|
|||||||
"""
|
"""
|
||||||
batch_input_shape, _ = get_input_shape_and_dtype(layer)
|
batch_input_shape, _ = get_input_shape_and_dtype(layer)
|
||||||
if batch_input_shape is not None:
|
if batch_input_shape is not None:
|
||||||
return tensor_shape.as_dimension(batch_input_shape[0]).value
|
return tensor_shape.Dimension(batch_input_shape[0]).value
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1750,7 +1750,7 @@ class Model(training_lib.Model):
|
|||||||
# Check Dataset/Iterator batch size is consistent with InputLayer.
|
# Check Dataset/Iterator batch size is consistent with InputLayer.
|
||||||
if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
|
if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
|
||||||
iterator_ops.IteratorBase)):
|
iterator_ops.IteratorBase)):
|
||||||
ds_batch_size = tensor_shape.as_dimension(
|
ds_batch_size = tensor_shape.Dimension(
|
||||||
nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
|
nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
|
||||||
if ds_batch_size is not None:
|
if ds_batch_size is not None:
|
||||||
if ds_batch_size % num_splits_for_ds != 0:
|
if ds_batch_size % num_splits_for_ds != 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user