[XLA] Fix documentation for the BatchNormGrad and Conditional ops
PiperOrigin-RevId: 314396627 Change-Id: I4b88a4953cf12b14eb0150569de7c373feab434d
This commit is contained in:
parent
4bddb55693
commit
123f576851
@ -182,7 +182,7 @@ respect to `operand`, `offset` and `scale` across all the other dimensions. The
|
||||
`feature_index` must be a valid index for the feature dimension in `operand`.
|
||||
|
||||
The three gradients are defined by the following formulas (assuming a
|
||||
4-dimensional array as `operand` and with feature dimension index $$l$$, batch
|
||||
4-dimensional array as `operand` and with feature dimension index `l`, batch
|
||||
size `m` and spatial sizes `w` and `h`):
|
||||
|
||||
\\[ \begin{split} c_l&=
|
||||
@ -618,36 +618,44 @@ See also
|
||||
<b> `Conditional(pred, true_operand, true_computation, false_operand,
|
||||
false_computation)` </b>
|
||||
|
||||
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
|
||||
|
||||
Arguments | Type | Semantics
|
||||
------------------- | ---------------- | --------------------------------------
|
||||
------------------- | ---------------- | ------------------------------------
|
||||
`pred` | `XlaOp` | Scalar of type `PRED`
|
||||
`true_operand` | `XlaOp` | Argument of type $$ T_0 $$
|
||||
`true_computation` | `XlaComputation` | XlaComputation of type $$ T_0 \to S$$
|
||||
`false_operand` | `XlaOp` | Argument of type $$ T_1 $$
|
||||
`false_computation` | `XlaComputation` | XlaComputation of type $$ T_1 \to S $$
|
||||
`true_operand` | `XlaOp` | Argument of type \\(T_0\\)
|
||||
`true_computation` | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
|
||||
`false_operand` | `XlaOp` | Argument of type \\(T_1\\)
|
||||
`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)
|
||||
|
||||
Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
|
||||
is `false`, and returns the result.
|
||||
|
||||
The `true_computation` must take in a single argument of type $$ T_0 $$ and will
|
||||
The `true_computation` must take in a single argument of type \\(T_0\\) and will
|
||||
be invoked with `true_operand` which must be of the same type. The
|
||||
`false_computation` must take in a single argument of type $$ T_1 $$ and will be
|
||||
`false_computation` must take in a single argument of type \\(T_1\\) and will be
|
||||
invoked with `false_operand` which must be of the same type. The type of the
|
||||
returned value of `true_computation` and `false_computation` must be the same.
|
||||
|
||||
<!-- mdformat on -->
|
||||
|
||||
Note that only one of `true_computation` and `false_computation` will be
|
||||
executed depending on the value of `pred`.
|
||||
|
||||
<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
|
||||
|
||||
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
|
||||
|
||||
| Arguments | Type | Semantics |
|
||||
| --------------------- | --------------------- | ---------------------------- |
|
||||
| `branch_index` | `XlaOp` | Scalar of type `S32` |
|
||||
| `branch_computations` | sequence of N | XlaComputations of type $$ |
|
||||
| `branch_computations` | sequence of N | XlaComputations of type \\( |
|
||||
: : `XlaComputation` : T_0 \to S , T_1 \to S , ..., :
|
||||
: : : T_{N-1} \to S $$ :
|
||||
| `branch_operands` | sequence of N `XlaOp` | Arguments of type $$ T_0 , |
|
||||
: : : T_1 , ..., T_{N-1} $$ :
|
||||
: : : T_{N-1} \to S \\) :
|
||||
| `branch_operands` | sequence of N `XlaOp` | Arguments of type \\( T_0 , |
|
||||
: : : T_1 , ..., T_{N-1} \\) :
|
||||
|
||||
<!-- mdformat on -->
|
||||
|
||||
Executes `branch_computations[branch_index]`, and returns the result. If
|
||||
`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
|
||||
|
Loading…
Reference in New Issue
Block a user