From 486dca315c481d95715b4e6ea8b5e20c0d66ae30 Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Thu, 1 Nov 2018 10:58:22 -0700 Subject: [PATCH] Verify the node before putting it into a TRT subgraph, so if the TRT converter doesn't support it, it can be excluded earlier. This prevents a single incompatible node from causing the conversion of a large subgraph to fail. Currently this is only added for Transpose, Reshape and Const. Implementation details: - A new class TrtCandidateSelector is added for the validation. Previously the free function IsTensorRTCandidate() was used to pick TRT candidate nodes, and it only checked the op type; the new TrtCandidateSelector::IsTensorRTCandidate() checks the op type as well as input properties of the node, e.g. const or not, the values of the const, etc. - TrtCandidateSelector uses TrtNodeValidator added earlier to validate the nodes. A new method TrtNodeValidator::ConvertToTensorOrWeights() is added to convert an input of a node to a TRT_TensorOrWeights, and after gathering all input TRT_TensorOrWeights, TrtNodeValidator::ValidateNode() is called to determine whether the node is supported by TRT. - The original InputEdgeValidator class is now removed. It only validated the input nodes of the subgraph; its logic is now merged into ValidateTensorProperties, which will be used by TrtNodeValidator to validate all nodes before putting them in a subgraph. 
PiperOrigin-RevId: 219662352 --- tensorflow/contrib/tensorrt/BUILD | 9 + .../contrib/tensorrt/convert/convert_graph.cc | 41 +- .../contrib/tensorrt/convert/convert_graph.h | 20 + .../tensorrt/convert/convert_graph_test.cc | 78 +++- .../contrib/tensorrt/convert/convert_nodes.cc | 287 ++++++++---- .../contrib/tensorrt/convert/convert_nodes.h | 82 ++-- .../tensorrt/convert/convert_nodes_test.cc | 427 ++++++++++++++---- .../contrib/tensorrt/segment/segment.cc | 13 +- tensorflow/contrib/tensorrt/segment/segment.h | 4 +- .../contrib/tensorrt/segment/segment_test.cc | 9 +- .../tensorrt/test/biasadd_matmul_test.py | 23 +- .../tensorrt/test/reshape_transpose_test.py | 2 +- 12 files changed, 760 insertions(+), 235 deletions(-) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index c96ca302d9e..20bcd2447e6 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -312,15 +312,20 @@ tf_cuda_cc_test( ], deps = [ ":trt_conversion", + "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:direct_session", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), @@ -340,6 +345,10 @@ tf_cuda_cc_test( ":trt_conversion", ":trt_plugins", "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 
26f13b02a89..1f5591fe2a6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -81,12 +81,13 @@ std::vector GetLoadedTensorRTVersion() { return {ver_major, ver_minor, ver_patch}; } -namespace { +TrtCandidateSelector::TrtCandidateSelector( + const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} -bool IsTensorRTCandidate(const tensorflow::Node* node) { +Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { + // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange - // TODO(jie): Segmentation shouldn't associated with op name. - // Split it into a registration for each kernel. static const std::set candidate_ops = { "Identity", "Snapshot", @@ -127,13 +128,29 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { "Prod", "Max", "Min", - // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) - return (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + const bool is_supported_op_type = + (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + if (!is_supported_op_type) { + return errors::Unimplemented("Op type ", node->type_string(), + " is not supported."); + } + + std::vector input_edges; + TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); + std::vector> input_node_and_ports; + for (const Edge* input_edge : input_edges) { + input_node_and_ports.emplace_back(&input_edge->src()->def(), + input_edge->src_output()); + } + return validator_.ValidateNode(node->def(), input_node_and_ports, + graph_properties_); } +namespace { + tensorflow::Status BuildNodeMap( const tensorflow::Graph& graph, std::unordered_map* node_map) { @@ -846,9 +863,15 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } 
segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; + TrtCandidateSelector candidate_selector(*params.graph_properties); TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties), - OutputEdgeValidator(), segment_options, &initial_segments)); + &graph, + std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector, + std::placeholders::_1), + // Input validation is already done by TrtCandidateSelector, so we don't + // need to check the input edges. + [](const Edge* edge) { return true; }, OutputEdgeValidator(), + segment_options, &initial_segments)); if (initial_segments.size() > 1) { VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << initial_segments.size(); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 35252023698..1c9d82105a7 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -31,6 +31,26 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// Helper class for the segmenter to determine whether given TF node is +// supported by TRT. +class TrtCandidateSelector { + public: + TrtCandidateSelector(const grappler::GraphProperties& graph_properties); + + // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added + // to TRT subgraph and later converted into TRT engine. + Status IsTensorRTCandidate(const tensorflow::Node* node); + + private: + // The TF-TRT node converter used to verify whether individual node is + // supported. It will operate in validation-only mode. + TrtNodeValidator validator_; + + // GraphProperties of the graph whose nodes are to be validated by + // IsTensorRTCandidate(). 
+ const grappler::GraphProperties& graph_properties_; +}; + struct ConversionParams { ConversionParams() : input_graph_def(nullptr), diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc index 8146bed4b05..f10729987fd 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -15,9 +15,14 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include +#include +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" @@ -33,6 +38,76 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// TODO(laigd): put this into some test utils file. +void ExpectStatus(Status status, error::Code code = error::OK, + const char* substr = nullptr) { + EXPECT_EQ(code, status.code()) + << status << " vs expected error code \"" << error::Code_Name(code) + << "\" and message \"" << substr << "\""; + if (substr) { + EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status; + } +} + +TEST(TrtCandidateSelector, Basics) { + // Create a graph containing both TRT-compatible and TRT-incompatible nodes + // and use it to test TrtCandidateSelector::IsTensorRTCandidate(). 
+ const std::vector input_shape_array{2, 2}; + TensorShape input_shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape)); + + Scope s = Scope::NewRootScope(); + ops::Placeholder::Attrs feed_attrs; + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_)); + + // Compatible input. + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs); + auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape); + + // Compatible MatMul. + auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1); + + // Incompatible MatMul. + ops::MatMul::Attrs matmul_attrs; + matmul_attrs.transpose_a_ = true; + auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"), + feed, const_1, matmul_attrs); + + // Unsupported op. + auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed); + + // Incompatible input. + auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE); + auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape); + // Compatible op with incompatible input. 
+ auto matmul_with_incompatible_input = + ops::MatMul(s.WithOpName("matmul_with_incompatible_input"), + incompatible_feed, const_2); + + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + Tensor feed_tensor(DT_FLOAT, input_shape); + item.feed.push_back(std::make_pair("feed", feed_tensor)); + + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + TrtCandidateSelector selector(graph_properties); + TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); + ExpectStatus( + selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), + error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected " + "(op: MatMul), at: incompatible_matmul"); + ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), + error::UNIMPLEMENTED, "Op type Sin is not supported"); + ExpectStatus(selector.IsTensorRTCandidate( + matmul_with_incompatible_input.operation.node()), + error::INTERNAL, + "Failed to convert input with index 0 to a TRT_TensorOrWeights"); +} + class FakeCluster : public grappler::Cluster { public: FakeCluster() : Cluster(0) {} @@ -48,8 +123,7 @@ class FakeCluster : public grappler::Cluster { } Status Run(const GraphDef& graph_def, const std::vector>& feed, - const std::vector& fetch, - RunMetadata* metadata) override { + const std::vector& fetch, RunMetadata* metadata) override { return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index dcbc75aebf4..a6f954391d3 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -108,6 +108,18 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, return tensorflow::Status::OK(); } +template +inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, + bool ignore_first_dim) { + nvinfer1::Dims trt_dims; + 
const int offset = (ignore_first_dim ? 1 : 0); + for (int i = offset; i < shape.dims(); i++) { + trt_dims.d[i - offset] = shape.dim_size(i); + } + trt_dims.nbDims = shape.dims() - offset; + return trt_dims; +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -137,22 +149,37 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties, } } -tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, - const tensorflow::DataType dtype, - nvinfer1::DataType* trt_dtype) { - // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so - // put them there instead. +Status ValidateTensorProperties(const string& producer_node_type, + const tensorflow::DataType dtype, + const PartialTensorShape& shape, + bool validation_only, + nvinfer1::DataType* trt_dtype, + nvinfer1::Dims* trt_dims, int* batch_size) { + // Convert data type. TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + + // Convert shape. if (shape.dims() < 0) { - return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); + return errors::InvalidArgument("Input tensor rank is unknown."); } - if (shape.dims() > 9) { - return tensorflow::errors::OutOfRange( - "Input tensor rank is greater than 8."); + if (shape.dims() > nvinfer1::Dims::MAX_DIMS + 1) { // +1 for batch dim + return errors::OutOfRange("Input tensor rank is greater than ", + nvinfer1::Dims::MAX_DIMS + 1); } + if (producer_node_type != "Const" && shape.dims() < 2) { + return errors::InvalidArgument( + "Input tensor with rank<2 is not supported since the first dimension " + "is treated as batch dimension by TRT"); + } + *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true); + *batch_size = shape.dim_size(0); + + if (validation_only) return Status::OK(); + // Following are validations at runtime. 
+ for (int d = 1; d < shape.dims(); ++d) { if (shape.dim_size(d) < 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Input tensor with shape ", shape.DebugString(), " has an unknown non-batch dimemension at dim ", d); } @@ -358,22 +385,75 @@ string TRT_ShapedWeights::DebugString() const { ", values=", reinterpret_cast(GetValues()), ")"); } +// A fake ITensor implementation used to check whether the TF-TRT converter can +// handle specific node. We only need shape and type information, and the +// converter won't (and shouldn't) use this to build the TRT network. +class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { + public: + SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims) + : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {} + + void setName(const char* name) override {} + + const char* getName() const override { return ""; } + + void setDimensions(nvinfer1::Dims dimensions) override { + trt_dims_ = dimensions; + } + + nvinfer1::Dims getDimensions() const override { return trt_dims_; } + + void setType(nvinfer1::DataType trt_dtype) override { + trt_dtype_ = trt_dtype; + } + + nvinfer1::DataType getType() const override { return trt_dtype_; } + + bool isNetworkInput() const override { return false; } + + bool isNetworkOutput() const override { return false; } + + void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {} + + bool getBroadcastAcrossBatch() const override { return false; } + + nvinfer1::TensorLocation getLocation() const override { + // This is arbitrary, since we don't use it. 
+ return nvinfer1::TensorLocation::kDEVICE; + } + + void setLocation(nvinfer1::TensorLocation location) override {} + +#if NV_TENSORRT_MAJOR >= 5 + bool setDynamicRange(float min, float max) override {} +#endif + + private: + nvinfer1::DataType trt_dtype_; + nvinfer1::Dims trt_dims_; +}; + TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size) : tensor_(tensor), batch_size_(batch_size), - weights_(DT_FLOAT), + initialized_(true), + is_tensor_(true) {} + +TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, + int batch_size) + : simple_itensor_(new SimpleITensor(trt_dtype, trt_dims)), + batch_size_(batch_size), initialized_(true), is_tensor_(true) {} TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights) - : tensor_(nullptr), - weights_(weights), - initialized_(true), - is_tensor_(false) {} + : weights_(weights), initialized_(true), is_tensor_(false) {} TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) : tensor_(rhs.tensor_), + simple_itensor_(rhs.simple_itensor_), batch_size_(rhs.batch_size_), weights_(rhs.weights_), initialized_(rhs.initialized_), @@ -381,12 +461,23 @@ TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) { tensor_ = rhs.tensor_; + simple_itensor_ = rhs.simple_itensor_; batch_size_ = rhs.batch_size_; weights_ = rhs.weights_; initialized_ = rhs.initialized_; is_tensor_ = rhs.is_tensor_; } +nvinfer1::ITensor* TRT_TensorOrWeights::tensor() { + CHECK(is_tensor()); + return tensor_ == nullptr ? simple_itensor_.get() : tensor_; +} + +const nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const { + CHECK(is_tensor()); + return tensor_ == nullptr ? 
simple_itensor_.get() : tensor_; +} + nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { if (is_tensor()) { return tensor()->getDimensions(); @@ -398,8 +489,8 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor @", reinterpret_cast(tensor_), - ", shape=", convert::DebugString(tensor_->getDimensions()), + StrAppend(&output, "tensor @", reinterpret_cast(tensor()), + ", shape=", convert::DebugString(tensor()->getDimensions()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -560,11 +651,10 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G const int c = iweights.shape_.d[2] / num_groups; const int k = iweights.shape_.d[3] * num_groups; - VLOG(2) << "num_groups: " << num_groups - << "c" << iweights.shape_.d[2] << " then " << c - << "k" << iweights.shape_.d[3] << " then " << k - << "r" << iweights.shape_.d[0] << " then " << r - << "s" << iweights.shape_.d[1] << " then " << s; + VLOG(2) << "num_groups: " << num_groups << "c" << iweights.shape_.d[2] + << " then " << c << "k" << iweights.shape_.d[3] << " then " << k + << "r" << iweights.shape_.d[0] << " then " << r << "s" + << iweights.shape_.d[1] << " then " << s; oweights->shape_.d[0] = k / num_groups; oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; @@ -608,9 +698,68 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(tensorflow::DataType type, TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); } +Status TrtNodeValidator::ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights) { + if (node_def.op() == "Const") { + if (output_port != 0) { + return errors::InvalidArgument("Const node should only have one 
output."); + } + // The output of the conversion will be used as input to other nodes to + // determine whether TRT supports those nodes. If it cannot convert the + // Const, it's very likely we cannot treat it as a tensor and make it an + // input to the TRT network, since TRT removes the first dimension and + // treats it as batch size. Also, it's not likely that the converter can + // support the op, and performance may suffer even if it can, so we just + // simply return error if the conversion fails. + std::vector inputs; + return ConvertConstToWeights(node_def, inputs, tensor_or_weights); + } + if (!graph_properties.HasOutputProperties(node_def.name())) { + return errors::InvalidArgument("Shape and data type are unknown"); + } + + // Validate and convert shape and dtype. + const auto& output_params = + graph_properties.GetOutputProperties(node_def.name()); + const auto& tensor_properties = output_params.at(output_port); + const DataType dtype = tensor_properties.dtype(); + const PartialTensorShape shape = tensor_properties.shape(); + nvinfer1::DataType trt_dtype; + nvinfer1::Dims trt_dims; + int batch_size = -1; + TF_RETURN_IF_ERROR(ValidateTensorProperties( + node_def.op(), dtype, shape, /*validation_only_=*/true, &trt_dtype, + &trt_dims, &batch_size)); + + // Adds a fake ITensor. This is fine since op converter operates in + // validation-only mode and it won't (and shouldn't) use the tensor to do + // any TRT network operations. + *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size); + return Status::OK(); +} + Status TrtNodeValidator::ValidateNode( const tensorflow::NodeDef& node_def, - const std::vector& inputs) { + const std::vector>& input_node_and_ports, + const grappler::GraphProperties& graph_properties) { + // Convert input NodeDef and corresponding output ports to + // TRT_TensorOrWeights. 
+ std::vector inputs; + for (int i = 0; i < input_node_and_ports.size(); ++i) { + const auto& pair = input_node_and_ports[i]; + TRT_TensorOrWeights tensor_or_weights; + Status status = ConvertToTensorOrWeights( + *pair.first, pair.second, graph_properties, &tensor_or_weights); + if (!status.ok()) { + return errors::Internal("Failed to convert input with index ", i, + " to a TRT_TensorOrWeights"); + } + inputs.push_back(tensor_or_weights); + } + + // Validate the node. const auto iter = op_validators_.find(node_def.op()); if (iter == op_validators_.end()) { // If validator is not registered, it means no validation is needed. @@ -621,7 +770,19 @@ Status TrtNodeValidator::ValidateNode( OpConverterParams params( /*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr, /*arg_validation_only=*/true, &weight_store_); - Status status = validator(¶ms); + return validator(¶ms); +} + +Status TrtNodeValidator::ConvertConstToWeights( + const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output) { + std::vector outputs; + OpConverterParams params( + /*arg_converter=*/nullptr, const_node_def, inputs, &outputs, + /*arg_validation_only=*/true, &weight_store_); + Status status = op_validators_["Const"](¶ms); + if (status.ok() && output) *output = outputs[0]; return status; } @@ -1663,7 +1824,7 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { } tensorflow::Status ConvertScale(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) { @@ -1798,8 +1959,13 @@ Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, return Status::OK(); } +// Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter, it +// always ignores the params->validation_only parameter but adds the converted +// weights to params->outputs. 
We did this since TrtNodeValidator needs the +// weights as input to other nodes, and use it to determine whether those nodes +// are supported by TRT. tensorflow::Status ConvertConst(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (!inputs.empty()) { return errors::InvalidArgument( @@ -1896,11 +2062,10 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { return errors::Unimplemented("Not supported constant type, at ", node_def.name()); } - // Pass the output. - if (!params->validation_only) { + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } - return tensorflow::Status::OK(); + return Status::OK(); } tensorflow::Status ConvertIdentity(OpConverterParams* params) { @@ -1909,7 +2074,7 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) { } tensorflow::Status ConvertBinary(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { return tensorflow::errors::FailedPrecondition( @@ -2418,19 +2583,20 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { // TODO(jie): INT32 should be converted? 
tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { - return tensorflow::errors::Unimplemented( - "data type is not supported, for node " + node_def.name() + " got " + - tensorflow::DataTypeString(tf_dtype)); + return errors::Unimplemented("Data type is not supported, for node ", + node_def.name(), " got ", + DataTypeString(tf_dtype)); } bool transpose_a = attrs.get("transpose_a"); bool transpose_b = attrs.get("transpose_b"); // FullyConnected: if (transpose_a) { - return tensorflow::errors::Internal( - "Transpose_a is not supported for TensorRT FullyConnected (op: " + - node_def.op() + "), at: " + node_def.name()); + return errors::InvalidArgument( + "transpose_a is not supported for TensorRT FullyConnected (op: ", + node_def.op(), "), at: ", node_def.name()); } + if (params->validation_only) return Status::OK(); return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(), transpose_b, node_def.name()); } @@ -2673,10 +2839,13 @@ tensorflow::Status ConvertGraphDefToEngine( return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); } - nvinfer1::DataType dtype; + nvinfer1::DataType trt_dtype; + nvinfer1::Dims trt_dims; + int batch_size = -1; auto shape = input_shapes.at(slot_number); - auto status = ValidateInputProperties( - shape, node_def.attr().at("dtype").type(), &dtype); + auto status = ValidateTensorProperties( + node_def.op(), node_def.attr().at("dtype").type(), shape, + /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size); if (!status.ok()) { const string error_message = StrCat("Validation failed for ", node_name, " and input slot ", @@ -2684,19 +2853,13 @@ tensorflow::Status ConvertGraphDefToEngine( LOG(WARNING) << error_message; return Status(status.code(), error_message); } - - nvinfer1::Dims input_dim; - for (int i = 1; i < shape.dims(); i++) { - input_dim.d[i - 1] = shape.dim_size(i); - } - input_dim.nbDims = shape.dims() - 1; 
VLOG(2) << "Adding engine input tensor " << node_name << " with shape " - << DebugString(input_dim); + << DebugString(trt_dims); // TODO(laigd): the conversion should always happen at runtime where all // the shapes are known, and we can provide a mode to generate the // engines offline, by calling sess.run() and cache/serialize the engines. - TF_RETURN_IF_ERROR(converter.AddInputTensor(node_name, dtype, input_dim, - shape.dim_size(0))); + TF_RETURN_IF_ERROR( + converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && (node_def.op() == "Identity")) { int32 slot_number = -1; @@ -2866,34 +3029,6 @@ tensorflow::Status ConvertSegmentToGraphDef( return tensorflow::Status::OK(); } -bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { - if (in_edge->IsControlEdge()) return true; - PartialTensorShape shape; - tensorflow::DataType dtype; - GetOutputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), - &shape, &dtype); - nvinfer1::DataType trt_dtype; - Status status = ValidateInputProperties(shape, dtype, &trt_dtype); - if (!status.ok()) { - VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() - << ": " << status; - return false; - } - - - if (in_edge->src()->type_string() != "Const" && - // Single dimensional input tensor is not supported since the first - // dimension is treated as batch dimension. 
- shape.dims() < 2) { - VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() - << " which has an input at port " << in_edge->dst_input() << " with" - << " #dim<2" - << " and is not a const: " << shape; - return false; - } - return true; -} - bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { if (out_edge->IsControlEdge()) return true; if (out_edge->src()->type_string() == "Const") { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 699b50b37e3..5cc28b33e7f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -148,21 +148,6 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType* engine, bool* convert_successfully); -// Helper class for the segmenter to determine whether an input edge to the TRT -// segment is valid. -class InputEdgeValidator { - public: - InputEdgeValidator(const grappler::GraphProperties& graph_properties) - : graph_properties_(graph_properties) {} - - // Return true if the specified edge is eligible to be an input edge of the - // TRT segment. - bool operator()(const tensorflow::Edge* in_edge) const; - - private: - const grappler::GraphProperties& graph_properties_; -}; - // Helper class for the segmenter to determine whether an output edge from the // TRT segment is valid. class OutputEdgeValidator { @@ -245,8 +230,21 @@ class TRT_TensorOrWeights { public: TRT_TensorOrWeights() {} + // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'. + // This is used by Converter when building the TRT network, where the ITensor + // is owned by the TRT network being built. See comment for 'tensor_' below. explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1); + // Constructor that makes it an ITensor by creating one using provided data + // type and shape, and takes ownership of the created ITensor. 
This is used by + // TrtNodeValidator to encapsulate the type and shape information for + // validation of graph nodes, and the created ITensor is fake and temporary, + // and should not be used to build any TRT network. See comment for + // 'simple_itensor_' below. + explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, int batch_size); + + // Constructor that makes it a TRT_TensorOrWeights. explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights); TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs); @@ -256,15 +254,9 @@ class TRT_TensorOrWeights { bool is_tensor() const { return initialized_ && is_tensor_; } bool is_weights() const { return initialized_ && !is_tensor_; } - nvinfer1::ITensor* tensor() { - CHECK(is_tensor()); - return tensor_; - } + nvinfer1::ITensor* tensor(); - const nvinfer1::ITensor* tensor() const { - CHECK(is_tensor()); - return tensor_; - } + const nvinfer1::ITensor* tensor() const; TRT_ShapedWeights& weights() { CHECK(is_weights()); @@ -283,9 +275,25 @@ class TRT_TensorOrWeights { string DebugString() const; private: + class SimpleITensor; + void set_batch_size(int batch_size) { batch_size_ = batch_size; } + // When it represents an ITensor, the ITensor can be either passed by the + // caller via the constructor that takes an ITensor* as parameter, or be + // created as a SimpleITensor. + // + // In the first case, the ITensor pointer is stored in 'tensor_' below, and + // the ITensor itself is not owned by this class. This method is used by + // Converter (e.g. AddInputTensor) and op converters during TRT network + // construction, where the TRT network owns the ITensor. + // + // In the second case, the created SimpleITensor is stored in + // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake + // implementation of ITensor and is used only by TrtNodeValidator to validate + // the graph nodes. nvinfer1::ITensor* tensor_ = nullptr; // Not owned. 
+ std::shared_ptr simple_itensor_ = nullptr; // First dimension of the TF tensor (NOT tensor_) that is represented by // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s @@ -339,13 +347,35 @@ class TrtNodeValidator { public: TrtNodeValidator(); - // Validate the node, and return ok if it's supported by the converter. - Status ValidateNode(const NodeDef& node_def, - const std::vector& inputs); + // Validate the node, and return ok if it's supported by TRT. + // + // - 'node_def' is the node to validate. + // - 'input_node_and_ports' are the input NodeDefs and their output ports that + // are connected to 'node_def' in the TF graph. + // - 'graph_properties' is the GraphProperties of the graph where 'node_def' + // belongs. It is used to get the shape and data type information of a + // tensor for validation purpose. + Status ValidateNode( + const NodeDef& node_def, + const std::vector>& input_node_and_ports, + const grappler::GraphProperties& graph_properties); private: void RegisterOpValidators(); + // Convert a Const node to a TRT_TensorOrWeights. + Status ConvertConstToWeights(const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output); + + // Convert the output tensor at 'output_port' of 'node_def' to a + // TRT_TensorOrWeights which will be later used as an input to other nodes and + // passed to ValidateNode() below. + Status ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights); + // Stores all the validators by op type. If no validator is registered for // specific op, it means no validation is needed and ValidateNode() will // return OK. 
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index bc390743335..c3a39395f3a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -29,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -45,6 +49,7 @@ namespace convert { using ::testing::ElementsAre; +// TODO(laigd): put this into some test utils file. 
void ExpectStatus(Status status, error::Code code = error::OK, const char* substr = nullptr) { EXPECT_EQ(code, status.code()) @@ -75,6 +80,23 @@ NodeDef MakeNodeDef(const string& name, const string& op, return node_def; } +template +NodeDef MakeConstNodeDef(const string& name, const std::vector& vals, + const TensorShape& shape) { + Scope s = Scope::NewRootScope(); + Tensor t = ::tensorflow::test::AsTensor(vals, shape); + auto const_op = ops::Const(s.WithOpName(name), t); + return const_op.node()->def(); +} + +template +NodeDef MakeConstNodeDef(const string& name, const std::vector& vals) { + TensorShape shape; + const std::vector shape_dims = {static_cast(vals.size())}; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape)); + return MakeConstNodeDef(name, vals, shape); +} + bool TrtDimsEquals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) { if (lhs.nbDims != rhs.nbDims) return false; for (int i = 0; i < lhs.nbDims; ++i) { @@ -95,6 +117,19 @@ bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, lhs.GetValues() == rhs.GetValues(); } +template +void ValidateWeights(const TRT_ShapedWeights& weights, + const std::vector& expected_dims, + const std::vector& expected_value) { + EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_)) + << weights.DebugString(); + ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString(); + const T* actual_values = static_cast(weights.GetValues()); + for (int i = 0; i < expected_value.size(); ++i) { + EXPECT_EQ(expected_value[i], actual_values[i]); + } +} + // Fake ITensor implementation for testing purposes. class FakeITensor : public nvinfer1::ITensor { public: @@ -194,32 +229,86 @@ TEST(TRT_ShapedWeights_Test, Basic) { } TEST(TRT_TensorOrWeights_Test, Basic) { + // Test constructor with no arguments. 
+ { + TRT_TensorOrWeights tw; + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + for (auto ptr : {&tw, ©, &assigned}) { + EXPECT_EQ(false, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + EXPECT_EQ(-1, ptr->batch_size()); + } + } + + // Test constructor with ITensor and batch size argument. { nvinfer1::Dims dims; dims.nbDims = 1; dims.d[0] = 1; FakeITensor itensor(dims); - TRT_TensorOrWeights tw(&itensor); - EXPECT_EQ(true, tw.is_tensor()); - EXPECT_EQ(false, tw.is_weights()); - EXPECT_EQ(&itensor, tw.tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, tw.GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(tw.GetTrtDims()); - } - { - TRT_ShapedWeights weights(DT_FLOAT); - TRT_TensorOrWeights tw(weights); - EXPECT_EQ(false, tw.is_tensor()); - EXPECT_EQ(true, tw.is_weights()); - EXPECT_TRUE(TrtShapedWeightsEquals(weights, tw.weights())); + TRT_TensorOrWeights tw1(&itensor, /*batch_size=*/1); + for (auto original_ptr : {&tw, &tw1}) { + TRT_TensorOrWeights copy(*original_ptr); + TRT_TensorOrWeights assigned; + assigned = *original_ptr; + + for (auto ptr : {original_ptr, ©, &assigned}) { + EXPECT_EQ(true, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + if (original_ptr == &tw) { + EXPECT_EQ(-1, ptr->batch_size()); + } else { + EXPECT_EQ(1, ptr->batch_size()); + } + EXPECT_EQ(&itensor, ptr->tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } + } + } + // Test constructor which creates and owns an ITensor. 
+ { nvinfer1::Dims dims; - dims.nbDims = 0; - EXPECT_TRUE(TrtDimsEqualsArray({}, tw.GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(tw.GetTrtDims()); + dims.nbDims = 1; + dims.d[0] = 1; + TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1); + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + + for (auto ptr : {&tw, ©, &assigned}) { + EXPECT_EQ(true, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + EXPECT_EQ(1, ptr->batch_size()); + EXPECT_NE(nullptr, ptr->tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } + } + // Test constructor with TRT_ShapedWeights argument. + { + TRT_ShapedWeights weights; + TRT_TensorOrWeights tw(weights); + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + for (auto ptr : {&tw, ©, &assigned}) { + EXPECT_EQ(false, ptr->is_tensor()); + EXPECT_EQ(true, ptr->is_weights()); + EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights())); + + nvinfer1::Dims dims; + dims.nbDims = 0; + EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } } } @@ -229,11 +318,64 @@ class ValidatorTest : public ::testing::Test { validator_.op_validators_[op_name] = op_validator; } + Status ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights) { + return validator_.ConvertToTensorOrWeights( + node_def, output_port, graph_properties, tensor_or_weights); + } + protected: TrtNodeValidator validator_; }; +TEST_F(ValidatorTest, ConvertToTensorOrWeights) { + // Convert Const. 
+ { + NodeDef node_def = MakeConstNodeDef("my_const", {1.0f, 2.0f}); + TRT_TensorOrWeights output; + grappler::GrapplerItem item; + grappler::GraphProperties graph_properties(item); + ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, &output)); + ValidateWeights(output.weights(), {2}, {1.0, 2.0}); + } + // Convert non-Const. We test the case where the non-batch dimemsion is + // unknown as well, to make sure the validator allows that. + for (const int32 non_batch_dim : {-1, 2}) { + const int32 batch_size = 12; + + Scope s = Scope::NewRootScope(); + ops::Placeholder::Attrs attrs; + TF_EXPECT_OK(TensorShapeUtils::MakeShape( + std::vector{batch_size, non_batch_dim}, &attrs.shape_)); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs); + auto add = ops::Add(s.WithOpName("add"), feed, feed); + + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + auto& node_def = add.operation.node()->def(); + TRT_TensorOrWeights output; + ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, &output)); + EXPECT_EQ(true, output.is_tensor()); + EXPECT_EQ(batch_size, output.batch_size()); + EXPECT_NE(nullptr, output.tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims())) + << "- expected: {" << non_batch_dim << "} \n vs\n" + << "- actual: " << DebugString(output.GetTrtDims()); + } +} + TEST_F(ValidatorTest, ValidateNode) { + grappler::GrapplerItem item; + grappler::GraphProperties graph_properties(item); + bool start_conversion = false; bool should_fail = false; auto op_converter = [&start_conversion, @@ -245,16 +387,17 @@ TEST_F(ValidatorTest, ValidateNode) { NodeDef node_def = MakeNodeDef("my_op", "MyOp", {}); // Validator not registered, validation should pass. 
- TF_EXPECT_OK(validator_.ValidateNode(node_def, {})); + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); // Register validator. AddOpValidator("MyOp", op_converter); - TF_EXPECT_OK(validator_.ValidateNode(node_def, {})); + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); EXPECT_EQ(false, start_conversion); // Let the converter return error. should_fail = true; - ExpectStatus(validator_.ValidateNode(node_def, {}), error::INVALID_ARGUMENT); + ExpectStatus(validator_.ValidateNode(node_def, {}, graph_properties), + error::INVALID_ARGUMENT); } class ConverterTest : public ::testing::Test { @@ -289,6 +432,8 @@ class ConverterTest : public ::testing::Test { return converter_->GetInputs(node_def, inputs); } + int batch_size() const { return converter_->batch_size_; } + private: Logger logger_; // These members are ordered in a way such that the destruction order is: @@ -474,11 +619,48 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { << DebugString(*output_tensor); } +TEST_F(ConverterTest, MaybeUpdateBatchSize) { + EXPECT_EQ(-1, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(-1)); + EXPECT_EQ(-1, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + EXPECT_EQ(123, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + EXPECT_EQ(123, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(-1)); + EXPECT_EQ(123, batch_size()); + + ExpectStatus(MaybeUpdateBatchSize(124), error::INVALID_ARGUMENT, + "Provided batch size does not match converter batch size"); +} + +TEST_F(ConverterTest, AddAndGetTensorOrWeights) { + // Add a tensor. + FakeITensor fake_tensor; + TRT_TensorOrWeights tensor(&fake_tensor); + EXPECT_EQ(-1, tensor.batch_size()); + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + TF_EXPECT_OK(AddTensorOrWeights("my_tensor", tensor)); + + // Get the added tensor. 
+ TRT_TensorOrWeights added_tensor; + TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &added_tensor)); + EXPECT_EQ(123, added_tensor.batch_size()); + + // Add the same tensor again. + ExpectStatus(AddTensorOrWeights("my_tensor", tensor), error::ALREADY_EXISTS, + "tensor/weights my_tensor already exist"); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { public: - OpConverterTest() { + OpConverterTest() : scope_(Scope::NewRootScope()) { QCHECK_EQ(0, cudaStreamCreate(&stream_)); Reset(); } @@ -505,8 +687,8 @@ class OpConverterTest : public ::testing::Test { converter_.reset(new Converter(network_.get(), /*fp16=*/false)); // Reset other related artifacts. - fake_itensors_.clear(); - fake_tensor_or_weights_.clear(); + scope_ = Scope::NewRootScope(); + validator_inputs_.clear(); } void BuildAndRun(const char* input_name, const std::vector& input_data, @@ -551,33 +733,41 @@ class OpConverterTest : public ::testing::Test { } // Add ITensor for both validation and conversion. - void AddTestTensor(const char* name, const std::vector& dims, - int batch_size = 1) { - nvinfer1::Dims trt_dims = GetTestDims(dims); - // Add FakeITensor for validation. - // - // TRT cannot add a tensor that has undetermined dims, so we manage the - // tensor using a vector. These tensors are used to test validation-only - // mode and thus should not be used to build the engine. 
- FakeITensor* fake_itensor = new FakeITensor(trt_dims); - fake_itensors_.emplace_back(fake_itensor); - fake_tensor_or_weights_[string(name)] = - TRT_TensorOrWeights{fake_itensor, batch_size}; + void AddTestTensor( + const char* name, const std::vector& dims, int batch_size = 1, + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { + DataType tf_dtype = DT_FLOAT; + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + tf_dtype = DT_FLOAT; + break; + case nvinfer1::DataType::kINT32: + tf_dtype = DT_INT32; + break; + default: + ASSERT_TRUE(false) << "Unexpected data type " + << static_cast(trt_dtype); + } + ops::Placeholder::Attrs attrs; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); + attrs.shape_.InsertDim(0, batch_size); + auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs); + validator_inputs_[name] = input.operation.node()->def(); // Add a real ITensor for conversion conditionally. + const nvinfer1::Dims trt_dims = GetTestDims(dims); if (HasStaticShape(trt_dims)) { - TF_EXPECT_OK(converter_->AddInputTensor(name, nvinfer1::DataType::kFLOAT, - trt_dims, batch_size)); + TF_EXPECT_OK( + converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size)); ASSERT_EQ(batch_size, converter_->batch_size_); } } // Add weights for both validation and conversion. 
- template - void AddTestWeights(const char* name, const DataType dtype, - const std::vector& dims, - const std::vector& values) { - QCHECK_EQ(DataTypeToEnum::v(), dtype); + template + void AddTestWeights(const char* name, const std::vector& dims, + const std::vector& values) { + const DataType dtype = DataTypeToEnum::v(); const nvinfer1::Dims trt_dims = GetTestDims(dims); const int64_t num_elements = TrtDimsNumElements(trt_dims); QCHECK_EQ(num_elements, values.size()) @@ -585,13 +775,15 @@ class OpConverterTest : public ::testing::Test { TRT_ShapedWeights weights(dtype); if (num_elements) { weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims); - QCHECK_EQ(weights.size_bytes(), sizeof(CType) * values.size()) - << weights.size_bytes() << " vs " << sizeof(CType) * values.size(); + QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size()) + << weights.size_bytes() << " vs " << sizeof(T) * values.size(); memcpy(const_cast(weights.GetValues()), values.data(), weights.size_bytes()); } // Add weights for validation. - fake_tensor_or_weights_[string(name)] = TRT_TensorOrWeights{weights}; + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape)); + validator_inputs_[name] = MakeConstNodeDef(name, values, shape); // Add weights for conversion. 
TF_EXPECT_OK( converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights})); @@ -601,12 +793,18 @@ class OpConverterTest : public ::testing::Test { void RunValidation(const NodeDef& node_def, error::Code expected_code = error::OK, const char* expected_msg_substr = nullptr) { - std::vector inputs; + std::vector> input_node_and_ports; for (const string& input : node_def.input()) { - inputs.emplace_back(fake_tensor_or_weights_[input]); + input_node_and_ports.emplace_back(&validator_inputs_[input], 0); } - ExpectStatus(validator_->ValidateNode(node_def, inputs), expected_code, - expected_msg_substr); + grappler::GrapplerItem item; + TF_EXPECT_OK(scope_.ToGraphDef(&item.graph)); + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + ExpectStatus(validator_->ValidateNode(node_def, input_node_and_ports, + graph_properties), + expected_code, expected_msg_substr); } void RunConversion(const NodeDef& node_def, @@ -637,8 +835,8 @@ class OpConverterTest : public ::testing::Test { TrtUniquePtrType network_; TrtUniquePtrType engine_; cudaStream_t stream_; - std::vector> fake_itensors_; - std::unordered_map fake_tensor_or_weights_; + Scope scope_; + std::unordered_map validator_inputs_; }; template @@ -662,15 +860,7 @@ void TestConvertConst(OpConverterTest* test) { test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output)); - EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, output.weights().shape_)) - << output.DebugString(); - ASSERT_EQ(expected_value.size(), output.weights().count()) - << output.DebugString(); - const OutputCType* actual_values = - static_cast(output.weights().GetValues()); - for (int i = 0; i < expected_value.size(); ++i) { - EXPECT_EQ(expected_value[i], actual_values[i]); - } + ValidateWeights(output.weights(), expected_dims, expected_value); }; auto& attr = *node_def.mutable_attr(); @@ -700,8 +890,6 @@ void 
TestConvertConst(OpConverterTest* test) { } } -// TODO(laigd): we should use c++ API to create the nodedef, so any change in -// the API will be captured. TEST_F(OpConverterTest, ConvertConst) { { Reset(); @@ -713,10 +901,9 @@ TEST_F(OpConverterTest, ConvertConst) { } { Reset(); - NodeDef node_def = MakeNodeDef("my_const", "Const", {}); - (*node_def.mutable_attr())["dtype"].set_type(DT_DOUBLE); + NodeDef node_def = MakeConstNodeDef("my_const", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Unsupported data type"); + "Unsupported data type double"); } TestConvertConst(this); @@ -732,8 +919,14 @@ TEST_F(OpConverterTest, ConvertTranspose) { node_def, error::INVALID_ARGUMENT, "Input expects tensor and weights, at my_transpose"); } - NodeDef node_def = - MakeNodeDef("my_transpose", "Transpose", {"input", "weights"}); + + // Get the NodeDef for Transpose. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); + const NodeDef& node_def = transpose.operation.node()->def(); + { // Permutation is a tensor, should fail. Reset(); @@ -747,7 +940,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Transpose at batch dimension, should fail. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {4}, {1, 0, 2, 3}); + AddTestWeights("weights", {4}, {1, 0, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Transpose at batch dimension is not supported"); } @@ -755,7 +948,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Permutation rank doesn't match, should fail. 
Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {3}, {0, 1, 2}); + AddTestWeights("weights", {3}, {0, 1, 2}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Rank of perm for transpose does not match with that of the input."); @@ -764,7 +957,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Ok. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {4}, {0, 3, 1, 2}); + AddTestWeights("weights", {4}, {0, 3, 1, 2}); RunConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); @@ -786,7 +979,14 @@ TEST_F(OpConverterTest, ConvertReshape) { node_def, error::INVALID_ARGUMENT, "Input expects weights for shape, at my_reshape"); } - NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {"input", "weights"}); + + // Get the NodeDef for Reshape. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights); + const NodeDef& node_def = reshape.operation.node()->def(); + { // Shape is a tensor, should fail. Reset(); @@ -800,7 +1000,7 @@ TEST_F(OpConverterTest, ConvertReshape) { // Reshape to scalar, should fail. 
Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {}, {}); + AddTestWeights("weights", {0}, {}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Reshape to shape=[] is not supported, at my_reshape"); @@ -830,7 +1030,7 @@ TEST_F(OpConverterTest, ConvertReshape) { Reset(); const std::vector& dims = params[i].tensor_dims; AddTestTensor("input", dims, params[i].batch_size); - AddTestWeights("weights", DT_INT32, {4}, params[i].shape); + AddTestWeights("weights", {4}, params[i].shape); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Reshape on batch dimension is not supported, at my_reshape", @@ -847,7 +1047,7 @@ TEST_F(OpConverterTest, ConvertReshape) { for (int i = 0; i < kReshapeOKCases; ++i) { Reset(); AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); - AddTestWeights("weights", DT_INT32, {4}, ok_params[i].shape); + AddTestWeights("weights", {4}, ok_params[i].shape); RunConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); @@ -869,24 +1069,67 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::INVALID_ARGUMENT, "Input expects tensor and weights, at my_matmul"); } - NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {"input", "weights"}); - auto& attr = *node_def.mutable_attr(); - attr["T"].set_type(DT_FLOAT); - attr["transpose_a"].set_b(false); - attr["transpose_b"].set_b(false); - { - AddTestTensor("input", {2}, 1); - AddTestWeights("weights", DT_FLOAT, {2, 1}, {3, 5}); - RunConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); - EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, output.tensor()->getDimensions())) - << output.DebugString(); - std::vector output_data(1); - BuildAndRun("input", {2, 7}, "my_matmul", &output_data); - EXPECT_THAT(output_data, ElementsAre(41)); + // Get the NodeDef for Reshape. 
+ auto get_matmul_nodedef = [](DataType dtype, bool transpose_a, + bool transpose_b) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + ops::MatMul::Attrs matmul_attrs; + matmul_attrs.transpose_a_ = transpose_a; + matmul_attrs.transpose_b_ = transpose_b; + auto matmul = + ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs); + return matmul.operation.node()->def(); + }; + + { + // Unsupported data type. + Reset(); + NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false); + AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); + AddTestWeights("weights", {2, 1}, {3, 5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Data type is not supported, for node my_matmul got int32"); + } + { + // transpose_a is set. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected"); + } + } + { + // OK. 
+ for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + EXPECT_TRUE(output.is_tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions())) + << output.DebugString(); + + std::vector output_data(2); + BuildAndRun("input", {0, 1}, "my_matmul", &output_data); + if (transpose_b) { + EXPECT_THAT(output_data, ElementsAre(1, 3)); + } else { + EXPECT_THAT(output_data, ElementsAre(2, 3)); + } + } } } diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index c82d4a01839..4f64b7a9522 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -389,7 +389,7 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, + const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { @@ -409,9 +409,16 @@ tensorflow::Status SegmentGraph( std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); - if (options.exclude_node_list.count(node->name()) != 0 || - !candidate_fn(node->tf_node())) { + if (options.exclude_node_list.count(node->name()) != 0) { + VLOG(1) << "Not a TF-TRT candidate: " << node->name() + << " (excluded by segmenter option)."; node = nullptr; + } else { + const Status status = candidate_fn(node->tf_node()); + if (!status.ok()) { + VLOG(1) << "Not a TF-TRT candidate: " << node->name() << ": " << status; + node = nullptr; + } } 
node_segments.emplace_back(node); } diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 8c44eb782aa..b9693aad1b7 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -43,7 +43,7 @@ struct SegmentOptions { // Get the subgraphs of a graph that can be handled by TensorRT. // // @param graph tensorflow::Graph of the network -// @param candidate_fn A function that returns true for a Node* if +// @param candidate_fn A function that returns OK for a Node* if // that node can be handled by TensorRT. // @param segments Returns the TensorRT segments/subgraphs. Each entry // in the vector describes a subgraph by giving a set of the names of @@ -51,7 +51,7 @@ struct SegmentOptions { // @return the status. tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, + const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 5937fa8259a..4805ef9c61a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -34,10 +34,13 @@ namespace ops = ::tensorflow::ops; class SegmentTest : public ::testing::Test { protected: - std::function MakeCandidateFn( + std::function MakeCandidateFn( const std::set& node_names) { - return [node_names](const tensorflow::Node* node) -> bool { - return node_names.find(node->name()) != node_names.end(); + return [node_names](const tensorflow::Node* node) -> Status { + if (node_names.find(node->name()) != node_names.end()) { + return Status::OK(); + } + return errors::NotFound(""); }; } diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py 
b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 7d006b73d53..7545bb9df20 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -118,30 +118,11 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): """Return a ConversionParams for test.""" return super(BiasaddMatMulTest, self).GetConversionParams(run_params)._replace( - max_batch_size=4, maximum_cached_engines=2) - - def _ValidEngines(self): - """Engines expected to build and run.""" - return ["my_trt_op_0"] - - def _InvalidEngines(self): - """Engines that will cause conversion error at building time.""" - return ["my_trt_op_1", "my_trt_op_2"] + max_batch_size=4, maximum_cached_engines=1) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # In dynamic engine mode the engines are built in execution time, not in - # conversion time, so build errors occurs later. Here three of the engines - # will be failed to built but the corresponding engine op are still created. - # TODO(aaroey, jjsjann123): fix this. 
- if (run_params.dynamic_engine and - not trt_test.IsQuantizationMode(run_params.precision_mode)): - return self._ValidEngines() + self._InvalidEngines() - return self._ValidEngines() - - def ExpectedEnginesToRun(self, run_params): - """Return the expected engines to run.""" - return self._ValidEngines() + return ["my_trt_op_0"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py index 3cf7dadb1f4..bbc724ab18e 100644 --- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py +++ b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py @@ -79,7 +79,7 @@ class ReshapeTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_3": ["reshape-%d" % i for i in range(7)] + + "my_trt_op_0": ["reshape-%d" % i for i in range(7)] + ["reshape-%d/shape" % i for i in range(7)] }