[XLA] Remove duplicated XlaOp comments for private XlaOp methods.
PiperOrigin-RevId: 222171307
This commit is contained in:
parent
7b327e27d6
commit
3affd7655e
@ -280,31 +280,14 @@ class XlaBuilder {
|
|||||||
// Build helper which takes the id of the root operation..
|
// Build helper which takes the id of the root operation..
|
||||||
StatusOr<XlaComputation> Build(int64 root_id);
|
StatusOr<XlaComputation> Build(int64 root_id);
|
||||||
|
|
||||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
// Description for the methods below can be found in the corresponding public
|
||||||
// passed to the computation.
|
// functions section in this file.
|
||||||
|
|
||||||
XlaOp Parameter(int64 parameter_number, const Shape& shape,
|
XlaOp Parameter(int64 parameter_number, const Shape& shape,
|
||||||
const string& name);
|
const string& name);
|
||||||
|
|
||||||
// Enqueues a constant with the value of the given literal onto the
|
|
||||||
// computation.
|
|
||||||
XlaOp ConstantLiteral(const LiteralSlice& literal);
|
XlaOp ConstantLiteral(const LiteralSlice& literal);
|
||||||
|
|
||||||
// Enqueues a constant onto the computation. Methods are templated on the
|
|
||||||
// native host type (NativeT) which corresponds to a specific XLA
|
|
||||||
// PrimitiveType as given in the following table:
|
|
||||||
//
|
|
||||||
// Native Type PrimitiveType
|
|
||||||
// -----------------------------
|
|
||||||
// bool PRED
|
|
||||||
// int32 S32
|
|
||||||
// int64 S64
|
|
||||||
// uint32 U32
|
|
||||||
// uint64 U64
|
|
||||||
// float F32
|
|
||||||
// double F64
|
|
||||||
//
|
|
||||||
// Note: not all primitive types defined in xla_data.proto have a
|
|
||||||
// corresponding native type yet.
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
XlaOp ConstantR0(NativeT value);
|
XlaOp ConstantR0(NativeT value);
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
@ -334,181 +317,78 @@ class XlaBuilder {
|
|||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
|
XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
|
||||||
|
|
||||||
// Enqueues a rank one constant (vector) onto the computation. The vector has
|
|
||||||
// size 'length' and every element has the value 'value'.
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
XlaOp ConstantR1(int64 length, NativeT value);
|
XlaOp ConstantR1(int64 length, NativeT value);
|
||||||
|
|
||||||
// Adds dimensions to an array by duplicating the data in the array.
|
|
||||||
//
|
|
||||||
// The new dimensions are inserted on the left, i.e. if
|
|
||||||
// broadcast_sizes has values {a0, ..., aN} and the operand shape
|
|
||||||
// has dimensions {b0, ..., bM} then the shape of the output has
|
|
||||||
// dimensions {a0, ..., aN, b0, ..., bM}.
|
|
||||||
//
|
|
||||||
// The new dimensions index into copies of the operand, i.e.
|
|
||||||
//
|
|
||||||
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
|
|
||||||
XlaOp Broadcast(const XlaOp& operand,
|
XlaOp Broadcast(const XlaOp& operand,
|
||||||
absl::Span<const int64> broadcast_sizes);
|
absl::Span<const int64> broadcast_sizes);
|
||||||
|
|
||||||
XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
|
XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
|
||||||
const absl::Span<const int64> broadcast_dimensions);
|
const absl::Span<const int64> broadcast_dimensions);
|
||||||
|
|
||||||
// Enqueues a pad operation onto the computation that pads the given value on
|
|
||||||
// the edges as well as between the elements of the input. padding_config
|
|
||||||
// specifies the padding amount for each dimension.
|
|
||||||
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
|
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
|
||||||
const PaddingConfig& padding_config);
|
const PaddingConfig& padding_config);
|
||||||
|
|
||||||
// Enqueues an operation onto the computation that flattens the operand based
|
|
||||||
// on the dimension order (major/slowest-varying to minor/fastest-varying)
|
|
||||||
// given, followed by reshaping it into the shape with the given dimension
|
|
||||||
// sizes (also major to minor). Conceptually, this is a limited form of
|
|
||||||
// "shape casting".
|
|
||||||
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
|
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
|
||||||
absl::Span<const int64> new_sizes);
|
absl::Span<const int64> new_sizes);
|
||||||
|
|
||||||
// Enqueues an operation onto the computation that collapses the operand, from
|
|
||||||
// first to last dimension (C order), then reshapes it to the given dimension
|
|
||||||
// sizes. Conceptually, this is a limited form of "shape casting".
|
|
||||||
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
|
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
|
||||||
|
|
||||||
// Wrapper for Reshape.
|
|
||||||
// Enqueues an operation to collapse the provided dimensions; e.g. an
|
|
||||||
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
|
|
||||||
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
|
|
||||||
// be a consecutive, in-order subsequence of the operand dimensions.
|
|
||||||
//
|
|
||||||
// Note that collapsing a single dimension does nothing:
|
|
||||||
//
|
|
||||||
// {256} collapsing {0} => {256}
|
|
||||||
// {1} collapsing {0} => {1}
|
|
||||||
//
|
|
||||||
// Collapsing multiple dimensions produces a single result dimension:
|
|
||||||
//
|
|
||||||
// {256, 2} collapsing {0,1} => {512}
|
|
||||||
// {256, 2, 3} collapsing {0,1} => {512, 3}
|
|
||||||
//
|
|
||||||
// This could potentially cause data to be moved -- it provides a more
|
|
||||||
// structured form of reshaping than an arbitrary Reshape operation.
|
|
||||||
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
|
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
|
||||||
|
|
||||||
// Enqueues a slice operation onto the computation that slices the operand
|
|
||||||
// from the start indices to the limit indices; e.g.
|
|
||||||
//
|
|
||||||
// x
|
|
||||||
// [ 0 1 2 3 ]
|
|
||||||
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
|
|
||||||
// [ 8 9 a b ]
|
|
||||||
//
|
|
||||||
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
|
|
||||||
// range notation.
|
|
||||||
// The strides parameter determines the stride over the slice
|
|
||||||
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
|
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
|
||||||
absl::Span<const int64> limit_indices,
|
absl::Span<const int64> limit_indices,
|
||||||
absl::Span<const int64> strides);
|
absl::Span<const int64> strides);
|
||||||
|
|
||||||
// Enqueues a slice operation in a given dimension, taking all other
|
|
||||||
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
|
|
||||||
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
|
|
||||||
// for:
|
|
||||||
//
|
|
||||||
// array[:, 2:4:1, :]
|
|
||||||
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
|
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
|
||||||
int64 stride, int64 dimno);
|
int64 stride, int64 dimno);
|
||||||
|
|
||||||
// Enqueues a slice operation onto the computation that slices the 'operand'
|
|
||||||
// from dynamic start indices which are passed in 'start_indices'.
|
|
||||||
// The size of the slice in each dimension is passed in 'slice_sizes',
|
|
||||||
// which specify the end point of exclusive slice intervals in each
|
|
||||||
// dimension [start, start + size).
|
|
||||||
// The shape of 'start_indices' must be rank == 1, with dimension size
|
|
||||||
// equal to the rank of the 'operand'.
|
|
||||||
// Slice index calculations are computed modulo input dimension sizes to
|
|
||||||
// prevent dynamic start indices from generating out-of-bound array accesses.
|
|
||||||
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
|
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
|
||||||
absl::Span<const int64> slice_sizes);
|
absl::Span<const int64> slice_sizes);
|
||||||
|
|
||||||
// Enqueues a dynamic update slice operation onto the computation, which
|
|
||||||
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
|
|
||||||
// The shape of 'update' determines the shape of the slice of 'operand'
|
|
||||||
// which is updated.
|
|
||||||
// The indices specified in 'start_indices' specify the offset of the slice
|
|
||||||
// of 'operand' which is updated.
|
|
||||||
//
|
|
||||||
// update = {10, 11} // calculated at runtime.
|
|
||||||
// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
|
|
||||||
// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
|
|
||||||
// [7 8 9] [7 8 9 ]
|
|
||||||
//
|
|
||||||
// The shape of 'start_indices' must be rank == 1, with dimension size
|
|
||||||
// equal to the rank of the 'operand'.
|
|
||||||
// Slice index calculations are computed modulo update dimension sizes to
|
|
||||||
// prevent dynamic start indices from generating out-of-bound array accesses.
|
|
||||||
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
|
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
|
||||||
const XlaOp& start_indices);
|
const XlaOp& start_indices);
|
||||||
|
|
||||||
// Enqueues a concatenate instruction onto the computation. 'operands' must
|
|
||||||
// have >= 1 entry.
|
|
||||||
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
|
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
|
||||||
|
|
||||||
// Enqueue a tracing operation onto the computation; the computation will emit
|
|
||||||
// a logging message with the operand.
|
|
||||||
void Trace(const string& tag, const XlaOp& operand);
|
void Trace(const string& tag, const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a conditional-move-like select operation onto the computation;
|
|
||||||
// predicated on pred, selects between on_true and on_false.
|
|
||||||
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
|
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
|
||||||
|
|
||||||
// Enqueues a tuple-creation instruction onto the computation.
|
|
||||||
XlaOp Tuple(absl::Span<const XlaOp> elements);
|
XlaOp Tuple(absl::Span<const XlaOp> elements);
|
||||||
|
|
||||||
// Enqueues a tuple-element-get instruction onto the computation.
|
|
||||||
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
|
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
|
||||||
|
|
||||||
// Enqueues an equal-to comparison instruction onto the computation.
|
|
||||||
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a not-equal comparison instruction onto the computation.
|
|
||||||
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a greater-or-equal comparison instruction onto the computation.
|
|
||||||
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a greater-than comparison instruction onto the computation.
|
|
||||||
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a less-than comparison instruction onto the computation.
|
|
||||||
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a less-or-equal comparison instruction onto the computation.
|
|
||||||
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a dot instruction onto the computation.
|
|
||||||
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues a general dot instruction onto the computation.
|
|
||||||
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
const DotDimensionNumbers& dimension_numbers,
|
const DotDimensionNumbers& dimension_numbers,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues a convolution instruction onto the computation, which uses the
|
|
||||||
// default convolution dimension numbers.
|
|
||||||
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> window_strides, Padding padding,
|
absl::Span<const int64> window_strides, Padding padding,
|
||||||
int64 feature_group_count = 1,
|
int64 feature_group_count = 1,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues a convolution instruction onto the computation, with the caller
|
|
||||||
// provided padding configuration in the format returned by MakePadding().
|
|
||||||
XlaOp ConvWithGeneralPadding(
|
XlaOp ConvWithGeneralPadding(
|
||||||
const XlaOp& lhs, const XlaOp& rhs,
|
const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> window_strides,
|
absl::Span<const int64> window_strides,
|
||||||
@ -516,8 +396,6 @@ class XlaBuilder {
|
|||||||
int64 feature_group_count = 1,
|
int64 feature_group_count = 1,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues a convolution instruction onto the computation, with the caller
|
|
||||||
// provided dimension numbers configuration.
|
|
||||||
XlaOp ConvWithGeneralDimensions(
|
XlaOp ConvWithGeneralDimensions(
|
||||||
const XlaOp& lhs, const XlaOp& rhs,
|
const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> window_strides, Padding padding,
|
absl::Span<const int64> window_strides, Padding padding,
|
||||||
@ -525,8 +403,6 @@ class XlaBuilder {
|
|||||||
int64 feature_group_count = 1,
|
int64 feature_group_count = 1,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues a convolution instruction onto the computation, with the caller
|
|
||||||
// provided padding configuration as well as the dimension numbers.
|
|
||||||
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> window_strides,
|
absl::Span<const int64> window_strides,
|
||||||
absl::Span<const std::pair<int64, int64>> padding,
|
absl::Span<const std::pair<int64, int64>> padding,
|
||||||
@ -534,8 +410,6 @@ class XlaBuilder {
|
|||||||
int64 feature_group_count = 1,
|
int64 feature_group_count = 1,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues a convolution instruction onto the computation, with the caller
|
|
||||||
// provided padding configuration, dilation factors and dimension numbers.
|
|
||||||
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> window_strides,
|
absl::Span<const int64> window_strides,
|
||||||
absl::Span<const std::pair<int64, int64>> padding,
|
absl::Span<const std::pair<int64, int64>> padding,
|
||||||
@ -545,80 +419,53 @@ class XlaBuilder {
|
|||||||
int64 feature_group_count = 1,
|
int64 feature_group_count = 1,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
|
||||||
// Enqueues an FFT instruction onto the computation, of the given type and
|
|
||||||
// with the given FFT length.
|
|
||||||
XlaOp Fft(const XlaOp& operand, FftType fft_type,
|
XlaOp Fft(const XlaOp& operand, FftType fft_type,
|
||||||
absl::Span<const int64> fft_length);
|
absl::Span<const int64> fft_length);
|
||||||
|
|
||||||
// Enqueues an infeed instruction onto the computation, which writes data of
|
|
||||||
// the given shape to the infeed buffer of the device.
|
|
||||||
XlaOp Infeed(const Shape& shape, const string& config = "");
|
XlaOp Infeed(const Shape& shape, const string& config = "");
|
||||||
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
|
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
|
||||||
const string& config = "");
|
const string& config = "");
|
||||||
|
|
||||||
// Enqueues an outfeed instruction onto the computation. This instruction
|
|
||||||
// generates outgoing data transfers for the given data.
|
|
||||||
//
|
|
||||||
// shape_with_layout communicates the laid out shape that we want to outfeed
|
|
||||||
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
|
|
||||||
// will occur.
|
|
||||||
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
|
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
|
||||||
const string& outfeed_config);
|
const string& outfeed_config);
|
||||||
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
|
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
|
||||||
const Shape& shape_with_layout,
|
const Shape& shape_with_layout,
|
||||||
const string& outfeed_config);
|
const string& outfeed_config);
|
||||||
|
|
||||||
// Enqueues a call instruction onto the computation.
|
|
||||||
XlaOp Call(const XlaComputation& computation,
|
XlaOp Call(const XlaComputation& computation,
|
||||||
absl::Span<const XlaOp> operands);
|
absl::Span<const XlaOp> operands);
|
||||||
|
|
||||||
// Enqueues a custom call instruction onto the computation.
|
|
||||||
XlaOp CustomCall(
|
XlaOp CustomCall(
|
||||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||||
const Shape& shape_with_layout, const string& opaque,
|
const Shape& shape_with_layout, const string& opaque,
|
||||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
|
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
|
||||||
|
|
||||||
// The following methods enqueue element-wise binary arithmetic operations
|
|
||||||
// onto the computation. The shapes of the operands have to match unless one
|
|
||||||
// of the operands is a scalar, or an explicit broadcast dimension is given
|
|
||||||
// (see g3doc for more details).
|
|
||||||
|
|
||||||
// Enqueues a complex compose instruction onto the computation.
|
|
||||||
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
|
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a complex conjugate instruction onto the computation.
|
|
||||||
XlaOp Conj(const XlaOp& operand);
|
XlaOp Conj(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues an add instruction onto the computation.
|
|
||||||
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a subtract instruction onto the computation.
|
|
||||||
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a multiply instruction onto the computation.
|
|
||||||
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a divide instruction onto the computation.
|
|
||||||
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a remainder instruction onto the computation.
|
|
||||||
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a max instruction onto the computation.
|
|
||||||
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a min instruction onto the computation.
|
|
||||||
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Element-wise logical operators
|
|
||||||
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
@ -637,32 +484,23 @@ class XlaBuilder {
|
|||||||
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Reduces an array among the provided dimensions, given "computation" as a
|
|
||||||
// reduction operator.
|
|
||||||
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
|
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
|
||||||
const XlaComputation& computation,
|
const XlaComputation& computation,
|
||||||
absl::Span<const int64> dimensions_to_reduce);
|
absl::Span<const int64> dimensions_to_reduce);
|
||||||
|
|
||||||
// Reduces several arrays simultaneously among the provided dimensions, given
|
|
||||||
// "computation" as a reduction operator.
|
|
||||||
XlaOp Reduce(absl::Span<const XlaOp> operands,
|
XlaOp Reduce(absl::Span<const XlaOp> operands,
|
||||||
absl::Span<const XlaOp> init_values,
|
absl::Span<const XlaOp> init_values,
|
||||||
const XlaComputation& computation,
|
const XlaComputation& computation,
|
||||||
absl::Span<const int64> dimensions_to_reduce);
|
absl::Span<const int64> dimensions_to_reduce);
|
||||||
|
|
||||||
// Convenience wrapper around the above that reduces all the dimensions in the
|
|
||||||
// operand shape.
|
|
||||||
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
|
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
|
||||||
const XlaComputation& computation);
|
const XlaComputation& computation);
|
||||||
|
|
||||||
// Enqueues a windowed reduce instruction onto the computation.
|
|
||||||
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
|
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
|
||||||
const XlaComputation& computation,
|
const XlaComputation& computation,
|
||||||
absl::Span<const int64> window_dimensions,
|
absl::Span<const int64> window_dimensions,
|
||||||
absl::Span<const int64> window_strides, Padding padding);
|
absl::Span<const int64> window_strides, Padding padding);
|
||||||
|
|
||||||
// As ReduceWindow(), but the padding is given in the format
|
|
||||||
// returned by MakePadding().
|
|
||||||
XlaOp ReduceWindowWithGeneralPadding(
|
XlaOp ReduceWindowWithGeneralPadding(
|
||||||
const XlaOp& operand, const XlaOp& init_value,
|
const XlaOp& operand, const XlaOp& init_value,
|
||||||
const XlaComputation& computation,
|
const XlaComputation& computation,
|
||||||
@ -672,48 +510,22 @@ class XlaBuilder {
|
|||||||
absl::Span<const int64> window_dilations,
|
absl::Span<const int64> window_dilations,
|
||||||
absl::Span<const std::pair<int64, int64>> padding);
|
absl::Span<const std::pair<int64, int64>> padding);
|
||||||
|
|
||||||
// Returns the sum of the operand value within each subgroup of replicas. All
|
|
||||||
// replicas supply one input to the sum and all replicas receive the resulting
|
|
||||||
// sum for each subgroup.
|
|
||||||
XlaOp CrossReplicaSum(const XlaOp& operand,
|
XlaOp CrossReplicaSum(const XlaOp& operand,
|
||||||
absl::Span<const ReplicaGroup> replica_groups = {});
|
absl::Span<const ReplicaGroup> replica_groups = {});
|
||||||
|
|
||||||
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
|
|
||||||
// AllReduce means doing a reduction on the input operand cross cores and then
|
|
||||||
// broadcasting the reduction result to those cores. The reduction function is
|
|
||||||
// defined by `computation`, which should be a commutative computation on
|
|
||||||
// scalars, e.g., add, min, or max. The way that AllReduce is applied is
|
|
||||||
// configured by:
|
|
||||||
//
|
|
||||||
// - `replica_groups`: each ReplicaGroup contains a list of replica id. If
|
|
||||||
// empty, all replicas belong to one group. Allreduce will be applied within
|
|
||||||
// subgroups. For example, we have 4 replicas, then
|
|
||||||
// replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0,
|
|
||||||
// replica 1 and 3 are in subgroup 1.
|
|
||||||
//
|
|
||||||
// - `channel_id`: for Allreduce nodes from different modules, if they have
|
|
||||||
// the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
|
|
||||||
// not be applied cross modules.
|
|
||||||
//
|
|
||||||
// TODO(b/117564385): Rename this to AllReduce when it's ready to use.
|
|
||||||
XlaOp CrossReplicaSum(
|
XlaOp CrossReplicaSum(
|
||||||
const XlaOp& operand, const XlaComputation& computation,
|
const XlaOp& operand, const XlaComputation& computation,
|
||||||
absl::Span<const ReplicaGroup> replica_groups = {},
|
absl::Span<const ReplicaGroup> replica_groups = {},
|
||||||
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
|
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
|
||||||
|
|
||||||
// Enqueues an operation that do an Alltoall of the operand cross cores.
|
|
||||||
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
|
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
|
||||||
int64 concat_dimension, int64 split_count,
|
int64 concat_dimension, int64 split_count,
|
||||||
const std::vector<ReplicaGroup>& replica_groups);
|
const std::vector<ReplicaGroup>& replica_groups);
|
||||||
|
|
||||||
// Enqueues an operation that do an CollectivePermute of the operand cross
|
|
||||||
// cores.
|
|
||||||
XlaOp CollectivePermute(
|
XlaOp CollectivePermute(
|
||||||
const XlaOp& operand,
|
const XlaOp& operand,
|
||||||
const std::vector<std::pair<int64, int64>>& source_target_pairs);
|
const std::vector<std::pair<int64, int64>>& source_target_pairs);
|
||||||
|
|
||||||
// Enqueues an operation that scatters the `source` array to the selected
|
|
||||||
// indices of each window.
|
|
||||||
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
||||||
absl::Span<const int64> window_dimensions,
|
absl::Span<const int64> window_dimensions,
|
||||||
absl::Span<const int64> window_strides,
|
absl::Span<const int64> window_strides,
|
||||||
@ -721,8 +533,6 @@ class XlaBuilder {
|
|||||||
const XlaOp& init_value,
|
const XlaOp& init_value,
|
||||||
const XlaComputation& scatter);
|
const XlaComputation& scatter);
|
||||||
|
|
||||||
// As SelectAndScatter(), but the padding is given in the format
|
|
||||||
// returned by MakePadding().
|
|
||||||
XlaOp SelectAndScatterWithGeneralPadding(
|
XlaOp SelectAndScatterWithGeneralPadding(
|
||||||
const XlaOp& operand, const XlaComputation& select,
|
const XlaOp& operand, const XlaComputation& select,
|
||||||
absl::Span<const int64> window_dimensions,
|
absl::Span<const int64> window_dimensions,
|
||||||
@ -730,217 +540,119 @@ class XlaBuilder {
|
|||||||
absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
|
absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
|
||||||
const XlaOp& init_value, const XlaComputation& scatter);
|
const XlaOp& init_value, const XlaComputation& scatter);
|
||||||
|
|
||||||
// Enqueues an abs instruction onto the computation.
|
|
||||||
XlaOp Abs(const XlaOp& operand);
|
XlaOp Abs(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a atan2 instruction onto the computation.
|
|
||||||
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
|
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues an exp instruction onto the computation.
|
|
||||||
XlaOp Exp(const XlaOp& operand);
|
XlaOp Exp(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues an expm1 instruction onto the computation.
|
|
||||||
XlaOp Expm1(const XlaOp& operand);
|
XlaOp Expm1(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a floor instruction onto the computation.
|
|
||||||
XlaOp Floor(const XlaOp& operand);
|
XlaOp Floor(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a ceil instruction onto the computation.
|
|
||||||
XlaOp Ceil(const XlaOp& operand);
|
XlaOp Ceil(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a round instruction onto the computation, rounding to nearest even
|
|
||||||
// with half-way cases rounding away from zero.
|
|
||||||
XlaOp Round(const XlaOp& operand);
|
XlaOp Round(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues an log instruction (natural logarithm) onto the computation.
|
|
||||||
XlaOp Log(const XlaOp& operand);
|
XlaOp Log(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues an log1p instruction (log(x+1)) onto the computation.
|
|
||||||
XlaOp Log1p(const XlaOp& operand);
|
XlaOp Log1p(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a sign instruction onto the computation.
|
|
||||||
XlaOp Sign(const XlaOp& operand);
|
XlaOp Sign(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a count leading zeros instruction onto the computation.
|
|
||||||
XlaOp Clz(const XlaOp& operand);
|
XlaOp Clz(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a cosine instruction onto the computation.
|
|
||||||
XlaOp Cos(const XlaOp& operand);
|
XlaOp Cos(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a sine instruction onto the computation.
|
|
||||||
XlaOp Sin(const XlaOp& operand);
|
XlaOp Sin(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a tanh instruction onto the computation.
|
|
||||||
XlaOp Tanh(const XlaOp& operand);
|
XlaOp Tanh(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a real-part instruction onto the computation.
|
|
||||||
XlaOp Real(const XlaOp& operand);
|
XlaOp Real(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues an imaginary-part instruction onto the computation.
|
|
||||||
XlaOp Imag(const XlaOp& operand);
|
XlaOp Imag(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a lhs^rhs computation onto the computation.
|
|
||||||
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues an operator that tests if the operand's values are finite, i.e.,
|
|
||||||
// not Inf or NaN. Defined only for floating-point types. Returns an array of
|
|
||||||
// booleans with the same shape where entries are true iff the corresponding
|
|
||||||
// entry was NaN.
|
|
||||||
XlaOp IsFinite(const XlaOp& operand);
|
XlaOp IsFinite(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues an iota operation onto the computation.
|
|
||||||
XlaOp Iota(const Shape& shape, int64 iota_dimension);
|
XlaOp Iota(const Shape& shape, int64 iota_dimension);
|
||||||
|
|
||||||
// Enqueues a rank-1 iota operation onto the computation.
|
|
||||||
XlaOp Iota(PrimitiveType type, int64 size);
|
XlaOp Iota(PrimitiveType type, int64 size);
|
||||||
|
|
||||||
// Enqueues a convert instruction onto the computation that changes the
|
|
||||||
// element type of the operand array to primitive_type.
|
|
||||||
XlaOp ConvertElementType(const XlaOp& operand,
|
XlaOp ConvertElementType(const XlaOp& operand,
|
||||||
PrimitiveType new_element_type);
|
PrimitiveType new_element_type);
|
||||||
|
|
||||||
// Enqueues a no-op instruction onto the computation that changes
|
|
||||||
// the element type of the operand array to primitive_type. The
|
|
||||||
// bit-widths of the source and destination element types must be
|
|
||||||
// identical.
|
|
||||||
XlaOp BitcastConvertType(const XlaOp& operand,
|
XlaOp BitcastConvertType(const XlaOp& operand,
|
||||||
PrimitiveType new_element_type);
|
PrimitiveType new_element_type);
|
||||||
|
|
||||||
// Enqueues a negate instruction onto the computation.
|
|
||||||
XlaOp Neg(const XlaOp& operand);
|
XlaOp Neg(const XlaOp& operand);
|
||||||
|
|
||||||
// Enqueues a transpose instruction onto the computation.
|
|
||||||
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
|
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
|
||||||
|
|
||||||
// Enqueues a reverse instruction onto the computation. The order of the
|
|
||||||
// elements in the given dimensions is reversed (i.e., the element at index i
|
|
||||||
// is moved to index dimension_size - 1 - i).
|
|
||||||
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
|
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
|
||||||
|
|
||||||
// Enqueues a sort (as increasing order) instruction onto the computation.
|
|
||||||
// If only keys are provided:
|
|
||||||
// * If the keys are an rank-1 tensor (an array), the result is a sorted array
|
|
||||||
// of keys, in ascending order.
|
|
||||||
// * If the keys have higher rank, the keys are sorted along the provided
|
|
||||||
// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
|
|
||||||
// value of 0 will indepenently sort every column, and a dimension value of 1
|
|
||||||
// will independently sort each row. If no dimension number is provided, then
|
|
||||||
// the last dimension is chosen by default.
|
|
||||||
//
|
|
||||||
// If both keys and values are provided:
|
|
||||||
// * The keys and all values must be tensors with the same dimensions. The
|
|
||||||
// element types of the tensors may be different.
|
|
||||||
// * The result is a tuple that consists of a sorted tensor of keys (along the
|
|
||||||
// provided dimension, as above) as the first element, and tensors with their
|
|
||||||
// corresponding values as the other elements.
|
|
||||||
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
|
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
|
||||||
int64 dimension = -1);
|
int64 dimension = -1);
|
||||||
|
|
||||||
// Enqueues a clamp instruction onto the computation.
|
|
||||||
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
|
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
|
||||||
|
|
||||||
// Enqueues a map instruction onto the computation.
|
|
||||||
XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||||
absl::Span<const int64> dimensions,
|
absl::Span<const int64> dimensions,
|
||||||
absl::Span<const XlaOp> static_operands = {});
|
absl::Span<const XlaOp> static_operands = {});
|
||||||
|
|
||||||
// Enqueues a N(mu, sigma) random number generation instruction onto the
|
|
||||||
// computation.
|
|
||||||
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
|
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
|
||||||
|
|
||||||
// Enqueues a U(a, b) random number generation instruction onto the
|
|
||||||
// computation. Returns values in the semi-open interval [a, b).
|
|
||||||
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
|
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
|
||||||
|
|
||||||
// Enqueues a while node onto the computation.
|
|
||||||
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
|
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
|
||||||
const XlaOp& init);
|
const XlaOp& init);
|
||||||
|
|
||||||
// Enqueues a conditional node onto the computation.
|
|
||||||
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
|
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
|
||||||
const XlaComputation& true_computation,
|
const XlaComputation& true_computation,
|
||||||
const XlaOp& false_operand,
|
const XlaOp& false_operand,
|
||||||
const XlaComputation& false_computation);
|
const XlaComputation& false_computation);
|
||||||
|
|
||||||
// Enqueues a ReducePrecision node onto the computation.
|
|
||||||
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
|
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
|
||||||
const int mantissa_bits);
|
const int mantissa_bits);
|
||||||
|
|
||||||
// Enqueues a Gather node onto the computation.
|
|
||||||
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
|
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
|
||||||
const GatherDimensionNumbers& dimension_numbers,
|
const GatherDimensionNumbers& dimension_numbers,
|
||||||
absl::Span<const int64> slice_sizes);
|
absl::Span<const int64> slice_sizes);
|
||||||
|
|
||||||
// Enqueues a Scatter node onto the computation.
|
|
||||||
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
|
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
|
||||||
const XlaOp& updates, const XlaComputation& update_computation,
|
const XlaOp& updates, const XlaComputation& update_computation,
|
||||||
const ScatterDimensionNumbers& dimension_numbers);
|
const ScatterDimensionNumbers& dimension_numbers);
|
||||||
|
|
||||||
// Enqueues a Send node onto the computation for device-to-device
|
|
||||||
// communication, to send the given operand to a Recv instruction that shares
|
|
||||||
// the same channel handle.
|
|
||||||
void Send(const XlaOp& operand, const ChannelHandle& handle);
|
void Send(const XlaOp& operand, const ChannelHandle& handle);
|
||||||
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
|
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
|
||||||
const ChannelHandle& handle);
|
const ChannelHandle& handle);
|
||||||
|
|
||||||
// Enqueues a Send node which sends data to the host.
|
|
||||||
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
|
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
|
||||||
const Shape& shape_with_layout, const ChannelHandle& handle);
|
const Shape& shape_with_layout, const ChannelHandle& handle);
|
||||||
|
|
||||||
// Enqueues a Recv node which receives data from the host.
|
|
||||||
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
|
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
|
||||||
const ChannelHandle& handle);
|
const ChannelHandle& handle);
|
||||||
|
|
||||||
// Enqueues an AfterAll operation with no operands producing a token-shaped
|
|
||||||
// value.
|
|
||||||
XlaOp CreateToken();
|
XlaOp CreateToken();
|
||||||
|
|
||||||
// Enqueues an AfterAll operation with no operands producing a token-shaped
|
|
||||||
// value.
|
|
||||||
XlaOp AfterAll(absl::Span<const XlaOp> tokens);
|
XlaOp AfterAll(absl::Span<const XlaOp> tokens);
|
||||||
|
|
||||||
// Enqueues a Recv node onto the computation. The data comes from a Send
|
|
||||||
// instruction that shares the same channel handle and its shape must
|
|
||||||
// be the same as the given shape.
|
|
||||||
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
|
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
|
||||||
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
|
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
|
||||||
const ChannelHandle& handle);
|
const ChannelHandle& handle);
|
||||||
|
|
||||||
// Normalizes operand across spatial and batch dimensions for each feature.
|
|
||||||
//
|
|
||||||
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
|
|
||||||
// is the normalized result and batch_mean and batch_var are the mean and
|
|
||||||
// variance, respectively, across batch for the operand.
|
|
||||||
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
|
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
|
||||||
const XlaOp& offset, float epsilon,
|
const XlaOp& offset, float epsilon,
|
||||||
int64 feature_index);
|
int64 feature_index);
|
||||||
|
|
||||||
// Normalizes operand across spatial and batch dimensions for each feature.
|
|
||||||
//
|
|
||||||
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
|
|
||||||
// computing `mean` and `variance` for each batch inside the operation. It
|
|
||||||
// uses the input `mean` and `variance` instead as estimated values. The
|
|
||||||
// purpose of this op is to reduce latency in inference, hence the name
|
|
||||||
// `BatchNormInference`.
|
|
||||||
//
|
|
||||||
// The output has the same shape as `operand`, and contains the normalized
|
|
||||||
// values for each batch.
|
|
||||||
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
|
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
|
||||||
const XlaOp& offset, const XlaOp& mean,
|
const XlaOp& offset, const XlaOp& mean,
|
||||||
const XlaOp& variance, float epsilon,
|
const XlaOp& variance, float epsilon,
|
||||||
int64 feature_index);
|
int64 feature_index);
|
||||||
|
|
||||||
// Calculates the gradients of a batch norm op.
|
|
||||||
//
|
|
||||||
// The inputs `batch_mean` and `batch_var` represent the mean and variance
|
|
||||||
// across the batch.
|
|
||||||
//
|
|
||||||
// Returns a tuple of three elements:
|
|
||||||
// - grad_operand: Gradient with respect to input `operand`
|
|
||||||
// - grad_offset: Gradient with respect to input `offset`
|
|
||||||
// - grad_scale: Gradient with respect to input `scale`
|
|
||||||
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
|
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
|
||||||
const XlaOp& batch_mean, const XlaOp& batch_var,
|
const XlaOp& batch_mean, const XlaOp& batch_var,
|
||||||
const XlaOp& grad_output, float epsilon,
|
const XlaOp& grad_output, float epsilon,
|
||||||
@ -1409,6 +1121,7 @@ class XlaScopedShardingAssignment {
|
|||||||
// Free functions for building XlaOps. The intention is that these will
|
// Free functions for building XlaOps. The intention is that these will
|
||||||
// become the public API for building XlaOps rather than calling methods on
|
// become the public API for building XlaOps rather than calling methods on
|
||||||
// XlaBuilder directly.
|
// XlaBuilder directly.
|
||||||
|
//
|
||||||
|
|
||||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||||
// passed to the computation.
|
// passed to the computation.
|
||||||
@ -2154,6 +1867,7 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
|
|||||||
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
|
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
|
||||||
|
|
||||||
// Implementation details below this point.
|
// Implementation details below this point.
|
||||||
|
//
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
XlaOp XlaBuilder::ConstantR0(NativeT value) {
|
XlaOp XlaBuilder::ConstantR0(NativeT value) {
|
||||||
|
Loading…
Reference in New Issue
Block a user