Merge pull request #24078 from trevor-m:tmorris_tftrt_strided_slice_op
PiperOrigin-RevId: 225296218
This commit is contained in:
commit
5ce09c9962
@ -89,51 +89,52 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
|
|||||||
// TODO(laigd): move this set to TrtNodeValidator where it should belong.
|
// TODO(laigd): move this set to TrtNodeValidator where it should belong.
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
static const std::set<string> candidate_ops = {
|
static const std::set<string> candidate_ops = {
|
||||||
"Identity",
|
"Abs",
|
||||||
"Snapshot",
|
"Add",
|
||||||
|
"AvgPool",
|
||||||
|
"BatchMatMul",
|
||||||
|
"BiasAdd",
|
||||||
|
"ConcatV2",
|
||||||
"Const",
|
"Const",
|
||||||
"Conv2D",
|
"Conv2D",
|
||||||
"MaxPool",
|
|
||||||
"BiasAdd",
|
|
||||||
"Relu",
|
|
||||||
"Sigmoid",
|
|
||||||
"Tanh",
|
|
||||||
"Add",
|
|
||||||
"Mul",
|
|
||||||
"Sub",
|
|
||||||
"Rsqrt",
|
|
||||||
"Pad",
|
|
||||||
"Mean",
|
|
||||||
"AvgPool",
|
|
||||||
"ConcatV2",
|
|
||||||
"DepthwiseConv2dNative",
|
"DepthwiseConv2dNative",
|
||||||
|
"Div",
|
||||||
|
"Exp",
|
||||||
|
"ExpandDims",
|
||||||
"FusedBatchNorm",
|
"FusedBatchNorm",
|
||||||
"FusedBatchNormV2",
|
"FusedBatchNormV2",
|
||||||
"Div",
|
"Identity",
|
||||||
"RealDiv",
|
|
||||||
"Rsqrt",
|
|
||||||
"Reciprocal",
|
|
||||||
"Exp",
|
|
||||||
"Log",
|
"Log",
|
||||||
"Sqrt",
|
|
||||||
"Abs",
|
|
||||||
"Neg",
|
|
||||||
"Transpose",
|
|
||||||
"Reshape",
|
|
||||||
"MatMul",
|
"MatMul",
|
||||||
"BatchMatMul",
|
|
||||||
"Softmax",
|
|
||||||
"Minimum",
|
|
||||||
"Maximum",
|
|
||||||
"TopKV2",
|
|
||||||
"Sum",
|
|
||||||
"Prod",
|
|
||||||
"Max",
|
"Max",
|
||||||
|
"MaxPool",
|
||||||
|
"Maximum",
|
||||||
|
"Mean",
|
||||||
"Min",
|
"Min",
|
||||||
|
"Minimum",
|
||||||
|
"Mul",
|
||||||
|
"Neg",
|
||||||
|
"Pad",
|
||||||
|
"Prod",
|
||||||
|
"RealDiv",
|
||||||
|
"Reciprocal",
|
||||||
|
"Relu",
|
||||||
"Relu6",
|
"Relu6",
|
||||||
|
"Reshape",
|
||||||
|
"Rsqrt",
|
||||||
|
"Rsqrt",
|
||||||
|
"Sigmoid",
|
||||||
|
"Snapshot",
|
||||||
|
"Softmax",
|
||||||
|
"Sqrt",
|
||||||
"Square",
|
"Square",
|
||||||
"ExpandDims",
|
|
||||||
"Squeeze",
|
"Squeeze",
|
||||||
|
"StridedSlice",
|
||||||
|
"Sub",
|
||||||
|
"Sum",
|
||||||
|
"Tanh",
|
||||||
|
"TopKV2",
|
||||||
|
"Transpose",
|
||||||
};
|
};
|
||||||
bool is_supported_op_type =
|
bool is_supported_op_type =
|
||||||
(candidate_ops.count(node->type_string()) ||
|
(candidate_ops.count(node->type_string()) ||
|
||||||
|
@ -632,6 +632,11 @@ bool TFAttrs::get<bool>(const string& key) const {
|
|||||||
return this->at(key)->b();
|
return this->at(key)->b();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
int TFAttrs::get<int>(const string& key) const {
|
||||||
|
return this->at(key)->i();
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(jie): reorder4 & reorder2 should be merged?
|
// TODO(jie): reorder4 & reorder2 should be merged?
|
||||||
// TODO(aaroey): fix the order of parameters.
|
// TODO(aaroey): fix the order of parameters.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -2028,6 +2033,245 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) {
|
|||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Gets the bounds (start or end) from the weights of a StridedSlice op.
|
||||||
|
tensorflow::Status GetStridedSliceBound(const std::vector<int>& input_dims,
|
||||||
|
const TRT_ShapedWeights& bound_weights,
|
||||||
|
int mask, bool begin, string node_name,
|
||||||
|
std::vector<int>* output_bound) {
|
||||||
|
const string bound_name = (begin) ? "begin" : "end";
|
||||||
|
const int* weights_ptr = static_cast<int*>(bound_weights.GetValues());
|
||||||
|
*output_bound =
|
||||||
|
std::vector<int>(weights_ptr, weights_ptr + bound_weights.count());
|
||||||
|
if (output_bound->size() != input_dims.size()) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"StridedSlice \"", bound_name, "\" specified ",
|
||||||
|
std::to_string(output_bound->size()), " dimensions, but input rank is ",
|
||||||
|
std::to_string(input_dims.size()), ", at ", node_name);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < output_bound->size(); i++) {
|
||||||
|
if ((1 << i) & mask) {
|
||||||
|
// Apply mask.
|
||||||
|
(*output_bound)[i] = (begin) ? 0 : input_dims[i];
|
||||||
|
// Masked bound will always result in a valid, non-negative bound, so we
|
||||||
|
// don't need the following checks. For the common case of using masks on
|
||||||
|
// a undefined batch dim (-1), we specifically don't want to do the
|
||||||
|
// following checks because they will erroneously detect an out of range
|
||||||
|
// bound or try to correct the negative value.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Make sure bound is valid.
|
||||||
|
if (((*output_bound)[i] < -input_dims[i]) ||
|
||||||
|
((*output_bound)[i] > input_dims[i])) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
bound_name, " value of ", std::to_string((*output_bound)[i]),
|
||||||
|
" for StridedSlice is invalid, must be in the range "
|
||||||
|
"[-dim_size(i), dim_size(i)], at ",
|
||||||
|
node_name);
|
||||||
|
}
|
||||||
|
// Convert negative values to their positive equivalent.
|
||||||
|
if ((*output_bound)[i] < 0) {
|
||||||
|
(*output_bound)[i] += input_dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status ConvertStridedSlice(OpConverterParams* params) {
|
||||||
|
const auto& inputs = params->inputs;
|
||||||
|
const auto& node_def = params->node_def;
|
||||||
|
if (inputs.size() != 4) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"StridedSlice expects 4 inputs, at ", node_def.name());
|
||||||
|
}
|
||||||
|
if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights() ||
|
||||||
|
!inputs.at(3).is_weights()) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"StridedSlice expects weights for begin, end, and strides, at ",
|
||||||
|
node_def.name());
|
||||||
|
}
|
||||||
|
if (!inputs.at(0).is_tensor()) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"StridedSlice is only implemented for tensors, at ", node_def.name());
|
||||||
|
}
|
||||||
|
// Get input dims.
|
||||||
|
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||||
|
std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
|
||||||
|
if (inputs.at(0).is_tensor()) {
|
||||||
|
// Temporarily add batch dimension so that indexes line up properly.
|
||||||
|
input_dims.insert(input_dims.begin(), inputs.at(0).batch_size());
|
||||||
|
}
|
||||||
|
if (input_dims.size() > 4) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"StridedSlice is not implemented for tensors with rank > 4, at ",
|
||||||
|
node_def.name());
|
||||||
|
}
|
||||||
|
TFAttrs attrs(node_def);
|
||||||
|
// Get begin and end bounds per axis.
|
||||||
|
std::vector<int> begin, end;
|
||||||
|
TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(),
|
||||||
|
attrs.get<int>("begin_mask"), true,
|
||||||
|
node_def.name(), &begin));
|
||||||
|
TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(),
|
||||||
|
attrs.get<int>("end_mask"), false,
|
||||||
|
node_def.name(), &end));
|
||||||
|
// Get strides per axis (must all be 1).
|
||||||
|
TRT_ShapedWeights stride_weights = inputs.at(3).weights();
|
||||||
|
const int* stride_weights_ptr = static_cast<int*>(stride_weights.GetValues());
|
||||||
|
std::vector<int> strides(stride_weights_ptr,
|
||||||
|
stride_weights_ptr + stride_weights.count());
|
||||||
|
for (int x : strides) {
|
||||||
|
if (x != 1) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"StridedSlice is only implemented for stride of 1, at ",
|
||||||
|
node_def.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Unsupported mask options.
|
||||||
|
for (const string& attr :
|
||||||
|
{"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) {
|
||||||
|
int attr_val = attrs.get<int>(attr);
|
||||||
|
if (attr_val != 0) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
attr, " is not supported for StridedSlice, at ", node_def.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::ITensor* tensor =
|
||||||
|
const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor());
|
||||||
|
// Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input.
|
||||||
|
const bool need_reshape = (input_dims.size() != 4);
|
||||||
|
int reshape_dims_added = 0;
|
||||||
|
nvinfer1::Dims reshape_dims;
|
||||||
|
if (need_reshape) {
|
||||||
|
// Add new dims after batch dim until tensor is 4D.
|
||||||
|
while (input_dims.size() < 4) {
|
||||||
|
input_dims.insert(input_dims.begin() + 1, 1);
|
||||||
|
begin.insert(begin.begin() + 1, 0);
|
||||||
|
end.insert(end.begin() + 1, 1);
|
||||||
|
reshape_dims_added++;
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims,
|
||||||
|
/*ignore_first_dim=*/true));
|
||||||
|
}
|
||||||
|
// Find dimensions which need to be sliced.
|
||||||
|
std::vector<int> pad_dims;
|
||||||
|
for (int i = 0; i < input_dims.size(); i++) {
|
||||||
|
if ((begin[i] != 0) || (end[i] != input_dims[i])) {
|
||||||
|
if (i == 0) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"StridedSlice can't modify batch dim, at ", node_def.name());
|
||||||
|
} else if ((end[i] - begin[i]) < 0) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"New size of sliced dimension is negative, at ", node_def.name());
|
||||||
|
}
|
||||||
|
pad_dims.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pad_dims.size() == 0) {
|
||||||
|
// No dimensions are changed. We could create a padding layer anyway with
|
||||||
|
// values of 0.
|
||||||
|
if (params->validation_only) return Status::OK();
|
||||||
|
params->outputs->push_back(inputs.at(0));
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
} else if (pad_dims.size() == 1) {
|
||||||
|
// Only one dim is modified but we have to have 2, mark a second dim which
|
||||||
|
// will have padding of 0. The dim we add is chosen to avoid an unecessary
|
||||||
|
// transpose.
|
||||||
|
if (pad_dims[0] != 2) {
|
||||||
|
pad_dims.push_back(2);
|
||||||
|
} else {
|
||||||
|
pad_dims.push_back(3);
|
||||||
|
}
|
||||||
|
} else if (pad_dims.size() > 2) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"StridedSlice can only modify 2 dimensions, at ", node_def.name());
|
||||||
|
}
|
||||||
|
std::sort(pad_dims.begin(), pad_dims.end());
|
||||||
|
// Convert to pre/post padding values. Since TRT does not have a StridedSlice
|
||||||
|
// or Slice layer, we instead create an IPaddingLayer with negative padding.
|
||||||
|
nvinfer1::DimsHW pre_padding, post_padding;
|
||||||
|
for (int i = 0; i < pad_dims.size(); i++) {
|
||||||
|
const int axis = pad_dims[i];
|
||||||
|
pre_padding.d[i] = -begin[axis];
|
||||||
|
post_padding.d[i] = end[axis] - input_dims[axis];
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPaddingLayer will always apply the padding to dims 2,3 (input format is
|
||||||
|
// NCHW).
|
||||||
|
const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3);
|
||||||
|
std::vector<int> transpose_order(input_dims.size());
|
||||||
|
std::vector<int> inv_transpose_order(input_dims.size());
|
||||||
|
if (need_transpose) {
|
||||||
|
if (pad_dims[0] == 1 && pad_dims[1] == 3) {
|
||||||
|
transpose_order = {0, 2, 1, 3};
|
||||||
|
inv_transpose_order = {0, 2, 1, 3};
|
||||||
|
} else if (pad_dims[0] == 1 && pad_dims[1] == 2) {
|
||||||
|
transpose_order = {0, 3, 1, 2};
|
||||||
|
inv_transpose_order = {0, 2, 3, 1};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (params->validation_only) return Status::OK();
|
||||||
|
|
||||||
|
// Start conversion.
|
||||||
|
if (need_reshape) {
|
||||||
|
const nvinfer1::ITensor* output_tensor = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||||
|
inputs.at(0), reshape_dims, &output_tensor));
|
||||||
|
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
|
||||||
|
}
|
||||||
|
if (need_transpose) {
|
||||||
|
const nvinfer1::ITensor* output_tensor = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
|
||||||
|
tensor, transpose_order, &output_tensor));
|
||||||
|
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add padding layer
|
||||||
|
nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
|
||||||
|
*const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
|
||||||
|
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||||
|
params->converter->MarkQuantizationRangesAsInferrable(tensor,
|
||||||
|
layer->getOutput(0));
|
||||||
|
tensor = layer->getOutput(0);
|
||||||
|
|
||||||
|
// Restore transpose
|
||||||
|
if (need_transpose) {
|
||||||
|
const nvinfer1::ITensor* output_tensor = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
|
||||||
|
tensor, inv_transpose_order, &output_tensor));
|
||||||
|
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
|
||||||
|
}
|
||||||
|
// Restore reshape
|
||||||
|
if (need_reshape) {
|
||||||
|
// Calculate output dimensions
|
||||||
|
for (int i = 0; i < pad_dims.size(); i++) {
|
||||||
|
const int axis = pad_dims[i];
|
||||||
|
input_dims[axis] = end[axis] - begin[axis];
|
||||||
|
}
|
||||||
|
// Remove added 1 dimensions
|
||||||
|
for (int i = 0; i < reshape_dims_added; i++) {
|
||||||
|
int value = input_dims[1];
|
||||||
|
if (value != 1) {
|
||||||
|
return tensorflow::errors::Internal(
|
||||||
|
"StridedSlice error when reshaping, at ", node_def.name());
|
||||||
|
}
|
||||||
|
input_dims.erase(input_dims.begin() + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::Dims new_dims;
|
||||||
|
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims,
|
||||||
|
/*ignore_first_dim=*/true));
|
||||||
|
const nvinfer1::ITensor* output_tensor = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||||
|
TRT_TensorOrWeights(tensor), new_dims, &output_tensor));
|
||||||
|
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
params->outputs->push_back(
|
||||||
|
TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(tensor)));
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::Status ConvertConv2D(OpConverterParams* params) {
|
tensorflow::Status ConvertConv2D(OpConverterParams* params) {
|
||||||
return ConvertConv2DHelper(params, ConvolutionType::DEFAULT);
|
return ConvertConv2DHelper(params, ConvolutionType::DEFAULT);
|
||||||
}
|
}
|
||||||
@ -3335,14 +3579,15 @@ static void RegisterValidatableOpConverters(
|
|||||||
(*registration)["Const"] = ConvertConst;
|
(*registration)["Const"] = ConvertConst;
|
||||||
(*registration)["Conv2D"] = ConvertConv2D;
|
(*registration)["Conv2D"] = ConvertConv2D;
|
||||||
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
|
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
|
||||||
(*registration)["Transpose"] = ConvertTranspose;
|
(*registration)["ExpandDims"] = ConvertExpandDims;
|
||||||
(*registration)["Reshape"] = ConvertReshape;
|
|
||||||
(*registration)["MatMul"] = ConvertMatMul;
|
(*registration)["MatMul"] = ConvertMatMul;
|
||||||
(*registration)["Pad"] = ConvertPad;
|
(*registration)["Pad"] = ConvertPad;
|
||||||
(*registration)["Relu6"] = ConvertRelu6;
|
(*registration)["Relu6"] = ConvertRelu6;
|
||||||
|
(*registration)["Reshape"] = ConvertReshape;
|
||||||
(*registration)["Square"] = ConvertSquare;
|
(*registration)["Square"] = ConvertSquare;
|
||||||
(*registration)["ExpandDims"] = ConvertExpandDims;
|
|
||||||
(*registration)["Squeeze"] = ConvertSqueeze;
|
(*registration)["Squeeze"] = ConvertSqueeze;
|
||||||
|
(*registration)["StridedSlice"] = ConvertStridedSlice;
|
||||||
|
(*registration)["Transpose"] = ConvertTranspose;
|
||||||
|
|
||||||
for (auto quantization_op_type :
|
for (auto quantization_op_type :
|
||||||
{"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
|
{"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
|
||||||
|
@ -2129,7 +2129,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) {
|
|||||||
auto expanddims =
|
auto expanddims =
|
||||||
ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
|
ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
|
||||||
const NodeDef& node_def = expanddims.operation.node()->def();
|
const NodeDef& node_def = expanddims.operation.node()->def();
|
||||||
|
|
||||||
{
|
{
|
||||||
// Input is weights, should fail.
|
// Input is weights, should fail.
|
||||||
Reset();
|
Reset();
|
||||||
@ -2349,6 +2348,277 @@ TEST_F(OpConverterTest, ConvertSqueeze) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpConverterTest, ConvertStridedSlice) {
|
||||||
|
{
|
||||||
|
// Input list is empty, should fail.
|
||||||
|
NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
"StridedSlice expects 4 inputs, at my_strided_slice");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get nodedef for StridedSlice layer.
|
||||||
|
auto get_strided_slice_nodedef =
|
||||||
|
[](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0,
|
||||||
|
int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef {
|
||||||
|
Scope s = Scope::NewRootScope();
|
||||||
|
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||||
|
auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
|
||||||
|
auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32);
|
||||||
|
auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32);
|
||||||
|
ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs()
|
||||||
|
.BeginMask(begin_mask)
|
||||||
|
.EndMask(end_mask)
|
||||||
|
.EllipsisMask(ellipsis_mask)
|
||||||
|
.NewAxisMask(new_axis_mask)
|
||||||
|
.ShrinkAxisMask(shrink_axis_mask);
|
||||||
|
auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"),
|
||||||
|
input, begin, end, strides, attrs);
|
||||||
|
return strided_slice.operation.node()->def();
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
|
||||||
|
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::UNIMPLEMENTED,
|
||||||
|
"StridedSlice is only implemented for tensors, at my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Begin, end, strides are tensors, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestTensor("begin", {4});
|
||||||
|
AddTestTensor("end", {4});
|
||||||
|
AddTestTensor("strides", {4});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
"StridedSlice expects weights for begin, end, and strides, at "
|
||||||
|
"my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Non-zero ellipsis_mask, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef(
|
||||||
|
/*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2,
|
||||||
|
/*new_axis_mask=*/0, /*shrink_axis_mask=*/0);
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
|
||||||
|
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::UNIMPLEMENTED,
|
||||||
|
"ellipsis_mask is not supported for StridedSlice, at "
|
||||||
|
"my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Modify batch dim, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
|
||||||
|
AddTestWeights<int32>("end", {4}, {0, 1, 2, 3});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::UNIMPLEMENTED,
|
||||||
|
"StridedSlice can't modify batch dim, at my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Stride is not 1, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
|
||||||
|
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 2, -1, 3});
|
||||||
|
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||||
|
"StridedSlice is only implemented for stride of "
|
||||||
|
"1, at my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Begin out of bounds, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {1, 2, 3, 4});
|
||||||
|
AddTestWeights<int32>("end", {4}, {0, 1, 2, 3});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
"begin value of 2 for StridedSlice is invalid, must be in the range "
|
||||||
|
"[-dim_size(i), dim_size(i)], at my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// End out of bounds, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
|
||||||
|
AddTestWeights<int32>("end", {4}, {1, 2, 3, 4});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
"end value of 2 for StridedSlice is invalid, must be in the range "
|
||||||
|
"[-dim_size(i), dim_size(i)], at my_strided_slice");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Size of sliced dim is negative, should fail.
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef();
|
||||||
|
AddTestTensor("input", {1, 2, 3});
|
||||||
|
AddTestWeights<int32>("begin", {4}, {0, 0, 2, 0});
|
||||||
|
AddTestWeights<int32>("end", {4}, {1, 1, 0, 3});
|
||||||
|
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
"New size of sliced dimension is negative, at my_strided_slice");
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestParams {
|
||||||
|
TestParams(const std::vector<int>& input_dims,
|
||||||
|
const std::vector<int>& expected_output_dims,
|
||||||
|
const std::vector<int>& begin, const std::vector<int>& end,
|
||||||
|
const std::vector<int>& begin_mask,
|
||||||
|
const std::vector<int>& end_mask,
|
||||||
|
const std::vector<int>& expected_output)
|
||||||
|
: input_dims(input_dims),
|
||||||
|
expected_output_dims(expected_output_dims),
|
||||||
|
begin(begin),
|
||||||
|
end(end),
|
||||||
|
expected_output(expected_output) {
|
||||||
|
// Masks are provided in terms of vectors for readability. Convert them to
|
||||||
|
// binary here.
|
||||||
|
this->begin_mask = 0;
|
||||||
|
for (int i = 0; i < begin_mask.size(); i++) {
|
||||||
|
if (begin_mask[i]) this->begin_mask |= (1 << i);
|
||||||
|
}
|
||||||
|
this->end_mask = 0;
|
||||||
|
for (int i = 0; i < end_mask.size(); i++) {
|
||||||
|
if (end_mask[i]) this->end_mask |= (1 << i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> input_dims;
|
||||||
|
std::vector<int> expected_output_dims;
|
||||||
|
std::vector<int> begin;
|
||||||
|
std::vector<int> end;
|
||||||
|
int begin_mask;
|
||||||
|
int end_mask;
|
||||||
|
std::vector<int> expected_output;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ok.
|
||||||
|
const int kStridedSliceOKCases = 18;
|
||||||
|
TestParams ok_params[kStridedSliceOKCases] = {
|
||||||
|
// 2D Crop.
|
||||||
|
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2},
|
||||||
|
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0},
|
||||||
|
/*expected_output=*/{1, 2}},
|
||||||
|
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2},
|
||||||
|
/*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1},
|
||||||
|
/*expected_output=*/{5, 6}},
|
||||||
|
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2},
|
||||||
|
/*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0},
|
||||||
|
/*expected_output=*/{5, 6}},
|
||||||
|
// 2D Crop, with transpose.
|
||||||
|
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1},
|
||||||
|
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
|
||||||
|
/*expected_output=*/{1, 2}},
|
||||||
|
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1},
|
||||||
|
/*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
|
||||||
|
/*expected_output=*/{5, 6}},
|
||||||
|
TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2},
|
||||||
|
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
|
||||||
|
/*expected_output=*/{1, 2}},
|
||||||
|
TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2},
|
||||||
|
/*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
|
||||||
|
/*expected_output=*/{5, 6}},
|
||||||
|
// 2D Crop, with reshape.
|
||||||
|
TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2},
|
||||||
|
/*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2},
|
||||||
|
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0},
|
||||||
|
/*expected_output=*/{1, 2}},
|
||||||
|
TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2},
|
||||||
|
/*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1},
|
||||||
|
/*expected_output=*/{5, 6}},
|
||||||
|
// 1D Crop.
|
||||||
|
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2},
|
||||||
|
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0},
|
||||||
|
/*expected_output=*/{1, 2, 4, 5}},
|
||||||
|
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3},
|
||||||
|
/*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1},
|
||||||
|
/*expected_output=*/{4, 5, 6}},
|
||||||
|
// 1D Crop, with transpose.
|
||||||
|
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1},
|
||||||
|
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1},
|
||||||
|
/*expected_output=*/{1, 2, 3}},
|
||||||
|
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1},
|
||||||
|
/*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1},
|
||||||
|
/*expected_output=*/{4, 5, 6}},
|
||||||
|
// 1D Crop, with reshape.
|
||||||
|
TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3},
|
||||||
|
/*begin=*/{0, 0}, /*end=*/{0, 3},
|
||||||
|
/*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0},
|
||||||
|
/*expected_output=*/{1, 2, 3}},
|
||||||
|
TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3},
|
||||||
|
/*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5},
|
||||||
|
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0},
|
||||||
|
/*expected_output=*/{3, 4, 5}},
|
||||||
|
TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1},
|
||||||
|
/*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1},
|
||||||
|
/*expected_output=*/{3, 4, 5}},
|
||||||
|
// Negative axis.
|
||||||
|
TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1},
|
||||||
|
/*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1},
|
||||||
|
/*expected_output=*/{1, 2, 3}},
|
||||||
|
TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1},
|
||||||
|
/*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0},
|
||||||
|
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1},
|
||||||
|
/*expected_output=*/{1, 2, 3, 4, 5}},
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int i = 0; i < kStridedSliceOKCases; i++) {
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask,
|
||||||
|
ok_params[i].end_mask);
|
||||||
|
AddTestTensor("input", ok_params[i].input_dims);
|
||||||
|
AddTestWeights<int32>("begin",
|
||||||
|
{static_cast<int>(ok_params[i].begin.size())},
|
||||||
|
ok_params[i].begin);
|
||||||
|
AddTestWeights<int32>("end", {static_cast<int>(ok_params[i].end.size())},
|
||||||
|
ok_params[i].end);
|
||||||
|
std::vector<int> strides(ok_params[i].input_dims.size(), 1);
|
||||||
|
AddTestWeights<int32>("strides", {static_cast<int>(strides.size())},
|
||||||
|
strides);
|
||||||
|
RunValidationAndConversion(node_def);
|
||||||
|
|
||||||
|
TRT_TensorOrWeights output;
|
||||||
|
TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output));
|
||||||
|
std::vector<float> output_data(ok_params[i].expected_output.size());
|
||||||
|
BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice",
|
||||||
|
&output_data);
|
||||||
|
EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace convert
|
} // namespace convert
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user