Expose batch_group_count in xla_client python wrapper.
PiperOrigin-RevId: 252001924
This commit is contained in:
parent
89b1e717c0
commit
1ad04efcee
@ -1269,7 +1269,8 @@ class ComputationBuilder(object):
|
||||
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
|
||||
return ops.DotGeneral(lhs, rhs, dimension_numbers)
|
||||
|
||||
def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
|
||||
def Conv(self, lhs, rhs, window_strides, padding,
|
||||
feature_group_count=1, batch_group_count=1):
|
||||
"""Enqueues a Conv operation onto the computation.
|
||||
|
||||
Args:
|
||||
@ -1278,6 +1279,7 @@ class ComputationBuilder(object):
|
||||
window_strides: length-N array-like of integer kernel strides.
|
||||
padding: PaddingType representing either 'SAME' or 'VALID' padding.
|
||||
feature_group_count: number of feature groups for grouped convolution.
|
||||
batch_group_count: number of batch groups for grouped convolution.
|
||||
Returns: a XlaOp representing the Conv operation.
|
||||
"""
|
||||
pads = _convert_padding_type_to_pad_values(
|
||||
@ -1290,7 +1292,8 @@ class ComputationBuilder(object):
|
||||
window_strides,
|
||||
pads, [], [],
|
||||
dimension_numbers=None,
|
||||
feature_group_count=feature_group_count)
|
||||
feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count)
|
||||
|
||||
def ConvWithGeneralPadding(self,
|
||||
lhs,
|
||||
@ -1299,7 +1302,8 @@ class ComputationBuilder(object):
|
||||
padding,
|
||||
lhs_dilation,
|
||||
rhs_dilation,
|
||||
feature_group_count=1):
|
||||
feature_group_count=1,
|
||||
batch_group_count=1):
|
||||
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
|
||||
|
||||
Args:
|
||||
@ -1310,6 +1314,7 @@ class ComputationBuilder(object):
|
||||
lhs_dilation: length-N array-like of dilation factors.
|
||||
rhs_dilation: length-N array-like of dilation factors.
|
||||
feature_group_count: number of feature groups for grouped convolution.
|
||||
batch_group_count: number of batch groups for grouped convolution.
|
||||
|
||||
Returns:
|
||||
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
|
||||
@ -1322,7 +1327,8 @@ class ComputationBuilder(object):
|
||||
list(lhs_dilation),
|
||||
list(rhs_dilation),
|
||||
dimension_numbers=None,
|
||||
feature_group_count=feature_group_count)
|
||||
feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count)
|
||||
|
||||
def _GetConvDimensionNumbers(self, num_spatial_dims):
|
||||
"""Create ConvolutionDimensionNumbers proto for convolutions."""
|
||||
@ -1347,7 +1353,8 @@ class ComputationBuilder(object):
|
||||
lhs_dilation,
|
||||
rhs_dilation,
|
||||
dimension_numbers=None,
|
||||
feature_group_count=1):
|
||||
feature_group_count=1,
|
||||
batch_group_count=1):
|
||||
"""Enqueues a ConvGeneralDilated operation onto the computation.
|
||||
|
||||
Args:
|
||||
@ -1377,6 +1384,7 @@ class ComputationBuilder(object):
|
||||
default, use the same dimension numbering as Conv and
|
||||
ConvWithGeneralPadding.
|
||||
feature_group_count: number of feature groups for grouped convolution.
|
||||
batch_group_count: number of batch groups for grouped convolution.
|
||||
Returns: a XlaOp representing the ConvGenralDilated operation.
|
||||
"""
|
||||
if dimension_numbers is None:
|
||||
@ -1402,7 +1410,7 @@ class ComputationBuilder(object):
|
||||
key=lambda i: rhs_spec.index(out_spec[i])))
|
||||
return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dimension_numbers,
|
||||
feature_group_count)
|
||||
feature_group_count, batch_group_count)
|
||||
|
||||
def Sort(self, operand, dimension=-1):
|
||||
"""Enqueues a sort operation onto the computation."""
|
||||
|
Loading…
Reference in New Issue
Block a user