Copy squeeze_batch_dims (private TF API) to keras.

PiperOrigin-RevId: 351886087
Change-Id: Id1540eb2c4338e3442207cd1cdca02c83d0c7227
This commit is contained in:
Scott Zhu 2021-01-14 15:13:11 -08:00 committed by TensorFlower Gardener
parent 74c2870b60
commit 828f43c473
2 changed files with 51 additions and 1 deletions

View File

@ -260,7 +260,7 @@ class Conv(Layer):
def _apply_fn(o):
return nn.bias_add(o, self.bias, data_format=self._tf_data_format)
outputs = nn_ops.squeeze_batch_dims(
outputs = conv_utils.squeeze_batch_dims(
outputs, _apply_fn, inner_rank=self.rank + 1)
else:
outputs = nn.bias_add(

View File

@ -22,7 +22,10 @@ import itertools
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops
def convert_data_format(data_format, ndim):
@ -467,3 +470,50 @@ def conv_output_shape(input_shape, kernel_shape, strides, padding):
output_shape = tuple(
[0 if input_shape[d] == 0 else output_shape[d] for d in dims])
return output_shape
def squeeze_batch_dims(inp, op, inner_rank):
"""Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.
Where `squeeze_batch` reshapes `inp` to shape
`[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
and `unsqueeze_batch` does the reverse reshape but on the output.
Args:
inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
is length `inner_rank`.
op: A callable that takes a single input tensor and returns a single.
output tensor.
inner_rank: A python integer.
Returns:
`unsqueeze_batch_op(squeeze_batch(inp))`.
"""
with ops.name_scope_v2('squeeze_batch_dims'):
shape = inp.shape
inner_shape = shape[-inner_rank:]
if not inner_shape.is_fully_defined():
inner_shape = array_ops.shape(inp)[-inner_rank:]
batch_shape = shape[:-inner_rank]
if not batch_shape.is_fully_defined():
batch_shape = array_ops.shape(inp)[:-inner_rank]
if isinstance(inner_shape, tensor_shape.TensorShape):
inp_reshaped = array_ops.reshape(inp, [-1] + inner_shape.as_list())
else:
inp_reshaped = array_ops.reshape(
inp, array_ops.concat(([-1], inner_shape), axis=-1))
out_reshaped = op(inp_reshaped)
out_inner_shape = out_reshaped.shape[-inner_rank:]
if not out_inner_shape.is_fully_defined():
out_inner_shape = array_ops.shape(out_reshaped)[-inner_rank:]
out = array_ops.reshape(
out_reshaped, array_ops.concat((batch_shape, out_inner_shape), axis=-1))
out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
return out