[XLA] Remove duplicated XlaOp comments for private XlaOp methods.

PiperOrigin-RevId: 222171307
This commit is contained in:
Kay Zhu 2018-11-19 18:02:45 -08:00 committed by TensorFlower Gardener
parent 7b327e27d6
commit 3affd7655e

View File

@ -280,31 +280,14 @@ class XlaBuilder {
// Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id);
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
// Description for the methods below can be found in the corresponding public
// functions section in this file.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
const string& name);
// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(const LiteralSlice& literal);
// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
// Native Type PrimitiveType
// -----------------------------
// bool PRED
// int32 S32
// int64 S64
// uint32 U32
// uint64 U64
// float F32
// double F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(NativeT value);
template <typename NativeT>
@ -334,181 +317,78 @@ class XlaBuilder {
template <typename NativeT>
XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(int64 length, NativeT value);
// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand,
absl::Span<const int64> broadcast_sizes);
XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
const absl::Span<const int64> broadcast_dimensions);
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config);
// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
// {256} collapsing {0} => {256}
// {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
// {256, 2} collapsing {0,1} => {512}
// {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
// x
// [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
// [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
// array[:, 2:4:1, :]
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specify the end point of exclusive slice intervals in each
// dimension [start, start + size).
// The shape of 'start_indices' must be rank == 1, with dimension size
// equal to the rank of the 'operand'.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
absl::Span<const int64> slice_sizes);
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
// update = {10, 11} // calculated at runtime.
// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
// [7 8 9] [7 8 9 ]
//
// The shape of 'start_indices' must be rank == 1, with dimension size
// equal to the rank of the 'operand'.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices);
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
// Enqueue a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
void Trace(const string& tag, const XlaOp& operand);
// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(absl::Span<const XlaOp> elements);
// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
@ -516,8 +396,6 @@ class XlaBuilder {
int64 feature_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides, Padding padding,
@ -525,8 +403,6 @@ class XlaBuilder {
int64 feature_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
@ -534,8 +410,6 @@ class XlaBuilder {
int64 feature_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
@ -545,80 +419,53 @@ class XlaBuilder {
int64 feature_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
absl::Span<const int64> fft_length);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(const Shape& shape, const string& config = "");
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
const string& config = "");
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
const Shape& shape_with_layout,
const string& outfeed_config);
// Enqueues a call instruction onto the computation.
XlaOp Call(const XlaComputation& computation,
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).
// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);
// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
@ -637,32 +484,23 @@ class XlaBuilder {
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce);
// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce);
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation);
// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding);
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
@ -672,48 +510,22 @@ class XlaBuilder {
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(const XlaOp& operand,
absl::Span<const ReplicaGroup> replica_groups = {});
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
// AllReduce means doing a reduction on the input operand cross cores and then
// broadcasting the reduction result to those cores. The reduction function is
// defined by `computation`, which should be a commutative computation on
// scalars, e.g., add, min, or max. The way that AllReduce is applied is
// configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica id. If
// empty, all replicas belong to one group. Allreduce will be applied within
// subgroups. For example, we have 4 replicas, then
// replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0,
// replica 1 and 3 are in subgroup 1.
//
// - `channel_id`: for Allreduce nodes from different modules, if they have
// the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
// not be applied cross modules.
//
// TODO(b/117564385): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
// Enqueues an operation that do an Alltoall of the operand cross cores.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
// Enqueues an operation that do an CollectivePermute of the operand cross
// cores.
XlaOp CollectivePermute(
const XlaOp& operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs);
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
@ -721,8 +533,6 @@ class XlaBuilder {
const XlaOp& init_value,
const XlaComputation& scatter);
// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
absl::Span<const int64> window_dimensions,
@ -730,217 +540,119 @@ class XlaBuilder {
absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
const XlaOp& init_value, const XlaComputation& scatter);
// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);
// Enqueues a atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(const XlaOp& operand);
// Enqueues a floor instruction onto the computation.
XlaOp Floor(const XlaOp& operand);
// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(const XlaOp& operand);
// Enqueues a round instruction onto the computation, rounding to nearest even
// with half-way cases rounding away from zero.
XlaOp Round(const XlaOp& operand);
// Enqueues an log instruction (natural logarithm) onto the computation.
XlaOp Log(const XlaOp& operand);
// Enqueues an log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(const XlaOp& operand);
// Enqueues a sign instruction onto the computation.
XlaOp Sign(const XlaOp& operand);
// Enqueues a count leading zeros instruction onto the computation.
XlaOp Clz(const XlaOp& operand);
// Enqueues a cosine instruction onto the computation.
XlaOp Cos(const XlaOp& operand);
// Enqueues a sine instruction onto the computation.
XlaOp Sin(const XlaOp& operand);
// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(const XlaOp& operand);
// Enqueues a real-part instruction onto the computation.
XlaOp Real(const XlaOp& operand);
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
// booleans with the same shape where entries are true iff the corresponding
// entry was NaN.
XlaOp IsFinite(const XlaOp& operand);
// Enqueues an iota operation onto the computation.
XlaOp Iota(const Shape& shape, int64 iota_dimension);
// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(PrimitiveType type, int64 size);
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type);
// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to primitive_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
// * If the keys are an rank-1 tensor (an array), the result is a sorted array
// of keys, in ascending order.
// * If the keys have higher rank, the keys are sorted along the provided
// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
// value of 0 will indepenently sort every column, and a dimension value of 1
// will independently sort each row. If no dimension number is provided, then
// the last dimension is chosen by default.
//
// If both keys and values are provided:
// * The keys and all values must be tensors with the same dimensions. The
// element types of the tensors may be different.
// * The result is a tuple that consists of a sorted tensor of keys (along the
// provided dimension, as above) as the first element, and tensors with their
// corresponding values as the other elements.
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
// Enqueues a map instruction onto the computation.
XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
absl::Span<const int64> dimensions,
absl::Span<const XlaOp> static_operands = {});
// Enqueues a N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
const XlaOp& init);
// Enqueues a conditional node onto the computation.
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation);
// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
absl::Span<const int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers);
// Enqueues a Send node onto the computation for device-to-device
// communication, to send the given operand to a Recv instruction that shares
// the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
const ChannelHandle& handle);
// Enqueues a Send node which sends data to the host.
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
const Shape& shape_with_layout, const ChannelHandle& handle);
// Enqueues a Recv node which receives data from the host.
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
const ChannelHandle& handle);
// Enqueues an AfterAll operation with no operands producing a token-shaped
// value.
XlaOp CreateToken();
// Enqueues an AfterAll operation with no operands producing a token-shaped
// value.
XlaOp AfterAll(absl::Span<const XlaOp> tokens);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
// be the same as the given shape.
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
const ChannelHandle& handle);
// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index);
// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, const XlaOp& mean,
const XlaOp& variance, float epsilon,
int64 feature_index);
// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
// - grad_operand: Gradient with respect to input `operand`
// - grad_offset: Gradient with respect to input `offset`
// - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
const XlaOp& batch_mean, const XlaOp& batch_var,
const XlaOp& grad_output, float epsilon,
@ -1409,6 +1121,7 @@ class XlaScopedShardingAssignment {
// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
//
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
@ -2154,6 +1867,7 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
// Implementation details below this point.
//
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {