Expose batch_group_count in xla_client python wrapper.

PiperOrigin-RevId: 252001924
This commit is contained in:
A. Unique TensorFlower 2019-06-07 00:27:15 -07:00 committed by TensorFlower Gardener
parent 89b1e717c0
commit 1ad04efcee

View File

@ -1269,7 +1269,8 @@ class ComputationBuilder(object):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return ops.DotGeneral(lhs, rhs, dimension_numbers) return ops.DotGeneral(lhs, rhs, dimension_numbers)
def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): def Conv(self, lhs, rhs, window_strides, padding,
feature_group_count=1, batch_group_count=1):
"""Enqueues a Conv operation onto the computation. """Enqueues a Conv operation onto the computation.
Args: Args:
@ -1278,6 +1279,7 @@ class ComputationBuilder(object):
window_strides: length-N array-like of integer kernel strides. window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding. padding: PaddingType representing either 'SAME' or 'VALID' padding.
feature_group_count: number of feature groups for grouped convolution. feature_group_count: number of feature groups for grouped convolution.
batch_group_count: number of batch groups for grouped convolution.
Returns: a XlaOp representing the Conv operation. Returns: a XlaOp representing the Conv operation.
""" """
pads = _convert_padding_type_to_pad_values( pads = _convert_padding_type_to_pad_values(
@ -1290,7 +1292,8 @@ class ComputationBuilder(object):
window_strides, window_strides,
pads, [], [], pads, [], [],
dimension_numbers=None, dimension_numbers=None,
feature_group_count=feature_group_count) feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
def ConvWithGeneralPadding(self, def ConvWithGeneralPadding(self,
lhs, lhs,
@ -1299,7 +1302,8 @@ class ComputationBuilder(object):
padding, padding,
lhs_dilation, lhs_dilation,
rhs_dilation, rhs_dilation,
feature_group_count=1): feature_group_count=1,
batch_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation. """Enqueues a ConvWithGeneralPadding operation onto the computation.
Args: Args:
@ -1310,6 +1314,7 @@ class ComputationBuilder(object):
lhs_dilation: length-N array-like of dilation factors. lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors. rhs_dilation: length-N array-like of dilation factors.
feature_group_count: number of feature groups for grouped convolution. feature_group_count: number of feature groups for grouped convolution.
batch_group_count: number of batch groups for grouped convolution.
Returns: Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op. A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@ -1322,7 +1327,8 @@ class ComputationBuilder(object):
list(lhs_dilation), list(lhs_dilation),
list(rhs_dilation), list(rhs_dilation),
dimension_numbers=None, dimension_numbers=None,
feature_group_count=feature_group_count) feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims): def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions.""" """Create ConvolutionDimensionNumbers proto for convolutions."""
@ -1347,7 +1353,8 @@ class ComputationBuilder(object):
lhs_dilation, lhs_dilation,
rhs_dilation, rhs_dilation,
dimension_numbers=None, dimension_numbers=None,
feature_group_count=1): feature_group_count=1,
batch_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation. """Enqueues a ConvGeneralDilated operation onto the computation.
Args: Args:
@ -1377,6 +1384,7 @@ class ComputationBuilder(object):
default, use the same dimension numbering as Conv and default, use the same dimension numbering as Conv and
ConvWithGeneralPadding. ConvWithGeneralPadding.
feature_group_count: number of feature groups for grouped convolution. feature_group_count: number of feature groups for grouped convolution.
batch_group_count: number of batch groups for grouped convolution.
Returns: a XlaOp representing the ConvGenralDilated operation. Returns: a XlaOp representing the ConvGenralDilated operation.
""" """
if dimension_numbers is None: if dimension_numbers is None:
@ -1402,7 +1410,7 @@ class ComputationBuilder(object):
key=lambda i: rhs_spec.index(out_spec[i]))) key=lambda i: rhs_spec.index(out_spec[i])))
return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding, return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count) feature_group_count, batch_group_count)
def Sort(self, operand, dimension=-1): def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation.""" """Enqueues a sort operation onto the computation."""