diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index e0e85295fec..fe270af3d63 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -292,8 +292,15 @@ class PoolGradTest(XLATestCase):
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
- def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize,
- strides, padding, data_format):
+ def _VerifyOneTest(self,
+ pool_func,
+ pool_grad_func,
+ input_sizes,
+ ksize,
+ strides,
+ padding,
+ data_format,
+ pool_grad_grad_func=None):
"""Verifies the output values of the pooling gradient function.
Args:
@@ -304,9 +311,19 @@ class PoolGradTest(XLATestCase):
strides: The stride dimensions
padding: Padding type.
data_format: The data format we use to run the pooling operation.
+ pool_grad_grad_func: Second-order gradient function, if available.
"""
total_size = np.prod(input_sizes)
- x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
+ # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally
+    # maximal at 16-bit resolution. Switch to np.random.randn when resolved.
+ x = np.arange(1, total_size + 1, dtype=np.float32)
+ x *= (np.random.randint(2, size=total_size) * 2 - 1) # Flip signs randomly
+ # Verify some specifically interesting values...
+ x[np.random.choice(total_size)] = np.inf
+ x[np.random.choice(total_size)] = -np.inf
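+    # Sign flips give a mix of negative and positive inputs, and the +/-inf
+    # entries exercise extreme-value handling in the pooling gradients.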
+ # TODO(b/74222344): Fix nan handling for max pool grad.
+ # x[np.random.choice(total_size)] = np.nan
+ x = x.reshape(input_sizes)
with self.test_session() as sess:
# Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases).
@@ -323,6 +340,8 @@ class PoolGradTest(XLATestCase):
output_gradient_vals = np.arange(
1, output_vals.size + 1, dtype=np.float32)
output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
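+      # Second-order test data: incoming gradients w.r.t. the first-order
+      # input gradients; it has the same shape as the pooling input x.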
+ output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32)
+ output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape)
# Use the Tensorflow CPU pooling gradient to compute the expected input
# gradients.
@@ -342,18 +361,36 @@ class PoolGradTest(XLATestCase):
{inputs: x,
output_gradients: output_gradient_vals})
+ output_grad_gradients = array_ops.placeholder(
+ dtypes.float32, shape=expected_input_gradient_vals.shape)
+ if pool_grad_grad_func is not None:
+ expected_grad_gradients = pool_grad_grad_func(
+ inputs,
+ outputs,
+ output_grad_gradients,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format="NHWC")
+ expected_grad_gradients_vals = sess.run(expected_grad_gradients, {
+ inputs: x,
+ output_grad_gradients: output_grad_grad_vals
+ })
+
# Run the gradient op on the XLA device
with self.test_scope():
outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
xla_inputs = inputs
xla_outputs = outputs
xla_output_gradients = output_gradients
+ xla_output_grad_gradients = output_grad_gradients
xla_ksize = ksize
xla_strides = strides
if data_format == "NCHW":
xla_inputs = NHWCToNCHW(inputs)
xla_outputs = NHWCToNCHW(outputs)
xla_output_gradients = NHWCToNCHW(output_gradients)
+ xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients)
xla_ksize = NHWCToNCHW(ksize)
xla_strides = NHWCToNCHW(strides)
actual_input_gradients = pool_grad_func(
@@ -366,22 +403,54 @@ class PoolGradTest(XLATestCase):
data_format=data_format)
if data_format == "NCHW":
actual_input_gradients = NCHWToNHWC(actual_input_gradients)
- actual = sess.run(actual_input_gradients, {
+ if pool_grad_grad_func is not None:
+ actual_grad_gradients = pool_grad_grad_func(
+ xla_inputs,
+ xla_outputs,
+ xla_output_grad_gradients,
+ ksize=xla_ksize,
+ strides=xla_strides,
+ padding=padding,
+ data_format=data_format)
+ if data_format == "NCHW":
+ actual_grad_gradients = NCHWToNHWC(actual_grad_gradients)
+ actual_input_gradients_vals = sess.run(actual_input_gradients, {
inputs: x,
outputs: output_vals,
output_gradients: output_gradient_vals
})
-
# Compare the Tensorflow and XLA results.
self.assertAllClose(
- expected_input_gradient_vals.flatten(),
- actual.flatten(),
+ expected_input_gradient_vals,
+ actual_input_gradients_vals,
rtol=1e-4,
atol=1e-6)
- self.assertShapeEqual(actual, inputs)
+ self.assertShapeEqual(actual_input_gradients_vals, inputs)
- def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize,
- strides, padding):
+ if pool_grad_grad_func is not None:
+ actual_grad_gradients_vals = sess.run(
+ actual_grad_gradients, {
+ inputs: x,
+ outputs: output_vals,
+ output_grad_gradients: output_grad_grad_vals
+ })
+
+ # Compare the Tensorflow and XLA results.
+ self.assertAllClose(
+ expected_grad_gradients_vals,
+ actual_grad_gradients_vals,
+ rtol=1e-4,
+ atol=1e-6)
+ self.assertShapeEqual(actual_grad_gradients_vals, outputs)
+
+ def _VerifyValues(self,
+ pool_func,
+ pool_grad_func,
+ input_sizes,
+ ksize,
+ strides,
+ padding,
+ pool_grad_grad_func=None):
"""Verifies the output values of the pooling function.
Args:
@@ -391,12 +460,20 @@ class PoolGradTest(XLATestCase):
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
+ pool_grad_grad_func: Second-order gradient function, if available.
"""
for data_format in GetTestConfigs():
- self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize,
- strides, padding, data_format)
+ self._VerifyOneTest(
+ pool_func,
+ pool_grad_func,
+ input_sizes,
+ ksize,
+ strides,
+ padding,
+ data_format,
+ pool_grad_grad_func=pool_grad_grad_func)
- def _TestPooling(self, forward_op, backward_op):
+ def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None):
# VALID padding
self._VerifyValues(
forward_op,
@@ -404,7 +481,8 @@ class PoolGradTest(XLATestCase):
input_sizes=[1, 3, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
- padding="VALID")
+ padding="VALID",
+ pool_grad_grad_func=pool_grad_grad_func)
# SAME padding
self._VerifyValues(
@@ -413,7 +491,8 @@ class PoolGradTest(XLATestCase):
input_sizes=[1, 2, 3, 3],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=pool_grad_grad_func)
# SAME padding, non square window
self._VerifyValues(
@@ -422,7 +501,8 @@ class PoolGradTest(XLATestCase):
input_sizes=[1, 2, 2, 1],
ksize=[1, 1, 2, 1],
strides=[1, 1, 1, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=pool_grad_grad_func)
# VALID padding, uneven stride
self._VerifyValues(
@@ -431,14 +511,16 @@ class PoolGradTest(XLATestCase):
input_sizes=[1, 4, 4, 1],
ksize=[1, 2, 2, 1],
strides=[1, 1, 2, 1],
- padding="VALID")
+ padding="VALID",
+ pool_grad_grad_func=pool_grad_grad_func)
self._VerifyValues(
forward_op,
backward_op,
input_sizes=[1, 4, 4, 1],
ksize=[1, 2, 2, 1],
strides=[1, 2, 1, 1],
- padding="VALID")
+ padding="VALID",
+ pool_grad_grad_func=pool_grad_grad_func)
# SAME padding, size 4 input
self._VerifyValues(
@@ -447,7 +529,8 @@ class PoolGradTest(XLATestCase):
input_sizes=[1, 4, 4, 4],
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=pool_grad_grad_func)
# SAME padding, size 8 input
self._VerifyValues(
@@ -456,10 +539,14 @@ class PoolGradTest(XLATestCase):
input_sizes=[1, 8, 8, 8],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=pool_grad_grad_func)
def testMaxPool(self):
- self._TestPooling(nn_ops.max_pool, gen_nn_ops.max_pool_grad)
+ self._TestPooling(
+ nn_ops.max_pool,
+ gen_nn_ops.max_pool_grad,
+ pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad)
def testAvgPool(self):
# Wrapper around AvgPoolGrad that ignores extra arguments needed by
diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
index 91351421bca..20179b67991 100644
--- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
+++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
@@ -3,6 +3,7 @@
Operator | Type Constraint
------------------------------------- | ---------------
`Abs` | `T={double,float,int32,int64}`
+`Acos` | `T={complex64,double,float,int32,int64}`
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
@@ -15,10 +16,12 @@ Operator | Type Constraint
`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`ArgMax` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={float}`
`ArgMin` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
+`Asin` | `T={complex64,double,float,int32,int64}`
`Asinh` | `T={complex64,double,float}`
`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`Atan` | `T={complex64,double,float,int32,int64}`
`Atan2` | `T={double,float}`
`Atanh` | `T={complex64,double,float}`
`AvgPool` | `T={double,float}`
@@ -75,6 +78,10 @@ Operator | Type Constraint
`FFT` |
`FFT2D` |
`FFT3D` |
+`FakeQuantWithMinMaxArgs` |
+`FakeQuantWithMinMaxArgsGradient` |
+`FakeQuantWithMinMaxVars` |
+`FakeQuantWithMinMaxVarsGradient` |
`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
@@ -84,6 +91,7 @@ Operator | Type Constraint
`FusedBatchNormGradV2` | `U={float}`<br>`T={float}`
`FusedBatchNormV2` | `U={float}`<br>`T={float}`
`Gather` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`GatherNd` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
@@ -117,14 +125,18 @@ Operator | Type Constraint
`LogicalNot` |
`LogicalOr` |
`MatMul` | `T={complex64,double,float}`
+`MatrixBandPart` | `Tindex={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixTriangularSolve` | `T={complex64,double,float}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolGradGrad` | `T={float}`
+`MaxPoolGradGradV2` | `T={float}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
@@ -186,6 +198,7 @@ Operator | Type Constraint
`Round` | `T={complex64,double,float,int32,int64}`
`Rsqrt` | `T={complex64,double,float}`
`RsqrtGrad` | `T={complex64,double,float}`
+`ScatterNd` | `Tindices={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Selu` | `T={double,float}`
`SeluGrad` | `T={double,float}`
@@ -198,6 +211,7 @@ Operator | Type Constraint
`Sinh` | `T={complex64,double,float}`
`Size` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Slice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Softmax` | `T={double,float}`
`SoftmaxCrossEntropyWithLogits` | `T={double,float}`
`Softplus` | `T={double,float,int32,int64,uint32,uint64}`
diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
index b9bdb829d77..55f0538dba7 100644
--- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
+++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
@@ -3,6 +3,7 @@
Operator | Type Constraint
------------------------------------- | ---------------
`Abs` | `T={double,float,int32,int64}`
+`Acos` | `T={complex64,double,float,int32,int64}`
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
@@ -15,10 +16,12 @@ Operator | Type Constraint
`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`ArgMax` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`ArgMin` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
+`Asin` | `T={complex64,double,float,int32,int64}`
`Asinh` | `T={complex64,double,float}`
`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`Atan` | `T={complex64,double,float,int32,int64}`
`Atan2` | `T={double,float}`
`Atanh` | `T={complex64,double,float}`
`AvgPool` | `T={double,float}`
@@ -75,6 +78,10 @@ Operator | Type Constraint
`FFT` |
`FFT2D` |
`FFT3D` |
+`FakeQuantWithMinMaxArgs` |
+`FakeQuantWithMinMaxArgsGradient` |
+`FakeQuantWithMinMaxVars` |
+`FakeQuantWithMinMaxVarsGradient` |
`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
@@ -84,6 +91,7 @@ Operator | Type Constraint
`FusedBatchNormGradV2` | `U={float}`<br>`T={float}`
`FusedBatchNormV2` | `U={float}`<br>`T={float}`
`Gather` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`GatherNd` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
@@ -117,14 +125,18 @@ Operator | Type Constraint
`LogicalNot` |
`LogicalOr` |
`MatMul` | `T={complex64,double,float}`
+`MatrixBandPart` | `Tindex={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixTriangularSolve` | `T={complex64,double,float}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolGradGrad` | `T={float}`
+`MaxPoolGradGradV2` | `T={float}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
@@ -183,6 +195,7 @@ Operator | Type Constraint
`Round` | `T={complex64,double,float,int32,int64}`
`Rsqrt` | `T={complex64,double,float}`
`RsqrtGrad` | `T={complex64,double,float}`
+`ScatterNd` | `Tindices={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Selu` | `T={double,float}`
`SeluGrad` | `T={double,float}`
@@ -195,6 +208,7 @@ Operator | Type Constraint
`Sinh` | `T={complex64,double,float}`
`Size` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Slice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Softmax` | `T={double,float}`
`SoftmaxCrossEntropyWithLogits` | `T={double,float}`
`Softplus` | `T={double,float,int32,int64,uint32,uint64}`
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index d4fb5dd4e06..086a9491aa9 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -525,5 +525,172 @@ class AvgPool3DGradOp : public AvgPoolGradOp {
REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"),
AvgPool3DGradOp);
+class MaxPoolGradGradOp : public XlaOpKernel {
+ public:
+ MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
+ : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
+ if (ctx->num_inputs() == 3) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
+ }
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ int num_dims() const { return num_spatial_dims_ + 2; }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ if (ctx->num_inputs() != 3) {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() == 5,
+ errors::InvalidArgument("Must supply ksize and stride arguments."));
+ const TensorShape ksize_shape = ctx->InputShape(3);
+ // Validate input sizes.
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
+ errors::InvalidArgument("ksize must be a vector, not shape ",
+ ksize_shape.DebugString()));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
+
+ const TensorShape stride_shape = ctx->InputShape(4);
+ // Validate input sizes.
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
+ errors::InvalidArgument("stride must be a vector, not shape ",
+ stride_shape.DebugString()));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
+ }
+
+ OP_REQUIRES(ctx, ksize_.size() == num_dims(),
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify ",
+ num_dims(), " dimensions"));
+ OP_REQUIRES(ctx, stride_.size() == num_dims(),
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify ",
+ num_dims(), " dimensions"));
+
+ const TensorShape tensor_in_shape = ctx->InputShape(0);
+ const TensorShape tensor_out_shape = ctx->InputShape(1);
+ const TensorShape out_backprop_shape = ctx->InputShape(2);
+
+ // For maxpooling, tensor_in should have num_dims() dimensions.
+ OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
+ errors::InvalidArgument("tensor_in must be ", num_dims(),
+ "-dimensional"));
+ OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
+ errors::InvalidArgument("tensor_out must be ", num_dims(),
+ "-dimensional"));
+ // For maxpooling, out_backprop should have num_dims() dimensions.
+ OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
+ errors::InvalidArgument("out_backprop must be ", num_dims(),
+ "-dimensional"));
+
+ // What we want to compute:
+ // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad)
+ // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
+ //
+ // In the regular TF op, this amounts to selecting for each window the
+ // incoming backprop value from xs_grad_grad that corresponds to the maximal
+ // value in the corresponding window of x.
+ //
+ // TODO(b/73062247): What we really want is a ReduceWindow with different
+ // arrays for index selection vs return value selection--a select-to-gather.
+ //
+ // Here, we implement a bitwise hack: we use the hi 16 bits of input for
+ // separate max pooling alongside each of the hi and lo 16 bits of
+ // out_backprop packed into 16 lo bits, which we then glue back together at
+ // the end to get a full 32 bits of gradient.
+ //
+ // This could select the wrong backprop value for two x values that are
+ // equally maximal up to the first 16 bits, in which case we are taking the
+ // latter.
+ //
+ // Note that in principle we could use 32 separate maxpools to recover each
+ // of 32 bits of the gradient while preserving 31 bits of input for the max
+ // pooling criteria; here, we just truncate to the first 16 bits of input.
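+    // Illustrative example (hypothetical values): take x = 2.5f
+    // (bits 0x40200000, exactly representable in bf16) and
+    // out_backprop = 7.0f (bits 0x40E00000). Then
+    //   in_hi       = 0x40200000
+    //   bp_hi       = 0x000040E0, bp_lo = 0x00000000
+    //   in_hi_bp_hi = 0x402040E0, in_hi_bp_lo = 0x40200000
+    // In each ReduceWindow below, packed values compete on their hi 16 bits
+    // (the bf16-rounded x); if this x is the window maximum, unpacking yields
+    // grads_hi = 0x40E00000 and grads_lo = 0x00000000, which add to
+    // 0x40E00000, i.e. the original 7.0f.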
+
+ auto input = ctx->Input(0);
+ auto out_backprop = ctx->Input(2);
+
+ auto b = ctx->builder();
+
+    auto sixteen = b->ConstantR0<uint32>(16);
+ // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
+ auto in_hi = b->BitcastConvertType(
+ b->ConvertElementType(b->ConvertElementType(input, xla::BF16),
+ xla::F32),
+ xla::U32);
+ auto bp_int = b->BitcastConvertType(out_backprop, xla::U32);
+ auto bp_hi = b->ShiftRightLogical(bp_int, sixteen);
+ auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen);
+ auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add.
+ auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add.
+
+ auto init_value = XlaHelpers::MinValue(b, DT_FLOAT);
+ // We will reduce by taking the maximal value up to 16 bits (ignoring the lo
+ // 16 bits of packed-in hi/lo backprop value).
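+    // For instance, 0x402040E0 and 0x40207FFF compare as equal here, since
+    // only the hi 16 bits (the bf16-rounded input) take part in the
+    // comparison; the packed lo 16 bits are ignored.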
+ auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
+ {
+ // F32 parameters to satisfy lowering type restriction for reduce opcode.
+ const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
+ auto lhs = rb->Parameter(0, scalar, "lhs");
+ auto rhs = rb->Parameter(1, scalar, "rhs");
+      auto sixteen = rb->ConstantR0<int32>(16);
+ auto lhs_criteria = rb->ShiftLeft(
+ rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen),
+ sixteen);
+ auto rhs_criteria = rb->ShiftLeft(
+ rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen),
+ sixteen);
+      // Must use an F32 comparison, because S32 would not work for negatives.
+ rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32),
+ rb->BitcastConvertType(rhs_criteria, xla::F32)),
+ lhs, rhs);
+ }
+ auto reduce = rb->BuildAndNoteError();
+ xla::Padding xla_padding =
+ (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+ auto pooled_hi =
+ b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32),
+ init_value, reduce, ksize_, stride_, xla_padding);
+ auto pooled_lo =
+ b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32),
+ init_value, reduce, ksize_, stride_, xla_padding);
+ auto grads_hi =
+ b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen);
+ auto grads_lo = b->ShiftRightLogical(
+ b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen),
+ sixteen);
+ auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add.
+
+ xla::PrimitiveType element_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
+ ctx->SetOutput(0, b->BitcastConvertType(grads, element_type));
+ }
+
+ protected:
+ const int num_spatial_dims_;
+  std::vector<int64> ksize_;
+  std::vector<int64> stride_;
+ Padding padding_;
+ TensorFormat data_format_ = FORMAT_NHWC;
+};
+
+class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
+ public:
+ explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
+ : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ }
+};
+REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
+ MaxPool2DGradGradOp);
+REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
+ .TypeConstraint("T", DT_FLOAT)
+ .CompileTimeConstInput("ksize")
+ .CompileTimeConstInput("strides"),
+ MaxPool2DGradGradOp);
+
} // anonymous namespace
} // namespace tensorflow