Copy ben's changes
This commit is contained in:
parent
7d57d32e44
commit
1051c37705
@ -62,14 +62,14 @@ limitations under the License.
|
|||||||
|
|
||||||
#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \
|
#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \
|
||||||
do { \
|
do { \
|
||||||
if (status == false) { \
|
if ((status) == false) { \
|
||||||
TFTRT_INTERNAL_ERROR_AT_NODE(node); \
|
TFTRT_INTERNAL_ERROR_AT_NODE(node); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
|
#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
|
||||||
do { \
|
do { \
|
||||||
if (ptr == nullptr) { \
|
if ((ptr) == nullptr) { \
|
||||||
TFTRT_INTERNAL_ERROR_AT_NODE(node); \
|
TFTRT_INTERNAL_ERROR_AT_NODE(node); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
@ -1577,12 +1577,14 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
|
|||||||
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
||||||
TFAttrs attrs(node_def);
|
TFAttrs attrs(node_def);
|
||||||
|
|
||||||
|
int c_index = 1;
|
||||||
int h_index = 2;
|
int h_index = 2;
|
||||||
int w_index = 3;
|
int w_index = 3;
|
||||||
auto data_format = attrs.get<string>("data_format");
|
auto data_format = attrs.get<string>("data_format");
|
||||||
if (data_format == "NHWC") {
|
if (data_format == "NHWC") {
|
||||||
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
|
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
|
||||||
const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
|
const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
|
||||||
|
c_index = 3;
|
||||||
h_index = 1;
|
h_index = 1;
|
||||||
w_index = 2;
|
w_index = 2;
|
||||||
// TODO(jie): transpose it
|
// TODO(jie): transpose it
|
||||||
@ -1618,14 +1620,30 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
|
|||||||
<< tf_stride[3];
|
<< tf_stride[3];
|
||||||
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
|
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
|
||||||
|
|
||||||
|
auto tf_dilations = attrs.get<std::vector<int>>("dilations");
|
||||||
|
if ((int)tf_dilations.size() != 4) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"Convolution dilations field must specify 4 dimensions " +
|
||||||
|
node_def.name());
|
||||||
|
}
|
||||||
|
if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
|
||||||
|
return tensorflow::errors::Unimplemented(
|
||||||
|
"Dilation rate must be 1 for batch and channel dimensions, at ",
|
||||||
|
node_def.name());
|
||||||
|
}
|
||||||
|
nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
|
||||||
|
|
||||||
std::vector<std::pair<int, int>> padding;
|
std::vector<std::pair<int, int>> padding;
|
||||||
// TODO(jie): padding.
|
// TODO(jie): padding.
|
||||||
if (attrs.get<string>("padding") == "SAME") {
|
if (attrs.get<string>("padding") == "SAME") {
|
||||||
// This is NCHW tensor with no batch dimension.
|
// This is NCHW tensor with no batch dimension.
|
||||||
// 1 -> h
|
// 1 -> h
|
||||||
// 2 -> w
|
// 2 -> w
|
||||||
|
nvinfer1::DimsHW effective_kernel_size = kernel_size;
|
||||||
|
effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1);
|
||||||
|
effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1);
|
||||||
padding = CreateSamePadding(
|
padding = CreateSamePadding(
|
||||||
stride, kernel_size,
|
stride, effective_kernel_size,
|
||||||
{static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
|
{static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
|
||||||
} else {
|
} else {
|
||||||
padding = {{0, 0}, {0, 0}};
|
padding = {{0, 0}, {0, 0}};
|
||||||
@ -1659,6 +1677,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
|
|||||||
layer->setPadding({padding[0].first, padding[1].first});
|
layer->setPadding({padding[0].first, padding[1].first});
|
||||||
layer->setName(node_def.name().c_str());
|
layer->setName(node_def.name().c_str());
|
||||||
layer->setNbGroups(num_groups);
|
layer->setNbGroups(num_groups);
|
||||||
|
layer->setDilation(dilation);
|
||||||
const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
||||||
VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
|
VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
|
||||||
VLOG(2) << "data_format: " << data_format;
|
VLOG(2) << "data_format: " << data_format;
|
||||||
|
172
tensorflow/contrib/tensorrt/test/conv2d_test.py
Normal file
172
tensorflow/contrib/tensorrt/test/conv2d_test.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Model script to test TF-TensorRT integration."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_array_ops
|
||||||
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import nn_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
def conv2d_layer(inputs, filters, kernel_size, strides=(1, 1), padding='valid',
|
||||||
|
data_format='channels_last', dilation_rate=(1, 1), name=None):
|
||||||
|
dtype = inputs.dtype
|
||||||
|
c_axis = -1 if data_format == 'channels_last' else 1
|
||||||
|
nchan = inputs.shape[c_axis]
|
||||||
|
weights_shape = (kernel_size[0], kernel_size[1], nchan, filters)
|
||||||
|
weights = constant_op.constant(np.random.randn(*weights_shape), dtype=dtype)
|
||||||
|
padding = padding.upper()
|
||||||
|
if data_format == 'channels_last':
|
||||||
|
strides = [1] + list(strides) + [1]
|
||||||
|
dilations = [1] + list(dilation_rate) + [1]
|
||||||
|
data_format = 'NHWC'
|
||||||
|
else:
|
||||||
|
strides = [1, 1] + list(strides)
|
||||||
|
dilations = [1, 1] + list(dilation_rate)
|
||||||
|
data_format = 'NCHW'
|
||||||
|
return gen_nn_ops.conv2d(inputs, weights, strides=strides, padding=padding,
|
||||||
|
dilations=dilations, data_format=data_format)
|
||||||
|
|
||||||
|
def div_round_up(n, d):
|
||||||
|
return (n - 1) // d + 1
|
||||||
|
|
||||||
|
class Conv2DNCHWTest(trt_test.TfTrtIntegrationTestBase):
|
||||||
|
|
||||||
|
def GetParams(self):
|
||||||
|
"""Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion."""
|
||||||
|
np.random.seed(1234)
|
||||||
|
dtype = dtypes.float32
|
||||||
|
input_name = "input"
|
||||||
|
n, c, h, w = 13, 3, 7, 11
|
||||||
|
num_filters = 5
|
||||||
|
input_dims = [n, c, h, w]
|
||||||
|
output_name = "output"
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
inp = array_ops.placeholder(
|
||||||
|
dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
|
||||||
|
with g.device("/GPU:0"):
|
||||||
|
results = []
|
||||||
|
for kernel_size in [(3, 3), (3, 2)]:
|
||||||
|
for dilation_rate in [(1, 1), (2, 3)]:
|
||||||
|
result = conv2d_layer(inp, num_filters, kernel_size,
|
||||||
|
dilation_rate=dilation_rate, padding='same',
|
||||||
|
data_format='channels_first')
|
||||||
|
results.append(result)
|
||||||
|
output = sum(results)
|
||||||
|
output = array_ops.identity(output, name=output_name)
|
||||||
|
return trt_test.TfTrtIntegrationTestParams(
|
||||||
|
gdef=g.as_graph_def(),
|
||||||
|
input_names=[input_name],
|
||||||
|
input_dims=[input_dims],
|
||||||
|
output_names=[output_name],
|
||||||
|
expected_output_dims=[(n, num_filters, h, w)])
|
||||||
|
|
||||||
|
def ExpectedEnginesToBuild(self, run_params):
|
||||||
|
"""Return the expected engines to build."""
|
||||||
|
return ["my_trt_op_0"]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase):
|
||||||
|
|
||||||
|
def GetParams(self):
|
||||||
|
"""Testing conversion of strided Conv2D (data_format=NCHW) in TF-TRT
|
||||||
|
conversion."""
|
||||||
|
np.random.seed(1234)
|
||||||
|
dtype = dtypes.float32
|
||||||
|
input_name = "input"
|
||||||
|
n, c, h, w = 13, 3, 7, 11
|
||||||
|
num_filters = 5
|
||||||
|
input_dims = [n, c, h, w]
|
||||||
|
output_name = "output"
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
inp = array_ops.placeholder(
|
||||||
|
dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
|
||||||
|
with g.device("/GPU:0"):
|
||||||
|
output = inp
|
||||||
|
output = conv2d_layer(output, num_filters, (3, 2), strides=(2, 2),
|
||||||
|
padding='same', data_format='channels_first')
|
||||||
|
h = div_round_up(h, 2)
|
||||||
|
w = div_round_up(w, 2)
|
||||||
|
output = conv2d_layer(output, num_filters, (3, 3), strides=(2, 2),
|
||||||
|
dilation_rate=(2, 3), padding='same',
|
||||||
|
data_format='channels_first')
|
||||||
|
h = div_round_up(h, 2)
|
||||||
|
w = div_round_up(w, 2)
|
||||||
|
output = array_ops.identity(output, name=output_name)
|
||||||
|
return trt_test.TfTrtIntegrationTestParams(
|
||||||
|
gdef=g.as_graph_def(),
|
||||||
|
input_names=[input_name],
|
||||||
|
input_dims=[input_dims],
|
||||||
|
output_names=[output_name],
|
||||||
|
expected_output_dims=[(n, num_filters, h, w)])
|
||||||
|
|
||||||
|
def ExpectedEnginesToBuild(self, run_params):
|
||||||
|
"""Return the expected engines to build."""
|
||||||
|
return ["my_trt_op_0"]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase):
|
||||||
|
|
||||||
|
def GetParams(self):
|
||||||
|
"""Testing conversion of Conv2D (data_format=NHWC) in TF-TRT conversion."""
|
||||||
|
np.random.seed(1234)
|
||||||
|
dtype = dtypes.float32
|
||||||
|
input_name = "input"
|
||||||
|
n, h, w, c = 13, 7, 11, 3
|
||||||
|
num_filters = 5
|
||||||
|
input_dims = [n, h, w, c]
|
||||||
|
output_name = "output"
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
inp = array_ops.placeholder(
|
||||||
|
dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
|
||||||
|
with g.device("/GPU:0"):
|
||||||
|
results = []
|
||||||
|
for kernel_size in [(3, 3), (3, 2)]:
|
||||||
|
for dilation_rate in [(1, 1), (2, 3)]:
|
||||||
|
result = conv2d_layer(inp, num_filters, kernel_size,
|
||||||
|
dilation_rate=dilation_rate, padding='same',
|
||||||
|
data_format='channels_last')
|
||||||
|
results.append(result)
|
||||||
|
output = sum(results)
|
||||||
|
output = array_ops.identity(output, name=output_name)
|
||||||
|
return trt_test.TfTrtIntegrationTestParams(
|
||||||
|
gdef=g.as_graph_def(),
|
||||||
|
input_names=[input_name],
|
||||||
|
input_dims=[input_dims],
|
||||||
|
output_names=[output_name],
|
||||||
|
expected_output_dims=[(n, h, w, num_filters)])
|
||||||
|
|
||||||
|
def ExpectedEnginesToBuild(self, run_params):
|
||||||
|
"""Return the expected engines to build."""
|
||||||
|
return ["my_trt_op_0"]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
Loading…
x
Reference in New Issue
Block a user