[XLA] Add free functions in the xla:: namespace that mirror XlaBuilder methods. These free functions can be used unqualified if desired (even outside xla:: because of argument-dependent lookup), removing the need to explicitly pass or call via an XlaBuilder for most methods.

The main exceptions to this rule are source nodes (e.g., Constant, Parameter, Recv) and nodes with possibly-zero arity (e.g., ConcatInDim, Tuple).
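
For illustration, a minimal sketch of client code written against the new free functions (the "axpy" computation and its shapes are invented for this example):

  XlaBuilder b("axpy");
  // Source nodes still take the builder explicitly.
  XlaOp alpha = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "alpha");
  XlaOp x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "x");
  XlaOp y = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {4}), "y");
  // Non-source ops recover the builder from an operand; because of ADL the
  // calls need no xla:: qualification even outside the namespace.
  XlaOp axpy = Add(Mul(alpha, x), y);
  StatusOr<XlaComputation> computation = b.Build();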

No tests are added for the new methods; the intention is to switch all existing users of the builder methods to the free functions, at which point they will have good test coverage.

PiperOrigin-RevId: 202167553
Peter Hawkins 2018-06-26 11:54:20 -07:00 committed by TensorFlower Gardener
parent bfda539bef
commit 623513f265
2 changed files with 1243 additions and 24 deletions


@@ -2122,4 +2122,500 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
return &instructions_[op.handle()];
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
const string& name) {
return builder->Parameter(parameter_number, shape, name);
}
// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
return builder->ConstantLiteral(literal);
}
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
return operand.builder()->Broadcast(operand, broadcast_sizes);
}
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config) {
return operand.builder()->Pad(operand, padding_value, padding_config);
}
XlaOp Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
return operand.builder()->Reshape(operand, dimensions, new_sizes);
}
XlaOp Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
return operand.builder()->Reshape(operand, new_sizes);
}
XlaOp Collapse(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
return operand.builder()->Collapse(operand, dimensions);
}
XlaOp Slice(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
return operand.builder()->Slice(operand, start_indices, limit_indices,
strides);
}
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno) {
return operand.builder()->SliceInDim(operand, start_index, limit_index,
stride, dimno);
}
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices) {
return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}
XlaOp ConcatInDim(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<XlaOp> operands,
int64 dimension) {
return builder->ConcatInDim(operands, dimension);
}
void Trace(const string& tag, const XlaOp& operand) {
return operand.builder()->Trace(tag, operand);
}
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
return pred.builder()->Select(pred, on_true, on_false);
}
XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements) {
return builder->Tuple(elements);
}
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
return tuple_data.builder()->GetTupleElement(tuple_data, index);
}
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions);
}
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions);
}
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions);
}
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions);
}
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions);
}
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
}
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) {
return lhs.builder()->Dot(lhs, rhs);
}
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers);
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding);
}
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
padding);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers) {
return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides,
padding, dimension_numbers);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
dimension_numbers);
}
XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
dimension_numbers);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length) {
return operand.builder()->Fft(operand, fft_type, fft_length);
}
XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
return builder->Infeed(shape, config);
}
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
return builder->Call(computation, operands);
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape) {
return builder->CustomCall(call_target_name, operands, shape);
}
XlaOp HostCompute(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const string& channel_name, int64 cost_estimate_ns,
const Shape& shape) {
return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return real.builder()->Complex(real, imag, broadcast_dimensions);
}
XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); }
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Add(lhs, rhs, broadcast_dimensions);
}
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions);
}
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions);
}
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Div(lhs, rhs, broadcast_dimensions);
}
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions);
}
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Max(lhs, rhs, broadcast_dimensions);
}
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Min(lhs, rhs, broadcast_dimensions);
}
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->And(lhs, rhs, broadcast_dimensions);
}
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Or(lhs, rhs, broadcast_dimensions);
}
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions);
}
XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); }
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions);
}
XlaOp ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions);
}
XlaOp ShiftRightLogical(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions);
}
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
return operand.builder()->Reduce(operand, init_value, computation,
dimensions_to_reduce);
}
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
return operand.builder()->ReduceAll(operand, init_value, computation);
}
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
return operand.builder()->ReduceWindow(operand, init_value, computation,
window_dimensions, window_strides,
padding);
}
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
padding);
}
XlaOp CrossReplicaSum(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
return operand.builder()->CrossReplicaSum(operand, replica_group_ids);
}
XlaOp CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
return operand.builder()->CrossReplicaSum(operand, computation,
replica_group_ids, channel_id);
}
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding, const XlaOp& source,
const XlaOp& init_value, const XlaComputation& scatter) {
return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
window_strides, padding, source,
init_value, scatter);
}
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter) {
return operand.builder()->SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides, padding, source,
init_value, scatter);
}
XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); }
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return y.builder()->Atan2(y, x, broadcast_dimensions);
}
XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); }
XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); }
XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); }
XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); }
XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); }
XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); }
XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); }
XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); }
XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); }
XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); }
XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); }
XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); }
XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
XlaOp SqrtF32(const XlaOp& operand) {
return operand.builder()->SqrtF32(operand);
}
XlaOp SquareF32(const XlaOp& operand) {
return operand.builder()->SquareF32(operand);
}
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
}
XlaOp IsFinite(const XlaOp& operand) {
return operand.builder()->IsFinite(operand);
}
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) {
return operand.builder()->ConvertElementType(operand, new_element_type);
}
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
return operand.builder()->BitcastConvertType(operand, new_element_type);
}
XlaOp ReciprocalF32(const XlaOp& operand) {
return operand.builder()->ReciprocalF32(operand);
}
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
XlaOp Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
return operand.builder()->Transpose(operand, permutation);
}
XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
return operand.builder()->Rev(operand, dimensions);
}
XlaOp Sort(const XlaOp& operand) { return operand.builder()->Sort(operand); }
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
return min.builder()->Clamp(min, operand, max);
}
XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
return builder->Map(operands, computation, dimensions, static_operands);
}
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) {
return mu.builder()->RngNormal(mu, sigma, shape);
}
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) {
return a.builder()->RngUniform(a, b, shape);
}
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
const XlaOp& init) {
return init.builder()->While(condition, body, init);
}
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation) {
return predicate.builder()->Conditional(predicate, true_operand,
true_computation, false_operand,
false_computation);
}
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits) {
return operand.builder()->ReducePrecision(operand, exponent_bits,
mantissa_bits);
}
XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
return input.builder()->Gather(input, gather_indices, dimension_numbers,
window_bounds);
}
void Send(const XlaOp& operand, const ChannelHandle& handle) {
return operand.builder()->Send(operand, handle);
}
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle) {
return builder->Recv(shape, handle);
}
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index) {
return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
feature_index);
}
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, const XlaOp& mean,
const XlaOp& variance, float epsilon,
int64 feature_index) {
return operand.builder()->BatchNormInference(
operand, scale, offset, mean, variance, epsilon, feature_index);
}
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
const XlaOp& batch_mean, const XlaOp& batch_var,
const XlaOp& grad_output, float epsilon,
int64 feature_index) {
return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
grad_output, epsilon, feature_index);
}
} // namespace xla


@@ -973,6 +973,670 @@ class XlaBuilder {
XlaBuilder* parent_builder_{nullptr};
};
// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class XlaScopedShardingAssignment {
public:
XlaScopedShardingAssignment(xla::XlaBuilder* builder,
tensorflow::gtl::optional<OpSharding> sharding)
: builder_(builder), prev_sharding_(builder->sharding()) {
SetSharding(sharding);
}
XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
delete;
~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
private:
void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
if (sharding.has_value()) {
builder_->SetSharding(sharding.value());
} else {
builder_->ClearSharding();
}
}
xla::XlaBuilder* const builder_;
tensorflow::gtl::optional<OpSharding> prev_sharding_;
};
// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
const string& name);
// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
// Enqueues a constant onto the computation. These functions are templated
// on the native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
// Native Type PrimitiveType
// -----------------------------
// bool PRED
// int32 S32
// int64 S64
// uint32 U32
// uint64 U64
// float F32
// double F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
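//
// Illustrative example (not from the original header): enqueueing a scalar
// F32 constant via the templated free function.
//
//   XlaOp half = ConstantR0<float>(&builder, 0.5f);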
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Array3D<NativeT>& values,
const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
const Array4D<NativeT>& values,
const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
const Array4D<NativeT>& values);
// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config);
// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> new_sizes);
// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
// {256} collapsing {0} => {256}
// {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
// {256, 2} collapsing {0,1} => {512}
// {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
// x
// [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
// [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
XlaOp Slice(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides);
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
// array[:, 2:4:1, :]
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specify the end point of exclusive slice intervals in each
// dimension [start, start + size).
// The shape of 'start_indices' must be rank == 1, with dimension size
// equal to the rank of the 'operand'.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
// update = {10, 11} // calculated at runtime.
// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
// [7 8 9] [7 8 9 ]
//
// The shape of 'start_indices' must be rank == 1, with dimension size
// equal to the rank of the 'operand'.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices);
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension);
// Enqueues a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
void Trace(const string& tag, const XlaOp& operand);
// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements);
// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const string& config = "");
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
// During code generation, a call instruction is emitted which targets a
// symbol with the name |call_target_name|. The |operands| are passed to the
// call instruction. |shape| is the resultant shape.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape);
// Enqueues a pseudo-op to represent host-side computation data-dependencies.
// During code generation, host send and receive operations will be generated
// to transfer |operands| to the host and a single result of |shape| back to
// the device. Host send/recv operations are emitted using |channel_name|.
// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
// instruction scheduling.
XlaOp HostCompute(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const string& channel_name, int64 cost_estimate_ns,
const Shape& shape);
// The following functions enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).
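//
// Illustrative example (not from the original header): adding a f32[3]
// vector to each row of a f32[2,3] matrix, by mapping the vector's
// dimension 0 to the matrix's dimension 1:
//
//   XlaOp sum = Add(matrix, vector, /*broadcast_dimensions=*/{1});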
// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);
// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
XlaOp Not(const XlaOp& operand);
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation);
// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding);
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(
const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
// Enqueues an operation that performs an AllReduce of the operand across
// cores. Here AllReduce means doing a reduction on the input operand across
// cores and then broadcasting the reduction result to those cores. The
// reduction function is defined by `computation`, which should be a
// commutative computation on scalars, e.g., add, min, or max. The way that
// AllReduce is applied is configured by:
//
// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
// replicas belong to one group. AllReduce will be applied within subgroups.
// For example, with 4 replicas, replica_group_ids={0,1,0,1} means replicas
// 0 and 2 are in subgroup 0 and replicas 1 and 3 are in subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different models, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not
// be applied across models.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
const tensorflow::gtl::optional<ChannelHandle>&
channel_id = tensorflow::gtl::nullopt);
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding, const XlaOp& source,
const XlaOp& init_value, const XlaComputation& scatter);
// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter);
// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);
// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(const XlaOp& operand);
// Enqueues a floor instruction onto the computation.
XlaOp Floor(const XlaOp& operand);
// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(const XlaOp& operand);
// Enqueues a round instruction onto the computation, rounding to the nearest
// integer, with half-way cases rounding away from zero.
XlaOp Round(const XlaOp& operand);
// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(const XlaOp& operand);
// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(const XlaOp& operand);
// Enqueues a sign instruction onto the computation.
XlaOp Sign(const XlaOp& operand);
// Enqueues a count leading zeros instruction onto the computation.
XlaOp Clz(const XlaOp& operand);
// Enqueues a cosine instruction onto the computation.
XlaOp Cos(const XlaOp& operand);
// Enqueues a sine instruction onto the computation.
XlaOp Sin(const XlaOp& operand);
// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(const XlaOp& operand);
// Enqueues a real-part instruction onto the computation.
XlaOp Real(const XlaOp& operand);
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
// Enqueues a float32 sqrt instruction onto the computation.
// (float32 is specified as there is an implicit float32 0.5f constant
// exponent).
XlaOp SqrtF32(const XlaOp& operand);
// Enqueues a float32 square instruction onto the computation.
// (float32 is specified as there is an implicit float32 2.0f constant
// exponent).
XlaOp SquareF32(const XlaOp& operand);
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
// booleans with the same shape where entries are true iff the corresponding
// entry was finite.
XlaOp IsFinite(const XlaOp& operand);
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to new_element_type.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to new_element_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
// Enqueues a float32 reciprocal instruction onto the computation.
// (float32 is specified as there is an implicit float32 -1.0f constant
// exponent).
//
// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
// shape of the operand.
XlaOp ReciprocalF32(const XlaOp& operand);
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation);
// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
XlaOp Sort(const XlaOp& operand);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
const XlaOp& init);
// Enqueues a conditional node onto the computation.
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation);
// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
// Enqueues a Send node onto the computation, to send the given operand to
// a Recv instruction that shares the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
// be the same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index);
// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, const XlaOp& mean,
const XlaOp& variance, float epsilon,
int64 feature_index);
// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
// - grad_operand: Gradient with respect to input `operand`
// - grad_offset: Gradient with respect to input `offset`
// - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
const XlaOp& batch_mean, const XlaOp& batch_var,
const XlaOp& grad_output, float epsilon,
int64 feature_index);
// Implementation details below this point.
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
@@ -1048,34 +1712,93 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
return ConstantFromArray(values);
}
// Free function template implementations.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
return ConstantLiteral(builder, *Literal::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<NativeT> values) {
return ConstantLiteral(builder, *Literal::CreateR1<NativeT>(values));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
literal.PopulateWithValue(value);
return ConstantLiteral(builder, literal);
}
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
return ConstantLiteral(builder, *Literal::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
return ConstantLiteral(builder, *Literal::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder, *Literal::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
builder, *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
*Literal::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Array3D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
builder,
*Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
const Array3D<NativeT>& values) {
return ConstantFromArray(builder, values);
}
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
const Array4D<NativeT>& values,
const Layout& layout) {
return ConstantFromArrayWithLayout(builder, values, layout);
}
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
const Array4D<NativeT>& values) {
return ConstantFromArray(builder, values);
}
} // namespace xla