Copy common_shape
util to Keras.
PiperOrigin-RevId: 340550537 Change-Id: Ic1b23d809102b18e6008c000f95e4e042892c417
This commit is contained in:
parent
317c2b8460
commit
e82b266a54
@ -25,9 +25,9 @@ from google.protobuf import message
|
|||||||
|
|
||||||
from tensorflow.core.framework import versions_pb2
|
from tensorflow.core.framework import versions_pb2
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function as defun
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras import regularizers
|
from tensorflow.python.keras import regularizers
|
||||||
@ -1065,6 +1065,31 @@ def recursively_deserialize_keras_object(config, module_objects=None):
|
|||||||
raise ValueError('Unable to decode config: {}'.format(config))
|
raise ValueError('Unable to decode config: {}'.format(config))
|
||||||
|
|
||||||
|
|
||||||
|
def get_common_shape(x, y):
|
||||||
|
"""Find a `TensorShape` that is compatible with both `x` and `y`."""
|
||||||
|
if x is None != y is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
'Cannot find a common shape when LHS shape is None but RHS shape '
|
||||||
|
'is not (or vice versa): %s vs. %s' % (x, y))
|
||||||
|
if x is None:
|
||||||
|
return None # The associated input was not a Tensor, no shape generated.
|
||||||
|
if not isinstance(x, tensor_shape.TensorShape):
|
||||||
|
raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
|
||||||
|
if not isinstance(y, tensor_shape.TensorShape):
|
||||||
|
raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
|
||||||
|
if x.rank != y.rank or x.rank is None:
|
||||||
|
return tensor_shape.TensorShape(None)
|
||||||
|
dims = []
|
||||||
|
for dim_x, dim_y in zip(x.dims, y.dims):
|
||||||
|
if (dim_x != dim_y
|
||||||
|
or tensor_shape.dimension_value(dim_x) is None
|
||||||
|
or tensor_shape.dimension_value(dim_y) is None):
|
||||||
|
dims.append(None)
|
||||||
|
else:
|
||||||
|
dims.append(tensor_shape.dimension_value(dim_x))
|
||||||
|
return tensor_shape.TensorShape(dims)
|
||||||
|
|
||||||
|
|
||||||
def infer_inputs_from_restored_call_function(fn):
|
def infer_inputs_from_restored_call_function(fn):
|
||||||
"""Returns TensorSpec of inputs from a restored call function.
|
"""Returns TensorSpec of inputs from a restored call function.
|
||||||
|
|
||||||
@ -1076,7 +1101,7 @@ def infer_inputs_from_restored_call_function(fn):
|
|||||||
TensorSpec of call function inputs.
|
TensorSpec of call function inputs.
|
||||||
"""
|
"""
|
||||||
def common_spec(x, y):
|
def common_spec(x, y):
|
||||||
common_shape = defun.common_shape(x.shape, y.shape)
|
common_shape = get_common_shape(x.shape, y.shape)
|
||||||
if isinstance(x, sparse_tensor.SparseTensorSpec):
|
if isinstance(x, sparse_tensor.SparseTensorSpec):
|
||||||
return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
|
return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
|
||||||
return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)
|
return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)
|
||||||
|
Loading…
Reference in New Issue
Block a user