[XLA] Fix documentation for the BatchNormGrad and Conditional ops

PiperOrigin-RevId: 314396627
Change-Id: I4b88a4953cf12b14eb0150569de7c373feab434d
This commit is contained in:
David Majnemer 2020-06-02 13:47:47 -07:00 committed by TensorFlower Gardener
parent 4bddb55693
commit 123f576851

View File

@ -182,7 +182,7 @@ respect to `operand`, `offset` and `scale` across all the other dimensions. The
`feature_index` must be a valid index for the feature dimension in `operand`.
The three gradients are defined by the following formulas (assuming a
4-dimensional array as `operand` and with feature dimension index $$l$$, batch
4-dimensional array as `operand` and with feature dimension index `l`, batch
size `m` and spatial sizes `w` and `h`):
\\[ \begin{split} c_l&=
@ -618,36 +618,44 @@ See also
<b> `Conditional(pred, true_operand, true_computation, false_operand,
false_computation)` </b>
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
Arguments | Type | Semantics
------------------- | ---------------- | --------------------------------------
------------------- | ---------------- | ------------------------------------
`pred` | `XlaOp` | Scalar of type `PRED`
`true_operand` | `XlaOp` | Argument of type $$ T_0 $$
`true_computation` | `XlaComputation` | XlaComputation of type $$ T_0 \to S$$
`false_operand` | `XlaOp` | Argument of type $$ T_1 $$
`false_computation` | `XlaComputation` | XlaComputation of type $$ T_1 \to S $$
`true_operand` | `XlaOp` | Argument of type \\(T_0\\)
`true_computation` | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
`false_operand` | `XlaOp` | Argument of type \\(T_1\\)
`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)
Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.
The `true_computation` must take in a single argument of type $$ T_0 $$ and will
The `true_computation` must take in a single argument of type \\(T_0\\) and will
be invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type $$ T_1 $$ and will be
`false_computation` must take in a single argument of type \\(T_1\\) and will be
invoked with `false_operand` which must be of the same type. The type of the
returned value of `true_computation` and `false_computation` must be the same.
<!-- mdformat on -->
Note that only one of `true_computation` and `false_computation` will be
executed depending on the value of `pred`.
<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
| Arguments | Type | Semantics |
| --------------------- | --------------------- | ---------------------------- |
| `branch_index` | `XlaOp` | Scalar of type `S32` |
| `branch_computations` | sequence of N | XlaComputations of type $$ |
| `branch_computations` | sequence of N | XlaComputations of type \\( |
: : `XlaComputation` : T_0 \to S , T_1 \to S , ..., :
: : : T_{N-1} \to S $$ :
| `branch_operands` | sequence of N `XlaOp` | Arguments of type $$ T_0 , |
: : : T_1 , ..., T_{N-1} $$ :
: : : T_{N-1} \to S \\) :
| `branch_operands` | sequence of N `XlaOp` | Arguments of type \\( T_0 , |
: : : T_1 , ..., T_{N-1} \\) :
<!-- mdformat on -->
Executes `branch_computations[branch_index]`, and returns the result. If
`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`