Copy common_shape util to Keras.

PiperOrigin-RevId: 340550537
Change-Id: Ic1b23d809102b18e6008c000f95e4e042892c417
This commit is contained in:
Yanhui Liang 2020-11-03 16:43:39 -08:00 committed by TensorFlower Gardener
parent 317c2b8460
commit e82b266a54

View File

@ -25,9 +25,9 @@ from google.protobuf import message
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras import regularizers
@ -1065,6 +1065,31 @@ def recursively_deserialize_keras_object(config, module_objects=None):
raise ValueError('Unable to decode config: {}'.format(config))
def get_common_shape(x, y):
"""Find a `TensorShape` that is compatible with both `x` and `y`."""
if x is None != y is None:
raise RuntimeError(
'Cannot find a common shape when LHS shape is None but RHS shape '
'is not (or vice versa): %s vs. %s' % (x, y))
if x is None:
return None # The associated input was not a Tensor, no shape generated.
if not isinstance(x, tensor_shape.TensorShape):
raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
if not isinstance(y, tensor_shape.TensorShape):
raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
if x.rank != y.rank or x.rank is None:
return tensor_shape.TensorShape(None)
dims = []
for dim_x, dim_y in zip(x.dims, y.dims):
if (dim_x != dim_y
or tensor_shape.dimension_value(dim_x) is None
or tensor_shape.dimension_value(dim_y) is None):
dims.append(None)
else:
dims.append(tensor_shape.dimension_value(dim_x))
return tensor_shape.TensorShape(dims)
def infer_inputs_from_restored_call_function(fn):
"""Returns TensorSpec of inputs from a restored call function.
@ -1076,7 +1101,7 @@ def infer_inputs_from_restored_call_function(fn):
TensorSpec of call function inputs.
"""
def common_spec(x, y):
common_shape = defun.common_shape(x.shape, y.shape)
common_shape = get_common_shape(x.shape, y.shape)
if isinstance(x, sparse_tensor.SparseTensorSpec):
return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)