diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a735333153d..3aada44fb68 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1269,7 +1269,8 @@ class ComputationBuilder(object): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) return ops.DotGeneral(lhs, rhs, dimension_numbers) - def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): + def Conv(self, lhs, rhs, window_strides, padding, + feature_group_count=1, batch_group_count=1): """Enqueues a Conv operation onto the computation. Args: @@ -1278,6 +1279,7 @@ class ComputationBuilder(object): window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. feature_group_count: number of feature groups for grouped convolution. + batch_group_count: number of batch groups for grouped convolution. Returns: a XlaOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( @@ -1290,7 +1292,8 @@ class ComputationBuilder(object): window_strides, pads, [], [], dimension_numbers=None, - feature_group_count=feature_group_count) + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) def ConvWithGeneralPadding(self, lhs, @@ -1299,7 +1302,8 @@ class ComputationBuilder(object): padding, lhs_dilation, rhs_dilation, - feature_group_count=1): + feature_group_count=1, + batch_group_count=1): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: @@ -1310,6 +1314,7 @@ class ComputationBuilder(object): lhs_dilation: length-N array-like of dilation factors. rhs_dilation: length-N array-like of dilation factors. feature_group_count: number of feature groups for grouped convolution. + batch_group_count: number of batch groups for grouped convolution. Returns: A ComputationdataHandle representing the added ConvWithGeneralPadding op. @@ -1322,7 +1327,8 @@ class ComputationBuilder(object): list(lhs_dilation), list(rhs_dilation), dimension_numbers=None, - feature_group_count=feature_group_count) + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1347,7 +1353,8 @@ class ComputationBuilder(object): lhs_dilation, rhs_dilation, dimension_numbers=None, - feature_group_count=1): + feature_group_count=1, + batch_group_count=1): """Enqueues a ConvGeneralDilated operation onto the computation. Args: @@ -1377,6 +1384,7 @@ class ComputationBuilder(object): default, use the same dimension numbering as Conv and ConvWithGeneralPadding. feature_group_count: number of feature groups for grouped convolution. + batch_group_count: number of batch groups for grouped convolution. Returns: a XlaOp representing the ConvGenralDilated operation. """ if dimension_numbers is None: @@ -1402,7 +1410,7 @@ class ComputationBuilder(object): key=lambda i: rhs_spec.index(out_spec[i]))) return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, - feature_group_count) + feature_group_count, batch_group_count) def Sort(self, operand, dimension=-1): """Enqueues a sort operation onto the computation."""