From 828f43c473a8b561fdd75f805ea6ba245a43de86 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 14 Jan 2021 15:13:11 -0800 Subject: [PATCH] Copy squeeze_batch_dims (private TF API) to keras. PiperOrigin-RevId: 351886087 Change-Id: Id1540eb2c4338e3442207cd1cdca02c83d0c7227 --- .../python/keras/layers/convolutional.py | 2 +- tensorflow/python/keras/utils/conv_utils.py | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 731b51e2862..db71d98cbca 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -260,7 +260,7 @@ class Conv(Layer): def _apply_fn(o): return nn.bias_add(o, self.bias, data_format=self._tf_data_format) - outputs = nn_ops.squeeze_batch_dims( + outputs = conv_utils.squeeze_batch_dims( outputs, _apply_fn, inner_rank=self.rank + 1) else: outputs = nn.bias_add( diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py index 1d328c4422f..8bf7ea94930 100644 --- a/tensorflow/python/keras/utils/conv_utils.py +++ b/tensorflow/python/keras/utils/conv_utils.py @@ -22,7 +22,10 @@ import itertools import numpy as np from six.moves import range # pylint: disable=redefined-builtin +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend +from tensorflow.python.ops import array_ops def convert_data_format(data_format, ndim): @@ -467,3 +470,50 @@ def conv_output_shape(input_shape, kernel_shape, strides, padding): output_shape = tuple( [0 if input_shape[d] == 0 else output_shape[d] for d in dims]) return output_shape + + +def squeeze_batch_dims(inp, op, inner_rank): + """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`. + + Where `squeeze_batch` reshapes `inp` to shape + `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]` + and `unsqueeze_batch` does the reverse reshape but on the output. + + Args: + inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape` + is length `inner_rank`. + op: A callable that takes a single input tensor and returns a single. + output tensor. + inner_rank: A python integer. + + Returns: + `unsqueeze_batch_op(squeeze_batch(inp))`. + """ + with ops.name_scope_v2('squeeze_batch_dims'): + shape = inp.shape + + inner_shape = shape[-inner_rank:] + if not inner_shape.is_fully_defined(): + inner_shape = array_ops.shape(inp)[-inner_rank:] + + batch_shape = shape[:-inner_rank] + if not batch_shape.is_fully_defined(): + batch_shape = array_ops.shape(inp)[:-inner_rank] + + if isinstance(inner_shape, tensor_shape.TensorShape): + inp_reshaped = array_ops.reshape(inp, [-1] + inner_shape.as_list()) + else: + inp_reshaped = array_ops.reshape( + inp, array_ops.concat(([-1], inner_shape), axis=-1)) + + out_reshaped = op(inp_reshaped) + + out_inner_shape = out_reshaped.shape[-inner_rank:] + if not out_inner_shape.is_fully_defined(): + out_inner_shape = array_ops.shape(out_reshaped)[-inner_rank:] + + out = array_ops.reshape( + out_reshaped, array_ops.concat((batch_shape, out_inner_shape), axis=-1)) + + out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:]) + return out