STT-tensorflow/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
Brian Zhao 556824565d Automated g4 rollback of changelist 304856650.
PiperOrigin-RevId: 305076580
Change-Id: I98886941dbfb25acd99d6ca2601eaee6dc657034
2020-04-06 11:29:58 -07:00

683 lines
30 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <numeric>
#include <vector>
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
using xla::XlaOp;
// Computes the four bilinear blending weights for every sampling position.
//
// 'ratio' holds the basis ratio (px, py) of each sampling position and has
// the same dimensions as 'warp_shape': [batch, dim_0, ... dim_n, 2].
// 'xla_type' is the element type used for the constants and the reduction.
//
// Returns a tensor of dimensions [batch, dim_0, ... dim_n, 4] whose last
// dimension is
//   W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio,
                      const TensorShape warp_shape,
                      xla::PrimitiveType xla_type) {
  xla::XlaBuilder* builder = ctx->builder();

  // All intermediates below have dimensions [batch, dim_0, ... dim_n, 4, 2]:
  // one (x, y) factor pair for each of the four neighbors.
  const auto warp_sizes = warp_shape.dim_sizes();
  std::vector<int64> expanded_sizes(warp_sizes.begin(), warp_sizes.end() - 1);
  expanded_sizes.push_back(4);
  expanded_sizes.push_back(2);
  const int64 expanded_rank = expanded_sizes.size();
  const int64 pair_axis = expanded_rank - 1;
  // The [4, 2] constants occupy the two trailing axes.
  const std::vector<int64> constant_axes = {expanded_rank - 2, pair_axis};

  XlaOp base = xla::ConvertElementType(
      xla::ConstantR2<float>(builder,
                             {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}}),
      xla_type);
  XlaOp base_expanded =
      xla::BroadcastInDim(base, expanded_sizes, constant_axes);

  // 'ratio' keeps its leading [batch, dim_0, ... dim_n] axes and maps its
  // trailing size-2 axis onto the last axis, skipping over the size-4 axis.
  std::vector<int64> ratio_axes(expanded_rank - 1);
  std::iota(ratio_axes.begin(), ratio_axes.end(), 0);
  ratio_axes.back() = pair_axis;
  XlaOp ratio_expanded =
      xla::BroadcastInDim(ratio, expanded_sizes, ratio_axes);

  // The difference yields the pairs
  //   [(1-px, 1-py), (-px, 1-py), (1-px, -py), (-px, -py)].
  XlaOp offsets = base_expanded - ratio_expanded;

  // Negate the x component of the second and third pairs so that the product
  // of every pair equals the corresponding entry of W.
  XlaOp signs = xla::ConvertElementType(
      xla::ConstantR2<float>(
          builder, {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}}),
      xla_type);
  XlaOp signs_expanded =
      xla::BroadcastInDim(signs, expanded_sizes, constant_axes);
  XlaOp factors = offsets * signs_expanded;

  // Multiply the two factors of each pair together, collapsing the trailing
  // size-2 axis.
  return xla::Reduce(
      factors, xla::One(builder, xla_type),
      xla::CreateScalarMultiplyComputation(xla_type, builder), {pair_axis});
}
// Prepends the batch number to every (x, y) coordinate in 'indices'.
//
// An S32 iota tensor counting along dimension 0 is built with the dimensions
// of 'warp_shape', except that its last dimension has size 1, and is then
// concatenated with the given (coordinate) 'indices' along the last dimension.
//
// The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where the
// last dimension of size 3 is [batch_number, x, y]. The batch number is needed
// because [x, y] alone does not let xla::Gather address the input data, which
// is of dimension [batch, height(y), width(x), channel] with 'batch' first.
XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices,
                      const TensorShape& warp_shape) {
  // Same dimensions as warp, but with a trailing dimension of size 1.
  const auto warp_sizes = warp_shape.dim_sizes();
  std::vector<int64> iota_sizes(warp_sizes.begin(), warp_sizes.end());
  iota_sizes.back() = 1;

  XlaOp batch_numbers =
      xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, iota_sizes),
                /*iota_dimension=*/0);
  return xla::ConcatInDim(b, {batch_numbers, indices}, iota_sizes.size() - 1);
}
// Gathers, for every starting index, the 2x2 block of neighboring pixels from
// 'data' and returns a tensor of dimension
// [batch, dim_0, ... dim_n, 4, data_channels].
//
// 'data' is of dimension [batch, height(y), width(x), channel].
// 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the
// last dimension of size 3 is (batch_no, x, y).
// 'warp_dims' is the rank of the warp tensor.
XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices,
                          int64 data_channels, int warp_dims) {
  // The raw Gather output is [batch, dim_0, ... dim_n, 2, 2, data_channels].
  const int64 gathered_rank = warp_dims + 2;
  // The starting indices live in the last dimension of 'gather_indices'.
  const int64 index_vector_dim = warp_dims - 1;

  xla::GatherDimensionNumbers dnums;
  // The offset (window) dimensions are the last three output dimensions.
  for (int64 dim = gathered_rank - 3; dim < gathered_rank; ++dim) {
    dnums.add_offset_dims(dim);
  }
  dnums.set_index_vector_dim(index_vector_dim);
  // The batch dimension of every slice has size 1 and is dropped.
  dnums.add_collapsed_slice_dims(0);
  // Index components are (batch_no, x, y) while data is laid out as
  // [batch, height(y), width(x), channel], so x and y are swapped here.
  for (const int64 operand_dim : {0, 2, 1}) {
    dnums.add_start_index_map(operand_dim);
  }

  XlaOp neighbors = xla::Gather(data, gather_indices, dnums,
                                /*slice_sizes=*/{1, 2, 2, data_channels});
  // Collapse the ...,2,2,... dimensions into ...,4,...
  return xla::Collapse(neighbors, {index_vector_dim, index_vector_dim + 1});
}
// Scatter-adds the 'updates' tensor into 'grad_data' at the positions given by
// 'indices' and returns the result of the xla::Scatter. This function can be
// seen as the inverse of 'Gather2by2Neighbors'.
//
// 'updates' is of dimension [batch, dim_0, ...dim_n, 2, 2, data_channels].
// 'indices' is of dimension [batch, dim_0, ...dim_n, 3] where the last
// dimension of size 3 is (batch_no, x, y).
// 'warp_dims' is the rank of the warp tensor and 'xla_type' is the element
// type of the add computation used to combine colliding updates.
XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices,
                        XlaOp updates, int64 warp_dims,
                        xla::PrimitiveType xla_type) {
  const int64 updates_rank = warp_dims + 2;
  // The scatter indices live in the last dimension of 'indices'.
  const int64 index_vector_dim = warp_dims - 1;

  xla::ScatterDimensionNumbers dnums;
  // The update window is formed by the last three dimensions of 'updates'.
  for (int64 dim = updates_rank - 3; dim < updates_rank; ++dim) {
    dnums.add_update_window_dims(dim);
  }
  dnums.set_index_vector_dim(index_vector_dim);
  dnums.add_inserted_window_dims(0);
  // Index components are (batch_no, x, y) while the operand is laid out as
  // [batch, height(y), width(x), channel], so x and y are swapped here.
  for (const int64 operand_dim : {0, 2, 1}) {
    dnums.add_scatter_dims_to_operand_dims(operand_dim);
  }

  auto accumulate = xla::CreateScalarAddComputation(xla_type, ctx->builder());
  return xla::Scatter(grad_data, indices, updates, accumulate, dnums);
}
// Replaces 'sample' with 0 wherever the corresponding warp coordinate lies
// outside the open interval (-1, image_size) in x or in y.
//
// 'warp' is of dimension 'warp_shape', with (x, y) in its last dimension
// 'last_warp_dim'. The image width and height are read from dimensions 2 and 1
// of 'data_shape' respectively. The per-sample in-bound predicate is broadcast
// to 'result_dims' through 'broadcasted_dims'; 'sample' and the returned
// tensor have dimensions 'result_dims'.
XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp,
                   xla::PrimitiveType warp_type, TensorShape warp_shape,
                   std::vector<int64> result_dims,
                   std::vector<int64> broadcasted_dims, int64 last_warp_dim,
                   xla::Shape data_shape, XlaOp sample) {
  xla::XlaBuilder* builder = ctx->builder();
  // The 2-element bound vectors line up with the last dimension of warp.
  const int64 coord_axis = warp_shape.dims() - 1;

  // warp > (-1, -1), element-wise.
  XlaOp lower_bound = xla::ConvertElementType(
      xla::ConstantR1<float>(builder, {-1.0f, -1.0f}), warp_type);
  XlaOp above_lower = xla::Gt(warp, lower_bound,
                              /*broadcast_dimensions=*/{coord_axis});

  // warp < (width, height), element-wise.
  const float image_width = static_cast<float>(data_shape.dimensions(2));
  const float image_height = static_cast<float>(data_shape.dimensions(1));
  XlaOp upper_bound = xla::ConvertElementType(
      xla::ConstantR1<float>(builder, {image_width, image_height}), warp_type);
  XlaOp below_upper = xla::Lt(warp, upper_bound,
                              /*broadcast_dimensions=*/{coord_axis});

  XlaOp coordinate_ok = xla::And(above_lower, below_upper);

  // A sample is in bound only if both its x and y are; and-reduce along the
  // last dimension. The resulting dimension is [batch, dim_0, ...dim_n].
  XlaOp sample_ok = xla::Reduce(
      coordinate_ok, xla::ConstantR0<bool>(builder, true),
      xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, builder),
      {last_warp_dim});

  // Expand the predicate to 'result_dims' and zero out-of-bound samples.
  XlaOp mask = xla::BroadcastInDim(sample_ok, result_dims, broadcasted_dims);
  XlaOp zeros = xla::Broadcast(xla::Zero(builder, warp_type), result_dims);
  return xla::Select(mask, sample, zeros);
}
// Builds the computation for the backprop into input 'data'.
// Where input:
//   grad_output is of dimension [batch, dim_0, ...dim_n, channel]
//   ratio is of dimension [batch, dim_0, ...dim_n, 2]
//   gather_indices is of dimension [batch, dim_0, ...dim_n, 3]
//   warp is of dimension 'warp_shape', i.e. [batch, dim_0, ...dim_n, 2]
//   data_shape is of dimension [batch, height(y), width(x), channel]
//
// Output (of dimension 'data_shape'):
//   scatter-add to each 2x2 grad_data neighbor:
//     grad_data[fx, fy, chan] += output_grad * dx * dy
//     grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy
//     grad_data[fx, cy, chan] += output_grad * dx * (1 - dy)
//     grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy)
//   where (dx, dy) is (1 - ratio). Samples whose warp coordinate is out of
//   bound (see BoundSamples) contribute 0 to 'grad_data'.
XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
                        XlaOp gather_indices, XlaOp warp,
                        xla::PrimitiveType warp_type, TensorShape warp_shape,
                        int64 last_warp_dim, int64 data_channels,
                        xla::Shape data_shape) {
  xla::XlaBuilder* builder = ctx->builder();
  const int64 warp_rank = warp_shape.dims();
  const auto warp_sizes = warp_shape.dim_sizes();

  // [batch, dim_0, ... dim_n]
  const std::vector<int64> sample_sizes(warp_sizes.begin(),
                                        warp_sizes.end() - 1);
  // [batch, dim_0, ... dim_n, 2, 2]
  std::vector<int64> window_sizes = sample_sizes;
  window_sizes.push_back(2);
  window_sizes.push_back(2);
  // [batch, dim_0, ... dim_n, 2, 2, data_channels]
  std::vector<int64> update_sizes = window_sizes;
  update_sizes.push_back(data_channels);

  // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]; split its last
  // dimension of size 4 into two dimensions [2, 2].
  XlaOp weights = BilinearWeights(ctx, ratio, warp_shape, warp_type);
  std::vector<int64> weight_axes(warp_rank);
  std::iota(weight_axes.begin(), weight_axes.end(), 0);
  XlaOp window_weights = xla::Reshape(weights, /*dimensions=*/weight_axes,
                                      /*new_sizes=*/window_sizes);

  // Set the weights of out of bound samples to 0. The in-bound predicate is
  // per sample, so it is broadcast along the leading sample dimensions.
  std::vector<int64> sample_axes(sample_sizes.size());
  std::iota(sample_axes.begin(), sample_axes.end(), 0);
  window_weights =
      BoundSamples(ctx, warp, warp_type, warp_shape, window_sizes, sample_axes,
                   last_warp_dim, data_shape, window_weights);

  // Repeat the weights for every channel:
  // [batch, dim_0, ..., dim_n, 2, 2, data_channels].
  std::vector<int64> window_axes(window_sizes.size());
  std::iota(window_axes.begin(), window_axes.end(), 0);
  XlaOp weights_per_channel =
      xla::BroadcastInDim(window_weights, update_sizes, window_axes);

  // Repeat grad_output for every neighbor: its sample dimensions stay in place
  // and its channel dimension maps to the last dimension.
  std::vector<int64> grad_output_axes = sample_axes;
  grad_output_axes.push_back(update_sizes.size() - 1);
  XlaOp grad_output_per_neighbor =
      xla::BroadcastInDim(grad_output, update_sizes, grad_output_axes);

  XlaOp updates = grad_output_per_neighbor * weights_per_channel;

  // Start from a 'data_shape'-shaped literal and pad dimensions 1 and 2 by one
  // element on each side; the padding is sliced away again below.
  XlaOp initial_grad = xla::ConstantLiteral(
      builder, xla::Literal::CreateFromShape(data_shape));
  XlaOp padded_grad =
      xla::Pad(initial_grad, xla::Zero(builder, warp_type),
               xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));

  // Shift the (x, y) indices by one to account for the padding.
  XlaOp index_shift = xla::ConstantR1<int32>(
      builder, {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
  XlaOp shifted_indices =
      xla::Add(gather_indices, index_shift, {last_warp_dim});

  XlaOp scattered = ScatterToGradData(ctx, padded_grad, shifted_indices,
                                      updates, warp_rank, warp_type);

  // Slice out the result accounting for the padding.
  const int64 batch_size = data_shape.dimensions(0);
  const int64 rows = data_shape.dimensions(1);
  const int64 cols = data_shape.dimensions(2);
  return xla::Slice(
      scattered, /*start_indices=*/{0, 1, 1, 0},
      /*limit_indices=*/{batch_size, rows + 1, cols + 1, data_channels},
      /*strides=*/{1, 1, 1, 1});
}
// Builds the computation for the backprop into input 'warp'.
// Where input:
//   warp is of dimension [batch, dim_0, ...dim_n, 2]
//   grad_output is of dimension [batch, dim_0, ...dim_n, channel]
//   ratio is of dimension [batch, dim_0, ...dim_n, 2]
//   gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last
//     dimension of size 3 is for {batch, x(width), y(height)}.
//   data is of dimension [batch, height(y), width(x), channel]
//   data_shape is accepted for signature compatibility and is not used.
//
// Returns a tensor of dimension [batch, dim_0, ...dim_n, 2] holding
// (grad_warp_x, grad_warp_y).
//
// Output (simplified by ignoring the batch dimensions):
// Since the forward path has:
//   output = dot(weights * neighbors)
// The backprop into warp will therefore be:
//   grad_warp = output_grad * d_output / d_warp
//             = output_grad * (d_weights / d_warp * neighbors + d_neighbors /
//                              d_warp * weight)
// Where:
//   d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py]
//   d_weights / d_warp_y = [-(1 - px), -px, (1-px), px]
// and
//   d_neighbors / d_warp_x = 0
//
// Therefore, summing over channels:
//   grad_warp_x = sum(output_grad * (py * (img_cxcy - img_fxcy) +
//                                    (1-py) * (img_cxfy - img_fxfy)))
//   grad_warp_y = sum(output_grad * (px * (img_cxcy - img_cxfy) +
//                                    (1-px) * (img_fxcy - img_fxfy)))
//
// where (px, py) is ratio, (fx, fy) is the top left corner and (cx, cy) is the
// bottom right corner in a 2x2 neighborhood.
XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio,
                        XlaOp gather_indices, XlaOp data,
                        TensorShape warp_shape, int64 data_channels,
                        xla::PrimitiveType data_type, xla::Shape data_shape) {
  xla::XlaBuilder* builder = ctx->builder();
  auto warp_dims = warp_shape.dim_sizes();
  std::vector<int64> warp_dims_without_last_dims(warp_dims.begin(),
                                                 warp_dims.end() - 1);

  // With dimension [batch, dim_0, ...dim_n, 4]
  std::vector<int64> neighbor_broadcast_dims = warp_dims_without_last_dims;
  neighbor_broadcast_dims.push_back(4);

  const int64 last_warp_dim = warp_shape.dims() - 1;

  // Pad data with 0 before gathering, such that 0 will be returned for samples
  // in the range of (-1, 0) or (image_dimension-1, image_dimension).
  // Dimensions 1 and 2 of the padded data each grow by 2.
  auto padded_data =
      xla::Pad(data, xla::Zero(builder, data_type),
               xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}}));

  // Shift the (x, y) indices by one to account for the padding.
  auto shifting_value = xla::ConstantR1<int32>(
      builder, {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1});
  auto shifted_gather_indices =
      xla::Add(gather_indices, shifting_value, {last_warp_dim});

  // The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
  auto neighbors_data =
      Gather2by2Neighbors(builder, padded_data, shifted_gather_indices,
                          data_channels, warp_shape.dims());

  // Since we will be creating the dot product of:
  //   lhs: [batch, dim_0, ...dim_n, 4]
  // and
  //   rhs: [batch, dim_0, ...dim_n, 4, data_channels]
  // we choose the last dimension of lhs and the second last dimension of rhs,
  // with size 4, as the contracting dimension.
  xla::DotDimensionNumbers dot_dims;
  for (int i = 0; i < warp_shape.dims() - 1; ++i) {
    dot_dims.add_lhs_batch_dimensions(i);
    dot_dims.add_rhs_batch_dimensions(i);
  }
  dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
  dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);

  // Combines the four neighbors [img_fxfy, img_cxfy, img_fxcy, img_cxcy] of
  // every sample with the given coefficients. The result is of dimension
  // [batch, dim_0, ...dim_n, data_channels].
  auto neighbor_difference = [&](const std::vector<float>& coefficients) {
    return xla::DotGeneral(
        xla::BroadcastInDim(
            xla::ConvertElementType(
                xla::ConstantR1<float>(builder, coefficients), data_type),
            neighbor_broadcast_dims, {last_warp_dim}),
        neighbors_data, dot_dims, /*precision_config=*/nullptr);
  };

  // img_cxcy - img_fxcy
  auto bottom_right_minus_bottom_left =
      neighbor_difference({0.0f, 0.0f, -1.0f, 1.0f});
  // img_cxfy - img_fxfy
  auto top_right_minus_top_left =
      neighbor_difference({-1.0f, 1.0f, 0.0f, 0.0f});
  // img_cxcy - img_cxfy
  auto bottom_right_minus_top_right =
      neighbor_difference({0.0f, -1.0f, 0.0f, 1.0f});
  // img_fxcy - img_fxfy
  auto bottom_left_minus_top_left =
      neighbor_difference({-1.0f, 0.0f, 1.0f, 0.0f});

  // Slice out x and y; each is of dimension [batch, dim_0, ...dim_n, 1].
  auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1,
                                  /*stride=*/1, /*dimno=*/last_warp_dim);
  auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2,
                                  /*stride=*/1, /*dimno=*/last_warp_dim);

  // Build 1 - y and 1 - x.
  auto one_minus_y = xla::One(builder, data_type) - weight_y;
  auto one_minus_x = xla::One(builder, data_type) - weight_x;

  // grad_output must scale the whole interpolated difference, not only its
  // first term; the parentheses are required.
  auto x_before_reduce =
      grad_output * (weight_y * bottom_right_minus_bottom_left +
                     one_minus_y * top_right_minus_top_left);

  std::vector<int64> reshaped_sizes = warp_dims_without_last_dims;
  reshaped_sizes.push_back(1);
  std::vector<int64> reshaped_dims(warp_dims_without_last_dims.size());
  std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0);

  // Reduce-add along the channel dimension.
  auto x_result =
      xla::Reduce(x_before_reduce, xla::Zero(builder, data_type),
                  xla::CreateScalarAddComputation(data_type, builder),
                  {last_warp_dim});
  // Reshape to [batch, dim_0, ...dim_n, 1] before concatenating with y values.
  XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes);

  auto y_before_reduce =
      grad_output * (weight_x * bottom_right_minus_top_right +
                     one_minus_x * bottom_left_minus_top_left);
  // Reduce-add along the channel dimension.
  auto y_result =
      xla::Reduce(y_before_reduce, xla::Zero(builder, data_type),
                  xla::CreateScalarAddComputation(data_type, builder),
                  {last_warp_dim});
  XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes);

  return xla::ConcatInDim(builder, {reshaped_x, reshaped_y}, last_warp_dim);
}
class ResamplerOp : public XlaOpKernel {
public:
explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape data_shape = ctx->InputShape("data");
OP_REQUIRES(ctx, data_shape.dims() == 4,
errors::InvalidArgument("data must be 4-dimensional",
data_shape.DebugString()));
const int64 data_channels = data_shape.dim_size(3);
xla::PrimitiveType data_type = ctx->input_xla_type(0);
TensorShape warp_shape = ctx->InputShape("warp");
OP_REQUIRES(ctx, warp_shape.dims() >= 2,
errors::InvalidArgument("warp must be at least 2-dimensional",
warp_shape.DebugString()));
for (int size : warp_shape.dim_sizes()) {
OP_REQUIRES(ctx, size > 0,
errors::InvalidArgument("warp sizes must be positive, got [",
size, "]"));
}
const int64 last_warp_dim = warp_shape.dims() - 1;
// Last dimension of warp shape must be of size 2.
OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
errors::InvalidArgument(
"the last dimension of warp must be exactly size 2."));
xla::PrimitiveType warp_type = ctx->input_xla_type(1);
XlaOp data = ctx->Input("data");
XlaOp warp = ctx->Input("warp");
// Find the coordinates of the top left corner for the 2x2 region to be
// sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the
// last dimension of size 2 in turn is [x, y].
XlaOp top_left = xla::ConvertElementType(warp, xla::S32);
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
// The dimension is [batch, dim_0, ... dim_n, 4, data_channels]
auto neighbors_data = Gather2by2Neighbors(
ctx->builder(), data, gather_indices, data_channels, warp_shape.dims());
// Dimensions are [batch, dim_0, ... dim_n, 2].
XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type);
// Obtain the bilinear blending weights, the dimension is [batch, dim_0,
// ...dim_n, 4].
auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type);
// Since we will be creating the dot product of:
// lhs: [batch, dim_0, ...dim_n, 4]
// and
// rhs: [batch, dim_0, ...dim_n, 4, data_channels]
// we choose the last dimension of lhs and the second last dimension of rhs,
// with size 4, as the contracting dimension.
xla::DotDimensionNumbers dot_dims;
for (int i = 0; i < warp_shape.dims() - 1; ++i) {
dot_dims.add_lhs_batch_dimensions(i);
dot_dims.add_rhs_batch_dimensions(i);
}
dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
// The dimension is [batch, dim_0, ...dim_n, data_channels].
auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims,
/*precision_config=*/nullptr);
// Handle out of boundary cases by constructing a predicate mask array based
// on the in-bound condition, and output 0 for the blended pixel value if
// out-bound. The dimension is the same as top_left: [batch, dim_0,
// ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate.
auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp));
auto is_lt_image_size = xla::Lt(
warp,
xla::ConvertElementType(
xla::ConstantR1<float>(
ctx->builder(),
{/*width=*/static_cast<float>(data_shape.dim_size(2) - 1),
/*height=*/static_cast<float>(data_shape.dim_size(1) - 1)}),
warp_type),
/*broadcast_dimensions=*/{warp_shape.dims() - 1});
auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size);
// Reduce along last dimension. The resulting dimension is:
// [batch, dim_0, ...dim_n].
auto is_in_bound = xla::Reduce(
is_in_bound_x_y, xla::ConstantR0<bool>(ctx->builder(), true),
xla::CreateScalarAndComputation(xla::PrimitiveType::PRED,
ctx->builder()),
{last_warp_dim});
// Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which
// is the dimension of the result:
// [batch, dim_0, ...dim_n, data_channels].
auto warp_dims = warp_shape.dim_sizes();
std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
result_dims.push_back(data_channels);
std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
auto broadcasted_is_in_bound =
xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims);
// Set out of bound samples to zero.
auto zeros =
xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims);
auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros);
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("Resampler"), ResamplerOp);
class ResamplerGradOp : public XlaOpKernel {
public:
explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
DataType output_dtype;
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
}
// TODO(b/112295522): note that sampling from image boundary is not currently
// being handled properly.
void Compile(XlaOpKernelContext* ctx) override {
TensorShape data_shape_tf = ctx->InputShape("data");
OP_REQUIRES(ctx, data_shape_tf.dims() == 4,
errors::InvalidArgument("data must be 4-dimensional",
data_shape_tf.DebugString()));
const int64 data_channels = data_shape_tf.dim_size(3);
xla::PrimitiveType data_type = ctx->input_xla_type(0);
TensorShape warp_shape = ctx->InputShape("warp");
OP_REQUIRES(ctx, warp_shape.dims() >= 2,
errors::InvalidArgument("warp must be at least 2-dimensional",
warp_shape.DebugString()));
for (int size : warp_shape.dim_sizes()) {
OP_REQUIRES(ctx, size > 0,
errors::InvalidArgument("warp sizes must be positive, got [",
size, "]"));
}
// Last dimension of warp shape must be of size 2.
const int64 last_warp_dim = warp_shape.dims() - 1;
OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
errors::InvalidArgument(
"the last dimension of warp must be exactly size 2."));
xla::PrimitiveType warp_type = ctx->input_xla_type(1);
TensorShape output_grad_shape = ctx->InputShape("grad_output");
OP_REQUIRES(
ctx, output_grad_shape.dims() >= 2,
errors::InvalidArgument("output_grad must be at least 2-dimensional",
output_grad_shape.DebugString()));
// Dimensions are [batch, x, y, channel].
XlaOp data = ctx->Input("data");
xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf);
// Dimensions are [batch, dim_0, ...dim_n, 2].
XlaOp warp = ctx->Input("warp");
// Dimensions are [batch, dim_0, ...dim_n, channel].
XlaOp grad_output = ctx->Input("grad_output");
// Find the top left corner coordinate for the region to be sampled from.
// The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension
// of size 2 in turn is [x, y].
XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32);
// Dimensions are [batch, dim_0, ... dim_n, 2].
XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type);
// Indices for gathering neighboring pixels.
auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape);
auto grad_data = CalculateGradData(
ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape,
last_warp_dim, data_channels, data_shape);
auto grad_warp =
CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data,
warp_shape, data_channels, data_type, data_shape);
auto warp_dims = warp_shape.dim_sizes();
std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
result_dims.push_back(2);
std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
auto grad_warp_bounded =
BoundSamples(ctx, warp, warp_type, warp_shape, result_dims,
broadcasted_dims, last_warp_dim, data_shape, grad_warp);
ctx->SetOutput(0, grad_data);
ctx->SetOutput(1, grad_warp_bounded);
}
};
REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp);
} // namespace
} // namespace tensorflow