[XLA] Use HloEvaluator for convolution in reference_util.

Also speed up HloEvaluator's HandleConvolution in non-opt builds by moving calls to HloInstruction::shape() out of the inner loop.

PiperOrigin-RevId: 163416183
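The HandleConvolution speedup is an instance of a general pattern: an accessor that is trivially inlined in optimized builds can still dominate a hot loop in a non-opt build, so its result is bound to a local once, before the loop. A minimal standalone sketch of the pattern (the Tensor and SumOfDims names are illustrative, not XLA APIs):

#include <cstdint>
#include <vector>

// Illustrative stand-in for an object with a cheap-but-not-free accessor,
// analogous to HloInstruction::shape() returning a const Shape&.
struct Tensor {
  std::vector<int64_t> dims;
  // In a non-opt build this is a real function call on every use.
  const std::vector<int64_t>& shape() const { return dims; }
};

int64_t SumOfDims(const Tensor& t, int64_t reps) {
  // Hoist: call the accessor once and reuse the bound reference, instead
  // of calling t.shape() on every inner-loop iteration.
  const std::vector<int64_t>& shape = t.shape();
  int64_t total = 0;
  for (int64_t r = 0; r < reps; ++r) {
    for (int64_t d : shape) {
      total += d;
    }
  }
  return total;
}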
commit 253bcbb71b
parent 569a00e681

tensorflow/compiler/xla/BUILD
@@ -563,6 +563,9 @@ cc_library(
         ":xla_data_proto",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:padding",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_evaluator",
+        "//tensorflow/compiler/xla/service:shape_inference",
        "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
         "//tensorflow/core:lib",
     ],
tensorflow/compiler/xla/literal_util.h
@@ -470,10 +470,11 @@ class Literal {
 
   // Populates literal values by calling the generator function for every cell
   // in this literal object.
-  template <typename NativeT>
-  Status Populate(
-      const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
-          generator);
+  //
+  // generator must be a callable of the type
+  // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+  template <typename NativeT, typename FnType>
+  Status Populate(const FnType& generator);
 
   // Creates a Literal of the given dimensions with all elements set to the
   // given value.
@@ -1107,12 +1108,10 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
   PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
 }
 
-template <typename NativeT>
-Status Literal::Populate(
-    const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
-        generator) {
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
   const Shape& this_shape = shape();
-  int64 rank = ShapeUtil::Rank(this_shape);
+  const int64 rank = ShapeUtil::Rank(this_shape);
   TF_RET_CHECK(this_shape.element_type() ==
                primitive_util::NativeToPrimitiveType<NativeT>());
   tensorflow::gtl::MutableArraySlice<NativeT> data =
@@ -1125,7 +1124,7 @@ Status Literal::Populate(
       ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
 
   auto init_function = [&](const std::vector<int64>& indexes) {
-    int64 index = LinearIndex(indexes);
+    const int64 index = LinearIndex(indexes);
     std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
     for (int64 i = 0; i < minor_dimension_size; ++i) {
       minor_scan_indexes[stride_config.minor_dimension] = i;
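For context, a sketch of a call site for the templated Populate (the shape and generator are illustrative; assumes code inside namespace xla, in a function returning Status). Because FnType is deduced, the lambda is invoked directly rather than through a type-erased std::function, which is what removes the per-cell call overhead in non-opt builds:

// Fill a 2x3 F32 literal so that cell (i, j) holds 10*i + j.
auto literal = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {2, 3}));
TF_RETURN_IF_ERROR(literal->Populate<float>(
    [](tensorflow::gtl::ArraySlice<int64> indexes) {
      // indexes is the multi-dimensional index of the cell being filled.
      return static_cast<float>(indexes[0] * 10 + indexes[1]);
    }));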
tensorflow/compiler/xla/reference_util.cc
@@ -20,6 +20,9 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/math/math_util.h"
@@ -446,179 +449,85 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
     std::pair<int64, int64> kernel_stride, Padding padding,
     std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
     ConvolutionDimensionNumbers dnums) {
-  std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
-  std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
-
-  const int64 ksy = kernel_stride.first;
-  const int64 ksx = kernel_stride.second;
-  const int64 dy = lhs_dilation.first;
-  const int64 dx = lhs_dilation.second;
-  const int64 dky = rhs_dilation.first;
-  const int64 dkx = rhs_dilation.second;
-  CHECK_GE(dky, 1);
-  CHECK_GE(dkx, 1);
-  CHECK_GE(dy, 1);
-  CHECK_GE(dx, 1);
-
-  // Get all dimension sizes in lhs and rhs based on the given convolution
-  // dimension configuration.
-  const int64 ix = window_util::DilatedBound(
-      lhs_dimensions[dnums.spatial_dimensions(1)], dx);
-  const int64 iy = window_util::DilatedBound(
-      lhs_dimensions[dnums.spatial_dimensions(0)], dy);
-  const int64 iz = lhs_dimensions[dnums.feature_dimension()];
-  const int64 samples = lhs_dimensions[dnums.batch_dimension()];
-  const int64 kx = window_util::DilatedBound(
-      rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
-  const int64 ky = window_util::DilatedBound(
-      rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
-  const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
-  {
-    const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
-    CHECK_EQ(kiz, iz);
-  }
-
-  if (padding == Padding::kSame) {
-    // We reject same padding with kernel striding, since it's somewhat
-    // nonsensical. We can always follow up to implement this with the desired
-    // semantics if anybody actually uses it.
-    CHECK_EQ(1, ksy);
-    CHECK_EQ(1, ksx);
-  }
-
-  const int64 ox =
-      padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
-  const int64 oy =
-      padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
-  const int64 istartx =
-      padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
-  const int64 istarty =
-      padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
-  // Create the output result array and reset the values to 0.
-  std::array<int64, 4> result_dimensions;
-  result_dimensions[dnums.batch_dimension()] = samples;
-  result_dimensions[dnums.feature_dimension()] = oz;
-  result_dimensions[dnums.spatial_dimensions(0)] = oy;
-  result_dimensions[dnums.spatial_dimensions(1)] = ox;
+  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
+  auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
+  auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
+
+  std::array<int64, 2> ordered_kernel_strides;
+  std::array<int64, 2> ordered_input_dimensions;
+  std::array<int64, 2> ordered_kernel_dimensions;
+  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
+    ordered_kernel_strides[0] = kernel_stride.second;
+    ordered_kernel_strides[1] = kernel_stride.first;
+  } else {
+    ordered_kernel_strides[0] = kernel_stride.first;
+    ordered_kernel_strides[1] = kernel_stride.second;
+  }
+  ordered_input_dimensions[0] =
+      lhs_literal->shape().dimensions(dnums.spatial_dimensions(0));
+  ordered_input_dimensions[1] =
+      lhs_literal->shape().dimensions(dnums.spatial_dimensions(1));
+  ordered_kernel_dimensions[0] =
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+  ordered_kernel_dimensions[1] =
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+
+  std::vector<std::pair<int64, int64>> paddings =
+      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
+                  ordered_kernel_strides, padding);
+  CHECK_EQ(paddings.size(), 2);
+
+  Window window;
+
+  WindowDimension dim;
+  dim.set_size(
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+  dim.set_stride(kernel_stride.first);
+  dim.set_padding_low(paddings[0].first);
+  dim.set_padding_high(paddings[0].second);
+  dim.set_window_dilation(rhs_dilation.first);
+  dim.set_base_dilation(lhs_dilation.first);
+  *window.add_dimensions() = dim;
+
+  WindowDimension dim2;
+  dim2.set_size(
+      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+  dim2.set_stride(kernel_stride.second);
+  dim2.set_padding_low(paddings[1].first);
+  dim2.set_padding_high(paddings[1].second);
+  dim2.set_window_dilation(rhs_dilation.second);
+  dim2.set_base_dilation(lhs_dilation.second);
+  *window.add_dimensions() = dim2;
+
+  const Shape& shape =
+      ShapeInference::InferConvolveShape(lhs_literal->shape(),
+                                         rhs_literal->shape(), window, dnums)
+          .ConsumeValueOrDie();
+
+  HloInstruction* lhs_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
+  HloInstruction* rhs_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+
+  b.AddInstruction(HloInstruction::CreateConvolve(
+      shape, lhs_instruction, rhs_instruction, window, dnums));
+
+  HloEvaluator evaluator;
+  std::unique_ptr<Literal> result_literal =
+      evaluator.Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
   auto result =
-      MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
-                                 result_dimensions[2], result_dimensions[3]);
-  result->Fill(0.0);
-
-  const auto is_int32 = [](int64 x) {
-    return x >= std::numeric_limits<int32>::min() &&
-           x <= std::numeric_limits<int32>::max();
-  };
-
-  // 64-bit idiv/mod are much more expensive than 32-bit idiv/imod (at
-  // least on x86-64), so we avoid them where possible.
-  const auto fast_idiv64 = [&](int64 a, int64 b) {
-    if (is_int32(a) && is_int32(b)) {
-      return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
-    }
-    return a / b;
-  };
-  const auto fast_imod64 = [&](int64 a, int64 b) {
-    if (is_int32(a) && is_int32(b)) {
-      return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
-    }
-    return a % b;
-  };
-
-  // Lambda to access the lhs operand at the given 4D index.
-  const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
-                               int64 width) {
-    if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
-      return 0.0f;
-    }
-
-    std::array<int64, 4> index;
-    index[dnums.batch_dimension()] = batch;
-    index[dnums.feature_dimension()] = feature;
-    index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
-    index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
-    return lhs(index[0], index[1], index[2], index[3]);
-  };
-
-  // Lambda to access the rhs operand at the given 4D index. height_over_dky
-  // should be equal to height / dky, and width_over_dkx should be equal to
-  // width / dkx. (This is an optimization to avoid doing divisions.)
-  const auto rhs_element =
-      [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height,
-          int64 width, int64 height_over_dky, int64 width_over_dkx) {
-        DCHECK_EQ(height % dky, 0);
-        DCHECK_EQ(width % dkx, 0);
-        DCHECK_EQ(height / dky, height_over_dky);
-        DCHECK_EQ(width / dkx, width_over_dkx);
-
-        std::array<int64, 4> index;
-        index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
-        index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
-        index[dnums.kernel_spatial_dimensions(0)] = height_over_dky;
-        index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx;
-        return rhs(index[0], index[1], index[2], index[3]);
-      };
-
-  // Lambda to access the result data at the given 4D index.
-  const auto result_element = [&](int64 batch, int64 kernel_output_feature,
-                                  int64 height, int64 width) -> float& {
-    std::array<int64, 4> index;
-    index[dnums.batch_dimension()] = batch;
-    index[dnums.feature_dimension()] = kernel_output_feature;
-    index[dnums.spatial_dimensions(0)] = height;
-    index[dnums.spatial_dimensions(1)] = width;
-    return (*result)(index[0], index[1], index[2], index[3]);
-  };
-
-  for (int64 oyi = 0; oyi < oy; ++oyi) {
-    for (int64 oxi = 0; oxi < ox; ++oxi) {
-      for (int64 sample = 0; sample < samples; ++sample) {
-        for (int64 izi = 0; izi < iz; ++izi) {
-          for (int64 ozi = 0; ozi < oz; ++ozi) {
-            for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky;
-                 kyi += dky, kyi_over_dky++) {
-              for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx;
-                   kxi += dkx, kxi_over_dkx++) {
-                int64 iyi = istarty + ksy * oyi + kyi;
-                int64 ixi = istartx + ksx * oxi + kxi;
-                float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
-                                  ? 0.0
-                                  : lhs_element(sample, izi, iyi, ixi);
-                float gain =
-                    rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx);
-                float addend = input * gain;
-                result_element(sample, ozi, oyi, oxi) += addend;
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-  if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 ||
-      iz == 0) {
-    LOG(INFO) << "Output will be trivially empty because one of these "
-                 "dimensions is 0: samples: "
-              << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox
-              << " oy: " << oy << " oz: " << oz << " iz: " << iz;
-    return result;
-  }
-  bool trivial = true;
-  auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice<int64> indices,
-                                  float value) {
-    if (value != 0.0) {
-      trivial = false;
-    }
-  };
-  lhs.Each(check_trivial);
-  if (trivial) {
-    LOG(FATAL) << "LHS is all 0.0.";
-  }
-  trivial = true;
-  rhs.Each(check_trivial);
-  if (trivial) {
-    LOG(FATAL) << "RHS is all 0.0.";
-  }
+      MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
+                                 result_literal->shape().dimensions(1),
+                                 result_literal->shape().dimensions(2),
+                                 result_literal->shape().dimensions(3));
+
+  result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+    *value = result_literal->Get<float>(indices);
+  });
+
   return result;
 }
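Call sites are unchanged by the rewrite. As a usage sketch, assuming the existing plain ConvArray4D convenience wrapper (which reaches this dilated variant with default dimension numbers and unit dilation); sizes and fill values here are illustrative:

// batch=1, features=1, 3x3 input, 2x2 kernel, valid padding, stride 1x1.
Array4D<float> input(1, 1, 3, 3);
input.Fill(1.0f);
Array4D<float> kernel(1, 1, 2, 2);
kernel.Fill(2.0f);
// Each output cell sums four taps of 1.0f * 2.0f, so every cell is 8.
std::unique_ptr<Array4D<float>> out =
    ReferenceUtil::ConvArray4D(input, kernel, {1, 1}, Padding::kValid);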
tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -404,12 +404,16 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 
   Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
                            HloInstruction* rhs, const Window& window) override {
-    CHECK(ShapeUtil::IsArray(lhs->shape()));
-    CHECK(ShapeUtil::IsArray(rhs->shape()));
-    CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
-    CHECK(ShapeUtil::SameElementType(lhs->shape(), conv->shape()));
-    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs->shape()));
-    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs->shape()));
+    const Shape& result_shape = conv->shape();
+    const Shape& lhs_shape = lhs->shape();
+    const Shape& rhs_shape = rhs->shape();
+
+    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
+    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
+    CHECK(ShapeUtil::IsArray(lhs_shape));
+    CHECK(ShapeUtil::IsArray(rhs_shape));
+    CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
+    CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));
 
     const auto& dnums = conv->convolution_dimension_numbers();
     const int64 num_spatial_dims = dnums.spatial_dimensions_size();
@@ -417,23 +421,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     CHECK_GE(num_spatial_dims, 1);
     CHECK_EQ(window.dimensions_size(), num_spatial_dims);
 
-    CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(lhs->shape()));
-    CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(rhs->shape()));
+    const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
+    const auto rhs_rank = ShapeUtil::Rank(rhs_shape);
+
+    CHECK_EQ(num_spatial_dims + 2, lhs_rank);
+    CHECK_EQ(num_spatial_dims + 2, rhs_rank);
 
     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
-                        ShapeInference::InferConvolveShape(
-                            lhs->shape(), rhs->shape(), window, dnums));
-    CHECK(ShapeUtil::Compatible(conv->shape(), inferred_return_shape))
-        << "return shape set to: " << ShapeUtil::HumanString(conv->shape())
+                        ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
+                                                           window, dnums));
+    CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
         << " but is inferred to be: "
         << ShapeUtil::HumanString(inferred_return_shape);
 
     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
 
-    const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
-    const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
-
     // Dimension number applicable for both input (lhs), and output.
     const int64 batch_dim = dnums.batch_dimension();
     const int64 z_dim = dnums.feature_dimension();
@@ -441,78 +445,78 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
     const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
 
-    const int64 z_size = ShapeUtil::GetDimension(lhs->shape(), z_dim);
+    const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim);
 
     std::vector<int64> window_dimension_sizes;
     for (auto i : dnums.kernel_spatial_dimensions()) {
-      window_dimension_sizes.push_back(
-          ShapeUtil::GetDimension(rhs->shape(), i));
+      window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
     }
 
-    const Shape& window_shape = ShapeUtil::MakeShape(
-        rhs->shape().element_type(), window_dimension_sizes);
+    const Shape& window_shape =
+        ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
 
-    auto result = Literal::CreateFromShape(conv->shape());
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
-        [&](tensorflow::gtl::ArraySlice<int64> out_index) {
-          ReturnT result_val = static_cast<ReturnT>(0);
-
-          std::vector<int64> lhs_index(lhs_rank, 0);
-          std::vector<int64> rhs_index(rhs_rank, 0);
-          lhs_index[batch_dim] = out_index[batch_dim];
-          rhs_index[kernel_output_z_dim] = out_index[z_dim];
-
-          std::vector<int64> rhs_spatial_index(
-              dnums.kernel_spatial_dimensions_size(), 0);
-
-          // Convolve input feature with kernel.
-          do {
-            for (int64 iz = 0; iz < z_size; ++iz) {
-              lhs_index[z_dim] = iz;
-              rhs_index[kernel_input_z_dim] = iz;
-
-              // Find corresponding spatial dimension index for input (lhs).
-              for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
-                // Spatial dimension number for input (lhs) and output.
-                const int64 spatial_dim = dnums.spatial_dimensions(ki);
-
-                // Calculate lhs (input) index without taking base dilation into
-                // account.
-                const int64 undilated_index =
-                    out_index[spatial_dim] * window.dimensions(ki).stride() -
-                    window.dimensions(ki).padding_low() +
-                    rhs_spatial_index[ki] *
-                        window.dimensions(ki).window_dilation();
-                // Skip if the lhs (input) index is to be dilated.
-                if (undilated_index % window.dimensions(ki).base_dilation() !=
-                    0) {
-                  goto cnt;
-                }
-
-                // Calculate the actual lhs (input) index after dilation.
-                lhs_index[spatial_dim] =
-                    undilated_index / window.dimensions(ki).base_dilation();
-
-                // Skip if input index is not in bound.
-                if (!(lhs_index[spatial_dim] >= 0 &&
-                      lhs_index[spatial_dim] <
-                          lhs->shape().dimensions(spatial_dim))) {
-                  goto cnt;
-                }
-
-                rhs_index[dnums.kernel_spatial_dimensions(ki)] =
-                    rhs_spatial_index[ki];
-              }
-
-              result_val += lhs_literal.Get<ReturnT>(lhs_index) *
-                            rhs_literal.Get<ReturnT>(rhs_index);
-            }
-          cnt:;
-          } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
-
-          return result_val;
-        }));
+    DimensionVector lhs_index(lhs_rank);
+    DimensionVector rhs_index(rhs_rank);
+    DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
+
+    auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+      ReturnT result_val = static_cast<ReturnT>(0);
+
+      std::fill(lhs_index.begin(), lhs_index.end(), 0);
+      std::fill(rhs_index.begin(), rhs_index.end(), 0);
+      std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
+
+      lhs_index[batch_dim] = out_index[batch_dim];
+      rhs_index[kernel_output_z_dim] = out_index[z_dim];
+
+      // Convolve input feature with kernel.
+      do {
+        for (int64 iz = 0; iz < z_size; ++iz) {
+          lhs_index[z_dim] = iz;
+          rhs_index[kernel_input_z_dim] = iz;
+
+          // Find corresponding spatial dimension index for input (lhs).
+          for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
+            // Spatial dimension number for input (lhs) and output.
+            const int64 spatial_dim = dnums.spatial_dimensions(ki);
+
+            // Calculate lhs (input) index without taking base dilation into
+            // account.
+            const auto& window_dim = window.dimensions(ki);
+            const int64 undilated_index =
+                out_index[spatial_dim] * window_dim.stride() -
+                window_dim.padding_low() +
+                rhs_spatial_index[ki] * window_dim.window_dilation();
+            // Skip if the lhs (input) index is to be dilated.
+            if (undilated_index % window_dim.base_dilation() != 0) {
+              goto cnt;
+            }
+
+            // Calculate the actual lhs (input) index after dilation.
+            lhs_index[spatial_dim] =
+                undilated_index / window_dim.base_dilation();
+
+            // Skip if input index is not in bound.
+            if (!(lhs_index[spatial_dim] >= 0 &&
+                  lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) {
+              goto cnt;
+            }
+
+            rhs_index[dnums.kernel_spatial_dimensions(ki)] =
+                rhs_spatial_index[ki];
+          }
+
+          result_val += lhs_literal.Get<ReturnT>(lhs_index) *
+                        rhs_literal.Get<ReturnT>(rhs_index);
+        }
+      cnt:;
+      } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+
+      return result_val;
+    };
+
+    auto result = Literal::CreateFromShape(result_shape);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
 
     parent_->evaluated_[conv] = std::move(result);
     return Status::OK();
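The do/while above treats rhs_spatial_index as an odometer over the kernel window, bumping it through every window position until it wraps. A standalone sketch of that index-bumping idea (an illustration only, not the actual IndexUtil::BumpIndices implementation, which iterates against a Shape):

#include <cstdint>
#include <vector>

// Advances `indices` to the next position within `bounds`, least-significant
// dimension last. Returns false after the final position wraps back to all
// zeros, which is what terminates a do/while loop driven by it.
bool BumpIndices(const std::vector<int64_t>& bounds,
                 std::vector<int64_t>* indices) {
  for (int64_t d = static_cast<int64_t>(bounds.size()) - 1; d >= 0; --d) {
    ++(*indices)[d];
    if ((*indices)[d] < bounds[d]) {
      return true;  // found the next valid index combination
    }
    (*indices)[d] = 0;  // carry into the next more-significant dimension
  }
  return false;  // wrapped past the final combination
}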
tensorflow/compiler/xla/shape_util.cc
@@ -1218,34 +1218,4 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
   return shape;
 }
 
-/* static */ void ShapeUtil::ForEachIndex(
-    const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
-    tensorflow::gtl::ArraySlice<int64> count,
-    tensorflow::gtl::ArraySlice<int64> incr,
-    const IndexVisitorFunction& visitor_function) {
-  if (ShapeUtil::HasZeroElements(shape)) {
-    return;
-  }
-  DCHECK_EQ(Rank(shape), base.size());
-  DCHECK_EQ(incr.size(), base.size());
-  DCHECK_EQ(count.size(), base.size());
-  const Layout& layout = shape.layout();
-  int64 rank = layout.minor_to_major_size();
-  // Allows handling R0 arrays, such that the visitor function will be called
-  // once with the proper empty indexes.
-  int64 n = -1;
-  std::vector<int64> indexes(base.begin(), base.end());
-  while (n < rank && visitor_function(indexes)) {
-    // Increments dimensions in minor to major order.
-    for (n = 0; n < rank; ++n) {
-      int64 dim = layout.minor_to_major(n);
-      indexes[dim] += incr[dim];
-      if (indexes[dim] < base[dim] + count[dim]) {
-        break;
-      }
-      indexes[dim] = base[dim];
-    }
-  }
-}
-
 }  // namespace xla
tensorflow/compiler/xla/shape_util.h
@@ -421,12 +421,39 @@ class ShapeUtil {
   // current index.
   // The visitor_function visitor function should return true if it wants to
   // continue, or false otherwise.
-  using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>;
+  //
+  // visitor_function must be a callable of type
+  // bool(const std::vector<int64>&) or compatible.
+  template <typename FnType>
   static void ForEachIndex(const Shape& shape,
                            tensorflow::gtl::ArraySlice<int64> base,
                            tensorflow::gtl::ArraySlice<int64> count,
                            tensorflow::gtl::ArraySlice<int64> incr,
-                           const IndexVisitorFunction& visitor_function);
+                           const FnType& visitor_function) {
+    if (ShapeUtil::HasZeroElements(shape)) {
+      return;
+    }
+    CHECK_EQ(Rank(shape), base.size());
+    CHECK_EQ(incr.size(), base.size());
+    CHECK_EQ(count.size(), base.size());
+    const Layout& layout = shape.layout();
+    const int64 rank = layout.minor_to_major_size();
+    // Allows handling R0 arrays, such that the visitor function will be called
+    // once with the proper empty indexes.
+    int64 n = -1;
+    std::vector<int64> indexes(base.begin(), base.end());
+    while (n < rank && visitor_function(indexes)) {
+      // Increments dimensions in minor to major order.
+      for (n = 0; n < rank; ++n) {
+        int64 dim = layout.minor_to_major(n);
+        indexes[dim] += incr[dim];
+        if (indexes[dim] < base[dim] + count[dim]) {
+          break;
+        }
+        indexes[dim] = base[dim];
+      }
+    }
+  }
 
  private:
  // Validates all of the non-layout properties of the shape -- this is a helper
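A sketch of calling the now-templated ForEachIndex with a plain lambda (inside namespace xla; the shape and bounds are illustrative). Any callable compatible with bool(const std::vector<int64>&) works, with no std::function constructed:

// Visit all 6 indexes of a 2x3 array, stepping by 1 in each dimension.
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
std::vector<int64> base(2, 0);
std::vector<int64> count = {2, 3};
std::vector<int64> incr(2, 1);
int64 visited = 0;
ShapeUtil::ForEachIndex(shape, base, count, incr,
                        [&](const std::vector<int64>& indexes) {
                          ++visited;
                          return true;  // returning false stops the iteration
                        });
// visited == 6 here.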