diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index e0e85295fec..fe270af3d63 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -292,8 +292,15 @@ class PoolGradTest(XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" - def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format): + def _VerifyOneTest(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -304,9 +311,19 @@ class PoolGradTest(XLATestCase): strides: The stride dimensions padding: Padding type. data_format: The data format we use to run the pooling operation. + pool_grad_grad_func: Second-order gradient function, if available. """ total_size = np.prod(input_sizes) - x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) + # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally + # maximal at 16 bits. Switch to np.random.randn when resolved. + x = np.arange(1, total_size + 1, dtype=np.float32) + x *= (np.random.randint(2, size=total_size) * 2 - 1) # Flip signs randomly + # Verify some specifically interesting values... + x[np.random.choice(total_size)] = np.inf + x[np.random.choice(total_size)] = -np.inf + # TODO(b/74222344): Fix nan handling for max pool grad. + # x[np.random.choice(total_size)] = np.nan + x = x.reshape(input_sizes) with self.test_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). @@ -323,6 +340,8 @@ class PoolGradTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -342,18 +361,36 @@ class PoolGradTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) xla_inputs = inputs xla_outputs = outputs xla_output_gradients = output_gradients + xla_output_grad_gradients = output_grad_gradients xla_ksize = ksize xla_strides = strides if data_format == "NCHW": xla_inputs = NHWCToNCHW(inputs) xla_outputs = NHWCToNCHW(outputs) xla_output_gradients = NHWCToNCHW(output_gradients) + xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients) xla_ksize = NHWCToNCHW(ksize) xla_strides = NHWCToNCHW(strides) actual_input_gradients = pool_grad_func( @@ -366,22 +403,54 @@ class PoolGradTest(XLATestCase): data_format=data_format) if data_format == "NCHW": actual_input_gradients = NCHWToNHWC(actual_input_gradients) - actual = sess.run(actual_input_gradients, { + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + xla_inputs, + xla_outputs, + xla_output_grad_gradients, + ksize=xla_ksize, + strides=xla_strides, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + actual_grad_gradients = NCHWToNHWC(actual_grad_gradients) + actual_input_gradients_vals = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, output_gradients: output_gradient_vals }) - # Compare the Tensorflow and XLA results. self.assertAllClose( - expected_input_gradient_vals.flatten(), - actual.flatten(), + expected_input_gradient_vals, + actual_input_gradients_vals, rtol=1e-4, atol=1e-6) - self.assertShapeEqual(actual, inputs) + self.assertShapeEqual(actual_input_gradients_vals, inputs) - def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + + def _VerifyValues(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling function. Args: @@ -391,12 +460,20 @@ class PoolGradTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ for data_format in GetTestConfigs(): - self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format) + self._VerifyOneTest( + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=pool_grad_grad_func) - def _TestPooling(self, forward_op, backward_op): + def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None): # VALID padding self._VerifyValues( forward_op, @@ -404,7 +481,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 3, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding self._VerifyValues( @@ -413,7 +491,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, non square window self._VerifyValues( @@ -422,7 +501,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 2, 1], ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # VALID padding, uneven stride self._VerifyValues( @@ -431,14 +511,16 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) self._VerifyValues( forward_op, backward_op, input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 4 input self._VerifyValues( @@ -447,7 +529,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 4], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 8 input self._VerifyValues( @@ -456,10 +539,14 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 8, 8, 8], ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) def testMaxPool(self): - self._TestPooling(nn_ops.max_pool, gen_nn_ops.max_pool_grad) + self._TestPooling( + nn_ops.max_pool, + gen_nn_ops.max_pool_grad, + pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad) def testAvgPool(self): # Wrapper around AvgPoolGrad that ignores extra arguments needed by diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md index 91351421bca..20179b67991 100644 --- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -186,6 +198,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -198,6 +211,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md index b9bdb829d77..55f0538dba7 100644 --- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -183,6 +195,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -195,6 +208,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4fb5dd4e06..086a9491aa9 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -525,5 +525,172 @@ class AvgPool3DGradOp : public AvgPoolGradOp { REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), AvgPool3DGradOp); +class MaxPoolGradGradOp : public XlaOpKernel { + public: + MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + + OP_REQUIRES(ctx, ksize_.size() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES(ctx, stride_.size() == num_dims(), + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims(), " dimensions")); + + const TensorShape tensor_in_shape = ctx->InputShape(0); + const TensorShape tensor_out_shape = ctx->InputShape(1); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // For maxpooling, tensor_in should have num_dims() dimensions. + OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_in must be ", num_dims(), + "-dimensional")); + OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_out must be ", num_dims(), + "-dimensional")); + // For maxpooling, out_backprop should have num_dims() dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), + errors::InvalidArgument("out_backprop must be ", num_dims(), + "-dimensional")); + + // What we want to compute: + // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad) + // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad. + // + // In the regular TF op, this amounts to selecting for each window the + // incoming backprop value from xs_grad_grad that corresponds to the maximal + // value in the corresponding window of x. + // + // TODO(b/73062247): What we really want is a ReduceWindow with different + // arrays for index selection vs return value selection--a select-to-gather. + // + // Here, we implement a bitwise hack: we use the hi 16 bits of input for + // separate max pooling alongside each of the hi and lo 16 bits of + // out_backprop packed into 16 lo bits, which we then glue back together at + // the end to get a full 32 bits of gradient. + // + // This could select the wrong backprop value for two x values that are + // equally maximal up to the first 16 bits, in which case we are taking the + // latter. + // + // Note that in principle we could use 32 separate maxpools to recover each + // of 32 bits of the gradient while preserving 31 bits of input for the max + // pooling criteria; here, we just truncate to the first 16 bits of input. + + auto input = ctx->Input(0); + auto out_backprop = ctx->Input(2); + + auto b = ctx->builder(); + + auto sixteen = b->ConstantR0(16); + // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 + auto in_hi = b->BitcastConvertType( + b->ConvertElementType(b->ConvertElementType(input, xla::BF16), + xla::F32), + xla::U32); + auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); + auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. + + auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + // We will reduce by taking the maximal value up to 16 bits (ignoring the lo + // 16 bits of packed-in hi/lo backprop value). + auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); + { + // F32 parameters to satisfy lowering type restriction for reduce opcode. + const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); + auto lhs = rb->Parameter(0, scalar, "lhs"); + auto rhs = rb->Parameter(1, scalar, "rhs"); + auto sixteen = rb->ConstantR0(16); + auto lhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); + // Must use a F32 comparison, because S32 would not work for negatives. + rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), + rb->BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); + } + auto reduce = rb->BuildAndNoteError(); + xla::Padding xla_padding = + (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + auto pooled_hi = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto pooled_lo = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto grads_hi = + b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = b->ShiftRightLogical( + b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + sixteen); + auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + + xla::PrimitiveType element_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); + ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + } + + protected: + const int num_spatial_dims_; + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_ = FORMAT_NHWC; +}; + +class MaxPool2DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool2DGradGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradGradV2") + .TypeConstraint("T", DT_FLOAT) + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradGradOp); + } // anonymous namespace } // namespace tensorflow