Internal change: Remove tensor_shape.as_dimension from Keras.

To facilitate splitting Keras without relying on private apis.

PiperOrigin-RevId: 341892302
Change-Id: I0713e2ea8409263e0e0e68fe7cfaf7c5074c7871
This commit is contained in:
Tomer Kaftan 2020-11-11 12:56:28 -08:00 committed by TensorFlower Gardener
parent 2889cec62c
commit c11d0dea22
2 changed files with 2 additions and 2 deletions

View File

@ -214,7 +214,7 @@ def get_static_batch_size(layer):
""" """
batch_input_shape, _ = get_input_shape_and_dtype(layer) batch_input_shape, _ = get_input_shape_and_dtype(layer)
if batch_input_shape is not None: if batch_input_shape is not None:
return tensor_shape.as_dimension(batch_input_shape[0]).value return tensor_shape.Dimension(batch_input_shape[0]).value
return None return None

View File

@ -1750,7 +1750,7 @@ class Model(training_lib.Model):
# Check Dataset/Iterator batch size is consistent with InputLayer. # Check Dataset/Iterator batch size is consistent with InputLayer.
if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator, if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
iterator_ops.IteratorBase)): iterator_ops.IteratorBase)):
ds_batch_size = tensor_shape.as_dimension( ds_batch_size = tensor_shape.Dimension(
nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
if ds_batch_size is not None: if ds_batch_size is not None:
if ds_batch_size % num_splits_for_ds != 0: if ds_batch_size % num_splits_for_ds != 0: