diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index c96ca302d9e..20bcd2447e6 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -312,15 +312,20 @@ tf_cuda_cc_test( ], deps = [ ":trt_conversion", + "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:direct_session", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), @@ -340,6 +345,10 @@ tf_cuda_cc_test( ":trt_conversion", ":trt_plugins", "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 26f13b02a89..1f5591fe2a6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -81,12 +81,13 @@ std::vector GetLoadedTensorRTVersion() { return {ver_major, ver_minor, ver_patch}; } -namespace { +TrtCandidateSelector::TrtCandidateSelector( + const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} -bool IsTensorRTCandidate(const tensorflow::Node* node) { +Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { + // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange - // TODO(jie): Segmentation shouldn't associated with op name. 
- // Split it into a registration for each kernel. static const std::set candidate_ops = { "Identity", "Snapshot", @@ -127,13 +128,29 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { "Prod", "Max", "Min", - // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) - return (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + const bool is_supported_op_type = + (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + if (!is_supported_op_type) { + return errors::Unimplemented("Op type ", node->type_string(), + " is not supported."); + } + + std::vector input_edges; + TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); + std::vector> input_node_and_ports; + for (const Edge* input_edge : input_edges) { + input_node_and_ports.emplace_back(&input_edge->src()->def(), + input_edge->src_output()); + } + return validator_.ValidateNode(node->def(), input_node_and_ports, + graph_properties_); } +namespace { + tensorflow::Status BuildNodeMap( const tensorflow::Graph& graph, std::unordered_map* node_map) { @@ -846,9 +863,15 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; + TrtCandidateSelector candidate_selector(*params.graph_properties); TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties), - OutputEdgeValidator(), segment_options, &initial_segments)); + &graph, + std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector, + std::placeholders::_1), + // Input validation is already done by TrtCandidateSelector, so we don't + // need to check the input edges. 
+ [](const Edge* edge) { return true; }, OutputEdgeValidator(), + segment_options, &initial_segments)); if (initial_segments.size() > 1) { VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << initial_segments.size(); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 35252023698..1c9d82105a7 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -31,6 +31,26 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// Helper class for the segmenter to determine whether given TF node is +// supported by TRT. +class TrtCandidateSelector { + public: + TrtCandidateSelector(const grappler::GraphProperties& graph_properties); + + // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added + // to TRT subgraph and later converted into TRT engine. + Status IsTensorRTCandidate(const tensorflow::Node* node); + + private: + // The TF-TRT node converter used to verify whether individual node is + // supported. It will operate in validation-only mode. + TrtNodeValidator validator_; + + // GraphProperties of the graph whose nodes are to be validated by + // IsTensorRTCandidate(). + const grappler::GraphProperties& graph_properties_; +}; + struct ConversionParams { ConversionParams() : input_graph_def(nullptr), diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc index 8146bed4b05..f10729987fd 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -15,9 +15,14 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include +#include +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" @@ -33,6 +38,76 @@ namespace tensorflow { namespace tensorrt { namespace convert { +// TODO(laigd): put this into some test utils file. +void ExpectStatus(Status status, error::Code code = error::OK, + const char* substr = nullptr) { + EXPECT_EQ(code, status.code()) + << status << " vs expected error code \"" << error::Code_Name(code) + << "\" and message \"" << substr << "\""; + if (substr) { + EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status; + } +} + +TEST(TrtCandidateSelector, Basics) { + // Create a graph containing both TRT-compatible and TRT-incompatible nodes + // and use it to test TrtCandidateSelector::IsTensorRTCandidate(). + const std::vector input_shape_array{2, 2}; + TensorShape input_shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_shape_array, &input_shape)); + + Scope s = Scope::NewRootScope(); + ops::Placeholder::Attrs feed_attrs; + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(input_shape_array, &feed_attrs.shape_)); + + // Compatible input. + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, feed_attrs); + auto const_1 = ops::Const(s.WithOpName("const_1"), 1.0f, input_shape); + + // Compatible MatMul. + auto matmul = ops::MatMul(s.WithOpName("matmul"), feed, const_1); + + // Incompatible MatMul. 
+ ops::MatMul::Attrs matmul_attrs; + matmul_attrs.transpose_a_ = true; + auto incompatible_matmul = ops::MatMul(s.WithOpName("incompatible_matmul"), + feed, const_1, matmul_attrs); + + // Unsupported op. + auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed); + + // Incompatible input. + auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE); + auto const_2 = ops::Const(s.WithOpName("const_2"), 1.0, input_shape); + // Compatible op with incompatible input. + auto matmul_with_incompatible_input = + ops::MatMul(s.WithOpName("matmul_with_incompatible_input"), + incompatible_feed, const_2); + + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + Tensor feed_tensor(DT_FLOAT, input_shape); + item.feed.push_back(std::make_pair("feed", feed_tensor)); + + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + TrtCandidateSelector selector(graph_properties); + TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); + ExpectStatus( + selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), + error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected " + "(op: MatMul), at: incompatible_matmul"); + ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), + error::UNIMPLEMENTED, "Op type Sin is not supported"); + ExpectStatus(selector.IsTensorRTCandidate( + matmul_with_incompatible_input.operation.node()), + error::INTERNAL, + "Failed to convert input with index 0 to a TRT_TensorOrWeights"); +} + class FakeCluster : public grappler::Cluster { public: FakeCluster() : Cluster(0) {} @@ -48,8 +123,7 @@ class FakeCluster : public grappler::Cluster { } Status Run(const GraphDef& graph_def, const std::vector>& feed, - const std::vector& fetch, - RunMetadata* metadata) override { + const std::vector& fetch, RunMetadata* metadata) override { return Status::OK(); } diff --git 
a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index dcbc75aebf4..a6f954391d3 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -108,6 +108,18 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, return tensorflow::Status::OK(); } +template +inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, + bool ignore_first_dim) { + nvinfer1::Dims trt_dims; + const int offset = (ignore_first_dim ? 1 : 0); + for (int i = offset; i < shape.dims(); i++) { + trt_dims.d[i - offset] = shape.dim_size(i); + } + trt_dims.nbDims = shape.dims() - offset; + return trt_dims; +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -137,22 +149,37 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties, } } -tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, - const tensorflow::DataType dtype, - nvinfer1::DataType* trt_dtype) { - // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so - // put them there instead. +Status ValidateTensorProperties(const string& producer_node_type, + const tensorflow::DataType dtype, + const PartialTensorShape& shape, + bool validation_only, + nvinfer1::DataType* trt_dtype, + nvinfer1::Dims* trt_dims, int* batch_size) { + // Convert data type. TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + + // Convert shape. 
if (shape.dims() < 0) { - return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); + return errors::InvalidArgument("Input tensor rank is unknown."); } - if (shape.dims() > 9) { - return tensorflow::errors::OutOfRange( - "Input tensor rank is greater than 8."); + if (shape.dims() > nvinfer1::Dims::MAX_DIMS + 1) { // +1 for batch dim + return errors::OutOfRange("Input tensor rank is greater than ", + nvinfer1::Dims::MAX_DIMS + 1); } + if (producer_node_type != "Const" && shape.dims() < 2) { + return errors::InvalidArgument( + "Input tensor with rank<2 is not supported since the first dimension " + "is treated as batch dimension by TRT"); + } + *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true); + *batch_size = shape.dim_size(0); + + if (validation_only) return Status::OK(); + // Following are validations at runtime. + for (int d = 1; d < shape.dims(); ++d) { if (shape.dim_size(d) < 0) { - return tensorflow::errors::InvalidArgument( + return errors::InvalidArgument( "Input tensor with shape ", shape.DebugString(), " has an unknown non-batch dimemension at dim ", d); } @@ -358,22 +385,75 @@ string TRT_ShapedWeights::DebugString() const { ", values=", reinterpret_cast(GetValues()), ")"); } +// A fake ITensor implementation used to check whether the TF-TRT converter can +// handle specific node. We only need shape and type information, and the +// converter won't (and shouldn't) use this to build the TRT network. 
+class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { + public: + SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims) + : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {} + + void setName(const char* name) override {} + + const char* getName() const override { return ""; } + + void setDimensions(nvinfer1::Dims dimensions) override { + trt_dims_ = dimensions; + } + + nvinfer1::Dims getDimensions() const override { return trt_dims_; } + + void setType(nvinfer1::DataType trt_dtype) override { + trt_dtype_ = trt_dtype; + } + + nvinfer1::DataType getType() const override { return trt_dtype_; } + + bool isNetworkInput() const override { return false; } + + bool isNetworkOutput() const override { return false; } + + void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {} + + bool getBroadcastAcrossBatch() const override { return false; } + + nvinfer1::TensorLocation getLocation() const override { + // This is arbitrary, since we don't use it. + return nvinfer1::TensorLocation::kDEVICE; + } + + void setLocation(nvinfer1::TensorLocation location) override {} + +#if NV_TENSORRT_MAJOR >= 5 + bool setDynamicRange(float min, float max) override { return true; } +#endif + + private: + nvinfer1::DataType trt_dtype_; + nvinfer1::Dims trt_dims_; +}; + TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size) : tensor_(tensor), batch_size_(batch_size), - weights_(DT_FLOAT), + initialized_(true), + is_tensor_(true) {} + +TRT_TensorOrWeights::TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, + int batch_size) + : simple_itensor_(new SimpleITensor(trt_dtype, trt_dims)), + batch_size_(batch_size), initialized_(true), is_tensor_(true) {} TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_ShapedWeights& weights) - : tensor_(nullptr), - weights_(weights), - initialized_(true), - is_tensor_(false) {} + : weights_(weights), initialized_(true), is_tensor_(false) {} 
TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) : tensor_(rhs.tensor_), + simple_itensor_(rhs.simple_itensor_), batch_size_(rhs.batch_size_), weights_(rhs.weights_), initialized_(rhs.initialized_), @@ -381,12 +461,23 @@ TRT_TensorOrWeights::TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) { tensor_ = rhs.tensor_; + simple_itensor_ = rhs.simple_itensor_; batch_size_ = rhs.batch_size_; weights_ = rhs.weights_; initialized_ = rhs.initialized_; is_tensor_ = rhs.is_tensor_; } +nvinfer1::ITensor* TRT_TensorOrWeights::tensor() { + CHECK(is_tensor()); + return tensor_ == nullptr ? simple_itensor_.get() : tensor_; +} + +const nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const { + CHECK(is_tensor()); + return tensor_ == nullptr ? simple_itensor_.get() : tensor_; +} + nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { if (is_tensor()) { return tensor()->getDimensions(); @@ -398,8 +489,8 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor @", reinterpret_cast(tensor_), - ", shape=", convert::DebugString(tensor_->getDimensions()), + StrAppend(&output, "tensor @", reinterpret_cast(tensor()), + ", shape=", convert::DebugString(tensor()->getDimensions()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -560,11 +651,10 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G const int c = iweights.shape_.d[2] / num_groups; const int k = iweights.shape_.d[3] * num_groups; - VLOG(2) << "num_groups: " << num_groups - << "c" << iweights.shape_.d[2] << " then " << c - << "k" << iweights.shape_.d[3] << " then " << k - << "r" << iweights.shape_.d[0] << " then " << r - << "s" << iweights.shape_.d[1] << " then " << s; + 
VLOG(2) << "num_groups: " << num_groups << "c" << iweights.shape_.d[2] + << " then " << c << "k" << iweights.shape_.d[3] << " then " << k + << "r" << iweights.shape_.d[0] << " then " << r << "s" + << iweights.shape_.d[1] << " then " << s; oweights->shape_.d[0] = k / num_groups; oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; @@ -608,9 +698,68 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(tensorflow::DataType type, TrtNodeValidator::TrtNodeValidator() { RegisterOpValidators(); } +Status TrtNodeValidator::ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights) { + if (node_def.op() == "Const") { + if (output_port != 0) { + return errors::InvalidArgument("Const node should only have one output."); + } + // The output of the conversion will be used as input to other nodes to + // determine whether TRT supports those nodes. If it cannot convert the + // Const, it's very likely we cannot treat it as a tensor and make it an + // input to the TRT network, since TRT removes the first dimension and + // treats it as batch size. Also, it's not likely that the converter can + // support the op, and performance may suffer even if it can, so we just + // simply return error if the conversion fails. + std::vector inputs; + return ConvertConstToWeights(node_def, inputs, tensor_or_weights); + } + if (!graph_properties.HasOutputProperties(node_def.name())) { + return errors::InvalidArgument("Shape and data type are unknown"); + } + + // Validate and convert shape and dtype. 
+ const auto& output_params = + graph_properties.GetOutputProperties(node_def.name()); + const auto& tensor_properties = output_params.at(output_port); + const DataType dtype = tensor_properties.dtype(); + const PartialTensorShape shape = tensor_properties.shape(); + nvinfer1::DataType trt_dtype; + nvinfer1::Dims trt_dims; + int batch_size = -1; + TF_RETURN_IF_ERROR(ValidateTensorProperties( + node_def.op(), dtype, shape, /*validation_only=*/true, &trt_dtype, + &trt_dims, &batch_size)); + + // Adds a fake ITensor. This is fine since op converter operates in + // validation-only mode and it won't (and shouldn't) use the tensor to do + // any TRT network operations. + *tensor_or_weights = TRT_TensorOrWeights(trt_dtype, trt_dims, batch_size); + return Status::OK(); +} + Status TrtNodeValidator::ValidateNode( const tensorflow::NodeDef& node_def, - const std::vector& inputs) { + const std::vector>& input_node_and_ports, + const grappler::GraphProperties& graph_properties) { + // Convert input NodeDef and corresponding output ports to + // TRT_TensorOrWeights. + std::vector inputs; + for (int i = 0; i < input_node_and_ports.size(); ++i) { + const auto& pair = input_node_and_ports[i]; + TRT_TensorOrWeights tensor_or_weights; + Status status = ConvertToTensorOrWeights( + *pair.first, pair.second, graph_properties, &tensor_or_weights); + if (!status.ok()) { + return errors::Internal("Failed to convert input with index ", i, + " to a TRT_TensorOrWeights: ", status.error_message()); + } + inputs.push_back(tensor_or_weights); + } + + // Validate the node. const auto iter = op_validators_.find(node_def.op()); if (iter == op_validators_.end()) { // If validator is not registered, it means no validation is needed. 
@@ -621,7 +770,19 @@ Status TrtNodeValidator::ValidateNode( OpConverterParams params( /*arg_converter=*/nullptr, node_def, inputs, /*arg_outputs=*/nullptr, /*arg_validation_only=*/true, &weight_store_); - Status status = validator(&params); + return validator(&params); +} + +Status TrtNodeValidator::ConvertConstToWeights( + const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output) { + std::vector outputs; + OpConverterParams params( + /*arg_converter=*/nullptr, const_node_def, inputs, &outputs, + /*arg_validation_only=*/true, &weight_store_); + Status status = op_validators_["Const"](&params); + if (status.ok() && output) *output = outputs[0]; return status; } @@ -1663,7 +1824,7 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { } tensorflow::Status ConvertScale(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) { @@ -1798,8 +1959,13 @@ Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, return Status::OK(); } +// Convert a Const NodeDef to TRT_ShapedWeights. This is a special converter, it +// always ignores the params->validation_only parameter but adds the converted +// weights to params->outputs. We did this since TrtNodeValidator needs the +// weights as input to other nodes, and use it to determine whether those nodes +// are supported by TRT. tensorflow::Status ConvertConst(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (!inputs.empty()) { return errors::InvalidArgument( @@ -1896,11 +2062,10 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { return errors::Unimplemented("Not supported constant type, at ", node_def.name()); } - // Pass the output. 
- if (!params->validation_only) { + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } - return tensorflow::Status::OK(); + return Status::OK(); } tensorflow::Status ConvertIdentity(OpConverterParams* params) { @@ -1909,7 +2074,7 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) { } tensorflow::Status ConvertBinary(OpConverterParams* params) { - const auto inputs = params->inputs; + const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { return tensorflow::errors::FailedPrecondition( @@ -2418,19 +2583,20 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { // TODO(jie): INT32 should be converted? tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { - return tensorflow::errors::Unimplemented( - "data type is not supported, for node " + node_def.name() + " got " + - tensorflow::DataTypeString(tf_dtype)); + return errors::Unimplemented("Data type is not supported, for node ", + node_def.name(), " got ", + DataTypeString(tf_dtype)); } bool transpose_a = attrs.get("transpose_a"); bool transpose_b = attrs.get("transpose_b"); // FullyConnected: if (transpose_a) { - return tensorflow::errors::Internal( - "Transpose_a is not supported for TensorRT FullyConnected (op: " + - node_def.op() + "), at: " + node_def.name()); + return errors::InvalidArgument( + "transpose_a is not supported for TensorRT FullyConnected (op: ", + node_def.op(), "), at: ", node_def.name()); } + if (params->validation_only) return Status::OK(); return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(), transpose_b, node_def.name()); } @@ -2673,10 +2839,13 @@ tensorflow::Status ConvertGraphDefToEngine( return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); } - nvinfer1::DataType dtype; + nvinfer1::DataType trt_dtype; + nvinfer1::Dims trt_dims; + int batch_size = -1; auto 
shape = input_shapes.at(slot_number); - auto status = ValidateInputProperties( - shape, node_def.attr().at("dtype").type(), &dtype); + auto status = ValidateTensorProperties( + node_def.op(), node_def.attr().at("dtype").type(), shape, + /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size); if (!status.ok()) { const string error_message = StrCat("Validation failed for ", node_name, " and input slot ", @@ -2684,19 +2853,13 @@ tensorflow::Status ConvertGraphDefToEngine( LOG(WARNING) << error_message; return Status(status.code(), error_message); } - - nvinfer1::Dims input_dim; - for (int i = 1; i < shape.dims(); i++) { - input_dim.d[i - 1] = shape.dim_size(i); - } - input_dim.nbDims = shape.dims() - 1; VLOG(2) << "Adding engine input tensor " << node_name << " with shape " - << DebugString(input_dim); + << DebugString(trt_dims); // TODO(laigd): the conversion should always happen at runtime where all // the shapes are known, and we can provide a mode to generate the // engines offline, by calling sess.run() and cache/serialize the engines. 
- TF_RETURN_IF_ERROR(converter.AddInputTensor(node_name, dtype, input_dim, - shape.dim_size(0))); + TF_RETURN_IF_ERROR( + converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && (node_def.op() == "Identity")) { int32 slot_number = -1; @@ -2866,34 +3029,6 @@ tensorflow::Status ConvertSegmentToGraphDef( return tensorflow::Status::OK(); } -bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { - if (in_edge->IsControlEdge()) return true; - PartialTensorShape shape; - tensorflow::DataType dtype; - GetOutputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), - &shape, &dtype); - nvinfer1::DataType trt_dtype; - Status status = ValidateInputProperties(shape, dtype, &trt_dtype); - if (!status.ok()) { - VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() - << ": " << status; - return false; - } - - - if (in_edge->src()->type_string() != "Const" && - // Single dimensional input tensor is not supported since the first - // dimension is treated as batch dimension. 
- shape.dims() < 2) { - VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() - << " which has an input at port " << in_edge->dst_input() << " with" - << " #dim<2" - << " and is not a const: " << shape; - return false; - } - return true; -} - bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { if (out_edge->IsControlEdge()) return true; if (out_edge->src()->type_string() == "Const") { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 699b50b37e3..5cc28b33e7f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -148,21 +148,6 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType* engine, bool* convert_successfully); -// Helper class for the segmenter to determine whether an input edge to the TRT -// segment is valid. -class InputEdgeValidator { - public: - InputEdgeValidator(const grappler::GraphProperties& graph_properties) - : graph_properties_(graph_properties) {} - - // Return true if the specified edge is eligible to be an input edge of the - // TRT segment. - bool operator()(const tensorflow::Edge* in_edge) const; - - private: - const grappler::GraphProperties& graph_properties_; -}; - // Helper class for the segmenter to determine whether an output edge from the // TRT segment is valid. class OutputEdgeValidator { @@ -245,8 +230,21 @@ class TRT_TensorOrWeights { public: TRT_TensorOrWeights() {} + // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'. + // This is used by Converter when building the TRT network, where the ITensor + // is owned by the TRT network being built. See comment for 'tensor_' below. explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1); + // Constructor that makes it an ITensor by creating one using provided data + // type and shape, and takes ownership of the created ITensor. 
This is used by + // TrtNodeValidator to encapsulate the type and shape information for + // validation of graph nodes, and the created ITensor is fake and temporary, + // and should not be used to build any TRT network. See comment for + // 'simple_itensor_' below. + explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, int batch_size); + + // Constructor that makes it a TRT_TensorOrWeights. explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights); TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs); @@ -256,15 +254,9 @@ class TRT_TensorOrWeights { bool is_tensor() const { return initialized_ && is_tensor_; } bool is_weights() const { return initialized_ && !is_tensor_; } - nvinfer1::ITensor* tensor() { - CHECK(is_tensor()); - return tensor_; - } + nvinfer1::ITensor* tensor(); - const nvinfer1::ITensor* tensor() const { - CHECK(is_tensor()); - return tensor_; - } + const nvinfer1::ITensor* tensor() const; TRT_ShapedWeights& weights() { CHECK(is_weights()); @@ -283,9 +275,25 @@ class TRT_TensorOrWeights { string DebugString() const; private: + class SimpleITensor; + void set_batch_size(int batch_size) { batch_size_ = batch_size; } + // When it represents an ITensor, the ITensor can be either passed by the + // caller via the constructor that takes an ITensor* as parameter, or be + // created as a SimpleITensor. + // + // In the first case, the ITensor pointer is stored in 'tensor_' below, and + // the ITensor itself is not owned by this class. This method is used by + // Converter (e.g. AddInputTensor) and op converters during TRT network + // construction, where the TRT network owns the ITensor. + // + // In the second case, the created SimpleITensor is stored in + // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake + // implementation of ITensor and is used only by TrtNodeValidator to validate + // the graph nodes. nvinfer1::ITensor* tensor_ = nullptr; // Not owned. 
+ std::shared_ptr simple_itensor_ = nullptr; // First dimension of the TF tensor (NOT tensor_) that is represented by // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s @@ -339,13 +347,35 @@ class TrtNodeValidator { public: TrtNodeValidator(); - // Validate the node, and return ok if it's supported by the converter. - Status ValidateNode(const NodeDef& node_def, - const std::vector& inputs); + // Validate the node, and return ok if it's supported by TRT. + // + // - 'node_def' is the node to validate. + // - 'input_node_and_ports' are the input NodeDefs and their output ports that + // are connected to 'node_def' in the TF graph. + // - 'graph_properties' is the GraphProperties of the graph where 'node_def' + // belongs. It is used to get the shape and data type information of a + // tensor for validation purpose. + Status ValidateNode( + const NodeDef& node_def, + const std::vector>& input_node_and_ports, + const grappler::GraphProperties& graph_properties); private: void RegisterOpValidators(); + // Convert a Const node to a TRT_TensorOrWeights. + Status ConvertConstToWeights(const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output); + + // Convert the output tensor at 'output_port' of 'node_def' to a + // TRT_TensorOrWeights which will be later used as an input to other nodes and + // passed to ValidateNode() below. + Status ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights); + // Stores all the validators by op type. If no validator is registered for // specific op, it means no validation is needed and ValidateNode() will // return OK. 
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index bc390743335..c3a39395f3a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -29,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -45,6 +49,7 @@ namespace convert { using ::testing::ElementsAre; +// TODO(laigd): put this into some test utils file. 
void ExpectStatus(Status status, error::Code code = error::OK, const char* substr = nullptr) { EXPECT_EQ(code, status.code()) @@ -75,6 +80,23 @@ NodeDef MakeNodeDef(const string& name, const string& op, return node_def; } +template +NodeDef MakeConstNodeDef(const string& name, const std::vector& vals, + const TensorShape& shape) { + Scope s = Scope::NewRootScope(); + Tensor t = ::tensorflow::test::AsTensor(vals, shape); + auto const_op = ops::Const(s.WithOpName(name), t); + return const_op.node()->def(); +} + +template +NodeDef MakeConstNodeDef(const string& name, const std::vector& vals) { + TensorShape shape; + const std::vector shape_dims = {static_cast(vals.size())}; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape)); + return MakeConstNodeDef(name, vals, shape); +} + bool TrtDimsEquals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) { if (lhs.nbDims != rhs.nbDims) return false; for (int i = 0; i < lhs.nbDims; ++i) { @@ -95,6 +117,19 @@ bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, lhs.GetValues() == rhs.GetValues(); } +template +void ValidateWeights(const TRT_ShapedWeights& weights, + const std::vector& expected_dims, + const std::vector& expected_value) { + EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_)) + << weights.DebugString(); + ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString(); + const T* actual_values = static_cast(weights.GetValues()); + for (int i = 0; i < expected_value.size(); ++i) { + EXPECT_EQ(expected_value[i], actual_values[i]); + } +} + // Fake ITensor implementation for testing purposes. class FakeITensor : public nvinfer1::ITensor { public: @@ -194,32 +229,86 @@ TEST(TRT_ShapedWeights_Test, Basic) { } TEST(TRT_TensorOrWeights_Test, Basic) { + // Test constructor with no arguments. 
+ { + TRT_TensorOrWeights tw; + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + for (auto ptr : {&tw, &copy, &assigned}) { + EXPECT_EQ(false, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + EXPECT_EQ(-1, ptr->batch_size()); + } + } + + // Test constructor with ITensor and batch size argument. { nvinfer1::Dims dims; dims.nbDims = 1; dims.d[0] = 1; FakeITensor itensor(dims); - TRT_TensorOrWeights tw(&itensor); - EXPECT_EQ(true, tw.is_tensor()); - EXPECT_EQ(false, tw.is_weights()); - EXPECT_EQ(&itensor, tw.tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, tw.GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(tw.GetTrtDims()); - } - { - TRT_ShapedWeights weights(DT_FLOAT); - TRT_TensorOrWeights tw(weights); - EXPECT_EQ(false, tw.is_tensor()); - EXPECT_EQ(true, tw.is_weights()); - EXPECT_TRUE(TrtShapedWeightsEquals(weights, tw.weights())); + TRT_TensorOrWeights tw1(&itensor, /*batch_size=*/1); + for (auto original_ptr : {&tw, &tw1}) { + TRT_TensorOrWeights copy(*original_ptr); + TRT_TensorOrWeights assigned; + assigned = *original_ptr; + + for (auto ptr : {original_ptr, &copy, &assigned}) { + EXPECT_EQ(true, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + if (original_ptr == &tw) { + EXPECT_EQ(-1, ptr->batch_size()); + } else { + EXPECT_EQ(1, ptr->batch_size()); + } + EXPECT_EQ(&itensor, ptr->tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } + } + } + // Test constructor which creates and owns an ITensor.
+ { nvinfer1::Dims dims; - dims.nbDims = 0; - EXPECT_TRUE(TrtDimsEqualsArray({}, tw.GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(tw.GetTrtDims()); + dims.nbDims = 1; + dims.d[0] = 1; + TRT_TensorOrWeights tw(nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/1); + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + + for (auto ptr : {&tw, &copy, &assigned}) { + EXPECT_EQ(true, ptr->is_tensor()); + EXPECT_EQ(false, ptr->is_weights()); + EXPECT_EQ(1, ptr->batch_size()); + EXPECT_NE(nullptr, ptr->tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } + } + // Test constructor with TRT_ShapedWeights argument. + { + TRT_ShapedWeights weights; + TRT_TensorOrWeights tw(weights); + TRT_TensorOrWeights copy(tw); + TRT_TensorOrWeights assigned; + assigned = tw; + for (auto ptr : {&tw, &copy, &assigned}) { + EXPECT_EQ(false, ptr->is_tensor()); + EXPECT_EQ(true, ptr->is_weights()); + EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights())); + + nvinfer1::Dims dims; + dims.nbDims = 0; + EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims())) + << "- expected: " << DebugString(dims) + << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + } } } @@ -229,11 +318,64 @@ class ValidatorTest : public ::testing::Test { validator_.op_validators_[op_name] = op_validator; } + Status ConvertToTensorOrWeights( + const NodeDef& node_def, int output_port, + const grappler::GraphProperties& graph_properties, + TRT_TensorOrWeights* tensor_or_weights) { + return validator_.ConvertToTensorOrWeights( + node_def, output_port, graph_properties, tensor_or_weights); + } + protected: TrtNodeValidator validator_; }; +TEST_F(ValidatorTest, ConvertToTensorOrWeights) { + // Convert Const.
+ { + NodeDef node_def = MakeConstNodeDef("my_const", {1.0f, 2.0f}); + TRT_TensorOrWeights output; + grappler::GrapplerItem item; + grappler::GraphProperties graph_properties(item); + ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, &output)); + ValidateWeights(output.weights(), {2}, {1.0, 2.0}); + } + // Convert non-Const. We test the case where the non-batch dimension is + // unknown as well, to make sure the validator allows that. + for (const int32 non_batch_dim : {-1, 2}) { + const int32 batch_size = 12; + + Scope s = Scope::NewRootScope(); + ops::Placeholder::Attrs attrs; + TF_EXPECT_OK(TensorShapeUtils::MakeShape( + std::vector{batch_size, non_batch_dim}, &attrs.shape_)); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs); + auto add = ops::Add(s.WithOpName("add"), feed, feed); + + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + auto& node_def = add.operation.node()->def(); + TRT_TensorOrWeights output; + ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, &output)); + EXPECT_EQ(true, output.is_tensor()); + EXPECT_EQ(batch_size, output.batch_size()); + EXPECT_NE(nullptr, output.tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims())) + << "- expected: {" << non_batch_dim << "} \n vs\n" + << "- actual: " << DebugString(output.GetTrtDims()); + } +} + TEST_F(ValidatorTest, ValidateNode) { + grappler::GrapplerItem item; + grappler::GraphProperties graph_properties(item); + bool start_conversion = false; bool should_fail = false; auto op_converter = [&start_conversion, @@ -245,16 +387,17 @@ TEST_F(ValidatorTest, ValidateNode) { NodeDef node_def = MakeNodeDef("my_op", "MyOp", {}); // Validator not registered, validation should pass.
- TF_EXPECT_OK(validator_.ValidateNode(node_def, {})); + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); // Register validator. AddOpValidator("MyOp", op_converter); - TF_EXPECT_OK(validator_.ValidateNode(node_def, {})); + TF_EXPECT_OK(validator_.ValidateNode(node_def, {}, graph_properties)); EXPECT_EQ(false, start_conversion); // Let the converter return error. should_fail = true; - ExpectStatus(validator_.ValidateNode(node_def, {}), error::INVALID_ARGUMENT); + ExpectStatus(validator_.ValidateNode(node_def, {}, graph_properties), + error::INVALID_ARGUMENT); } class ConverterTest : public ::testing::Test { @@ -289,6 +432,8 @@ class ConverterTest : public ::testing::Test { return converter_->GetInputs(node_def, inputs); } + int batch_size() const { return converter_->batch_size_; } + private: Logger logger_; // These members are ordered in a way such that the destruction order is: @@ -474,11 +619,48 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { << DebugString(*output_tensor); } +TEST_F(ConverterTest, MaybeUpdateBatchSize) { + EXPECT_EQ(-1, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(-1)); + EXPECT_EQ(-1, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + EXPECT_EQ(123, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + EXPECT_EQ(123, batch_size()); + + TF_EXPECT_OK(MaybeUpdateBatchSize(-1)); + EXPECT_EQ(123, batch_size()); + + ExpectStatus(MaybeUpdateBatchSize(124), error::INVALID_ARGUMENT, + "Provided batch size does not match converter batch size"); +} + +TEST_F(ConverterTest, AddAndGetTensorOrWeights) { + // Add a tensor. + FakeITensor fake_tensor; + TRT_TensorOrWeights tensor(&fake_tensor); + EXPECT_EQ(-1, tensor.batch_size()); + TF_EXPECT_OK(MaybeUpdateBatchSize(123)); + TF_EXPECT_OK(AddTensorOrWeights("my_tensor", tensor)); + + // Get the added tensor. 
+ TRT_TensorOrWeights added_tensor; + TF_EXPECT_OK(GetTensorOrWeights("my_tensor", &added_tensor)); + EXPECT_EQ(123, added_tensor.batch_size()); + + // Add the same tensor again. + ExpectStatus(AddTensorOrWeights("my_tensor", tensor), error::ALREADY_EXISTS, + "tensor/weights my_tensor already exist"); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { public: - OpConverterTest() { + OpConverterTest() : scope_(Scope::NewRootScope()) { QCHECK_EQ(0, cudaStreamCreate(&stream_)); Reset(); } @@ -505,8 +687,8 @@ class OpConverterTest : public ::testing::Test { converter_.reset(new Converter(network_.get(), /*fp16=*/false)); // Reset other related artifacts. - fake_itensors_.clear(); - fake_tensor_or_weights_.clear(); + scope_ = Scope::NewRootScope(); + validator_inputs_.clear(); } void BuildAndRun(const char* input_name, const std::vector& input_data, @@ -551,33 +733,41 @@ class OpConverterTest : public ::testing::Test { } // Add ITensor for both validation and conversion. - void AddTestTensor(const char* name, const std::vector& dims, - int batch_size = 1) { - nvinfer1::Dims trt_dims = GetTestDims(dims); - // Add FakeITensor for validation. - // - // TRT cannot add a tensor that has undetermined dims, so we manage the - // tensor using a vector. These tensors are used to test validation-only - // mode and thus should not be used to build the engine. 
- FakeITensor* fake_itensor = new FakeITensor(trt_dims); - fake_itensors_.emplace_back(fake_itensor); - fake_tensor_or_weights_[string(name)] = - TRT_TensorOrWeights{fake_itensor, batch_size}; + void AddTestTensor( + const char* name, const std::vector& dims, int batch_size = 1, + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { + DataType tf_dtype = DT_FLOAT; + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + tf_dtype = DT_FLOAT; + break; + case nvinfer1::DataType::kINT32: + tf_dtype = DT_INT32; + break; + default: + ASSERT_TRUE(false) << "Unexpected data type " + << static_cast(trt_dtype); + } + ops::Placeholder::Attrs attrs; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); + attrs.shape_.InsertDim(0, batch_size); + auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs); + validator_inputs_[name] = input.operation.node()->def(); // Add a real ITensor for conversion conditionally. + const nvinfer1::Dims trt_dims = GetTestDims(dims); if (HasStaticShape(trt_dims)) { - TF_EXPECT_OK(converter_->AddInputTensor(name, nvinfer1::DataType::kFLOAT, - trt_dims, batch_size)); + TF_EXPECT_OK( + converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size)); ASSERT_EQ(batch_size, converter_->batch_size_); } } // Add weights for both validation and conversion. 
- template - void AddTestWeights(const char* name, const DataType dtype, - const std::vector& dims, - const std::vector& values) { - QCHECK_EQ(DataTypeToEnum::v(), dtype); + template + void AddTestWeights(const char* name, const std::vector& dims, + const std::vector& values) { + const DataType dtype = DataTypeToEnum::v(); const nvinfer1::Dims trt_dims = GetTestDims(dims); const int64_t num_elements = TrtDimsNumElements(trt_dims); QCHECK_EQ(num_elements, values.size()) @@ -585,13 +775,15 @@ class OpConverterTest : public ::testing::Test { TRT_ShapedWeights weights(dtype); if (num_elements) { weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims); - QCHECK_EQ(weights.size_bytes(), sizeof(CType) * values.size()) - << weights.size_bytes() << " vs " << sizeof(CType) * values.size(); + QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size()) + << weights.size_bytes() << " vs " << sizeof(T) * values.size(); memcpy(const_cast(weights.GetValues()), values.data(), weights.size_bytes()); } // Add weights for validation. - fake_tensor_or_weights_[string(name)] = TRT_TensorOrWeights{weights}; + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape)); + validator_inputs_[name] = MakeConstNodeDef(name, values, shape); // Add weights for conversion. 
TF_EXPECT_OK( converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights})); @@ -601,12 +793,18 @@ class OpConverterTest : public ::testing::Test { void RunValidation(const NodeDef& node_def, error::Code expected_code = error::OK, const char* expected_msg_substr = nullptr) { - std::vector inputs; + std::vector> input_node_and_ports; for (const string& input : node_def.input()) { - inputs.emplace_back(fake_tensor_or_weights_[input]); + input_node_and_ports.emplace_back(&validator_inputs_[input], 0); } - ExpectStatus(validator_->ValidateNode(node_def, inputs), expected_code, - expected_msg_substr); + grappler::GrapplerItem item; + TF_EXPECT_OK(scope_.ToGraphDef(&item.graph)); + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + ExpectStatus(validator_->ValidateNode(node_def, input_node_and_ports, + graph_properties), + expected_code, expected_msg_substr); } void RunConversion(const NodeDef& node_def, @@ -637,8 +835,8 @@ class OpConverterTest : public ::testing::Test { TrtUniquePtrType network_; TrtUniquePtrType engine_; cudaStream_t stream_; - std::vector> fake_itensors_; - std::unordered_map fake_tensor_or_weights_; + Scope scope_; + std::unordered_map validator_inputs_; }; template @@ -662,15 +860,7 @@ void TestConvertConst(OpConverterTest* test) { test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_const", &output)); - EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, output.weights().shape_)) - << output.DebugString(); - ASSERT_EQ(expected_value.size(), output.weights().count()) - << output.DebugString(); - const OutputCType* actual_values = - static_cast(output.weights().GetValues()); - for (int i = 0; i < expected_value.size(); ++i) { - EXPECT_EQ(expected_value[i], actual_values[i]); - } + ValidateWeights(output.weights(), expected_dims, expected_value); }; auto& attr = *node_def.mutable_attr(); @@ -700,8 +890,6 @@ void 
TestConvertConst(OpConverterTest* test) { } } -// TODO(laigd): we should use c++ API to create the nodedef, so any change in -// the API will be captured. TEST_F(OpConverterTest, ConvertConst) { { Reset(); @@ -713,10 +901,9 @@ TEST_F(OpConverterTest, ConvertConst) { } { Reset(); - NodeDef node_def = MakeNodeDef("my_const", "Const", {}); - (*node_def.mutable_attr())["dtype"].set_type(DT_DOUBLE); + NodeDef node_def = MakeConstNodeDef("my_const", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Unsupported data type"); + "Unsupported data type double"); } TestConvertConst(this); @@ -732,8 +919,14 @@ TEST_F(OpConverterTest, ConvertTranspose) { node_def, error::INVALID_ARGUMENT, "Input expects tensor and weights, at my_transpose"); } - NodeDef node_def = - MakeNodeDef("my_transpose", "Transpose", {"input", "weights"}); + + // Get the NodeDef for Transpose. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); + const NodeDef& node_def = transpose.operation.node()->def(); + { // Permutation is a tensor, should fail. Reset(); @@ -747,7 +940,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Transpose at batch dimension, should fail. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {4}, {1, 0, 2, 3}); + AddTestWeights("weights", {4}, {1, 0, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Transpose at batch dimension is not supported"); } @@ -755,7 +948,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Permutation rank doesn't match, should fail. 
Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {3}, {0, 1, 2}); + AddTestWeights("weights", {3}, {0, 1, 2}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, "Rank of perm for transpose does not match with that of the input."); @@ -764,7 +957,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { // Ok. Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {4}, {0, 3, 1, 2}); + AddTestWeights("weights", {4}, {0, 3, 1, 2}); RunConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); @@ -786,7 +979,14 @@ TEST_F(OpConverterTest, ConvertReshape) { node_def, error::INVALID_ARGUMENT, "Input expects weights for shape, at my_reshape"); } - NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {"input", "weights"}); + + // Get the NodeDef for Reshape. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto reshape = ops::Reshape(s.WithOpName("my_reshape"), input, weights); + const NodeDef& node_def = reshape.operation.node()->def(); + { // Shape is a tensor, should fail. Reset(); @@ -800,7 +1000,7 @@ TEST_F(OpConverterTest, ConvertReshape) { // Reshape to scalar, should fail. 
Reset(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", DT_INT32, {}, {}); + AddTestWeights("weights", {0}, {}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Reshape to shape=[] is not supported, at my_reshape"); @@ -830,7 +1030,7 @@ TEST_F(OpConverterTest, ConvertReshape) { Reset(); const std::vector& dims = params[i].tensor_dims; AddTestTensor("input", dims, params[i].batch_size); - AddTestWeights("weights", DT_INT32, {4}, params[i].shape); + AddTestWeights("weights", {4}, params[i].shape); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Reshape on batch dimension is not supported, at my_reshape", @@ -847,7 +1047,7 @@ TEST_F(OpConverterTest, ConvertReshape) { for (int i = 0; i < kReshapeOKCases; ++i) { Reset(); AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); - AddTestWeights("weights", DT_INT32, {4}, ok_params[i].shape); + AddTestWeights("weights", {4}, ok_params[i].shape); RunConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); @@ -869,24 +1069,67 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::INVALID_ARGUMENT, "Input expects tensor and weights, at my_matmul"); } - NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {"input", "weights"}); - auto& attr = *node_def.mutable_attr(); - attr["T"].set_type(DT_FLOAT); - attr["transpose_a"].set_b(false); - attr["transpose_b"].set_b(false); - { - AddTestTensor("input", {2}, 1); - AddTestWeights("weights", DT_FLOAT, {2, 1}, {3, 5}); - RunConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); - EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, output.tensor()->getDimensions())) - << output.DebugString(); - std::vector output_data(1); - BuildAndRun("input", {2, 7}, "my_matmul", &output_data); - EXPECT_THAT(output_data, ElementsAre(41)); + // Get the NodeDef for MatMul.
+ auto get_matmul_nodedef = [](DataType dtype, bool transpose_a, + bool transpose_b) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + ops::MatMul::Attrs matmul_attrs; + matmul_attrs.transpose_a_ = transpose_a; + matmul_attrs.transpose_b_ = transpose_b; + auto matmul = + ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs); + return matmul.operation.node()->def(); + }; + + { + // Unsupported data type. + Reset(); + NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false); + AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); + AddTestWeights("weights", {2, 1}, {3, 5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Data type is not supported, for node my_matmul got int32"); + } + { + // transpose_a is set. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected"); + } + } + { + // OK. 
+ for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + EXPECT_TRUE(output.is_tensor()); + EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions())) + << output.DebugString(); + + std::vector output_data(2); + BuildAndRun("input", {0, 1}, "my_matmul", &output_data); + if (transpose_b) { + EXPECT_THAT(output_data, ElementsAre(1, 3)); + } else { + EXPECT_THAT(output_data, ElementsAre(2, 3)); + } + } } } diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index c82d4a01839..4f64b7a9522 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -389,7 +389,7 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, + const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { @@ -409,9 +409,16 @@ tensorflow::Status SegmentGraph( std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); - if (options.exclude_node_list.count(node->name()) != 0 || - !candidate_fn(node->tf_node())) { + if (options.exclude_node_list.count(node->name()) != 0) { + VLOG(1) << "Not a TF-TRT candidate: " << node->name() + << " (excluded by segmenter option)."; node = nullptr; + } else { + const Status status = candidate_fn(node->tf_node()); + if (!status.ok()) { + VLOG(1) << "Not a TF-TRT candidate: " << node->name() << ": " << status; + node = nullptr; + } } 
node_segments.emplace_back(node); } diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 8c44eb782aa..b9693aad1b7 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -43,7 +43,7 @@ struct SegmentOptions { // Get the subgraphs of a graph that can be handled by TensorRT. // // @param graph tensorflow::Graph of the network -// @param candidate_fn A function that returns true for a Node* if +// @param candidate_fn A function that returns OK for a Node* if // that node can be handled by TensorRT. // @param segments Returns the TensorRT segments/subgraphs. Each entry // in the vector describes a subgraph by giving a set of the names of @@ -51,7 +51,7 @@ struct SegmentOptions { // @return the status. tensorflow::Status SegmentGraph( const tensorflow::Graph* tf_graph, - const std::function& candidate_fn, + const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 5937fa8259a..4805ef9c61a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -34,10 +34,13 @@ namespace ops = ::tensorflow::ops; class SegmentTest : public ::testing::Test { protected: - std::function MakeCandidateFn( + std::function MakeCandidateFn( const std::set& node_names) { - return [node_names](const tensorflow::Node* node) -> bool { - return node_names.find(node->name()) != node_names.end(); + return [node_names](const tensorflow::Node* node) -> Status { + if (node_names.find(node->name()) != node_names.end()) { + return Status::OK(); + } + return errors::NotFound(""); }; } diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py 
b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 7d006b73d53..7545bb9df20 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -118,30 +118,11 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): """Return a ConversionParams for test.""" return super(BiasaddMatMulTest, self).GetConversionParams(run_params)._replace( - max_batch_size=4, maximum_cached_engines=2) - - def _ValidEngines(self): - """Engines expected to build and run.""" - return ["my_trt_op_0"] - - def _InvalidEngines(self): - """Engines that will cause conversion error at building time.""" - return ["my_trt_op_1", "my_trt_op_2"] + max_batch_size=4, maximum_cached_engines=1) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # In dynamic engine mode the engines are built in execution time, not in - # conversion time, so build errors occurs later. Here three of the engines - # will be failed to built but the corresponding engine op are still created. - # TODO(aaroey, jjsjann123): fix this. 
- if (run_params.dynamic_engine and - not trt_test.IsQuantizationMode(run_params.precision_mode)): - return self._ValidEngines() + self._InvalidEngines() - return self._ValidEngines() - - def ExpectedEnginesToRun(self, run_params): - """Return the expected engines to run.""" - return self._ValidEngines() + return ["my_trt_op_0"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py index 3cf7dadb1f4..bbc724ab18e 100644 --- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py +++ b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py @@ -79,7 +79,7 @@ class ReshapeTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_3": ["reshape-%d" % i for i in range(7)] + + "my_trt_op_0": ["reshape-%d" % i for i in range(7)] + ["reshape-%d/shape" % i for i in range(7)] }