/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted from f32 to f16 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000])
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts
  // with a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
//  https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges as well as between the elements is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
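
// Example of the PaddingConfig message above (illustrative textproto): pad a
// rank-2 array with one element of low edge padding in dimension 0 and one
// element of interior padding between elements in dimension 1:
//   dimensions { edge_padding_low: 1 edge_padding_high: 0 interior_padding: 0 }
//   dimensions { edge_padding_low: 0 edge_padding_high: 0 interior_padding: 1 }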

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  reserved 2;
  reserved "SPARSE";
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  reserved 5;
  reserved "max_sparse_elements";

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
//  https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
//  https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
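
// Example of the LayoutProto message above (illustrative textproto): a dense,
// row-major layout for a rank-2 array, where dimension 1 is the minor
// (fastest varying) dimension:
//   format: DENSE
//   minor_to_major: [1, 0]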

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic. In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is
  // the size of dimension 0, the second element is the size of dimension 1,
  // and so forth. An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent shapes in the tuple
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
//  https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
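
// Example of the ShapeProto message above (illustrative textproto): an
// f32[2,3] array with a dense row-major layout:
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }
// A tuple of that array and a scalar predicate would instead use:
//   element_type: TUPLE
//   tuple_shapes { element_type: F32 dimensions: [2, 3] }
//   tuple_shapes { element_type: PRED }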

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}
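
// Example of the ProgramShapeProto message above (illustrative textproto,
// hypothetical parameter names): a computation taking f32[784,2000] weights
// and an f32[2000] bias and returning them as a tuple:
//   parameters { element_type: F32 dimensions: [784, 2000] }
//   parameters { element_type: F32 dimensions: [2000] }
//   result {
//     element_type: TUPLE
//     tuple_shapes { element_type: F32 dimensions: [784, 2000] }
//     tuple_shapes { element_type: F32 dimensions: [2000] }
//   }
//   parameter_names: ["weights", "biases"]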

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type of optimization profiles in use.
enum ProfileType {
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  repeated ProfileType profile_type = 5;
}
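
// Example of the OpMetadata message above (illustrative textproto; the op and
// file names are hypothetical): a softmax layer lowered to several HLO ops,
// each carrying the same op_type:
//   op_type: "SoftMax"
//   op_name: "model/dense_1/SoftMax"
//   source_file: "model.py"
//   source_line: 42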

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any cycles
  // for the input data transfer since the memory is initialized with the proper
  // values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid primitive type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
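
// Example of the DeviceAssignmentProto message above (illustrative textproto):
// a single logical computation replicated across two physical devices, 0 and 1:
//   replica_count: 2
//   computation_count: 1
//   computation_devices { replica_device_ids: [0, 1] }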

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
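
// Example of the LiteralProto message above (illustrative textproto): the
// f32[2] constant {1.5, 2.5}:
//   shape {
//     element_type: F32
//     dimensions: [2]
//     layout { format: DENSE minor_to_major: [0] }
//   }
//   f32s: [1.5, 2.5]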

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the low
  // end of this dimension; if negative, its absolute value is the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
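
// Example of the Window message above (illustrative textproto): a 3x3 window
// moved with stride 2 in both dimensions, with no padding, no dilation and no
// reversal, as in a typical pooling operation:
//   dimensions { size: 3 stride: 2 window_dilation: 1 base_dilation: 1 }
//   dimensions { size: 3 stride: 2 window_dilation: 1 base_dilation: 1 }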

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantics
  // for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}
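
// Example of the GatherDimensionNumbers message above (illustrative
// textproto): gathering whole rows of an f32[N,4] operand with an s32[K,1]
// tensor of row indices, producing an f32[K,4] result:
//   offset_dims: [1]
//   collapsed_slice_dims: [0]
//   start_index_map: [0]
//   index_vector_dim: 1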

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
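
// Example of the ConvolutionDimensionNumbers message above (illustrative
// textproto): a 2D convolution with NHWC input/output and HWIO kernel:
//   input_batch_dimension: 0
//   input_spatial_dimensions: [1, 2]
//   input_feature_dimension: 3
//   kernel_spatial_dimensions: [0, 1]
//   kernel_input_feature_dimension: 2
//   kernel_output_feature_dimension: 3
//   output_batch_dimension: 0
//   output_spatial_dimensions: [1, 2]
//   output_feature_dimension: 3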

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
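
// Example of the DotDimensionNumbers message above (illustrative textproto):
// a batched matrix multiply of f32[B,M,K] by f32[B,K,N], batching over
// dimension 0 and contracting over K:
//   lhs_batch_dimensions: [0]
//   rhs_batch_dimensions: [0]
//   lhs_contracting_dimensions: [2]
//   rhs_contracting_dimensions: [1]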

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  RNG_THREE_FRY = 1;
  RNG_PHILOX = 2;
  // Next: 3
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied; this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
}
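
// Example of the OpSharding message above (illustrative textproto): an OTHER
// sharding that splits dimension 0 of an f32[8,8] operand into two f32[4,8]
// tiles, placed on devices 0 and 1:
//   type: OTHER
//   tile_shape { element_type: F32 dimensions: [4, 8] }
//   tile_assignment_dimensions: [2, 1]
//   tile_assignment_devices: [0, 1]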

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend-specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}

// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;
}