Verify each node before putting it into a TRT subgraph, so that if the TRT
converter doesn't support the node, it can be excluded early. This prevents a
single incompatible node from causing the conversion of a large subgraph to
fail. Currently this validation is only added for Transpose, Reshape, and
Const.

Implementation details:

- A new class, TrtCandidateSelector, is added to perform the validation.
  Previously the segmenter used IsTensorRTCandidate() to pick TRT candidate
  nodes, which checks only the op type; the new
  TrtCandidateSelector::IsTensorRTCandidate() checks the op type as well as
  the input properties of the node, e.g. whether an input is a Const and, if
  so, its values.
- TrtCandidateSelector uses the TrtNodeValidator added earlier to validate the
  nodes. A new method, TrtNodeValidator::ConvertToTensorOrWeights(), converts
  an input of a node to a TRT_TensorOrWeights, and after gathering all input
  TRT_TensorOrWeights, TrtNodeValidator::ValidateNode() is called to determine
  whether the node is supported by TRT.
- The original InputEdgeValidator class is removed. It validated only the
  input nodes of the subgraph; its logic is now merged into
  ValidateTensorProperties(), which TrtNodeValidator uses to validate all
  nodes before putting them in a subgraph.

PiperOrigin-RevId: 219662352
This commit is contained in: parent e204232f19, commit 486dca315c
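
To make the new flow concrete, here is a short, self-contained C++ sketch of the selection logic described above. Every name below (Status, Node, CandidateSelector, SegmentGraph, the op set) is a simplified stand-in, not the real TensorFlow or TensorRT type; only the shape of the control flow — an op-type check first, then per-input validation, with the member function bound into the segmenter via std::bind — mirrors the commit.

#include <functional>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Stand-in for tensorflow::Status.
struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
  static Status Error(const std::string& m) { return {false, m}; }
};

// Stand-in for tensorflow::Node: an op type plus the producers of its inputs.
struct Node {
  std::string op;
  std::vector<const Node*> inputs;
};

class CandidateSelector {
 public:
  Status IsCandidate(const Node* node) const {
    // Step 1: op-type check (all the old IsTensorRTCandidate() did).
    static const std::set<std::string> kSupportedOps = {"Const", "MatMul",
                                                        "Relu"};
    if (!kSupportedOps.count(node->op)) {
      return Status::Error("Op type " + node->op + " is not supported.");
    }
    // Step 2 (the new part): validate each input as well, so a node with an
    // unconvertible input is excluded before segmentation instead of failing
    // the whole subgraph later.
    for (const Node* producer : node->inputs) {
      const Status s = ValidateInput(producer);
      if (!s.ok) return s;
    }
    return Status::OK();
  }

 private:
  // Stand-in for TrtNodeValidator::ConvertToTensorOrWeights().
  Status ValidateInput(const Node* producer) const {
    if (producer->op == "UnknownShape") {
      return Status::Error("Shape and data type are unknown");
    }
    return Status::OK();
  }
};

// Stand-in for the segmenter: it receives the candidate test as a callback.
void SegmentGraph(const std::vector<const Node*>& graph,
                  const std::function<Status(const Node*)>& is_candidate) {
  for (const Node* n : graph) {
    const Status s = is_candidate(n);
    std::cout << n->op << ": " << (s.ok ? "candidate" : s.msg) << "\n";
  }
}

int main() {
  Node c{"Const", {}};
  Node matmul{"MatMul", {&c}};
  Node sine{"Sin", {&c}};  // Rejected early: unsupported op type.
  CandidateSelector selector;
  SegmentGraph({&c, &matmul, &sine},
               std::bind(&CandidateSelector::IsCandidate, &selector,
                         std::placeholders::_1));
}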
tensorflow/contrib/tensorrt/BUILD

@@ -312,15 +312,20 @@ tf_cuda_cc_test(
    ],
    deps = [
        ":trt_conversion",
        "@com_google_googletest//:gtest",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]),
@@ -340,6 +345,10 @@ tf_cuda_cc_test(
        ":trt_conversion",
        ":trt_plugins",
        "@com_google_googletest//:gtest",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
tensorflow/contrib/tensorrt/convert/convert_graph.cc

@@ -81,12 +81,13 @@ std::vector<int> GetLoadedTensorRTVersion() {
  return {ver_major, ver_minor, ver_patch};
}

namespace {
TrtCandidateSelector::TrtCandidateSelector(
    const grappler::GraphProperties& graph_properties)
    : graph_properties_(graph_properties) {}

bool IsTensorRTCandidate(const tensorflow::Node* node) {
Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
  // TODO(laigd): move this set to TrtNodeValidator where it should belong.
  // LINT.IfChange
  // TODO(jie): Segmentation shouldn't associated with op name.
  // Split it into a registration for each kernel.
  static const std::set<string> candidate_ops = {
      "Identity",
      "Snapshot",
@@ -127,13 +128,29 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
      "Prod",
      "Max",
      "Min",
      // TODO(ben,jie): ...
  };
  // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc)
  return (candidate_ops.count(node->type_string()) ||
          PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
  const bool is_supported_op_type =
      (candidate_ops.count(node->type_string()) ||
       PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
  if (!is_supported_op_type) {
    return errors::Unimplemented("Op type ", node->type_string(),
                                 " is not supported.");
  }

  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(node->input_edges(&input_edges));
  std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
  for (const Edge* input_edge : input_edges) {
    input_node_and_ports.emplace_back(&input_edge->src()->def(),
                                      input_edge->src_output());
  }
  return validator_.ValidateNode(node->def(), input_node_and_ports,
                                 graph_properties_);
}

namespace {

tensorflow::Status BuildNodeMap(
    const tensorflow::Graph& graph,
    std::unordered_map<string, tensorflow::Node*>* node_map) {
@@ -846,9 +863,15 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  tensorflow::tensorrt::segment::SegmentNodesVector initial_segments;
  TrtCandidateSelector candidate_selector(*params.graph_properties);
  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
      &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties),
      OutputEdgeValidator(), segment_options, &initial_segments));
      &graph,
      std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector,
                std::placeholders::_1),
      // Input validation is already done by TrtCandidateSelector, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  if (initial_segments.size() > 1) {
    VLOG(0) << "MULTIPLE tensorrt candidate conversion: "
            << initial_segments.size();
tensorflow/contrib/tensorrt/convert/convert_graph.h

@@ -31,6 +31,26 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {

// Helper class for the segmenter to determine whether given TF node is
// supported by TRT.
class TrtCandidateSelector {
 public:
  TrtCandidateSelector(const grappler::GraphProperties& graph_properties);

  // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
  // to TRT subgraph and later converted into TRT engine.
  Status IsTensorRTCandidate(const tensorflow::Node* node);

 private:
  // The TF-TRT node converter used to verify whether individual node is
  // supported. It will operate in validation-only mode.
  TrtNodeValidator validator_;

  // GraphProperties of the graph whose nodes are to be validated by
  // IsTensorRTCandidate().
  const grappler::GraphProperties& graph_properties_;
};

struct ConversionParams {
  ConversionParams()
      : input_graph_def(nullptr),
tensorflow/contrib/tensorrt/convert/convert_graph_test.cc

@@ -15,9 +15,14 @@ limitations under the License.

#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
@@ -33,6 +38,76 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {

// TODO(laigd): put this into some test utils file.
void ExpectStatus(Status status, error::Code code = error::OK,
                  const char* substr = nullptr) {
  EXPECT_EQ(code, status.code())
      << status << " vs expected error code \"" << error::Code_Name(code)
      << "\" and message \"" << substr << "\"";
  if (substr) {
    EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status;
  }
}

TEST(TrtCandidateSelector, Basics) {
  // Create a graph containing both TRT-compatible and TRT-incompatible nodes
  // and use it to test TrtCandidateSelector::IsTensorRTCandidate().
  const std::vector<int32> input_shape_array{2, 2};
  TensorShape input_shape;
  TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape));

  Scope s = Scope::NewRootScope();
  ops::Placeholder::Attrs feed_attrs;
  TF_EXPECT_OK(
      TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_));

  // Compatible input.
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs);
  auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape);

  // Compatible MatMul.
  auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1);

  // Incompatible MatMul.
  ops::MatMul::Attrs matmul_attrs;
  matmul_attrs.transpose_a_ = true;
  auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"),
                                         feed, const_1, matmul_attrs);

  // Unsupported op.
  auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed);

  // Incompatible input.
  auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE);
  auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape);
  // Compatible op with incompatible input.
  auto matmul_with_incompatible_input =
      ops::MatMul(s.WithOpName("matmul_with_incompatible_input"),
                  incompatible_feed, const_2);

  grappler::GrapplerItem item;
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));
  Tensor feed_tensor(DT_FLOAT, input_shape);
  item.feed.push_back(std::make_pair("feed", feed_tensor));

  grappler::GraphProperties graph_properties(item);
  TF_EXPECT_OK(graph_properties.InferStatically(true));

  TrtCandidateSelector selector(graph_properties);
  TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node()));
  ExpectStatus(
      selector.IsTensorRTCandidate(incompatible_matmul.operation.node()),
      error::INVALID_ARGUMENT,
      "transpose_a is not supported for TensorRT FullyConnected "
      "(op: MatMul), at: incompatible_matmul");
  ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()),
               error::UNIMPLEMENTED, "Op type Sin is not supported");
  ExpectStatus(selector.IsTensorRTCandidate(
                   matmul_with_incompatible_input.operation.node()),
               error::INTERNAL,
               "Failed to convert input with index 0 to a TRT_TensorOrWeights");
}

class FakeCluster : public grappler::Cluster {
 public:
  FakeCluster() : Cluster(0) {}
@@ -48,8 +123,7 @@ class FakeCluster : public grappler::Cluster {
  }
  Status Run(const GraphDef& graph_def,
             const std::vector<std::pair<string, Tensor>>& feed,
             const std::vector<string>& fetch,
             RunMetadata* metadata) override {
             const std::vector<string>& fetch, RunMetadata* metadata) override {
    return Status::OK();
  }
tensorflow/contrib/tensorrt/convert/convert_nodes.cc

@@ -108,6 +108,18 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
  return tensorflow::Status::OK();
}

template <typename TensorShapeType>
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                           bool ignore_first_dim) {
  nvinfer1::Dims trt_dims;
  const int offset = (ignore_first_dim ? 1 : 0);
  for (int i = offset; i < shape.dims(); i++) {
    trt_dims.d[i - offset] = shape.dim_size(i);
  }
  trt_dims.nbDims = shape.dims() - offset;
  return trt_dims;
}

void GetOutputProperties(const grappler::GraphProperties& graph_properties,
                         const Node* node, const int out_port,
                         PartialTensorShape* shape,
@@ -137,22 +149,37 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties,
  }
}

tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
                                           const tensorflow::DataType dtype,
                                           nvinfer1::DataType* trt_dtype) {
// TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so
// put them there instead.
Status ValidateTensorProperties(const string& producer_node_type,
                                const tensorflow::DataType dtype,
                                const PartialTensorShape& shape,
                                bool validation_only,
                                nvinfer1::DataType* trt_dtype,
                                nvinfer1::Dims* trt_dims, int* batch_size) {
  // Convert data type.
  TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));

  // Convert shape.
  if (shape.dims() < 0) {
    return tensorflow::errors::InvalidArgument("Input tensor rank is unknown.");
    return errors::InvalidArgument("Input tensor rank is unknown.");
  }
  if (shape.dims() > 9) {
    return tensorflow::errors::OutOfRange(
        "Input tensor rank is greater than 8.");
  if (shape.dims() > nvinfer1::Dims::MAX_DIMS + 1) {  // +1 for batch dim
    return errors::OutOfRange("Input tensor rank is greater than ",
                              nvinfer1::Dims::MAX_DIMS + 1);
  }
  if (producer_node_type != "Const" && shape.dims() < 2) {
    return errors::InvalidArgument(
        "Input tensor with rank<2 is not supported since the first dimension "
        "is treated as batch dimension by TRT");
  }
  *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true);
  *batch_size = shape.dim_size(0);

  if (validation_only) return Status::OK();
  // Following are validations at runtime.

  for (int d = 1; d < shape.dims(); ++d) {
    if (shape.dim_size(d) < 0) {
      return tensorflow::errors::InvalidArgument(
      return errors::InvalidArgument(
          "Input tensor with shape ", shape.DebugString(),
          " has an unknown non-batch dimemension at dim ", d);
    }
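
The shape handling above — TensorShapeToTrtDims() drops the first (batch) dimension, and ValidateTensorProperties() rejects ranks TRT cannot hold — can be illustrated with a minimal stand-alone sketch. Dims here is a stand-in for nvinfer1::Dims, which stores at most 8 dimensions; this is not the real implementation, just the same arithmetic.

#include <cassert>
#include <vector>

// Stand-in for nvinfer1::Dims: a fixed-capacity dimension array.
struct Dims {
  static constexpr int kMaxDims = 8;
  int nbDims = 0;
  int d[kMaxDims] = {};
};

// Mirrors TensorShapeToTrtDims(): optionally skip the batch dimension and
// copy the rest. Returns false when the rank (batch dim included) is larger
// than what Dims can hold, mirroring the OutOfRange check above.
bool ShapeToTrtDims(const std::vector<int>& shape, bool ignore_first_dim,
                    Dims* trt_dims) {
  if (static_cast<int>(shape.size()) > Dims::kMaxDims + 1) return false;
  const int offset = ignore_first_dim ? 1 : 0;
  for (int i = offset; i < static_cast<int>(shape.size()); ++i) {
    trt_dims->d[i - offset] = shape[i];
  }
  trt_dims->nbDims = static_cast<int>(shape.size()) - offset;
  return true;
}

int main() {
  // A TF tensor of shape [32, 224, 224, 3]: the batch size 32 is stripped
  // and the TRT dims become [224, 224, 3].
  Dims dims;
  assert(ShapeToTrtDims({32, 224, 224, 3}, /*ignore_first_dim=*/true, &dims));
  assert(dims.nbDims == 3);
  assert(dims.d[0] == 224 && dims.d[1] == 224 && dims.d[2] == 3);

  // Rank 10 (9 non-batch dims) does not fit and is rejected.
  assert(!ShapeToTrtDims({1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
                         /*ignore_first_dim=*/true, &dims));
}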
@@ -358,22 +385,75 @@ string TRT_ShapedWeights::DebugString() const {
            ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
}

// A fake ITensor implementation used to check whether the TF-TRT converter can
// handle specific node. We only need shape and type information, and the
// converter won't (and shouldn't) use this to build the TRT network.
class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor {
 public:
  SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims)
      : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {}

  void setName(const char* name) override {}

  const char* getName() const override { return ""; }

  void setDimensions(nvinfer1::Dims dimensions) override {
    trt_dims_ = dimensions;
  }

  nvinfer1::Dims getDimensions() const override { return trt_dims_; }

  void setType(nvinfer1::DataType trt_dtype) override {
    trt_dtype_ = trt_dtype;
  }

  nvinfer1::DataType getType() const override { return trt_dtype_; }

  bool isNetworkInput() const override { return false; }

  bool isNetworkOutput() const override { return false; }

  void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {}

  bool getBroadcastAcrossBatch() const override { return false; }

  nvinfer1::TensorLocation getLocation() const override {
    // This is arbitrary, since we don't use it.
    return nvinfer1::TensorLocation::kDEVICE;
  }

  void setLocation(nvinfer1::TensorLocation location) override {}

#if NV_TENSORRT_MAJOR >= 5
  bool setDynamicRange(float min, float max) override {}
#endif

 private:
  nvinfer1::DataType trt_dtype_;
  nvinfer1::Dims trt_dims_;
};

TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor,
                                         int batch_size)
    : tensor_(tensor),
      batch_size_(batch_size),
      weights_(DT_FLOAT),
      initialized_(true),
      is_tensor_(true) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                                         const nvinfer1::Dims& trt_dims,
                                         int batch_size)
    : simple_itensor_(new SimpleITensor(trt_dtype, trt_dims)),
      batch_size_(batch_size),
      initialized_(true),
      is_tensor_(true) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
    : tensor_(nullptr),
      weights_(weights),
      initialized_(true),
      is_tensor_(false) {}
    : weights_(weights), initialized_(true), is_tensor_(false) {}

TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
    : tensor_(rhs.tensor_),
      simple_itensor_(rhs.simple_itensor_),
      batch_size_(rhs.batch_size_),
      weights_(rhs.weights_),
      initialized_(rhs.initialized_),
@@ -381,12 +461,23 @@ TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)

void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) {
  tensor_ = rhs.tensor_;
  simple_itensor_ = rhs.simple_itensor_;
  batch_size_ = rhs.batch_size_;
  weights_ = rhs.weights_;
  initialized_ = rhs.initialized_;
  is_tensor_ = rhs.is_tensor_;
}

nvinfer1::ITensor* TRT_TensorOrWeights::tensor() {
  CHECK(is_tensor());
  return tensor_ == nullptr ? simple_itensor_.get() : tensor_;
}

const nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const {
  CHECK(is_tensor());
  return tensor_ == nullptr ? simple_itensor_.get() : tensor_;
}

nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
  if (is_tensor()) {
    return tensor()->getDimensions();
@@ -398,8 +489,8 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
string TRT_TensorOrWeights::DebugString() const {
  string output = "TRT_TensorOrWeights(type=";
  if (is_tensor()) {
    StrAppend(&output, "tensor @", reinterpret_cast<uintptr_t>(tensor_),
              ", shape=", convert::DebugString(tensor_->getDimensions()),
    StrAppend(&output, "tensor @", reinterpret_cast<uintptr_t>(tensor()),
              ", shape=", convert::DebugString(tensor()->getDimensions()),
              ", batch_size=", batch_size_);
  } else {
    StrAppend(&output, "weights=", weights_.DebugString());
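
The ownership split above — a borrowed ITensor* when the tensor lives in a real TRT network, versus an owned SimpleITensor created only for validation, with a single tensor() accessor hiding the difference — reduces to the following small stand-alone sketch. All types are simplified stand-ins, not the real classes.

#include <cassert>
#include <memory>

// Stand-in for nvinfer1::ITensor.
struct ITensor {
  int rank = 0;
};

class TensorOrWeights {
 public:
  // Borrow: the caller (the "network") keeps ownership.
  explicit TensorOrWeights(ITensor* borrowed) : borrowed_(borrowed) {}

  // Own: create a fake tensor that exists only for validation.
  explicit TensorOrWeights(int rank) : owned_(std::make_shared<ITensor>()) {
    owned_->rank = rank;
  }

  // One accessor hides which case we are in, like
  // TRT_TensorOrWeights::tensor() above.
  ITensor* tensor() const { return borrowed_ ? borrowed_ : owned_.get(); }

 private:
  ITensor* borrowed_ = nullptr;     // Not owned.
  std::shared_ptr<ITensor> owned_;  // Owned; copies share it.
};

int main() {
  ITensor real;                // Owned by the "network".
  TensorOrWeights a(&real);    // Borrowed case.
  TensorOrWeights b(3);        // Validation-only fake, owned case.
  TensorOrWeights c = b;       // Copying shares the owned tensor.
  assert(a.tensor() == &real);
  assert(b.tensor() == c.tensor());
}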
@@ -560,11 +651,10 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
  // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
  const int c = iweights.shape_.d[2] / num_groups;
  const int k = iweights.shape_.d[3] * num_groups;
  VLOG(2) << "num_groups: " << num_groups
          << "c" << iweights.shape_.d[2] << " then " << c
          << "k" << iweights.shape_.d[3] << " then " << k
          << "r" << iweights.shape_.d[0] << " then " << r
          << "s" << iweights.shape_.d[1] << " then " << s;
  VLOG(2) << "num_groups: " << num_groups << "c" << iweights.shape_.d[2]
          << " then " << c << "k" << iweights.shape_.d[3] << " then " << k
          << "r" << iweights.shape_.d[0] << " then " << r << "s"
          << iweights.shape_.d[1] << " then " << s;
  oweights->shape_.d[0] = k / num_groups;
  oweights->shape_.d[1] = c * num_groups;
  oweights->shape_.d[2] = r;
@@ -608,9 +698,68 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(tensorflow::DataType type,

TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); }

Status TrtNodeValidator::ConvertToTensorOrWeights(
    const NodeDef& node_def, int output_port,
    const grappler::GraphProperties& graph_properties,
    TRT_TensorOrWeights* tensor_or_weights) {
  if (node_def.op() == "Const") {
    if (output_port != 0) {
      return errors::InvalidArgument("Const node should only have one output.");
    }
    // The output of the conversion will be used as input to other nodes to
    // determine whether TRT supports those nodes. If it cannot convert the
    // Const, it's very likely we cannot treat it as a tensor and make it an
    // input to the TRT network, since TRT removes the first dimension and
    // treats it as batch size. Also, it's not likely that the converter can
    // support the op, and performance may suffer even if it can, so we just
    // simply return error if the conversion fails.
    std::vector<TRT_TensorOrWeights> inputs;
    return ConvertConstToWeights(node_def, inputs, tensor_or_weights);
  }
  if (!graph_properties.HasOutputProperties(node_def.name())) {
    return errors::InvalidArgument("Shape and data type are unknown");
  }

  // Validate and convert shape and dtype.
  const auto& output_params =
      graph_properties.GetOutputProperties(node_def.name());
  const auto& tensor_properties = output_params.at(output_port);
  const DataType dtype = tensor_properties.dtype();
  const PartialTensorShape shape = tensor_properties.shape();
  nvinfer1::DataType trt_dtype;
  nvinfer1::Dims trt_dims;
  int batch_size = -1;
  TF_RETURN_IF_ERROR(ValidateTensorProperties(
      node_def.op(), dtype, shape, /*validation_only_=*/true, &trt_dtype,
      &trt_dims, &batch_size));

  // Adds a fake ITensor. This is fine since op converter operates in
  // validation-only mode and it won't (and shouldn't) use the tensor to do
  // any TRT network operations.
  *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size);
  return Status::OK();
}

Status TrtNodeValidator::ValidateNode(
    const tensorflow::NodeDef& node_def,
    const std::vector<TRT_TensorOrWeights>& inputs) {
    const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
    const grappler::GraphProperties& graph_properties) {
  // Convert input NodeDef and corresponding output ports to
  // TRT_TensorOrWeights.
  std::vector<TRT_TensorOrWeights> inputs;
  for (int i = 0; i < input_node_and_ports.size(); ++i) {
    const auto& pair = input_node_and_ports[i];
    TRT_TensorOrWeights tensor_or_weights;
    Status status = ConvertToTensorOrWeights(
        *pair.first, pair.second, graph_properties, &tensor_or_weights);
    if (!status.ok()) {
      return errors::Internal("Failed to convert input with index ", i,
                              " to a TRT_TensorOrWeights");
    }
    inputs.push_back(tensor_or_weights);
  }

  // Validate the node.
  const auto iter = op_validators_.find(node_def.op());
  if (iter == op_validators_.end()) {
    // If validator is not registered, it means no validation is needed.
@@ -621,7 +770,19 @@ Status TrtNodeValidator::ValidateNode(
  OpConverterParams params(
      /*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr,
      /*arg_validation_only=*/true, &weight_store_);
  Status status = validator(&params);
  return validator(&params);
}

Status TrtNodeValidator::ConvertConstToWeights(
    const NodeDef& const_node_def,
    const std::vector<TRT_TensorOrWeights>& inputs,
    TRT_TensorOrWeights* output) {
  std::vector<TRT_TensorOrWeights> outputs;
  OpConverterParams params(
      /*arg_converter=*/nullptr, const_node_def, inputs, &outputs,
      /*arg_validation_only=*/true, &weight_store_);
  Status status = op_validators_["Const"](&params);
  if (status.ok() && output) *output = outputs[0];
  return status;
}
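
The op_validators_ lookup in ValidateNode() above, where an unregistered op type simply means "no validation needed", is a small registry-of-callbacks pattern. A self-contained sketch with stand-in types (not the real TrtNodeValidator API):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Stand-in for tensorflow::Status.
struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
};

using Validator = std::function<Status(const std::string& node_name)>;

class NodeValidator {
 public:
  void Register(const std::string& op, Validator v) {
    op_validators_[op] = std::move(v);
  }

  Status Validate(const std::string& op, const std::string& node_name) const {
    const auto it = op_validators_.find(op);
    // No validator registered: treat as "no validation needed", as above.
    if (it == op_validators_.end()) return Status::OK();
    return it->second(node_name);
  }

 private:
  std::unordered_map<std::string, Validator> op_validators_;
};

int main() {
  NodeValidator validator;
  validator.Register("MatMul", [](const std::string& name) {
    return Status{false, "transpose_a is not supported, at: " + name};
  });
  // Prints 1: no validator registered for Relu, so validation passes.
  std::cout << validator.Validate("Relu", "my_relu").ok << "\n";
  // Prints the registered validator's error message.
  std::cout << validator.Validate("MatMul", "my_matmul").msg << "\n";
}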
@@ -1663,7 +1824,7 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) {
}

tensorflow::Status ConvertScale(OpConverterParams* params) {
  const auto inputs = params->inputs;
  const auto& inputs = params->inputs;
  const auto& node_def = params->node_def;
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights()) {
@@ -1798,8 +1959,13 @@ Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor,
  return Status::OK();
}

// Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter, it
// always ignores the params->validation_only parameter but adds the converted
// weights to params->outputs. We did this since TrtNodeValidator needs the
// weights as input to other nodes, and use it to determine whether those nodes
// are supported by TRT.
tensorflow::Status ConvertConst(OpConverterParams* params) {
  const auto inputs = params->inputs;
  const auto& inputs = params->inputs;
  const auto& node_def = params->node_def;
  if (!inputs.empty()) {
    return errors::InvalidArgument(
@@ -1896,11 +2062,10 @@ tensorflow::Status ConvertConst(OpConverterParams* params) {
    return errors::Unimplemented("Not supported constant type, at ",
                                 node_def.name());
  }
  // Pass the output.
  if (!params->validation_only) {
  if (params->outputs != nullptr) {
    params->outputs->push_back(TRT_TensorOrWeights(weights));
  }
  return tensorflow::Status::OK();
  return Status::OK();
}

tensorflow::Status ConvertIdentity(OpConverterParams* params) {
@@ -1909,7 +2074,7 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) {
}

tensorflow::Status ConvertBinary(OpConverterParams* params) {
  const auto inputs = params->inputs;
  const auto& inputs = params->inputs;
  const auto& node_def = params->node_def;
  if (inputs.size() != 2) {
    return tensorflow::errors::FailedPrecondition(
@@ -2418,19 +2583,20 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) {
  // TODO(jie): INT32 should be converted?
  tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
  if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
    return tensorflow::errors::Unimplemented(
        "data type is not supported, for node " + node_def.name() + " got " +
        tensorflow::DataTypeString(tf_dtype));
    return errors::Unimplemented("Data type is not supported, for node ",
                                 node_def.name(), " got ",
                                 DataTypeString(tf_dtype));
  }
  bool transpose_a = attrs.get<bool>("transpose_a");
  bool transpose_b = attrs.get<bool>("transpose_b");

  // FullyConnected:
  if (transpose_a) {
    return tensorflow::errors::Internal(
        "Transpose_a is not supported for TensorRT FullyConnected (op: " +
        node_def.op() + "), at: " + node_def.name());
    return errors::InvalidArgument(
        "transpose_a is not supported for TensorRT FullyConnected (op: ",
        node_def.op(), "), at: ", node_def.name());
  }
  if (params->validation_only) return Status::OK();
  return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(),
                             transpose_b, node_def.name());
}
@@ -2673,10 +2839,13 @@ tensorflow::Status ConvertGraphDefToEngine(
      return tensorflow::errors::InvalidArgument(
          "Failed to parse slot number from ", node_name);
    }
    nvinfer1::DataType dtype;
    nvinfer1::DataType trt_dtype;
    nvinfer1::Dims trt_dims;
    int batch_size = -1;
    auto shape = input_shapes.at(slot_number);
    auto status = ValidateInputProperties(
        shape, node_def.attr().at("dtype").type(), &dtype);
    auto status = ValidateTensorProperties(
        node_def.op(), node_def.attr().at("dtype").type(), shape,
        /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size);
    if (!status.ok()) {
      const string error_message =
          StrCat("Validation failed for ", node_name, " and input slot ",
@@ -2684,19 +2853,13 @@ tensorflow::Status ConvertGraphDefToEngine(
      LOG(WARNING) << error_message;
      return Status(status.code(), error_message);
    }

    nvinfer1::Dims input_dim;
    for (int i = 1; i < shape.dims(); i++) {
      input_dim.d[i - 1] = shape.dim_size(i);
    }
    input_dim.nbDims = shape.dims() - 1;
    VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
            << DebugString(input_dim);
            << DebugString(trt_dims);
    // TODO(laigd): the conversion should always happen at runtime where all
    // the shapes are known, and we can provide a mode to generate the
    // engines offline, by calling sess.run() and cache/serialize the engines.
    TF_RETURN_IF_ERROR(converter.AddInputTensor(node_name, dtype, input_dim,
                                                shape.dim_size(0)));
    TF_RETURN_IF_ERROR(
        converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
  } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
             (node_def.op() == "Identity")) {
    int32 slot_number = -1;
@@ -2866,34 +3029,6 @@ tensorflow::Status ConvertSegmentToGraphDef(
  return tensorflow::Status::OK();
}

bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
  if (in_edge->IsControlEdge()) return true;
  PartialTensorShape shape;
  tensorflow::DataType dtype;
  GetOutputProperties(graph_properties_, in_edge->src(), in_edge->src_output(),
                      &shape, &dtype);
  nvinfer1::DataType trt_dtype;
  Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
  if (!status.ok()) {
    VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
            << ": " << status;
    return false;
  }

  if (in_edge->src()->type_string() != "Const" &&
      // Single dimensional input tensor is not supported since the first
      // dimension is treated as batch dimension.
      shape.dims() < 2) {
    VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
            << " which has an input at port " << in_edge->dst_input() << " with"
            << " #dim<2"
            << " and is not a const: " << shape;
    return false;
  }
  return true;
}

bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
  if (out_edge->IsControlEdge()) return true;
  if (out_edge->src()->type_string() == "Const") {
tensorflow/contrib/tensorrt/convert/convert_nodes.h

@@ -148,21 +148,6 @@ tensorflow::Status ConvertGraphDefToEngine(
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
    bool* convert_successfully);

// Helper class for the segmenter to determine whether an input edge to the TRT
// segment is valid.
class InputEdgeValidator {
 public:
  InputEdgeValidator(const grappler::GraphProperties& graph_properties)
      : graph_properties_(graph_properties) {}

  // Return true if the specified edge is eligible to be an input edge of the
  // TRT segment.
  bool operator()(const tensorflow::Edge* in_edge) const;

 private:
  const grappler::GraphProperties& graph_properties_;
};

// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid.
class OutputEdgeValidator {
@@ -245,8 +230,21 @@ class TRT_TensorOrWeights {
 public:
  TRT_TensorOrWeights() {}

  // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'.
  // This is used by Converter when building the TRT network, where the ITensor
  // is owned by the TRT network being built. See comment for 'tensor_' below.
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);

  // Constructor that makes it an ITensor by creating one using provided data
  // type and shape, and takes ownership of the created ITensor. This is used by
  // TrtNodeValidator to encapsulate the type and shape information for
  // validation of graph nodes, and the created ITensor is fake and temporary,
  // and should not be used to build any TRT network. See comment for
  // 'simple_itensor_' below.
  explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                               const nvinfer1::Dims& trt_dims, int batch_size);

  // Constructor that makes it a TRT_TensorOrWeights.
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);

  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);
@@ -256,15 +254,9 @@ class TRT_TensorOrWeights {
  bool is_tensor() const { return initialized_ && is_tensor_; }
  bool is_weights() const { return initialized_ && !is_tensor_; }

  nvinfer1::ITensor* tensor() {
    CHECK(is_tensor());
    return tensor_;
  }
  nvinfer1::ITensor* tensor();

  const nvinfer1::ITensor* tensor() const {
    CHECK(is_tensor());
    return tensor_;
  }
  const nvinfer1::ITensor* tensor() const;

  TRT_ShapedWeights& weights() {
    CHECK(is_weights());
@@ -283,9 +275,25 @@ class TRT_TensorOrWeights {
  string DebugString() const;

 private:
  class SimpleITensor;

  void set_batch_size(int batch_size) { batch_size_ = batch_size; }

  // When it represents an ITensor, the ITensor can be either passed by the
  // caller via the constructor that takes an ITensor* as parameter, or be
  // created as a SimpleITensor.
  //
  // In the first case, the ITensor pointer is stored in 'tensor_' below, and
  // the ITensor itself is not owned by this class. This method is used by
  // Converter (e.g. AddInputTensor) and op converters during TRT network
  // construction, where the TRT network owns the ITensor.
  //
  // In the second case, the created SimpleITensor is stored in
  // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake
  // implementation of ITensor and is used only by TrtNodeValidator to validate
  // the graph nodes.
  nvinfer1::ITensor* tensor_ = nullptr;  // Not owned.
  std::shared_ptr<SimpleITensor> simple_itensor_ = nullptr;

  // First dimension of the TF tensor (NOT tensor_) that is represented by
  // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
@@ -339,13 +347,35 @@ class TrtNodeValidator {
 public:
  TrtNodeValidator();

  // Validate the node, and return ok if it's supported by the converter.
  Status ValidateNode(const NodeDef& node_def,
                      const std::vector<TRT_TensorOrWeights>& inputs);
  // Validate the node, and return ok if it's supported by TRT.
  //
  // - 'node_def' is the node to validate.
  // - 'input_node_and_ports' are the input NodeDefs and their output ports that
  //   are connected to 'node_def' in the TF graph.
  // - 'graph_properties' is the GraphProperties of the graph where 'node_def'
  //   belongs. It is used to get the shape and data type information of a
  //   tensor for validation purpose.
  Status ValidateNode(
      const NodeDef& node_def,
      const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
      const grappler::GraphProperties& graph_properties);

 private:
  void RegisterOpValidators();

  // Convert a Const node to a TRT_TensorOrWeights.
  Status ConvertConstToWeights(const NodeDef& const_node_def,
                               const std::vector<TRT_TensorOrWeights>& inputs,
                               TRT_TensorOrWeights* output);

  // Convert the output tensor at 'output_port' of 'node_def' to a
  // TRT_TensorOrWeights which will be later used as an input to other nodes and
  // passed to ValidateNode() below.
  Status ConvertToTensorOrWeights(
      const NodeDef& node_def, int output_port,
      const grappler::GraphProperties& graph_properties,
      TRT_TensorOrWeights* tensor_or_weights);

  // Stores all the validators by op type. If no validator is registered for
  // specific op, it means no validation is needed and ValidateNode() will
  // return OK.
@ -21,6 +21,9 @@ limitations under the License.
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
|
||||
@ -29,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -45,6 +49,7 @@ namespace convert {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// TODO(laigd): put this into some test utils file.
|
||||
void ExpectStatus(Status status, error::Code code = error::OK,
|
||||
const char* substr = nullptr) {
|
||||
EXPECT_EQ(code, status.code())
|
||||
@ -75,6 +80,23 @@ NodeDef MakeNodeDef(const string& name, const string& op,
|
||||
return node_def;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
NodeDef MakeConstNodeDef(const string& name, const std::vector<T>& vals,
|
||||
const TensorShape& shape) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
Tensor t = ::tensorflow::test::AsTensor<T>(vals, shape);
|
||||
auto const_op = ops::Const(s.WithOpName(name), t);
|
||||
return const_op.node()->def();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
NodeDef MakeConstNodeDef(const string& name, const std::vector<T>& vals) {
|
||||
TensorShape shape;
|
||||
const std::vector<int32> shape_dims = {static_cast<int32>(vals.size())};
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape));
|
||||
return MakeConstNodeDef(name, vals, shape);
|
||||
}
|
||||
|
||||
bool TrtDimsEquals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) {
|
||||
if (lhs.nbDims != rhs.nbDims) return false;
|
||||
for (int i = 0; i < lhs.nbDims; ++i) {
|
||||
@ -95,6 +117,19 @@ bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
|
||||
lhs.GetValues() == rhs.GetValues();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ValidateWeights(const TRT_ShapedWeights& weights,
|
||||
const std::vector<int>& expected_dims,
|
||||
const std::vector<T>& expected_value) {
|
||||
EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_))
|
||||
<< weights.DebugString();
|
||||
ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString();
|
||||
const T* actual_values = static_cast<const T*>(weights.GetValues());
|
||||
for (int i = 0; i < expected_value.size(); ++i) {
|
||||
EXPECT_EQ(expected_value[i], actual_values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Fake ITensor implementation for testing purposes.
|
||||
class FakeITensor : public nvinfer1::ITensor {
|
||||
public:
|
||||
@ -194,32 +229,86 @@ TEST(TRT_ShapedWeights_Test, Basic) {
|
||||
}
|
||||
|
||||
TEST(TRT_TensorOrWeights_Test, Basic) {
|
||||
// Test constructor with no arguments.
|
||||
{
|
||||
TRT_TensorOrWeights tw;
|
||||
TRT_TensorOrWeights copy(tw);
|
||||
TRT_TensorOrWeights assigned;
|
||||
assigned = tw;
|
||||
for (auto ptr : {&tw, ©, &assigned}) {
|
||||
EXPECT_EQ(false, ptr->is_tensor());
|
||||
EXPECT_EQ(false, ptr->is_weights());
|
||||
EXPECT_EQ(-1, ptr->batch_size());
|
||||
}
|
||||
}
|
||||
|
||||
// Test constructor with ITensor and batch size argument.
|
||||
{
|
||||
nvinfer1::Dims dims;
|
||||
dims.nbDims = 1;
|
||||
dims.d[0] = 1;
|
||||
FakeITensor itensor(dims);
|
||||
|
||||
TRT_TensorOrWeights tw(&itensor);
|
||||
EXPECT_EQ(true, tw.is_tensor());
|
||||
EXPECT_EQ(false, tw.is_weights());
|
||||
EXPECT_EQ(&itensor, tw.tensor());
|
||||
EXPECT_TRUE(TrtDimsEqualsArray({1}, tw.GetTrtDims()))
|
||||
<< "- expected: " << DebugString(dims)
|
||||
<< "\n vs\n- actual: " << DebugString(tw.GetTrtDims());
|
||||
}
|
||||
{
|
||||
TRT_ShapedWeights weights(DT_FLOAT);
|
||||
TRT_TensorOrWeights tw(weights);
|
||||
EXPECT_EQ(false, tw.is_tensor());
|
||||
EXPECT_EQ(true, tw.is_weights());
|
||||
EXPECT_TRUE(TrtShapedWeightsEquals(weights, tw.weights()));
|
||||
TRT_TensorOrWeights tw1(&itensor, /*batch_size=*/1);
|
||||
|
||||
for (auto original_ptr : {&tw, &tw1}) {
|
||||
TRT_TensorOrWeights copy(*original_ptr);
|
||||
TRT_TensorOrWeights assigned;
|
||||
assigned = *original_ptr;
|
||||
|
||||
for (auto ptr : {original_ptr, ©, &assigned}) {
|
||||
EXPECT_EQ(true, ptr->is_tensor());
|
||||
EXPECT_EQ(false, ptr->is_weights());
|
||||
if (original_ptr == &tw) {
|
||||
EXPECT_EQ(-1, ptr->batch_size());
|
||||
} else {
|
||||
EXPECT_EQ(1, ptr->batch_size());
|
||||
}
|
||||
EXPECT_EQ(&itensor, ptr->tensor());
|
||||
EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims()))
|
||||
<< "- expected: " << DebugString(dims)
|
||||
<< "\n vs\n- actual: " << DebugString(ptr->GetTrtDims());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Test constructor which creates and owns an ITensor.
|
||||
{
|
||||
nvinfer1::Dims dims;
|
||||
dims.nbDims = 0;
|
||||
EXPECT_TRUE(TrtDimsEqualsArray({}, tw.GetTrtDims()))
|
||||
<< "- expected: " << DebugString(dims)
|
||||
<< "\n vs\n- actual: " << DebugString(tw.GetTrtDims());
|
||||
dims.nbDims = 1;
|
||||
dims.d[0] = 1;
|
||||
TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1);
|
||||
TRT_TensorOrWeights copy(tw);
|
||||
TRT_TensorOrWeights assigned;
|
||||
assigned = tw;
|
||||
|
||||
for (auto ptr : {&tw, ©, &assigned}) {
|
||||
EXPECT_EQ(true, ptr->is_tensor());
|
||||
EXPECT_EQ(false, ptr->is_weights());
|
||||
EXPECT_EQ(1, ptr->batch_size());
|
||||
EXPECT_NE(nullptr, ptr->tensor());
|
||||
EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims()))
|
||||
<< "- expected: " << DebugString(dims)
|
||||
<< "\n vs\n- actual: " << DebugString(ptr->GetTrtDims());
|
||||
}
|
||||
}
|
||||
// Test constructor with TRT_ShapedWeights argument.
|
||||
{
|
||||
TRT_ShapedWeights weights;
|
||||
TRT_TensorOrWeights tw(weights);
|
||||
TRT_TensorOrWeights copy(tw);
|
||||
TRT_TensorOrWeights assigned;
|
||||
assigned = tw;
|
||||
for (auto ptr : {&tw, ©, &assigned}) {
|
||||
EXPECT_EQ(false, ptr->is_tensor());
|
||||
EXPECT_EQ(true, ptr->is_weights());
|
||||
EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights()));
|
||||
|
||||
nvinfer1::Dims dims;
|
||||
dims.nbDims = 0;
|
||||
EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims()))
|
||||
<< "- expected: " << DebugString(dims)
|
||||
<< "\n vs\n- actual: " << DebugString(ptr->GetTrtDims());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,11 +318,64 @@ class ValidatorTest : public ::testing::Test {
|
||||
validator_.op_validators_[op_name] = op_validator;
|
||||
}
|
||||
|
||||
Status ConvertToTensorOrWeights(
|
||||
const NodeDef& node_def, int output_port,
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TRT_TensorOrWeights* tensor_or_weights) {
|
||||
return validator_.ConvertToTensorOrWeights(
|
||||
node_def, output_port, graph_properties, tensor_or_weights);
|
||||
}
|
||||
|
||||
protected:
|
||||
TrtNodeValidator validator_;
|
||||
};
|
||||
|
||||
TEST_F(ValidatorTest, ConvertToTensorOrWeights) {
|
||||
// Convert Const.
|
||||
{
|
||||
NodeDef node_def = MakeConstNodeDef<float>("my_const", {1.0f, 2.0f});
|
||||
TRT_TensorOrWeights output;
|
||||
grappler::GrapplerItem item;
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0,
|
||||
graph_properties, &output));
|
||||
ValidateWeights<float>(output.weights(), {2}, {1.0, 2.0});
|
||||
}
|
||||
// Convert non-Const. We test the case where the non-batch dimemsion is
|
||||
// unknown as well, to make sure the validator allows that.
|
||||
for (const int32 non_batch_dim : {-1, 2}) {
|
||||
const int32 batch_size = 12;
|
||||
|
||||
Scope s = Scope::NewRootScope();
|
||||
ops::Placeholder::Attrs attrs;
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(
|
||||
std::vector<int32>{batch_size, non_batch_dim}, &attrs.shape_));
|
||||
auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs);
|
||||
auto add = ops::Add(s.WithOpName("add"), feed, feed);
|
||||
|
||||
grappler::GrapplerItem item;
|
||||
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TF_EXPECT_OK(graph_properties.InferStatically(true));
|
||||
|
||||
auto& node_def = add.operation.node()->def();
|
||||
TRT_TensorOrWeights output;
|
||||
ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0,
|
||||
graph_properties, &output));
|
||||
EXPECT_EQ(true, output.is_tensor());
|
||||
EXPECT_EQ(batch_size, output.batch_size());
|
||||
EXPECT_NE(nullptr, output.tensor());
|
||||
EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims()))
|
||||
<< "- expected: {" << non_batch_dim << "} \n vs\n"
|
||||
<< "- actual: " << DebugString(output.GetTrtDims());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ValidatorTest, ValidateNode) {
|
||||
grappler::GrapplerItem item;
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
|
||||
bool start_conversion = false;
|
||||
bool should_fail = false;
|
||||
auto op_converter = [&start_conversion,
|
||||
@ -245,16 +387,17 @@ TEST_F(ValidatorTest, ValidateNode) {
|
||||
NodeDef node_def = MakeNodeDef("my_op", "MyOp", {});
|
||||
|
||||
// Validator not registered, validation should pass.
|
||||
TF_EXPECT_OK(validator_.ValidateNode(node_def, {}));
|
||||
TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties));
|
||||
|
||||
// Register validator.
|
||||
AddOpValidator("MyOp", op_converter);
|
||||
TF_EXPECT_OK(validator_.ValidateNode(node_def, {}));
|
||||
TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties));
|
||||
EXPECT_EQ(false, start_conversion);
|
||||
|
||||
// Let the converter return error.
|
||||
should_fail = true;
|
||||
ExpectStatus(validator_.ValidateNode(node_def, {}), error::INVALID_ARGUMENT);
|
||||
ExpectStatus(validator_.ValidateNode(node_def, {}, graph_properties),
|
||||
error::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
class ConverterTest : public ::testing::Test {
|
||||
@ -289,6 +432,8 @@ class ConverterTest : public ::testing::Test {
|
||||
return converter_->GetInputs(node_def, inputs);
|
||||
}
|
||||
|
||||
int batch_size() const { return converter_->batch_size_; }
|
||||
|
||||
private:
|
||||
Logger logger_;
|
||||
// These members are ordered in a way such that the destruction order is:
|
||||
@ -474,11 +619,48 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
|
||||
<< DebugString(*output_tensor);
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
|
||||
EXPECT_EQ(-1, batch_size());
|
||||
|
||||
TF_EXPECT_OK(MaybeUpdateBatchSize(-1));
|
||||
EXPECT_EQ(-1, batch_size());
|
||||
|
||||
TF_EXPECT_OK(MaybeUpdateBatchSize(123));
|
||||
EXPECT_EQ(123, batch_size());
|
||||
|
||||
TF_EXPECT_OK(MaybeUpdateBatchSize(123));
|
||||
EXPECT_EQ(123, batch_size());
|
||||
|
||||
TF_EXPECT_OK(MaybeUpdateBatchSize(-1));
|
||||
EXPECT_EQ(123, batch_size());
|
||||
|
||||
ExpectStatus(MaybeUpdateBatchSize(124), error::INVALID_ARGUMENT,
|
||||
"Provided batch size does not match converter batch size");
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, AddAndGetTensorOrWeights) {
|
||||
// Add a tensor.
|
||||
FakeITensor fake_tensor;
|
||||
TRT_TensorOrWeights tensor(&fake_tensor);
|
||||
EXPECT_EQ(-1, tensor.batch_size());
|
||||
TF_EXPECT_OK(MaybeUpdateBatchSize(123));
|
||||
TF_EXPECT_OK(AddTensorOrWeights("my_tensor", tensor));
|
||||
|
||||
// Get the added tensor.
|
||||
TRT_TensorOrWeights added_tensor;
|
||||
TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &added_tensor));
|
||||
EXPECT_EQ(123, added_tensor.batch_size());
|
||||
|
||||
// Add the same tensor again.
|
||||
ExpectStatus(AddTensorOrWeights("my_tensor", tensor), error::ALREADY_EXISTS,
|
||||
"tensor/weights my_tensor already exist");
|
||||
}
|
||||
|
||||
// Class to test various op converters, using both a TrtNodeValidator and
|
||||
// Converter.
|
||||
class OpConverterTest : public ::testing::Test {
|
||||
public:
|
||||
OpConverterTest() {
|
||||
OpConverterTest() : scope_(Scope::NewRootScope()) {
|
||||
QCHECK_EQ(0, cudaStreamCreate(&stream_));
|
||||
Reset();
|
||||
}
|
||||
@ -505,8 +687,8 @@ class OpConverterTest : public ::testing::Test {
|
||||
converter_.reset(new Converter(network_.get(), /*fp16=*/false));
|
||||
|
||||
// Reset other related artifacts.
|
||||
fake_itensors_.clear();
|
||||
fake_tensor_or_weights_.clear();
|
||||
scope_ = Scope::NewRootScope();
|
||||
validator_inputs_.clear();
|
||||
}
|
||||
|
||||
void BuildAndRun(const char* input_name, const std::vector<float>& input_data,
|
||||
@ -551,33 +733,41 @@ class OpConverterTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
// Add ITensor for both validation and conversion.
|
||||
void AddTestTensor(const char* name, const std::vector<int>& dims,
|
||||
int batch_size = 1) {
|
||||
nvinfer1::Dims trt_dims = GetTestDims(dims);
|
||||
// Add FakeITensor for validation.
|
||||
//
|
||||
// TRT cannot add a tensor that has undetermined dims, so we manage the
|
||||
// tensor using a vector. These tensors are used to test validation-only
|
||||
// mode and thus should not be used to build the engine.
|
||||
FakeITensor* fake_itensor = new FakeITensor(trt_dims);
|
||||
fake_itensors_.emplace_back(fake_itensor);
|
||||
fake_tensor_or_weights_[string(name)] =
|
||||
TRT_TensorOrWeights{fake_itensor, batch_size};
|
||||
void AddTestTensor(
|
||||
const char* name, const std::vector<int32>& dims, int batch_size = 1,
|
||||
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) {
|
||||
DataType tf_dtype = DT_FLOAT;
|
||||
switch (trt_dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
tf_dtype = DT_FLOAT;
|
||||
break;
|
||||
case nvinfer1::DataType::kINT32:
|
||||
tf_dtype = DT_INT32;
|
||||
break;
|
||||
default:
|
||||
ASSERT_TRUE(false) << "Unexpected data type "
|
||||
<< static_cast<int>(trt_dtype);
|
||||
}
|
||||
ops::Placeholder::Attrs attrs;
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
|
||||
attrs.shape_.InsertDim(0, batch_size);
|
||||
auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs);
|
||||
validator_inputs_[name] = input.operation.node()->def();
|
||||
|
||||
// Add a real ITensor for conversion conditionally.
|
||||
const nvinfer1::Dims trt_dims = GetTestDims(dims);
|
||||
if (HasStaticShape(trt_dims)) {
|
||||
TF_EXPECT_OK(converter_->AddInputTensor(name, nvinfer1::DataType::kFLOAT,
|
||||
trt_dims, batch_size));
|
||||
TF_EXPECT_OK(
|
||||
converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size));
|
||||
ASSERT_EQ(batch_size, converter_->batch_size_);
|
||||
}
|
||||
}
|
||||
|
||||
// Add weights for both validation and conversion.
|
||||
template <typename CType>
|
||||
void AddTestWeights(const char* name, const DataType dtype,
|
||||
const std::vector<int>& dims,
|
||||
const std::vector<CType>& values) {
|
||||
QCHECK_EQ(DataTypeToEnum<CType>::v(), dtype);
|
||||
template <typename T>
|
||||
void AddTestWeights(const char* name, const std::vector<int>& dims,
|
||||
const std::vector<T>& values) {
|
||||
const DataType dtype = DataTypeToEnum<T>::v();
|
||||
const nvinfer1::Dims trt_dims = GetTestDims(dims);
|
||||
const int64_t num_elements = TrtDimsNumElements(trt_dims);
|
||||
QCHECK_EQ(num_elements, values.size())
|
||||
@ -585,13 +775,15 @@ class OpConverterTest : public ::testing::Test {
|
||||
TRT_ShapedWeights weights(dtype);
|
||||
if (num_elements) {
|
||||
weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims);
|
||||
QCHECK_EQ(weights.size_bytes(), sizeof(CType) * values.size())
|
||||
<< weights.size_bytes() << " vs " << sizeof(CType) * values.size();
|
||||
QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size())
|
||||
<< weights.size_bytes() << " vs " << sizeof(T) * values.size();
|
||||
memcpy(const_cast<void*>(weights.GetValues()), values.data(),
|
||||
weights.size_bytes());
|
||||
}
|
||||
// Add weights for validation.
|
||||
fake_tensor_or_weights_[string(name)] = TRT_TensorOrWeights{weights};
|
||||
TensorShape shape;
|
||||
TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape));
|
||||
validator_inputs_[name] = MakeConstNodeDef<T>(name, values, shape);
|
||||
// Add weights for conversion.
|
||||
TF_EXPECT_OK(
|
||||
converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights}));
|
||||
@ -601,12 +793,18 @@ class OpConverterTest : public ::testing::Test {
|
||||
void RunValidation(const NodeDef& node_def,
|
||||
error::Code expected_code = error::OK,
|
||||
const char* expected_msg_substr = nullptr) {
|
||||
std::vector<TRT_TensorOrWeights> inputs;
|
||||
std::vector<std::pair<const NodeDef*, int>> input_node_and_ports;
|
||||
for (const string& input : node_def.input()) {
|
||||
inputs.emplace_back(fake_tensor_or_weights_[input]);
|
||||
input_node_and_ports.emplace_back(&validator_inputs_[input], 0);
|
||||
}
|
||||
ExpectStatus(validator_->ValidateNode(node_def, inputs), expected_code,
|
||||
expected_msg_substr);
|
||||
grappler::GrapplerItem item;
|
||||
TF_EXPECT_OK(scope_.ToGraphDef(&item.graph));
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TF_EXPECT_OK(graph_properties.InferStatically(true));
|
||||
|
||||
ExpectStatus(validator_->ValidateNode(node_def, input_node_and_ports,
|
||||
graph_properties),
|
||||
expected_code, expected_msg_substr);
|
||||
}
|
||||
|
||||
void RunConversion(const NodeDef& node_def,
|
||||
@ -637,8 +835,8 @@ class OpConverterTest : public ::testing::Test {
|
||||
TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||
cudaStream_t stream_;
|
||||
std::vector<std::unique_ptr<FakeITensor>> fake_itensors_;
|
||||
std::unordered_map<string, TRT_TensorOrWeights> fake_tensor_or_weights_;
|
||||
Scope scope_;
|
||||
std::unordered_map<string, NodeDef> validator_inputs_;
|
||||
};
|
||||
|
||||
template <DataType dtype, typename InputCType, typename OutputCType>
|
||||
@ -662,15 +860,7 @@ void TestConvertConst(OpConverterTest* test) {
|
||||
test->RunValidationAndConversion(node_def);
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output));
|
||||
EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, output.weights().shape_))
|
||||
<< output.DebugString();
|
||||
ASSERT_EQ(expected_value.size(), output.weights().count())
|
||||
<< output.DebugString();
|
||||
const OutputCType* actual_values =
|
||||
static_cast<const OutputCType*>(output.weights().GetValues());
|
||||
for (int i = 0; i < expected_value.size(); ++i) {
|
||||
EXPECT_EQ(expected_value[i], actual_values[i]);
|
||||
}
|
||||
ValidateWeights(output.weights(), expected_dims, expected_value);
|
||||
};
|
||||
|
||||
auto& attr = *node_def.mutable_attr();
|
||||
@ -700,8 +890,6 @@ void TestConvertConst(OpConverterTest* test) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(laigd): we should use c++ API to create the nodedef, so any change in
|
||||
// the API will be captured.
|
||||
TEST_F(OpConverterTest, ConvertConst) {
|
||||
{
|
||||
Reset();
|
||||
@ -713,10 +901,9 @@ TEST_F(OpConverterTest, ConvertConst) {
|
||||
}
|
||||
{
|
||||
Reset();
|
||||
NodeDef node_def = MakeNodeDef("my_const", "Const", {});
|
||||
(*node_def.mutable_attr())["dtype"].set_type(DT_DOUBLE);
|
||||
NodeDef node_def = MakeConstNodeDef<double>("my_const", {});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Unsupported data type");
|
||||
"Unsupported data type double");
|
||||
}
|
||||
|
||||
TestConvertConst<DT_FLOAT, float, float>(this);
|
||||
@@ -732,8 +919,14 @@ TEST_F(OpConverterTest, ConvertTranspose) {
        node_def, error::INVALID_ARGUMENT,
        "Input expects tensor and weights, at my_transpose");
  }
  NodeDef node_def =
      MakeNodeDef("my_transpose", "Transpose", {"input", "weights"});

  // Get the NodeDef for Transpose.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
  auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
  auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights);
  const NodeDef& node_def = transpose.operation.node()->def();

  {
    // Permutation is a tensor, should fail.
    Reset();
@@ -747,7 +940,7 @@ TEST_F(OpConverterTest, ConvertTranspose) {
    // Transpose at batch dimension, should fail.
    Reset();
    AddTestTensor("input", {1, 2, 3});
    AddTestWeights<int32>("weights", DT_INT32, {4}, {1, 0, 2, 3});
    AddTestWeights<int32>("weights", {4}, {1, 0, 2, 3});
    RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                               "Transpose at batch dimension is not supported");
  }
@@ -755,7 +948,7 @@ TEST_F(OpConverterTest, ConvertTranspose) {
    // Permutation rank doesn't match, should fail.
    Reset();
    AddTestTensor("input", {1, 2, 3});
    AddTestWeights<int32>("weights", DT_INT32, {3}, {0, 1, 2});
    AddTestWeights<int32>("weights", {3}, {0, 1, 2});
    RunValidationAndConversion(
        node_def, error::INVALID_ARGUMENT,
        "Rank of perm for transpose does not match with that of the input.");
@@ -764,7 +957,7 @@ TEST_F(OpConverterTest, ConvertTranspose) {
    // Ok.
    Reset();
    AddTestTensor("input", {1, 2, 3});
    AddTestWeights<int32>("weights", DT_INT32, {4}, {0, 3, 1, 2});
    AddTestWeights<int32>("weights", {4}, {0, 3, 1, 2});
    RunConversion(node_def);
    TRT_TensorOrWeights output;
    TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output));
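A note on the cases above: the 4-element weights are the permutation over [batch, dim1, dim2, dim3]. TensorRT handles the batch dimension implicitly, so a permutation must keep index 0 in place; {1, 0, 2, 3} moves the batch dimension and is rejected, while {0, 3, 1, 2} only rotates the non-batch dimensions and converts.
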
@@ -786,7 +979,14 @@ TEST_F(OpConverterTest, ConvertReshape) {
        node_def, error::INVALID_ARGUMENT,
        "Input expects weights for shape, at my_reshape");
  }
  NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {"input", "weights"});

  // Get the NodeDef for Reshape.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
  auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
  auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights);
  const NodeDef& node_def = reshape.operation.node()->def();

  {
    // Shape is a tensor, should fail.
    Reset();
@@ -800,7 +1000,7 @@ TEST_F(OpConverterTest, ConvertReshape) {
    // Reshape to scalar, should fail.
    Reset();
    AddTestTensor("input", {1, 2, 3});
    AddTestWeights<int32>("weights", DT_INT32, {}, {});
    AddTestWeights<int32>("weights", {0}, {});
    RunValidationAndConversion(
        node_def, error::UNIMPLEMENTED,
        "Reshape to shape=[] is not supported, at my_reshape");
@@ -830,7 +1030,7 @@ TEST_F(OpConverterTest, ConvertReshape) {
    Reset();
    const std::vector<int>& dims = params[i].tensor_dims;
    AddTestTensor("input", dims, params[i].batch_size);
    AddTestWeights<int32>("weights", DT_INT32, {4}, params[i].shape);
    AddTestWeights<int32>("weights", {4}, params[i].shape);
    RunValidationAndConversion(
        node_def, error::UNIMPLEMENTED,
        "Reshape on batch dimension is not supported, at my_reshape",
@@ -847,7 +1047,7 @@ TEST_F(OpConverterTest, ConvertReshape) {
  for (int i = 0; i < kReshapeOKCases; ++i) {
    Reset();
    AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size);
    AddTestWeights<int32>("weights", DT_INT32, {4}, ok_params[i].shape);
    AddTestWeights<int32>("weights", {4}, ok_params[i].shape);
    RunConversion(node_def);
    TRT_TensorOrWeights output;
    TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output));
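The Reshape cases follow the same rule as Transpose: since the batch dimension is implicit in TensorRT, only the non-batch dimensions may be reshaped, and any target shape that would alter the batch size fails with the UNIMPLEMENTED error above.
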
@@ -869,24 +1069,67 @@ TEST_F(OpConverterTest, ConvertMatMul) {
        node_def, error::INVALID_ARGUMENT,
        "Input expects tensor and weights, at my_matmul");
  }
  NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {"input", "weights"});
  auto& attr = *node_def.mutable_attr();
  attr["T"].set_type(DT_FLOAT);
  attr["transpose_a"].set_b(false);
  attr["transpose_b"].set_b(false);
  {
    AddTestTensor("input", {2}, 1);
    AddTestWeights<float>("weights", DT_FLOAT, {2, 1}, {3, 5});
    RunConversion(node_def);
    TRT_TensorOrWeights output;
    TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
    EXPECT_TRUE(output.is_tensor());
    EXPECT_TRUE(TrtDimsEqualsArray({1}, output.tensor()->getDimensions()))
        << output.DebugString();

    std::vector<float> output_data(1);
    BuildAndRun("input", {2, 7}, "my_matmul", &output_data);
    EXPECT_THAT(output_data, ElementsAre(41));
  // Get the NodeDef for MatMul.
  auto get_matmul_nodedef = [](DataType dtype, bool transpose_a,
                               bool transpose_b) -> NodeDef {
    Scope s = Scope::NewRootScope();
    auto input = ops::Placeholder(s.WithOpName("input"), dtype);
    auto weights = ops::Placeholder(s.WithOpName("weights"), dtype);
    ops::MatMul::Attrs matmul_attrs;
    matmul_attrs.transpose_a_ = transpose_a;
    matmul_attrs.transpose_b_ = transpose_b;
    auto matmul =
        ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs);
    return matmul.operation.node()->def();
  };

  {
    // Unsupported data type.
    Reset();
    NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false);
    AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
    AddTestWeights<int32>("weights", {2, 1}, {3, 5});
    RunValidationAndConversion(
        node_def, error::UNIMPLEMENTED,
        "Data type is not supported, for node my_matmul got int32");
  }
  {
    // transpose_a is set.
    for (bool transpose_b : {false, true}) {
      Reset();
      NodeDef node_def =
          get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b);
      AddTestTensor("input", {2}, /*batch_size=*/1);
      AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
      RunValidationAndConversion(
          node_def, error::INVALID_ARGUMENT,
          "transpose_a is not supported for TensorRT FullyConnected");
    }
  }
  {
    // OK.
    for (bool transpose_b : {false, true}) {
      Reset();
      NodeDef node_def =
          get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b);
      AddTestTensor("input", {2}, /*batch_size=*/1);
      AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
      RunConversion(node_def);
      TRT_TensorOrWeights output;
      TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
      EXPECT_TRUE(output.is_tensor());
      EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions()))
          << output.DebugString();

      std::vector<float> output_data(2);
      BuildAndRun("input", {0, 1}, "my_matmul", &output_data);
      if (transpose_b) {
        EXPECT_THAT(output_data, ElementsAre(1, 3));
      } else {
        EXPECT_THAT(output_data, ElementsAre(2, 3));
      }
    }
  }
}

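For the expectations just above: the test multiplies x = [0, 1] by the row-major 2x2 weights W = [[0, 1], [2, 3]]. Without transpose_b, x * W = [0*0 + 1*2, 0*1 + 1*3] = [2, 3]; with transpose_b, x * W^T = [0*0 + 1*1, 0*2 + 1*3] = [1, 3], exactly the two ElementsAre values.
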
@@ -389,7 +389,7 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,

tensorflow::Status SegmentGraph(
    const tensorflow::Graph* tf_graph,
    const std::function<bool(const tensorflow::Node*)>& candidate_fn,
    const std::function<Status(const tensorflow::Node*)>& candidate_fn,
    const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
    const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
    const SegmentOptions& options, SegmentNodesVector* segments) {
@@ -409,9 +409,16 @@ tensorflow::Status SegmentGraph(
  std::vector<UnionFind<SimpleNode*>> node_segments;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);
    if (options.exclude_node_list.count(node->name()) != 0 ||
        !candidate_fn(node->tf_node())) {
    if (options.exclude_node_list.count(node->name()) != 0) {
      VLOG(1) << "Not a TF-TRT candidate: " << node->name()
              << " (excluded by segmenter option).";
      node = nullptr;
    } else {
      const Status status = candidate_fn(node->tf_node());
      if (!status.ok()) {
        VLOG(1) << "Not a TF-TRT candidate: " << node->name() << ": " << status;
        node = nullptr;
      }
    }
    node_segments.emplace_back(node);
  }

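With candidate_fn now returning a Status, a rejected node carries the reason for its rejection into the VLOG(1) message above. A minimal sketch of a conforming callback (the helper below is illustrative and not part of this change; only the std::function signature is from the diff):

#include <set>
#include <string>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical candidate function: accept a small allowlist of op types and
// return a descriptive error for everything else, so the segmenter can log
// why the node was skipped.
tensorflow::Status ExampleCandidateFn(const tensorflow::Node* node) {
  static const std::set<std::string> kSupported = {"Const", "Transpose",
                                                   "Reshape", "MatMul"};
  if (kSupported.count(node->type_string()) == 0) {
    return tensorflow::errors::Unimplemented("Op type ", node->type_string(),
                                             " is not supported");
  }
  return tensorflow::Status::OK();
}
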
@@ -43,7 +43,7 @@ struct SegmentOptions {
// Get the subgraphs of a graph that can be handled by TensorRT.
//
// @param graph tensorflow::Graph of the network
// @param candidate_fn A function that returns true for a Node* if
// @param candidate_fn A function that returns OK for a Node* if
// that node can be handled by TensorRT.
// @param segments Returns the TensorRT segments/subgraphs. Each entry
// in the vector describes a subgraph by giving a set of the names of
@@ -51,7 +51,7 @@ struct SegmentOptions {
// @return the status.
tensorflow::Status SegmentGraph(
    const tensorflow::Graph* tf_graph,
    const std::function<bool(const tensorflow::Node*)>& candidate_fn,
    const std::function<Status(const tensorflow::Node*)>& candidate_fn,
    const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
    const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
    const SegmentOptions& options, SegmentNodesVector* segments);

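A hypothetical call site under the updated declaration (ExampleCandidateFn is the illustrative helper sketched above; the segment namespace is assumed from the header's path, and the edge filters here trivially accept everything):

// Sketch only: assumes this code can see the declarations in segment.h.
tensorflow::Status RunSegmentation(const tensorflow::Graph& graph) {
  tensorflow::tensorrt::segment::SegmentOptions options;
  tensorflow::tensorrt::segment::SegmentNodesVector segments;
  TF_RETURN_IF_ERROR(tensorflow::tensorrt::segment::SegmentGraph(
      &graph, ExampleCandidateFn,
      /*input_candidate_fn=*/[](const tensorflow::Edge*) { return true; },
      /*output_candidate_fn=*/[](const tensorflow::Edge*) { return true; },
      options, &segments));
  VLOG(1) << "Found " << segments.size() << " TensorRT candidate segments.";
  return tensorflow::Status::OK();
}
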
@@ -34,10 +34,13 @@ namespace ops = ::tensorflow::ops;

class SegmentTest : public ::testing::Test {
 protected:
  std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
  std::function<Status(const tensorflow::Node*)> MakeCandidateFn(
      const std::set<string>& node_names) {
    return [node_names](const tensorflow::Node* node) -> bool {
      return node_names.find(node->name()) != node_names.end();
    return [node_names](const tensorflow::Node* node) -> Status {
      if (node_names.find(node->name()) != node_names.end()) {
        return Status::OK();
      }
      return errors::NotFound("");
    };
  }

@@ -118,30 +118,11 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
    """Return a ConversionParams for test."""
    return super(BiasaddMatMulTest,
                 self).GetConversionParams(run_params)._replace(
                     max_batch_size=4, maximum_cached_engines=2)

  def _ValidEngines(self):
    """Engines expected to build and run."""
    return ["my_trt_op_0"]

  def _InvalidEngines(self):
    """Engines that will cause conversion error at building time."""
    return ["my_trt_op_1", "my_trt_op_2"]
                     max_batch_size=4, maximum_cached_engines=1)

  def ExpectedEnginesToBuild(self, run_params):
    """Return the expected engines to build."""
    # In dynamic engine mode the engines are built in execution time, not in
    # conversion time, so build errors occurs later. Here three of the engines
    # will be failed to built but the corresponding engine op are still created.
    # TODO(aaroey, jjsjann123): fix this.
    if (run_params.dynamic_engine and
        not trt_test.IsQuantizationMode(run_params.precision_mode)):
      return self._ValidEngines() + self._InvalidEngines()
    return self._ValidEngines()

  def ExpectedEnginesToRun(self, run_params):
    """Return the expected engines to run."""
    return self._ValidEngines()
    return ["my_trt_op_0"]

  def ShouldRunTest(self, run_params):
    """Whether to run the test."""

@@ -79,7 +79,7 @@ class ReshapeTest(trt_test.TfTrtIntegrationTestBase):
  def ExpectedEnginesToBuild(self, run_params):
    """Return the expected engines to build."""
    return {
        "my_trt_op_3": ["reshape-%d" % i for i in range(7)] +
        "my_trt_op_0": ["reshape-%d" % i for i in range(7)] +
                       ["reshape-%d/shape" % i for i in range(7)]
    }