Update contraction mapper benchmark

PiperOrigin-RevId: 219883684
This commit is contained in:
Eugene Zhulenev 2018-11-02 16:48:49 -07:00 committed by TensorFlower Gardener
parent bd08356cba
commit 2c923299cc

View File

@ -1392,14 +1392,17 @@ static void PackRhsHelper(int iters,
static const int packet_size = Eigen::internal::packet_traits<float>::size;
// Reshape dimensions.
using NewDimension = Eigen::array<Eigen::Index, 2>;
using NewDimension = Eigen::DSizes<Index, 2>;
// Contraction dimensions.
using nocontract_t = Eigen::array<Eigen::Index, 1>;
using contract_t = Eigen::array<Eigen::Index, 1>;
// Input to the TensorImagePatchOp.
using ArgType = Tensor<float, 4>;
// Input to the TensorImagePatchOp. It is the tensorflow TTypes<float>::Tensor
// with ColMajor layout, instead of RowMajor. But that doesn't make any
// difference, because TensorContraction swaps LHS with RHS for row major
// inputs, and contraction mapper always works with column major data.
using ArgType = TensorMap<Tensor<float, 4>, Eigen::Aligned>;
using Evaluator = TensorEvaluator<
const TensorReshapingOp<
@ -1454,9 +1457,11 @@ static void PackRhsHelper(int iters,
inputs.emplace_back(input_dims);
inputs[i].setRandom();
ArgType tensor_map(inputs[i].data(), input_dims);
// 1. Extract image patches from input tensor. All strides are `1`.
const auto image_patch_op = TensorImagePatchOp<Dynamic, Dynamic, ArgType>(
inputs[i], //
tensor_map, //
filter_rows, filter_cols, //
/*row_strides=*/1, /*col_strides=*/1, //
/*in_row_strides=*/1, /*in_col_strides=*/1, //
@ -1464,10 +1469,11 @@ static void PackRhsHelper(int iters,
Eigen::PADDING_SAME, /*padding_value=*/0.0);
// 2. Reshape extracted patches into "virtual" 2d tensor.
NewDimension reshape_dims = {
input_depth * filter_rows * filter_cols, // patch size
// PADDING_SAME: output {rows, cols} == input {rows, cols}
input_rows * input_cols * input_batches}; // num_patches
// NOTE: for PADDING_SAME output {rows, cols} == input {rows, cols}.
NewDimension reshape_dims;
reshape_dims[0] = input_depth * filter_rows * filter_cols; // patch size
reshape_dims[1] = input_rows * input_cols * input_batches; // num_patches
const auto reshape_op =
TensorReshapingOp<NewDimension, decltype(image_patch_op)>(
image_patch_op, reshape_dims);