Expose batch_group_count in xla_client python wrapper.
PiperOrigin-RevId: 252001924
This commit is contained in:
parent
89b1e717c0
commit
1ad04efcee
@ -1269,7 +1269,8 @@ class ComputationBuilder(object):
|
|||||||
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
|
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
|
||||||
return ops.DotGeneral(lhs, rhs, dimension_numbers)
|
return ops.DotGeneral(lhs, rhs, dimension_numbers)
|
||||||
|
|
||||||
def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
|
def Conv(self, lhs, rhs, window_strides, padding,
|
||||||
|
feature_group_count=1, batch_group_count=1):
|
||||||
"""Enqueues a Conv operation onto the computation.
|
"""Enqueues a Conv operation onto the computation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1278,6 +1279,7 @@ class ComputationBuilder(object):
|
|||||||
window_strides: length-N array-like of integer kernel strides.
|
window_strides: length-N array-like of integer kernel strides.
|
||||||
padding: PaddingType representing either 'SAME' or 'VALID' padding.
|
padding: PaddingType representing either 'SAME' or 'VALID' padding.
|
||||||
feature_group_count: number of feature groups for grouped convolution.
|
feature_group_count: number of feature groups for grouped convolution.
|
||||||
|
batch_group_count: number of batch groups for grouped convolution.
|
||||||
Returns: a XlaOp representing the Conv operation.
|
Returns: a XlaOp representing the Conv operation.
|
||||||
"""
|
"""
|
||||||
pads = _convert_padding_type_to_pad_values(
|
pads = _convert_padding_type_to_pad_values(
|
||||||
@ -1290,7 +1292,8 @@ class ComputationBuilder(object):
|
|||||||
window_strides,
|
window_strides,
|
||||||
pads, [], [],
|
pads, [], [],
|
||||||
dimension_numbers=None,
|
dimension_numbers=None,
|
||||||
feature_group_count=feature_group_count)
|
feature_group_count=feature_group_count,
|
||||||
|
batch_group_count=batch_group_count)
|
||||||
|
|
||||||
def ConvWithGeneralPadding(self,
|
def ConvWithGeneralPadding(self,
|
||||||
lhs,
|
lhs,
|
||||||
@ -1299,7 +1302,8 @@ class ComputationBuilder(object):
|
|||||||
padding,
|
padding,
|
||||||
lhs_dilation,
|
lhs_dilation,
|
||||||
rhs_dilation,
|
rhs_dilation,
|
||||||
feature_group_count=1):
|
feature_group_count=1,
|
||||||
|
batch_group_count=1):
|
||||||
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
|
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1310,6 +1314,7 @@ class ComputationBuilder(object):
|
|||||||
lhs_dilation: length-N array-like of dilation factors.
|
lhs_dilation: length-N array-like of dilation factors.
|
||||||
rhs_dilation: length-N array-like of dilation factors.
|
rhs_dilation: length-N array-like of dilation factors.
|
||||||
feature_group_count: number of feature groups for grouped convolution.
|
feature_group_count: number of feature groups for grouped convolution.
|
||||||
|
batch_group_count: number of batch groups for grouped convolution.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
|
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
|
||||||
@ -1322,7 +1327,8 @@ class ComputationBuilder(object):
|
|||||||
list(lhs_dilation),
|
list(lhs_dilation),
|
||||||
list(rhs_dilation),
|
list(rhs_dilation),
|
||||||
dimension_numbers=None,
|
dimension_numbers=None,
|
||||||
feature_group_count=feature_group_count)
|
feature_group_count=feature_group_count,
|
||||||
|
batch_group_count=batch_group_count)
|
||||||
|
|
||||||
def _GetConvDimensionNumbers(self, num_spatial_dims):
|
def _GetConvDimensionNumbers(self, num_spatial_dims):
|
||||||
"""Create ConvolutionDimensionNumbers proto for convolutions."""
|
"""Create ConvolutionDimensionNumbers proto for convolutions."""
|
||||||
@ -1347,7 +1353,8 @@ class ComputationBuilder(object):
|
|||||||
lhs_dilation,
|
lhs_dilation,
|
||||||
rhs_dilation,
|
rhs_dilation,
|
||||||
dimension_numbers=None,
|
dimension_numbers=None,
|
||||||
feature_group_count=1):
|
feature_group_count=1,
|
||||||
|
batch_group_count=1):
|
||||||
"""Enqueues a ConvGeneralDilated operation onto the computation.
|
"""Enqueues a ConvGeneralDilated operation onto the computation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1377,6 +1384,7 @@ class ComputationBuilder(object):
|
|||||||
default, use the same dimension numbering as Conv and
|
default, use the same dimension numbering as Conv and
|
||||||
ConvWithGeneralPadding.
|
ConvWithGeneralPadding.
|
||||||
feature_group_count: number of feature groups for grouped convolution.
|
feature_group_count: number of feature groups for grouped convolution.
|
||||||
|
batch_group_count: number of batch groups for grouped convolution.
|
||||||
Returns: a XlaOp representing the ConvGenralDilated operation.
|
Returns: a XlaOp representing the ConvGenralDilated operation.
|
||||||
"""
|
"""
|
||||||
if dimension_numbers is None:
|
if dimension_numbers is None:
|
||||||
@ -1402,7 +1410,7 @@ class ComputationBuilder(object):
|
|||||||
key=lambda i: rhs_spec.index(out_spec[i])))
|
key=lambda i: rhs_spec.index(out_spec[i])))
|
||||||
return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding,
|
return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding,
|
||||||
lhs_dilation, rhs_dilation, dimension_numbers,
|
lhs_dilation, rhs_dilation, dimension_numbers,
|
||||||
feature_group_count)
|
feature_group_count, batch_group_count)
|
||||||
|
|
||||||
def Sort(self, operand, dimension=-1):
|
def Sort(self, operand, dimension=-1):
|
||||||
"""Enqueues a sort operation onto the computation."""
|
"""Enqueues a sort operation onto the computation."""
|
||||||
|
Loading…
Reference in New Issue
Block a user