From e353b0e2ccd53b76e787e6c1d13b96b46f4cf5a9 Mon Sep 17 00:00:00 2001 From: Pranav Marathe Date: Mon, 8 Apr 2019 16:30:56 -0700 Subject: [PATCH 01/15] Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Better logging for broadcasting Adds back checks for non-trivial batch dimension when broadcasting weights Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Adds back checks for non-trivial batch dimension when broadcasting weights Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Adds back checks for non-trivial batch dimension when broadcasting weights Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Changes broadcast logic, and binary converter to use ElementWise instead of Scale Removes debug prints Better logging for broadcasting Adds back checks for non-trivial batch dimension when broadcasting weights Adds new unit tests for ConvertBinary Minor refactor Formats changes using clang-format --style=google Adds additional test cases for ConvertBinary --- .../tf2tensorrt/convert/convert_nodes.cc | 475 +++++------------- .../tf2tensorrt/convert/convert_nodes.h | 9 + .../tf2tensorrt/convert/convert_nodes_test.cc | 378 ++++---------- 3 files changed, 248 insertions(+), 614 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 306341301b9..739bf942c68 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -389,80 +389,83 @@ Status Converter::GetTrtBroadcastShape( const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, nvinfer1::Dims* operand_l_new_dims, nvinfer1::Dims* operand_r_new_dims) const { - // *************************************************************************** // TensorRT Elementwise op supports broadcast but requires both tensor to be - // of Identical rank - // - // We consider case of: - // 1. operand_l to be a Tensor & operand_r to be a Const; - // 2. operand_l to be a Tensor & operand_r to be a Tensor; - // note: const op const (constant folding) should fallback to TensorFlow - // - // broadcast scheme: - // T: 1 3 5 (tensor would not have batch dimension) - // W: 1 1 3 1 (weight would have all explicit dimensions) - // i. fill in explicit dimensions - // -> T: -1 1 3 5 (we put a -1 for batch dimension) - // -> W: 1 1 3 1 - // ii. compare broadcast feasibility - // - // We cannot support the following since TensorRT does not allow manipulation - // on batch dimension, we cannot generate output with proper shape - // T: 3 5 1 - // W: 1 1 1 1 3 5 1 - // -> T: 1 1 1 -1 3 5 1 - // -> W: 1 1 1 1 3 5 1 - // *************************************************************************** - if (!operand_l.is_tensor() && !operand_r.is_tensor()) { - return errors::InvalidArgument( - "Broadcasting requires at least one of the operands be tensors"); + // of Identical rank. + // This function broadcasts the lower rank dimension across the higher rank + // one. + (*operand_l_new_dims) = operand_l.GetTrtDims(); + (*operand_r_new_dims) = operand_r.GetTrtDims(); + + // clang-format off + // Weights may include a batch dimension, so we need to remove it. + // We determine if that is the case by checking if the rank of the weights is + // larger than the rank of the tensor. Needed for cases such as: + // t: [1, 1] w/ implicit batch size of 1 + // w: [1, 1, 1] + // where the output in TRT is expected to be 2D, not 3D. + // clang-format on + if (operand_l.is_weights() && + operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) { + if (operand_l_new_dims->d[0] != -1 && operand_l_new_dims->d[0] != 1) { + return errors::InvalidArgument( + "Cannot broadcast weights with non-trivial batch dimension"); + } + TF_RETURN_IF_ERROR(RemoveBatchDimension(operand_l_new_dims)); } - const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; - auto compute_output_dims = [](const TRT_TensorOrWeights& input, - int broadcast_num_dims, int* output_dims_array, - nvinfer1::Dims* output_dims) { - const nvinfer1::Dims input_dims = input.GetTrtDims(); - std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); - std::copy(input_dims.d, input_dims.d + input_dims.nbDims, - output_dims_array + broadcast_num_dims - input_dims.nbDims); - if (input.is_tensor()) { - const int true_input_dims = input_dims.nbDims + 1; - if (true_input_dims < broadcast_num_dims) { - return errors::InvalidArgument( - "Broadcasting beyond batch dimension is not supported ", - "(tensor #dims ", true_input_dims, " vs broadcast #dims ", - broadcast_num_dims, ")"); - } - // Set the batch dimension to -1, since batch size is not supposed to - // be broadcasted. - output_dims_array[0] = -1; + if (operand_r.is_weights() && + operand_r_new_dims->nbDims > operand_l_new_dims->nbDims) { + if (operand_r_new_dims->d[0] != -1 && operand_r_new_dims->d[0] != 1) { + return errors::InvalidArgument( + "Cannot broadcast weights with non-trivial batch dimension"); } - // Copy to output dimensions (stripping the batch dimension). - output_dims->nbDims = broadcast_num_dims - 1; - std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, - output_dims->d); + TF_RETURN_IF_ERROR(RemoveBatchDimension(operand_r_new_dims)); + } + + // If the rank of the tensors is already the same, we can't do anything + // further. + if (operand_l_new_dims->nbDims == operand_r_new_dims->nbDims) { + VLOG(2) << "Broadcasted operands to [L] " + << DebugString(*operand_l_new_dims) << " and [R] " + << DebugString(*operand_r_new_dims); return Status::OK(); + } + + const nvinfer1::Dims* higher_rank = + (operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) + ? operand_l_new_dims + : operand_r_new_dims; + nvinfer1::Dims* lower_rank = + (operand_l_new_dims->nbDims <= operand_r_new_dims->nbDims) + ? operand_l_new_dims + : operand_r_new_dims; + + // Broadcasts low_rank over high_rank in-place by inserting ones at the front + // of low_rank so the ranks match. + constexpr auto broadcast_dims = [](const nvinfer1::Dims& high_rank, + const nvinfer1::Dims& low_rank) { + nvinfer1::Dims ret{high_rank.nbDims}; + std::fill(ret.d, ret.d + ret.nbDims, 1); + int num_leading_ones = high_rank.nbDims - low_rank.nbDims; + std::copy(low_rank.d, low_rank.d + low_rank.nbDims, + ret.d + num_leading_ones); + return ret; }; - // Compute the output dimensions. - const int broadcast_num_dims = - std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0), - operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0)); - int output_l[max_nb_dims], output_r[max_nb_dims]; - TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims, - output_l, operand_l_new_dims)); - TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims, - output_r, operand_r_new_dims)); + (*lower_rank) = broadcast_dims(*higher_rank, *lower_rank); + VLOG(2) << "Broadcasted operands to [L] " << DebugString(*operand_l_new_dims) + << " and [R] " << DebugString(*operand_r_new_dims); // Compare broadcast feasibility - for (int i = 0; i < broadcast_num_dims; ++i) { - if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && - (output_r[i] != 1)) { + for (int i = 0; i < operand_r_new_dims->nbDims; ++i) { + if ((operand_l_new_dims->d[i] != operand_r_new_dims->d[i]) && + (operand_l_new_dims->d[i] != 1) && (operand_r_new_dims->d[i] != 1)) { return errors::InvalidArgument( - "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", - DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], - ", ", DebugString(*operand_r_new_dims), ")"); + "Infeasible broadcast scheme (", + "batch_dim: ", operand_l_new_dims->d[0], ", ", + DebugString(*operand_l_new_dims), " vs ", + "batch_dim: ", operand_r_new_dims->d[0], ", ", + DebugString(*operand_r_new_dims), ")"); } } return Status::OK(); @@ -1676,190 +1679,6 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights, return Status::OK(); } -// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the -// right operand. If swapped_inputs is true, those two are swapped. -// -// TODO(jie): broadcast is needed yet not implemented. -// Only implemented channel wise for the time being. -Status BinaryTensorOpWeight(OpConverterParams* params, - nvinfer1::ITensor* tensor, - TRT_ShapedWeights weights, bool swapped_inputs) { - static const std::unordered_set supported_ops = {"Sub", "Add", "Mul", - "Div", "RealDiv"}; - const auto& node_def = params->node_def; - if (!supported_ops.count(node_def.op())) { - return errors::Unimplemented(node_def.op(), " is not supported, at ", - node_def.name()); - } - - // Check scale mode. - auto dims_w = weights.shape_; - const auto dims_t = tensor->getDimensions(); - - // TODO(jie): addScale checks for input tensor dimension - if (dims_t.nbDims != 3) { - return errors::InvalidArgument("addScale requires tensor with rank 3, at ", - node_def.name()); - } - - // Default to element-wise - auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - - // TODO(jie): maybe use a permutation instead to support more cases; - bool need_to_permute = false; - - if (weights.count() == 1) { - scale_mode = nvinfer1::ScaleMode::kUNIFORM; - } else { - VLOG(2) << "weights dims: " << DebugString(dims_w) - << "; tensor dims: " << DebugString(dims_t); - // Make sure no broadcasting on batch dimension. - if (dims_w.nbDims == dims_t.nbDims + 1) { - if (dims_w.d[0] == 1) { - for (int i = 1; i < dims_w.nbDims; i++) { - dims_w.d[i - 1] = dims_w.d[i]; - } - dims_w.nbDims--; - } else { - return errors::InvalidArgument("Binary op cannot operate on batch, at ", - node_def.name()); - } - } - - if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) { - scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - // Default is element-wise - for (int i = 1; i < dims_w.nbDims; i++) { - if (dims_w.d[i] != dims_t.d[i]) { - // If dimension does not match, switch back to per-channel - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - break; - } - } - // If the mode is per-channel, since channel dimension is assumed to be - // the third to last dimension, we need to make sure all other dimensions - // have size 1. - if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { - for (int i = 1; i < dims_w.nbDims; i++) { - if (dims_w.d[i] != 1) - return errors::InvalidArgument( - "Weight dims not compatible for channel-wise broadcast at ", - node_def.name()); - } - } - } else if (dims_w.nbDims == 1 && - dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) { - // Channel wise and broadcast required. We compare the last dimension of - // the tensor shape because of tensorflow default broadcasting rules. - need_to_permute = true; - scale_mode = nvinfer1::ScaleMode::kCHANNEL; - } else { - return errors::InvalidArgument("Weight dims not compatible at ", - node_def.name()); - } - } - // TODO(laigd): we should add validation_only support in TransposeTensor() and - // PrepareTensorForShape(). - if (params->validation_only) return Status::OK(); - - // Transpose last dimension. - std::vector permutation(dims_t.nbDims + 1); - if (need_to_permute) { - // We swap the last dimension into channel for trt, because of tensorflow - // default broadcasting rules. - for (int i = 0; i < static_cast(permutation.size()); i++) { - permutation[i] = i; - } - permutation[1] = dims_t.nbDims; - permutation[dims_t.nbDims] = 1; - TF_RETURN_IF_ERROR( - params->converter->TransposeTensor(tensor, permutation, &tensor)); - } - - // Prepare weights - TRT_ShapedWeights shift_weights(weights.TrtDType()); - TRT_ShapedWeights scale_weights(weights.TrtDType()); - TRT_ShapedWeights power_weights(weights.TrtDType()); - - if (node_def.op() == "Sub") { - if (swapped_inputs) { - shift_weights = weights; - nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *tensor, nvinfer1::UnaryOperation::kNEG); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - // Since quantization ranges are symmetric, the same range as the input - // will work for the negation of the input. - params->converter->MarkQuantizationRangesAsInferrable( - tensor, layer->getOutput(0)); - tensor = layer->getOutput(0); - } else { - TRT_ShapedWeights neg_weights = - params->weight_store->GetTempWeights(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::NEG; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); - shift_weights = neg_weights; - } - } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { - if (swapped_inputs) { - // We need to infer the quantization range for this intermediate tensor. - // - // x -> [Recip] -> 1/x -> [Scale] -> s/x - // ^ - // need range for this - // - // We have the quantization scales for x and s/x - can we divide the scale - // for s/x by s? Only if it is a scalar. - // - // Because of this issue, fall back to BinaryTensorOpTensor if we are - // doing INT8 with no calibration. There is most likely no performance - // penalty by falling back here. - if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && - !params->converter->use_calibration()) { - return errors::Unimplemented( - "Intermediate quantization range cannot be determined without" - " calibration. Falling back to BinaryTensorOpTensor for ", - node_def.op(), ", at ", node_def.name()); - } - scale_weights = weights; - nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *tensor, nvinfer1::UnaryOperation::kRECIP); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - tensor = layer->getOutput(0); - } else { - TRT_ShapedWeights recip_weights = - params->weight_store->GetTempWeights(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::RECIP; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op)); - scale_weights = recip_weights; - } - } else if (node_def.op() == "Mul") { - scale_weights = weights; - } else if (node_def.op() == "Add") { - shift_weights = weights; - } else { - // This should not happen. - return errors::Unimplemented("Binary op not supported at ", node_def.op()); - } - - nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *tensor, scale_mode, shift_weights.GetTrtWeights(), - scale_weights.GetTrtWeights(), power_weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // Transpose back dimension - if (need_to_permute) { - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, permutation, &output_tensor)); - } - - // Pass the output - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return Status::OK(); -} - Status ConvertConv2DHelper(OpConverterParams* params, int group, bool is_conv2d_backprop_input) { const auto& inputs = params->inputs; @@ -2060,74 +1879,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, return Status::OK(); } -Status BinaryTensorOpTensor(OpConverterParams* params, - const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r) { - const auto& node_def = params->node_def; - static const std::unordered_map ops{ - {"Add", nvinfer1::ElementWiseOperation::kSUM}, - {"Mul", nvinfer1::ElementWiseOperation::kPROD}, - {"Sub", nvinfer1::ElementWiseOperation::kSUB}, - {"Div", nvinfer1::ElementWiseOperation::kDIV}, - {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, - {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, - {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, - {"Pow", nvinfer1::ElementWiseOperation::kPOW}, - }; - auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) { - return errors::Unimplemented("Binary op ", node_def.op(), - " not supported at: ", node_def.name()); - } - - nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; - Status status = params->converter->GetTrtBroadcastShape( - operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r); - if (!status.ok()) { - return errors::InvalidArgument( - "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", - status.error_message()); - } - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("T"); - if (dtype == nvinfer1::DataType::kINT32) { - return errors::Unimplemented("Binary op ", node_def.op(), - " does not support INT32, at ", - node_def.name()); - } - if (params->validation_only) return Status::OK(); - - nvinfer1::ITensor* tensor_l = nullptr; - nvinfer1::ITensor* tensor_r = nullptr; - status = params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l); - if (status.ok()) { - status = params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, /*validation_only=*/false, &tensor_r); - } - if (!status.ok()) { - return errors::Internal("Failed to convert binary op ", node_def.name(), - ": ", status.error_message()); - } - - // Check type consistency. - TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) - << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); - TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) - << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); - - // Add ElementWise layer. - nvinfer1::IElementWiseLayer* layer = - params->converter->network()->addElementWise(*tensor_l, *tensor_r, - op_pair->second); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - - // Pass the output - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return Status::OK(); -} - Status ConvertPlugin(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -3349,16 +3100,10 @@ Status ConvertIdentity(OpConverterParams* params) { Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - // TODO(tmorris): Enable once false is updated to mean either tensor or weight - // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", - // false}})); if (inputs.size() != 2) { - return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), - " inputs but expected 2, at ", + return errors::InvalidArgument("Binary ops require two inputs, at ", node_def.name()); } - TF_RETURN_IF_ERROR( - AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); // Constant folding should have been done by TensorFlow if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { @@ -3368,32 +3113,68 @@ Status ConvertBinary(OpConverterParams* params) { node_def.name()); } - // TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with - // IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For - // now, the performance will be slightly better with IScaleLayer because it - // can be fused in more situations. However, most of the benefits of - // IScaleLayer are when the layer performs both a shift and a scale, which we - // don't do except for convolutions. - // - // Try to convert into Scale layer first (for better performance). - // Since scale layer supports restricted broadcast policy and op types, we - // allow failure and try to handle it through Elementwise op - // (BinaryTensorOpTensor). - Status status = Status::OK(); - if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { - status = BinaryTensorOpWeight(params, inputs.at(0).tensor(), - inputs.at(1).weights(), false); - } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) { - status = BinaryTensorOpWeight(params, inputs.at(1).tensor(), - inputs.at(0).weights(), true); + const TRT_TensorOrWeights& operand_l = inputs.at(0); + const TRT_TensorOrWeights& operand_r = inputs.at(1); + + static const std::unordered_map ops{ + {"Add", nvinfer1::ElementWiseOperation::kSUM}, + {"Mul", nvinfer1::ElementWiseOperation::kPROD}, + {"Sub", nvinfer1::ElementWiseOperation::kSUB}, + {"Div", nvinfer1::ElementWiseOperation::kDIV}, + {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, + {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, + {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"Pow", nvinfer1::ElementWiseOperation::kPOW}, + }; + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) { + return errors::Unimplemented("Binary op ", node_def.op(), + " not supported at: ", node_def.name()); } - // If both input are tensors, or one of them is weights but the conversion - // above failed, try the conversion using BinaryTensorOpTensor. - if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { - if (!status.ok()) VLOG(2) << status; - status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1)); + + nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; + Status status = params->converter->GetTrtBroadcastShape( + operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r); + if (!status.ok()) { + return errors::InvalidArgument( + "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", + status.error_message()); } - return status; + if (params->validation_only) return Status::OK(); + + nvinfer1::ITensor* tensor_l = nullptr; + nvinfer1::ITensor* tensor_r = nullptr; + // This will also convert constants to tensors, and set quantization ranges. + status = params->converter->PrepareTensorForShape( + operand_l, broadcasted_dims_l, params->validation_only, &tensor_l); + if (status.ok()) { + status = params->converter->PrepareTensorForShape( + operand_r, broadcasted_dims_r, params->validation_only, &tensor_r); + } + if (!status.ok()) { + return errors::Internal("Failed to convert binary op ", node_def.name(), + ": ", status.error_message()); + } + + // Check type consistency. + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("T"); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) + << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) + << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); + + // Add ElementWise layer. + nvinfer1::IElementWiseLayer* layer = + params->converter->network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Pass the output + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); } Status ConvertRsqrt(OpConverterParams* params) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 763b28b7402..d1958456291 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -51,6 +51,15 @@ namespace convert { (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build)) +inline std::ostream& operator<<(std::ostream& o, const nvinfer1::Dims& dims) +{ + o << "["; + for (int i = 0; i < dims.nbDims; i++) + o << (i ? "," : "") << dims.d[i]; + o << "]"; + return o; +} + struct EngineConnection { // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 211a2ee5369..2963f25492c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1008,9 +1008,7 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { }; // Both inputs are weights. - symmetric_test( - {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT, - "Broadcasting requires at least one of the operands be tensors"); + symmetric_test({1}, {1}, kIsNotTensor, kIsNotTensor, {1}, {1}); // One tensor and one weights. symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2}); @@ -1018,6 +1016,7 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1}); symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 2, 3}); + symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsTensor, {1, 1, 1}, {1, 2, 3}); symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, {2, 3, 4}); symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, @@ -1025,26 +1024,21 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4}, {2, 1, 4}); symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, - error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); - symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, - error::INVALID_ARGUMENT, "Infeasible broadcast scheme", - /*operand_1_batch_size=*/2); - symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT, - "Broadcasting beyond batch dimension is not supported " - "(tensor #dims 4 vs broadcast #dims 5)"); + "Cannot broadcast weights with non-trivial batch dimension"); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Cannot broadcast weights with non-trivial batch dimension", + /*operand_1_batch_size=*/2); + symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, + {1, 1, 1, 1}, {1, 1, 1, 1}); // Both inputs are tensors. - symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, - error::INVALID_ARGUMENT, - "Broadcasting beyond batch dimension is not supported " - "(tensor #dims 3 vs broadcast #dims 4)"); + symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {1, 1, 1}, {1, 1, 1}); symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, {2, 1, 4}); - symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, - error::INVALID_ARGUMENT, - "Broadcasting beyond batch dimension is not supported " - "(tensor #dims 4 vs broadcast #dims 5)"); + symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {1, 1, 1, 1}, + {1, 1, 1, 1}); } TEST_F(ConverterTest, CreateConstantLayer) { @@ -2006,222 +2000,16 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { EXPECT_NE(expect_scale_layer, element_wise_layer_found); } -template -void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - for (auto swap_inputs : {false, true}) { - test->Reset(); - NodeDef node_def; - if (swap_inputs) { - node_def = GetBinaryOpNodeDef("weights", "input", dtype); - } else { - node_def = GetBinaryOpNodeDef("input", "weights", dtype); - } - - const std::vector operand1{CType(3), CType(7.5)}; - const std::vector operand2{CType(2), CType(3)}; - - // It requires the dims to be at least of rank 3 to apply an IScaleLayer. - test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("weights", /*dims=*/{1, 1, 2}, - /*values=*/swap_inputs ? operand1 : operand2); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. - CheckAddedLayers(test, /*expect_scale_layer=*/true); - - // Check the dims of the output ITensor. - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{ - {"input", test::AsTensor(swap_inputs ? operand2 : operand1)}}; - DataVec output_data{{"my_binary", ConstructTensor(2)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); - if (node_def.op() == "Add") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(5), CType(10.5))); - } else if (node_def.op() == "Sub") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(4.5))); - } else if (node_def.op() == "Mul") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(6), CType(22.5))); - } else if (node_def.op() == "Div") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(2.5))); - } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(2.5))); - } else { - ASSERT_TRUE(false); - } - } -} - template -void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - const NodeDef node_def = - GetBinaryOpNodeDef("input", "weights", dtype); - const std::vector input{CType(1), CType(2), CType(3), CType(4)}; - const std::vector weights{CType(10), CType(20)}; - // There are two types of valid dim pairs which requires channel-wise - // broadcasting: - // - input dims (X Y Z) vs weights dims (X 1 1) - // - input dims (X Y Z) vs weights dims (Z) - // Here X=Z=2 and Y=1. - for (auto weights_dims : std::vector>{{2, 1, 1}, {2}}) { - test->Reset(); - test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("weights", weights_dims, weights); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. - CheckAddedLayers(test, /*expect_scale_layer=*/true); - - // Check the dims of the output ITensor. - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{{"input", test::AsTensor(input)}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - test->BuildAndRun(input_data, &output_data); - if (weights_dims.size() == 1) { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(11), CType(22), CType(13), CType(24))); - } else { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(11), CType(12), CType(23), CType(24))); - } - } -} - -template -void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - const NodeDef node_def = - GetBinaryOpNodeDef("input", "weights", dtype); - const std::vector input{CType(1), CType(2), CType(3), CType(4)}; - const std::vector weights{CType(10)}; - test->Reset(); - test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("weights", {1, 1, 1, 1}, weights); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. - CheckAddedLayers(test, /*expect_scale_layer=*/true); - - // Check the dims of the output ITensor. - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{{"input", test::AsTensor(input)}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - test->BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(11), CType(12), CType(13), CType(14))); -} - -template -void TestBinaryTensorOpWeightFallback(OpConverterTest* test, - const std::vector& input_dims, - const std::vector& weights_dims, - error::Code code = error::OK, - const char* error_msg_substr = nullptr, - const int input_batch_size = 1) { - const DataType dtype = DT_FLOAT; - typedef typename EnumToDataType::Type CType; - const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims)); - const size_t num_weights = - TrtWeightDimsNumElements(GetTestDims(weights_dims)); - - test->Reset(); - const NodeDef node_def = - GetBinaryOpNodeDef("input", "weights", dtype); - test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size, - TfDataTypeToTrt(dtype)); - test->AddTestWeights( - "weights", /*dims=*/weights_dims, - /*values=*/std::vector(num_weights, CType(1))); - test->RunValidationAndConversion(node_def, code, error_msg_substr); - if (code != error::OK) return; - - // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. - CheckAddedLayers(test, /*expect_scale_layer=*/false); - - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - ASSERT_TRUE(output.is_tensor()); - - // Check the dims of the output ITensor. - std::vector expected_output_dims = input_dims; - for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1; - i >= 0 && j >= 0; --i, --j) { - if (expected_output_dims[i] == 1) { - expected_output_dims[i] = weights_dims[j]; - } - } - ExpectTrtDimsEqualsArray(expected_output_dims, - output.tensor()->getDimensions()); - - // Check the result of running the engine. - const int expected_num_outputs = - TrtTensorDimsNumElements(GetTestDims(expected_output_dims)); - const DataVec input_data{ - {"input", ConstructTensor(num_inputs, CType(2))}}; - DataVec output_data{ - {"my_binary", ConstructTensor(expected_num_outputs)}}; - test->BuildAndRun(input_data, &output_data); - if (node_def.op() == "Add") { - EXPECT_THAT( - GetSpanForData(output_data[0]), - ElementsAreArray(std::vector(expected_num_outputs, CType(3)))); - } else if (node_def.op() == "Minimum") { - EXPECT_THAT( - GetSpanForData(output_data[0]), - ElementsAreArray(std::vector(expected_num_outputs, CType(1)))); - } else { - ASSERT_TRUE(false); - } -} - -template -void TestBinaryTensorOpTensor(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - test->Reset(); - const NodeDef node_def = - GetBinaryOpNodeDef("input1", "input2", dtype); - test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->RunValidationAndConversion(node_def); - - // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. - CheckAddedLayers(test, /*expect_scale_layer=*/false); +void checkBinaryResults(OpConverterTest* test, const NodeDef& node_def, + const DataVec& input_data, DataVec& output_data) { + using CType = typename EnumToDataType::Type; // Check output dims. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{ - {"input1", test::AsTensor({CType(3), CType(6)})}, - {"input2", test::AsTensor({CType(2), CType(3)})}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. test->BuildAndRun( @@ -2257,6 +2045,61 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { } } +template +void TestBinaryTensorOpTensor(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + const DataVec input_data{ + {"input1", test::AsTensor({CType(3), CType(6)})}, + {"input2", test::AsTensor({CType(2), CType(3)})}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + checkBinaryResults(test, node_def, input_data, output_data); +} + +template +void TestBinaryTensorOpWeight(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("input2", /*dims=*/{2, 1}, + /*values=*/std::vector{CType(2), CType(3)}); + test->RunValidationAndConversion(node_def); + + const DataVec input_data{ + {"input1", test::AsTensor({CType(3), CType(6)})}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + checkBinaryResults(test, node_def, input_data, output_data); +} + +template +void TestBinaryWeightOpTensor(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + test->AddTestWeights("input1", /*dims=*/{1, 2}, + /*values=*/std::vector{CType(3), CType(6)}); + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + const DataVec input_data{ + {"input2", test::AsTensor({CType(2), CType(3)})}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + checkBinaryResults(test, node_def, input_data, output_data); +} + TEST_F(OpConverterTest, ConvertBinary) { AttrValue dtype; dtype.set_type(DT_FLOAT); @@ -2266,10 +2109,9 @@ TEST_F(OpConverterTest, ConvertBinary) { NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}}); AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - StrCat("Add got ", std::to_string(num_inputs), - " inputs but expected 2, at my_add") - .c_str()); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + StrCat("Binary ops require two inputs, at my_add").c_str()); } { // Both inputs are weights. @@ -2283,42 +2125,7 @@ TEST_F(OpConverterTest, ConvertBinary) { "Constant folding is falled back to TensorFlow, binary op received " "both input as constant at: my_add"); } - - // Test BinaryTensorOpWeight() without broadcasting. - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - TestBinaryTensorOpWeightNoBroadcast(this); - - // Test BinaryTensorOpWeight() with channel-wise broadcasting. - TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); - - // Test BinaryTensorOpWeight() with uniformly broadcasting. - TestBinaryTensorOpWeightWithUniformlyBroadcast(this); - - // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor(). - // Unsupported op. - TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1}); - // Rank of input tensor dimension <3. - TestBinaryTensorOpWeightFallback(this, {1, 1}, {1}); - // Broadcast on batch dimension, should fail. - TestBinaryTensorOpWeightFallback( - this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT, - "Unsupported binary op broadcast scheme for op my_binary", - /*input_batch_size=*/2); - // Incompatible dims with per-channel mode. - TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1, 2, 1}); - // Incompatible dims. - TestBinaryTensorOpWeightFallback(this, {1, 2, 1}, {2}); - - // Test BinaryTensorOpTensor() with broadcasting. + // FP32 tests TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); @@ -2327,7 +2134,26 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); + // Test with operand R = Weights + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + // Test with operand L = Weights + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + // FP16 tests TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); @@ -2336,6 +2162,24 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); + // Test with operand R = Weights + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + TestBinaryTensorOpWeight(this); + // Test with operand L = Weights + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); + TestBinaryWeightOpTensor(this); } TEST_F(OpConverterTest, ConvertQuantize) { From 5a3a5084712bbed016bfc5f0ff42c709260ad616 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 17 May 2019 09:44:32 -0700 Subject: [PATCH 02/15] Bring in merged MM changes --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d1958456291..763b28b7402 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -51,15 +51,6 @@ namespace convert { (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build)) -inline std::ostream& operator<<(std::ostream& o, const nvinfer1::Dims& dims) -{ - o << "["; - for (int i = 0; i < dims.nbDims; i++) - o << (i ? "," : "") << dims.d[i]; - o << "]"; - return o; -} - struct EngineConnection { // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, From 35b41992fa27845dd061257976b64c2d5f57530c Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 17 May 2019 10:23:46 -0700 Subject: [PATCH 03/15] Make GetTrtBroadcastShape a standalone function. Small improvements to ConvertBinary --- .../tf2tensorrt/convert/convert_nodes.cc | 49 +++++++------------ .../tf2tensorrt/convert/convert_nodes.h | 14 +++--- .../tf2tensorrt/convert/convert_nodes_test.cc | 6 +-- 3 files changed, 27 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 739bf942c68..2d70dde8f68 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -385,10 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) { ", dims=", DebugString(tensor.getDimensions()), ")"); } -Status Converter::GetTrtBroadcastShape( +Status GetTrtBroadcastShape( const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, nvinfer1::Dims* operand_l_new_dims, - nvinfer1::Dims* operand_r_new_dims) const { + nvinfer1::Dims* operand_r_new_dims) { // TensorRT Elementwise op supports broadcast but requires both tensor to be // of Identical rank. // This function broadcasts the lower rank dimension across the higher rank @@ -396,17 +396,15 @@ Status Converter::GetTrtBroadcastShape( (*operand_l_new_dims) = operand_l.GetTrtDims(); (*operand_r_new_dims) = operand_r.GetTrtDims(); - // clang-format off // Weights may include a batch dimension, so we need to remove it. // We determine if that is the case by checking if the rank of the weights is // larger than the rank of the tensor. Needed for cases such as: // t: [1, 1] w/ implicit batch size of 1 // w: [1, 1, 1] // where the output in TRT is expected to be 2D, not 3D. - // clang-format on if (operand_l.is_weights() && operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) { - if (operand_l_new_dims->d[0] != -1 && operand_l_new_dims->d[0] != 1) { + if (operand_l_new_dims->d[0] != 1) { return errors::InvalidArgument( "Cannot broadcast weights with non-trivial batch dimension"); } @@ -415,7 +413,7 @@ Status Converter::GetTrtBroadcastShape( if (operand_r.is_weights() && operand_r_new_dims->nbDims > operand_l_new_dims->nbDims) { - if (operand_r_new_dims->d[0] != -1 && operand_r_new_dims->d[0] != 1) { + if (operand_r_new_dims->d[0] != 1) { return errors::InvalidArgument( "Cannot broadcast weights with non-trivial batch dimension"); } @@ -3112,7 +3110,8 @@ Status ConvertBinary(OpConverterParams* params) { "both input as constant at: ", node_def.name()); } - + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); const TRT_TensorOrWeights& operand_l = inputs.at(0); const TRT_TensorOrWeights& operand_r = inputs.at(1); @@ -3133,30 +3132,20 @@ Status ConvertBinary(OpConverterParams* params) { } nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; - Status status = params->converter->GetTrtBroadcastShape( - operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r); - if (!status.ok()) { - return errors::InvalidArgument( - "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", - status.error_message()); - } - if (params->validation_only) return Status::OK(); + TF_RETURN_IF_ERROR(GetTrtBroadcastShape( + operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r)); nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; // This will also convert constants to tensors, and set quantization ranges. - status = params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, params->validation_only, &tensor_l); - if (status.ok()) { - status = params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, params->validation_only, &tensor_r); - } - if (!status.ok()) { - return errors::Internal("Failed to convert binary op ", node_def.name(), - ": ", status.error_message()); - } + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + operand_l, broadcasted_dims_l, params->validation_only, &tensor_l)); + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + operand_r, broadcasted_dims_r, params->validation_only, &tensor_r)); + if (params->validation_only) return Status::OK(); // Check type consistency. + // TODO(tmorris): Check if this is still necessary. TFAttrs attrs(node_def); nvinfer1::DataType dtype = attrs.get("T"); TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) @@ -3167,14 +3156,12 @@ Status ConvertBinary(OpConverterParams* params) { // Add ElementWise layer. nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), op_pair->second); + *tensor_l, *tensor_r, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Pass the output - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); } Status ConvertRsqrt(OpConverterParams* params) { @@ -4327,7 +4314,7 @@ Status ConvertSquaredDifference(OpConverterParams* params) { const auto& node_def = params->node_def; // Broadcast inputs. nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; - TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape( + TF_RETURN_IF_ERROR(GetTrtBroadcastShape( inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r)); nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 763b28b7402..d0f6d5ef1d1 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -512,13 +512,6 @@ class Converter { const bool validation_only, nvinfer1::ITensor** tensor); - // Return OK if the broadcast scheme is supported and compute the shapes after - // broadcasting. - Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r, - nvinfer1::Dims* operand_l_new_dims, - nvinfer1::Dims* operand_r_new_dims) const; - // Creates an IConstantLayer using 'weights' whose dimensions are specified by // 'dims', and returns the output ITensor. nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, @@ -592,6 +585,13 @@ class Converter { friend class OpConverterTest; }; +// Return OK if the broadcast scheme is supported and compute the shapes after +// broadcasting. +Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims); + // Map of all supported UnaryOperations const std::unordered_map* UnaryOperationMap(); // Map of all supported ActivationTypes diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 2963f25492c..cfbe203ec18 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -988,8 +988,7 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { operand_2_shape, operand_2_is_tensor, operand_2_batch_size); // operand_1 broadcast operand_2 - ExpectStatus( - this->converter_->GetTrtBroadcastShape( + ExpectStatus(GetTrtBroadcastShape( operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims), expected_code, expected_error_msg_substr); if (expected_code == error::OK) { @@ -997,8 +996,7 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); } // operand_2 broadcast operand_1 - ExpectStatus( - this->converter_->GetTrtBroadcastShape( + ExpectStatus(GetTrtBroadcastShape( operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims), expected_code, expected_error_msg_substr); if (expected_code == error::OK) { From 1d7af1619fb030948c2357ab0c2915f0baffe196 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 17 May 2019 10:47:20 -0700 Subject: [PATCH 04/15] Rename CheckBinaryResults --- .../tf2tensorrt/convert/convert_nodes_test.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index cfbe203ec18..e58f966e4bb 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1005,7 +1005,7 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { } }; - // Both inputs are weights. + // Both inputs are weights. This should be handled by constfold grappler. symmetric_test({1}, {1}, kIsNotTensor, kIsNotTensor, {1}, {1}); // One tensor and one weights. @@ -1999,7 +1999,7 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { } template -void checkBinaryResults(OpConverterTest* test, const NodeDef& node_def, +void CheckBinaryResults(OpConverterTest* test, const NodeDef& node_def, const DataVec& input_data, DataVec& output_data) { using CType = typename EnumToDataType::Type; @@ -2059,7 +2059,7 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { {"input1", test::AsTensor({CType(3), CType(6)})}, {"input2", test::AsTensor({CType(2), CType(3)})}}; DataVec output_data{{"my_binary", ConstructTensor(4)}}; - checkBinaryResults(test, node_def, input_data, output_data); + CheckBinaryResults(test, node_def, input_data, output_data); } template @@ -2077,7 +2077,7 @@ void TestBinaryTensorOpWeight(OpConverterTest* test) { const DataVec input_data{ {"input1", test::AsTensor({CType(3), CType(6)})}}; DataVec output_data{{"my_binary", ConstructTensor(4)}}; - checkBinaryResults(test, node_def, input_data, output_data); + CheckBinaryResults(test, node_def, input_data, output_data); } template @@ -2095,7 +2095,7 @@ void TestBinaryWeightOpTensor(OpConverterTest* test) { const DataVec input_data{ {"input2", test::AsTensor({CType(2), CType(3)})}}; DataVec output_data{{"my_binary", ConstructTensor(4)}}; - checkBinaryResults(test, node_def, input_data, output_data); + CheckBinaryResults(test, node_def, input_data, output_data); } TEST_F(OpConverterTest, ConvertBinary) { @@ -2152,6 +2152,7 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryWeightOpTensor(this); // FP16 tests + // TODO(tmorris): Use templates to avoid duplication. TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); From dffc568b12b3e2cc2cb57eaf2354c9f3e385a40c Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 17 May 2019 11:37:26 -0700 Subject: [PATCH 05/15] Remove unnecessary type consistency check --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 2d70dde8f68..78a1bd61a8d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -3144,15 +3144,6 @@ Status ConvertBinary(OpConverterParams* params) { operand_r, broadcasted_dims_r, params->validation_only, &tensor_r)); if (params->validation_only) return Status::OK(); - // Check type consistency. - // TODO(tmorris): Check if this is still necessary. - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("T"); - TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) - << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); - TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) - << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); - // Add ElementWise layer. nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( From af0a4a724bd4774f59c52c57dc5682ec79a9021a Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 17 May 2019 13:01:53 -0700 Subject: [PATCH 06/15] Fix format --- .../tf2tensorrt/convert/convert_nodes.cc | 76 +++++++++---------- .../tf2tensorrt/convert/convert_nodes_test.cc | 12 +-- 2 files changed, 43 insertions(+), 45 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 78a1bd61a8d..582eb4e256d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -385,10 +385,10 @@ string DebugString(const nvinfer1::ITensor& tensor) { ", dims=", DebugString(tensor.getDimensions()), ")"); } -Status GetTrtBroadcastShape( - const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, - nvinfer1::Dims* operand_l_new_dims, - nvinfer1::Dims* operand_r_new_dims) { +Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) { // TensorRT Elementwise op supports broadcast but requires both tensor to be // of Identical rank. // This function broadcasts the lower rank dimension across the higher rank @@ -459,11 +459,10 @@ Status GetTrtBroadcastShape( if ((operand_l_new_dims->d[i] != operand_r_new_dims->d[i]) && (operand_l_new_dims->d[i] != 1) && (operand_r_new_dims->d[i] != 1)) { return errors::InvalidArgument( - "Infeasible broadcast scheme (", - "batch_dim: ", operand_l_new_dims->d[0], ", ", - DebugString(*operand_l_new_dims), " vs ", - "batch_dim: ", operand_r_new_dims->d[0], ", ", - DebugString(*operand_r_new_dims), ")"); + "Infeasible broadcast scheme (batch_dim: ", operand_l_new_dims->d[0], + ", ", DebugString(*operand_l_new_dims), " vs batch_dim: ", + operand_r_new_dims->d[0], ", ", DebugString(*operand_r_new_dims), + ")"); } } return Status::OK(); @@ -3146,8 +3145,8 @@ Status ConvertBinary(OpConverterParams* params) { // Add ElementWise layer. nvinfer1::IElementWiseLayer* layer = - params->converter->network()->addElementWise( - *tensor_l, *tensor_r, op_pair->second); + params->converter->network()->addElementWise(*tensor_l, *tensor_r, + op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); // Pass the output @@ -3968,8 +3967,8 @@ Status ConvertMatMulHelper(OpConverterParams* params, params, input_a.tensor(), input_b.weights(), transpose_b, node_name); } - constexpr auto get_matrix_op = - [](nvinfer1::ITensor* in, bool transpose) -> nvinfer1::MatrixOperation { + constexpr auto get_matrix_op = []( + nvinfer1::ITensor* in, bool transpose) -> nvinfer1::MatrixOperation { return (in->getDimensions().nbDims < 2) ? nvinfer1::MatrixOperation::kVECTOR : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE @@ -3979,9 +3978,8 @@ Status ConvertMatMulHelper(OpConverterParams* params, // If the MatMul operand is a constant, applies transposes at conversion-time // as necessary. If the operand is a tensor, does nothing. If required // transposes were applied, sets transpose to false. - const auto prepare_matmul_operand = - [¶ms](TRT_TensorOrWeights operand, - bool* transpose) -> nvinfer1::ITensor* { + const auto prepare_matmul_operand = [¶ms]( + TRT_TensorOrWeights operand, bool* transpose) -> nvinfer1::ITensor* { if (operand.is_tensor()) { return operand.tensor(); } else { @@ -4055,29 +4053,29 @@ Status ConvertBatchMatMul(OpConverterParams* params) { const bool transpose_b = attrs.get("adj_y"); // Removes the batch dimension from weights. - const auto remove_weights_batch_dim = - [¶ms](const TRT_TensorOrWeights& input, TRT_TensorOrWeights* tensor) { - auto dims = input.GetTrtDims(); - if (input.is_weights()) { - // The other operand must be a tensor, this is ensured by earlier - // checks. Checks that the batch dimension is not changed by - // broadcasting. - if (dims.d[0] != 1) { - return errors::InvalidArgument( - "Input weight attempts to broadcast across batch dimension for " - "BatchMatMul, at ", - params->node_def.name()); - } - // Remove the batch dimension from the weights. - TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims)); - } - // Create tensor and reshape if necessary. - nvinfer1::ITensor* t{nullptr}; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, dims, params->validation_only, &t)); - *tensor = TRT_TensorOrWeights{t}; - return Status::OK(); - }; + const auto remove_weights_batch_dim = [¶ms]( + const TRT_TensorOrWeights& input, TRT_TensorOrWeights* tensor) { + auto dims = input.GetTrtDims(); + if (input.is_weights()) { + // The other operand must be a tensor, this is ensured by earlier + // checks. Checks that the batch dimension is not changed by + // broadcasting. + if (dims.d[0] != 1) { + return errors::InvalidArgument( + "Input weight attempts to broadcast across batch dimension for " + "BatchMatMul, at ", + params->node_def.name()); + } + // Remove the batch dimension from the weights. + TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims)); + } + // Create tensor and reshape if necessary. + nvinfer1::ITensor* t{nullptr}; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input, dims, params->validation_only, &t)); + *tensor = TRT_TensorOrWeights{t}; + return Status::OK(); + }; TRT_TensorOrWeights tensor_l{nullptr}; TRT_TensorOrWeights tensor_r{nullptr}; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index e58f966e4bb..ba877dd3700 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -988,17 +988,17 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { operand_2_shape, operand_2_is_tensor, operand_2_batch_size); // operand_1 broadcast operand_2 - ExpectStatus(GetTrtBroadcastShape( - operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims), - expected_code, expected_error_msg_substr); + ExpectStatus(GetTrtBroadcastShape(operand_1, operand_2, &operand_1_new_dims, + &operand_2_new_dims), + expected_code, expected_error_msg_substr); if (expected_code == error::OK) { ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); } // operand_2 broadcast operand_1 - ExpectStatus(GetTrtBroadcastShape( - operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims), - expected_code, expected_error_msg_substr); + ExpectStatus(GetTrtBroadcastShape(operand_2, operand_1, &operand_2_new_dims, + &operand_1_new_dims), + expected_code, expected_error_msg_substr); if (expected_code == error::OK) { ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); From 8e044458a6f7b607c2ff1a458ec3196c4d37a82e Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 17 May 2019 15:41:19 -0700 Subject: [PATCH 07/15] Clarify comment regarding broadcasting weights --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 582eb4e256d..331dd879a8f 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -396,11 +396,12 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, (*operand_l_new_dims) = operand_l.GetTrtDims(); (*operand_r_new_dims) = operand_r.GetTrtDims(); - // Weights may include a batch dimension, so we need to remove it. - // We determine if that is the case by checking if the rank of the weights is - // larger than the rank of the tensor. Needed for cases such as: + // Weights may include a dimension which must be broadcasted against a + // tensor's batch dimension. This occurs when the rank of the weights is + // larger than the rank of the tensor. Example: // t: [1, 1] w/ implicit batch size of 1 // w: [1, 1, 1] + /// ^ this dimension in w needs to be broadcasted against t's batch dim. // where the output in TRT is expected to be 2D, not 3D. if (operand_l.is_weights() && operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) { @@ -410,7 +411,6 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, } TF_RETURN_IF_ERROR(RemoveBatchDimension(operand_l_new_dims)); } - if (operand_r.is_weights() && operand_r_new_dims->nbDims > operand_l_new_dims->nbDims) { if (operand_r_new_dims->d[0] != 1) { From 761f172793217c9d5a868f2eba85cfa3fdbbd6bb Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 22 May 2019 12:56:32 -0700 Subject: [PATCH 08/15] Use only a single TestBinaryOp function with args for input being tensors or weights --- .../tf2tensorrt/convert/convert_nodes_test.cc | 194 +++++++----------- 1 file changed, 77 insertions(+), 117 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index ba877dd3700..8888caa1cfe 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1998,11 +1998,39 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { EXPECT_NE(expect_scale_layer, element_wise_layer_found); } -template -void CheckBinaryResults(OpConverterTest* test, const NodeDef& node_def, - const DataVec& input_data, DataVec& output_data) { - using CType = typename EnumToDataType::Type; +template +void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, + bool operand_2_is_tensor) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + if (operand_1_is_tensor) { + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + } else { + test->AddTestWeights("input1", /*dims=*/{1, 2}, + /*values=*/std::vector{CType(3), CType(6)}); + } + if (operand_2_is_tensor) { + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + } else { + test->AddTestWeights("input2", /*dims=*/{2, 1}, + /*values=*/std::vector{CType(2), CType(3)}); + } + test->RunValidationAndConversion(node_def); + DataVec input_data; + if (operand_1_is_tensor) { + input_data.emplace_back("input1", + test::AsTensor({CType(3), CType(6)})); + } + if (operand_2_is_tensor) { + input_data.emplace_back("input2", + test::AsTensor({CType(2), CType(3)})); + } + DataVec output_data{{"my_binary", ConstructTensor(4)}}; // Check output dims. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); @@ -2010,9 +2038,9 @@ void CheckBinaryResults(OpConverterTest* test, const NodeDef& node_def, ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data, dtype == DT_HALF + ? TrtPrecisionMode::FP16 + : TrtPrecisionMode::FP32); if (node_def.op() == "Add") { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(5), CType(8), CType(6), CType(9))); @@ -2043,61 +2071,6 @@ void CheckBinaryResults(OpConverterTest* test, const NodeDef& node_def, } } -template -void TestBinaryTensorOpTensor(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - test->Reset(); - const NodeDef node_def = - GetBinaryOpNodeDef("input1", "input2", dtype); - test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->RunValidationAndConversion(node_def); - - const DataVec input_data{ - {"input1", test::AsTensor({CType(3), CType(6)})}, - {"input2", test::AsTensor({CType(2), CType(3)})}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - CheckBinaryResults(test, node_def, input_data, output_data); -} - -template -void TestBinaryTensorOpWeight(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - test->Reset(); - const NodeDef node_def = - GetBinaryOpNodeDef("input1", "input2", dtype); - test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->AddTestWeights("input2", /*dims=*/{2, 1}, - /*values=*/std::vector{CType(2), CType(3)}); - test->RunValidationAndConversion(node_def); - - const DataVec input_data{ - {"input1", test::AsTensor({CType(3), CType(6)})}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - CheckBinaryResults(test, node_def, input_data, output_data); -} - -template -void TestBinaryWeightOpTensor(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - test->Reset(); - const NodeDef node_def = - GetBinaryOpNodeDef("input1", "input2", dtype); - test->AddTestWeights("input1", /*dims=*/{1, 2}, - /*values=*/std::vector{CType(3), CType(6)}); - test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->RunValidationAndConversion(node_def); - - const DataVec input_data{ - {"input2", test::AsTensor({CType(2), CType(3)})}}; - DataVec output_data{{"my_binary", ConstructTensor(4)}}; - CheckBinaryResults(test, node_def, input_data, output_data); -} - TEST_F(OpConverterTest, ConvertBinary) { AttrValue dtype; dtype.set_type(DT_FLOAT); @@ -2123,62 +2096,49 @@ TEST_F(OpConverterTest, ConvertBinary) { "Constant folding is falled back to TensorFlow, binary op received " "both input as constant at: my_add"); } - // FP32 tests - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - // Test with operand R = Weights - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - // Test with operand L = Weights - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - // FP16 tests - // TODO(tmorris): Use templates to avoid duplication. - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - TestBinaryTensorOpTensor(this); - // Test with operand R = Weights - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - TestBinaryTensorOpWeight(this); - // Test with operand L = Weights - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); - TestBinaryWeightOpTensor(this); + // Test combinations of tensor vs weight inputs (except when both inputs are + // weights). + for (const bool operand_1_is_tensor : {true, false}) { + for (const bool operand_2_is_tensor : {true, false}) { + if (!operand_1_is_tensor && !operand_2_is_tensor) continue; + // FP32 tests + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + // FP16 tests + // TODO(tmorris): Use templates to avoid duplication. + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + } + } } TEST_F(OpConverterTest, ConvertQuantize) { From bbb0a7c34c9e1dbcd72a2457307c59070e5a0ed0 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 22 May 2019 13:05:05 -0700 Subject: [PATCH 09/15] Undo some stuff that got moved around unecessarily --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc | 9 ++++----- .../compiler/tf2tensorrt/convert/convert_nodes_test.cc | 7 ++++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 331dd879a8f..fcf854393ff 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -3098,9 +3098,12 @@ Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { - return errors::InvalidArgument("Binary ops require two inputs, at ", + return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), + " inputs but expected 2, at ", node_def.name()); } + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); // Constant folding should have been done by TensorFlow if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { @@ -3109,8 +3112,6 @@ Status ConvertBinary(OpConverterParams* params) { "both input as constant at: ", node_def.name()); } - TF_RETURN_IF_ERROR( - AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); const TRT_TensorOrWeights& operand_l = inputs.at(0); const TRT_TensorOrWeights& operand_r = inputs.at(1); @@ -3148,8 +3149,6 @@ Status ConvertBinary(OpConverterParams* params) { params->converter->network()->addElementWise(*tensor_l, *tensor_r, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - - // Pass the output params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 8888caa1cfe..b3494c681d6 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -2080,9 +2080,10 @@ TEST_F(OpConverterTest, ConvertBinary) { NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}}); AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - StrCat("Binary ops require two inputs, at my_add").c_str()); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + StrCat("Add got ", std::to_string(num_inputs), + " inputs but expected 2, at my_add") + .c_str());); } { // Both inputs are weights. From c91db055aab1fb7815d2eaf08d517ec49ee0cfc3 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 22 May 2019 15:03:02 -0700 Subject: [PATCH 10/15] Fix compile errors --- .../compiler/tf2tensorrt/convert/convert_nodes_test.cc | 10 +++++----- tensorflow/compiler/xla/service/gpu/BUILD | 1 + tensorflow/stream_executor/cuda/BUILD | 2 ++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index b3494c681d6..5f8873e3610 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -2023,12 +2023,12 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, DataVec input_data; if (operand_1_is_tensor) { - input_data.emplace_back("input1", - test::AsTensor({CType(3), CType(6)})); + input_data.push_back({"input1", + test::AsTensor({CType(3), CType(6)})}); } if (operand_2_is_tensor) { - input_data.emplace_back("input2", - test::AsTensor({CType(2), CType(3)})); + input_data.push_back({"input2", + test::AsTensor({CType(2), CType(3)})}); } DataVec output_data{{"my_binary", ConstructTensor(4)}}; // Check output dims. @@ -2083,7 +2083,7 @@ TEST_F(OpConverterTest, ConvertBinary) { RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, StrCat("Add got ", std::to_string(num_inputs), " inputs but expected 2, at my_add") - .c_str());); + .c_str()); } { // Both inputs are weights. diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 3ff115d0b50..76443486b81 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -637,6 +637,7 @@ cc_library( hdrs = ["cusolver_context.h"], deps = [ "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cublas_headers", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index d21b7111642..277e11aa216 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -214,6 +214,7 @@ cc_library( textual_hdrs = glob(["cublas_*.inc"]), deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cublas_headers", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform:dso_loader", ]), @@ -244,6 +245,7 @@ cc_library( "@com_google_absl//absl/strings", "//third_party/eigen3", "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cublas_headers", "//tensorflow/core:lib_internal", "//tensorflow/stream_executor", "//tensorflow/stream_executor:event", From 5cece961120d081563af06f9c82b941ccb1db3dc Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 22 May 2019 15:13:02 -0700 Subject: [PATCH 11/15] Remove special case for rank(lhs) == rank(rhs) since that is covered already and it was skipping the feasibility check --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index fcf854393ff..22607d7a93d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -420,15 +420,6 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, TF_RETURN_IF_ERROR(RemoveBatchDimension(operand_r_new_dims)); } - // If the rank of the tensors is already the same, we can't do anything - // further. - if (operand_l_new_dims->nbDims == operand_r_new_dims->nbDims) { - VLOG(2) << "Broadcasted operands to [L] " - << DebugString(*operand_l_new_dims) << " and [R] " - << DebugString(*operand_r_new_dims); - return Status::OK(); - } - const nvinfer1::Dims* higher_rank = (operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) ? operand_l_new_dims From 2f41252aeec254836e277f6367366d4877de4f40 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 23 May 2019 10:34:47 -0700 Subject: [PATCH 12/15] Revert broadcasting changes --- .../tf2tensorrt/convert/convert_nodes.cc | 124 ++++++++++-------- .../tf2tensorrt/convert/convert_nodes_test.cc | 41 ++++-- 2 files changed, 95 insertions(+), 70 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 22607d7a93d..cfb0a6c3a3d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -390,70 +390,78 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, nvinfer1::Dims* operand_l_new_dims, nvinfer1::Dims* operand_r_new_dims) { // TensorRT Elementwise op supports broadcast but requires both tensor to be - // of Identical rank. - // This function broadcasts the lower rank dimension across the higher rank - // one. - (*operand_l_new_dims) = operand_l.GetTrtDims(); - (*operand_r_new_dims) = operand_r.GetTrtDims(); - - // Weights may include a dimension which must be broadcasted against a - // tensor's batch dimension. This occurs when the rank of the weights is - // larger than the rank of the tensor. Example: - // t: [1, 1] w/ implicit batch size of 1 - // w: [1, 1, 1] - /// ^ this dimension in w needs to be broadcasted against t's batch dim. - // where the output in TRT is expected to be 2D, not 3D. - if (operand_l.is_weights() && - operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) { - if (operand_l_new_dims->d[0] != 1) { - return errors::InvalidArgument( - "Cannot broadcast weights with non-trivial batch dimension"); - } - TF_RETURN_IF_ERROR(RemoveBatchDimension(operand_l_new_dims)); - } - if (operand_r.is_weights() && - operand_r_new_dims->nbDims > operand_l_new_dims->nbDims) { - if (operand_r_new_dims->d[0] != 1) { - return errors::InvalidArgument( - "Cannot broadcast weights with non-trivial batch dimension"); - } - TF_RETURN_IF_ERROR(RemoveBatchDimension(operand_r_new_dims)); + // of Identical rank + // + // We consider case of: + // 1. operand_l to be a Tensor & operand_r to be a Const; + // 2. operand_l to be a Tensor & operand_r to be a Tensor; + // note: const op const (constant folding) should fallback to TensorFlow + // + // broadcast scheme: + // T: 1 3 5 (tensor would not have batch dimension) + // W: 1 1 3 1 (weight would have all explicit dimensions) + // i. fill in explicit dimensions + // -> T: -1 1 3 5 (we put a -1 for batch dimension) + // -> W: 1 1 3 1 + // ii. compare broadcast feasibility + // + // We cannot support the following since TensorRT does not allow manipulation + // on batch dimension, we cannot generate output with proper shape + // T: 3 5 1 + // W: 1 1 1 1 3 5 1 + // -> T: 1 1 1 -1 3 5 1 + // -> W: 1 1 1 1 3 5 1 + // *************************************************************************** + if (!operand_l.is_tensor() && !operand_r.is_tensor()) { + return errors::InvalidArgument( + "Broadcasting requires at least one of the operands be tensors"); } - const nvinfer1::Dims* higher_rank = - (operand_l_new_dims->nbDims > operand_r_new_dims->nbDims) - ? operand_l_new_dims - : operand_r_new_dims; - nvinfer1::Dims* lower_rank = - (operand_l_new_dims->nbDims <= operand_r_new_dims->nbDims) - ? operand_l_new_dims - : operand_r_new_dims; + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; + auto compute_output_dims = [](const TRT_TensorOrWeights& input, + int broadcast_num_dims, int* output_dims_array, + nvinfer1::Dims* output_dims) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, + output_dims_array + broadcast_num_dims - input_dims.nbDims); + if (input.is_tensor()) { + const int true_input_dims = input_dims.nbDims + 1; + if (true_input_dims < broadcast_num_dims) { + return errors::InvalidArgument( + "Broadcasting beyond batch dimension is not supported ", + "(tensor #dims ", true_input_dims, " vs broadcast #dims ", + broadcast_num_dims, ")"); + } + // Set the batch dimension to -1, since batch size is not supposed to + // be broadcasted. + output_dims_array[0] = -1; + } + // Copy to output dimensions (stripping the batch dimension). + output_dims->nbDims = broadcast_num_dims - 1; + std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, + output_dims->d); + return Status::OK(); + }; - // Broadcasts low_rank over high_rank in-place by inserting ones at the front - // of low_rank so the ranks match. - constexpr auto broadcast_dims = [](const nvinfer1::Dims& high_rank, - const nvinfer1::Dims& low_rank) { - nvinfer1::Dims ret{high_rank.nbDims}; - std::fill(ret.d, ret.d + ret.nbDims, 1); - int num_leading_ones = high_rank.nbDims - low_rank.nbDims; - std::copy(low_rank.d, low_rank.d + low_rank.nbDims, - ret.d + num_leading_ones); - return ret; - }; - - (*lower_rank) = broadcast_dims(*higher_rank, *lower_rank); - VLOG(2) << "Broadcasted operands to [L] " << DebugString(*operand_l_new_dims) - << " and [R] " << DebugString(*operand_r_new_dims); + // Compute the output dimensions. + const int broadcast_num_dims = + std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0), + operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0)); + int output_l[max_nb_dims], output_r[max_nb_dims]; + TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims, + output_l, operand_l_new_dims)); + TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims, + output_r, operand_r_new_dims)); // Compare broadcast feasibility - for (int i = 0; i < operand_r_new_dims->nbDims; ++i) { - if ((operand_l_new_dims->d[i] != operand_r_new_dims->d[i]) && - (operand_l_new_dims->d[i] != 1) && (operand_r_new_dims->d[i] != 1)) { + for (int i = 0; i < broadcast_num_dims; ++i) { + if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && + (output_r[i] != 1)) { return errors::InvalidArgument( - "Infeasible broadcast scheme (batch_dim: ", operand_l_new_dims->d[0], - ", ", DebugString(*operand_l_new_dims), " vs batch_dim: ", - operand_r_new_dims->d[0], ", ", DebugString(*operand_r_new_dims), - ")"); + "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", + DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], + ", ", DebugString(*operand_r_new_dims), ")"); } } return Status::OK(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 5f8873e3610..b6082807ca1 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1005,8 +1005,10 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { } }; - // Both inputs are weights. This should be handled by constfold grappler. - symmetric_test({1}, {1}, kIsNotTensor, kIsNotTensor, {1}, {1}); + // Both inputs are weights. + symmetric_test( + {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT, + "Broadcasting requires at least one of the operands be tensors"); // One tensor and one weights. symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2}); @@ -1014,7 +1016,6 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1}); symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 2, 3}); - symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsTensor, {1, 1, 1}, {1, 2, 3}); symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, {2, 3, 4}); symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, @@ -1022,21 +1023,37 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4}, {2, 1, 4}); symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, - error::INVALID_ARGUMENT, - "Cannot broadcast weights with non-trivial batch dimension"); + error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, - error::INVALID_ARGUMENT, - "Cannot broadcast weights with non-trivial batch dimension", + error::INVALID_ARGUMENT, "Infeasible broadcast scheme", + /*operand_1_batch_size=*/2); + symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); + symmetric_test({3}, {1, 1, 3}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 2 vs broadcast #dims 3)", /*operand_1_batch_size=*/2); - symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, - {1, 1, 1, 1}, {1, 1, 1, 1}); // Both inputs are tensors. - symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {1, 1, 1}, {1, 1, 1}); + symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 3 vs broadcast #dims 4)"); + symmetric_test({1, 3}, {3}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 2 vs broadcast #dims 3)"); symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, {2, 1, 4}); - symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {1, 1, 1, 1}, - {1, 1, 1, 1}); + symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); + symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); } TEST_F(ConverterTest, CreateConstantLayer) { From a4fc90d6a2cddfb269e6d821cacf1fb838a38886 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 23 May 2019 10:38:07 -0700 Subject: [PATCH 13/15] Fix formatting --- .../tf2tensorrt/convert/convert_nodes.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index cfb0a6c3a3d..9b56fd56355 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -436,32 +436,32 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, // Set the batch dimension to -1, since batch size is not supposed to // be broadcasted. output_dims_array[0] = -1; - } + } // Copy to output dimensions (stripping the batch dimension). output_dims->nbDims = broadcast_num_dims - 1; std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, output_dims->d); return Status::OK(); - }; + }; // Compute the output dimensions. const int broadcast_num_dims = std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0), - operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0)); + operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0)); int output_l[max_nb_dims], output_r[max_nb_dims]; TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims, - output_l, operand_l_new_dims)); + output_l, operand_l_new_dims)); TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims, - output_r, operand_r_new_dims)); + output_r, operand_r_new_dims)); // Compare broadcast feasibility for (int i = 0; i < broadcast_num_dims; ++i) { if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && (output_r[i] != 1)) { return errors::InvalidArgument( - "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", - DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], - ", ", DebugString(*operand_r_new_dims), ")"); + "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", + DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], + ", ", DebugString(*operand_r_new_dims), ")"); } } return Status::OK(); From f5c20bd41b81a18446b970230c976aca4136f3e6 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 23 May 2019 10:40:07 -0700 Subject: [PATCH 14/15] Revert accidently changes --- tensorflow/compiler/xla/service/gpu/BUILD | 1 - tensorflow/stream_executor/cuda/BUILD | 2 -- 2 files changed, 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 76443486b81..3ff115d0b50 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -637,7 +637,6 @@ cc_library( hdrs = ["cusolver_context.h"], deps = [ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cublas_headers", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index 277e11aa216..d21b7111642 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -214,7 +214,6 @@ cc_library( textual_hdrs = glob(["cublas_*.inc"]), deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cublas_headers", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform:dso_loader", ]), @@ -245,7 +244,6 @@ cc_library( "@com_google_absl//absl/strings", "//third_party/eigen3", "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cublas_headers", "//tensorflow/core:lib_internal", "//tensorflow/stream_executor", "//tensorflow/stream_executor:event", From dd8929321fc51567d5026fff49e635dcb2678334 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 23 May 2019 13:48:36 -0700 Subject: [PATCH 15/15] Use batch size 2 in binary op test for tensors --- .../tf2tensorrt/convert/convert_nodes_test.cc | 70 +++++++++++-------- 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index b6082807ca1..09b7a60c083 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -2023,14 +2023,14 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, const NodeDef node_def = GetBinaryOpNodeDef("input1", "input2", dtype); if (operand_1_is_tensor) { - test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2, TfDataTypeToTrt(dtype)); } else { test->AddTestWeights("input1", /*dims=*/{1, 2}, /*values=*/std::vector{CType(3), CType(6)}); } if (operand_2_is_tensor) { - test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2, TfDataTypeToTrt(dtype)); } else { test->AddTestWeights("input2", /*dims=*/{2, 1}, @@ -2040,14 +2040,16 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, DataVec input_data; if (operand_1_is_tensor) { - input_data.push_back({"input1", - test::AsTensor({CType(3), CType(6)})}); + input_data.push_back( + {"input1", + test::AsTensor({CType(3), CType(6), CType(3), CType(6)})}); } if (operand_2_is_tensor) { - input_data.push_back({"input2", - test::AsTensor({CType(2), CType(3)})}); + input_data.push_back( + {"input2", + test::AsTensor({CType(2), CType(3), CType(2), CType(3)})}); } - DataVec output_data{{"my_binary", ConstructTensor(4)}}; + DataVec output_data{{"my_binary", ConstructTensor(8)}}; // Check output dims. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); @@ -2055,33 +2057,41 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun(input_data, &output_data, dtype == DT_HALF - ? TrtPrecisionMode::FP16 - : TrtPrecisionMode::FP32); + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, + /*batch_size=*/2); if (node_def.op() == "Add") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(5), CType(8), CType(6), CType(9))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({5, 8, 6, 9, 5, 8, 6, 9}))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(4), CType(0), CType(3))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({1, 4, 0, 3, 1, 4, 0, 3}))); } else if (node_def.op() == "Mul") { EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(6), CType(12), CType(9), CType(18))); + ElementsAreArray( + CastTestVector({6, 12, 9, 18, 6, 12, 9, 18}))); } else if (node_def.op() == "Div") { EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + ElementsAreArray(CastTestVector( + {1.5, 3, 1, 2, 1.5, 3, 1, 2}))); } else if (node_def.op() == "RealDiv") { EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + ElementsAreArray(CastTestVector( + {1.5, 3, 1, 2, 1.5, 3, 1, 2}))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(2), CType(2), CType(3), CType(3))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({2, 2, 3, 3, 2, 2, 3, 3}))); } else if (node_def.op() == "Maximum") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(3), CType(6), CType(3), CType(6))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(CastTestVector({3, 6, 3, 6, 3, 6, 3, 6}))); } else if (node_def.op() == "Pow") { ExpectArrayNear( - std::vector{CType(9), CType(36), CType(27), CType(216)}, + CastTestVector({9, 36, 27, 216, 9, 36, 27, 216}), GetSpanForData(output_data[0])); } else { ASSERT_TRUE(false); @@ -2139,14 +2149,14 @@ TEST_F(OpConverterTest, ConvertBinary) { operand_2_is_tensor); // FP16 tests // TODO(tmorris): Use templates to avoid duplication. - TestBinaryOp(this, operand_1_is_tensor, - operand_2_is_tensor); - TestBinaryOp(this, operand_1_is_tensor, - operand_2_is_tensor); - TestBinaryOp(this, operand_1_is_tensor, - operand_2_is_tensor); - TestBinaryOp(this, operand_1_is_tensor, - operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); + TestBinaryOp(this, operand_1_is_tensor, + operand_2_is_tensor); TestBinaryOp(this, operand_1_is_tensor, operand_2_is_tensor); TestBinaryOp(this, operand_1_is_tensor,