Support fp16 in the conversion of Square, Relu and LeakyRelu.
This also fixes a bug in the GatherV2 converter, where it used the wrong data type attribute name. fp16 support will be enabled for more converters later. PiperOrigin-RevId: 238485945
This commit is contained in:
parent
7239d65880
commit
762093140b
@ -326,6 +326,7 @@ tf_cuda_cc_test(
|
|||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/cc:cc_ops_internal",
|
||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/cc:scope",
|
"//tensorflow/cc:scope",
|
||||||
"//tensorflow/core/grappler/costs:graph_properties",
|
"//tensorflow/core/grappler/costs:graph_properties",
|
||||||
|
@ -116,6 +116,88 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TFAttrs {
|
||||||
|
public:
|
||||||
|
explicit TFAttrs(const NodeDef& tf_node) {
|
||||||
|
for (const auto& attr : tf_node.attr()) {
|
||||||
|
attrs_.insert({attr.first, &attr.second});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool count(const string& key) const { return attrs_.count(key); }
|
||||||
|
|
||||||
|
AttrValue const* at(const string& key) const {
|
||||||
|
if (!attrs_.count(key)) {
|
||||||
|
LOG(FATAL) << "Attribute not found: " << key;
|
||||||
|
}
|
||||||
|
return attrs_.at(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T get(const string& key) const;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T get(const string& key, const T& default_value) const {
|
||||||
|
return attrs_.count(key) ? this->get<T>(key) : default_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> GetAllAttrKeys() const {
|
||||||
|
std::vector<string> attr_list;
|
||||||
|
for (const auto& attr_item : attrs_) {
|
||||||
|
attr_list.emplace_back(attr_item.first);
|
||||||
|
}
|
||||||
|
return attr_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
typedef std::map<string, AttrValue const*> AttrMap;
|
||||||
|
AttrMap attrs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
string TFAttrs::get<string>(const string& key) const {
|
||||||
|
return this->at(key)->s();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
|
||||||
|
auto attr = this->at(key)->list().i();
|
||||||
|
return std::vector<int64>(attr.begin(), attr.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
|
||||||
|
auto attr = this->at(key)->list().f();
|
||||||
|
return std::vector<float>(attr.begin(), attr.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
|
||||||
|
nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
|
||||||
|
TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
|
||||||
|
return trt_dtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
DataType TFAttrs::get<DataType>(const string& key) const {
|
||||||
|
return this->at(key)->type();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
float TFAttrs::get<float>(const string& key) const {
|
||||||
|
return this->at(key)->f();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
bool TFAttrs::get<bool>(const string& key) const {
|
||||||
|
return this->at(key)->b();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
int64 TFAttrs::get<int64>(const string& key) const {
|
||||||
|
return this->at(key)->i();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename TensorShapeType>
|
template <typename TensorShapeType>
|
||||||
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
||||||
bool ignore_first_dim) {
|
bool ignore_first_dim) {
|
||||||
@ -379,17 +461,35 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
|
|||||||
|
|
||||||
Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
|
Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
|
||||||
const nvinfer1::Dims& dims,
|
const nvinfer1::Dims& dims,
|
||||||
const nvinfer1::ITensor** tensor) {
|
const nvinfer1::ITensor** tensor,
|
||||||
|
const char* dtype_attr_name = "T") {
|
||||||
|
TFAttrs attrs(params->node_def);
|
||||||
|
DataType dtype;
|
||||||
|
if (attrs.count(dtype_attr_name)) {
|
||||||
|
dtype = attrs.get<DataType>(dtype_attr_name);
|
||||||
|
} else {
|
||||||
|
dtype = DT_FLOAT; // Default to FP32.
|
||||||
|
}
|
||||||
|
|
||||||
// In order to be broadcastable, the number of dims has to match.
|
// In order to be broadcastable, the number of dims has to match.
|
||||||
nvinfer1::Dims broadcastable_dims(dims);
|
nvinfer1::Dims broadcastable_dims(dims);
|
||||||
for (int i = 0; i < broadcastable_dims.nbDims; i++) {
|
for (int i = 0; i < broadcastable_dims.nbDims; i++) {
|
||||||
broadcastable_dims.d[i] = 1;
|
broadcastable_dims.d[i] = 1;
|
||||||
}
|
}
|
||||||
TRT_ShapedWeights weights = params->weight_store->GetTempWeights(
|
TRT_ShapedWeights weights =
|
||||||
DataType::DT_FLOAT, broadcastable_dims);
|
params->weight_store->GetTempWeights(dtype, broadcastable_dims);
|
||||||
auto weights_ptr =
|
void* raw_ptr = const_cast<void*>(weights.GetValues());
|
||||||
static_cast<float*>(const_cast<void*>(weights.GetValues()));
|
switch (dtype) {
|
||||||
weights_ptr[0] = value;
|
case DataType::DT_FLOAT:
|
||||||
|
static_cast<float*>(raw_ptr)[0] = value;
|
||||||
|
break;
|
||||||
|
case DataType::DT_HALF:
|
||||||
|
static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return errors::InvalidArgument("Unsupported data type ",
|
||||||
|
DataTypeString(dtype));
|
||||||
|
}
|
||||||
*tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
|
*tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
|
||||||
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
|
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
|
||||||
params->converter->ProvideQuantizationRange(
|
params->converter->ProvideQuantizationRange(
|
||||||
@ -662,88 +762,6 @@ string TRT_TensorOrWeights::DebugString() const {
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
class TFAttrs {
|
|
||||||
public:
|
|
||||||
explicit TFAttrs(const NodeDef& tf_node) {
|
|
||||||
for (const auto& attr : tf_node.attr()) {
|
|
||||||
attrs_.insert({attr.first, &attr.second});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool count(const string& key) const { return attrs_.count(key); }
|
|
||||||
|
|
||||||
AttrValue const* at(const string& key) const {
|
|
||||||
if (!attrs_.count(key)) {
|
|
||||||
LOG(FATAL) << "Attribute not found: " << key;
|
|
||||||
}
|
|
||||||
return attrs_.at(key);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T get(const string& key) const;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T get(const string& key, const T& default_value) const {
|
|
||||||
return attrs_.count(key) ? this->get<T>(key) : default_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<string> GetAllAttrKeys() const {
|
|
||||||
std::vector<string> attr_list;
|
|
||||||
for (const auto& attr_item : attrs_) {
|
|
||||||
attr_list.emplace_back(attr_item.first);
|
|
||||||
}
|
|
||||||
return attr_list;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
typedef std::map<string, AttrValue const*> AttrMap;
|
|
||||||
AttrMap attrs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
string TFAttrs::get<string>(const string& key) const {
|
|
||||||
return this->at(key)->s();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
|
|
||||||
auto attr = this->at(key)->list().i();
|
|
||||||
return std::vector<int64>(attr.begin(), attr.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
|
|
||||||
auto attr = this->at(key)->list().f();
|
|
||||||
return std::vector<float>(attr.begin(), attr.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
|
|
||||||
nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
|
|
||||||
TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
|
|
||||||
return trt_dtype;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
DataType TFAttrs::get<DataType>(const string& key) const {
|
|
||||||
return this->at(key)->type();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
float TFAttrs::get<float>(const string& key) const {
|
|
||||||
return this->at(key)->f();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
bool TFAttrs::get<bool>(const string& key) const {
|
|
||||||
return this->at(key)->b();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
int64 TFAttrs::get<int64>(const string& key) const {
|
|
||||||
return this->at(key)->i();
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(jie): reorder4 & reorder2 should be merged?
|
// TODO(jie): reorder4 & reorder2 should be merged?
|
||||||
// TODO(aaroey): fix the order of parameters.
|
// TODO(aaroey): fix the order of parameters.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -1435,11 +1453,15 @@ Status CheckInputsWeights(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status AllowDataTypes(const OpConverterParams& params,
|
Status AllowDataTypes(const OpConverterParams& params,
|
||||||
const std::set<DataType>& allowed_dtypes) {
|
const std::set<DataType>& allowed_dtypes,
|
||||||
|
const char* dtype_attr_name = "T") {
|
||||||
const auto& node_def = params.node_def;
|
const auto& node_def = params.node_def;
|
||||||
TFAttrs attrs(params.node_def);
|
TFAttrs attrs(node_def);
|
||||||
if (attrs.count("T")) {
|
if (!attrs.count(dtype_attr_name)) {
|
||||||
const auto op_dtype = attrs.get<DataType>("T");
|
return errors::InvalidArgument("Attribute with name ", dtype_attr_name,
|
||||||
|
" not found.");
|
||||||
|
}
|
||||||
|
const auto op_dtype = attrs.get<DataType>(dtype_attr_name);
|
||||||
if (!allowed_dtypes.count(op_dtype)) {
|
if (!allowed_dtypes.count(op_dtype)) {
|
||||||
// Build string list of allowed types.
|
// Build string list of allowed types.
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
@ -1452,9 +1474,6 @@ Status AllowDataTypes(const OpConverterParams& params,
|
|||||||
", must be one of [", ss.str(), "], at ",
|
", must be one of [", ss.str(), "], at ",
|
||||||
node_def.name());
|
node_def.name());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// If there is no T attribute, we can't determine the type of the op. We will
|
|
||||||
// allow it to convert for now.
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3696,7 +3715,8 @@ Status ConvertGather(OpConverterParams* params) {
|
|||||||
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
||||||
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
|
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
|
||||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
|
||||||
|
/*dtype_attr_name=*/"Tparams"));
|
||||||
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
|
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
|
||||||
if (axis.size() != 1) {
|
if (axis.size() != 1) {
|
||||||
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
|
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
#include "tensorflow/cc/framework/scope.h"
|
#include "tensorflow/cc/framework/scope.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||||
@ -109,13 +110,17 @@ DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NodeDef MakeNodeDef(const string& name, const string& op,
|
NodeDef MakeNodeDef(const string& name, const string& op,
|
||||||
const std::vector<string>& inputs) {
|
const std::vector<string>& inputs,
|
||||||
|
const std::map<string, AttrValue> attrs = {}) {
|
||||||
NodeDef node_def;
|
NodeDef node_def;
|
||||||
node_def.set_name(name);
|
node_def.set_name(name);
|
||||||
node_def.set_op(op);
|
node_def.set_op(op);
|
||||||
for (const string& input : inputs) {
|
for (const string& input : inputs) {
|
||||||
node_def.add_input(input);
|
node_def.add_input(input);
|
||||||
}
|
}
|
||||||
|
for (const auto& attr : attrs) {
|
||||||
|
(*node_def.mutable_attr())[attr.first] = attr.second;
|
||||||
|
}
|
||||||
return node_def;
|
return node_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1094,8 +1099,22 @@ class OpConverterTest : public ::testing::Test {
|
|||||||
validator_inputs_.clear();
|
validator_inputs_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(laigd): test fp16 and int8 support.
|
void CheckDataTypeMatches(const DataVec& datas) {
|
||||||
void BuildAndRun(const DataVec& input_data, DataVec* output_data) {
|
for (const auto& data : datas) {
|
||||||
|
const int input_index = engine_->getBindingIndex(data.name);
|
||||||
|
ASSERT_NE(-1, input_index);
|
||||||
|
const nvinfer1::DataType trt_dtype =
|
||||||
|
engine_->getBindingDataType(input_index);
|
||||||
|
const DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
|
||||||
|
ASSERT_EQ(data.tensor.dtype(), tf_dtype)
|
||||||
|
<< DataTypeString(data.tensor.dtype()) << " vs. "
|
||||||
|
<< DataTypeString(tf_dtype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(laigd): test fp16 and int8 support for more converters.
|
||||||
|
void BuildAndRun(const DataVec& input_data, DataVec* output_data,
|
||||||
|
TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32) {
|
||||||
// Mark the output tensor as TRT engine output.
|
// Mark the output tensor as TRT engine output.
|
||||||
std::vector<Converter::EngineOutputInfo> output_info;
|
std::vector<Converter::EngineOutputInfo> output_info;
|
||||||
for (const auto& data : *output_data) {
|
for (const auto& data : *output_data) {
|
||||||
@ -1105,9 +1124,20 @@ class OpConverterTest : public ::testing::Test {
|
|||||||
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
|
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
|
||||||
|
|
||||||
// Build the TRT engine.
|
// Build the TRT engine.
|
||||||
|
if (precision_mode == TrtPrecisionMode::FP16) {
|
||||||
|
builder_->setFp16Mode(true);
|
||||||
|
} else if (precision_mode == TrtPrecisionMode::INT8) {
|
||||||
|
// Setting FP16 mode as well allows TRT to also consider FP16 kernels and
|
||||||
|
// use them in situations where they are faster than INT8 or where INT8 is
|
||||||
|
// not supported for a given layer.
|
||||||
|
builder_->setFp16Mode(true);
|
||||||
|
builder_->setInt8Mode(true);
|
||||||
|
}
|
||||||
ASSERT_EQ(nullptr, engine_.get());
|
ASSERT_EQ(nullptr, engine_.get());
|
||||||
engine_.reset(builder_->buildCudaEngine(*converter_->network()));
|
engine_.reset(builder_->buildCudaEngine(*converter_->network()));
|
||||||
CHECK_NOTNULL(engine_.get());
|
CHECK_NOTNULL(engine_.get());
|
||||||
|
CheckDataTypeMatches(input_data);
|
||||||
|
CheckDataTypeMatches(*output_data);
|
||||||
|
|
||||||
// Execute the TRT engine.
|
// Execute the TRT engine.
|
||||||
const int num_bindings = input_data.size() + output_data->size();
|
const int num_bindings = input_data.size() + output_data->size();
|
||||||
@ -1761,7 +1791,9 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
|
|||||||
const DataVec input_data{
|
const DataVec input_data{
|
||||||
{"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
|
{"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
|
||||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
|
DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
|
||||||
test->BuildAndRun(input_data, &output_data);
|
test->BuildAndRun(
|
||||||
|
input_data, &output_data,
|
||||||
|
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||||
if (node_def.op() == "Add") {
|
if (node_def.op() == "Add") {
|
||||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||||
ElementsAre(CType(5), CType(10.5)));
|
ElementsAre(CType(5), CType(10.5)));
|
||||||
@ -1942,7 +1974,9 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) {
|
|||||||
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
|
DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
|
||||||
// After broadcasting first input becomes {3, 6, 3, 6} and second input
|
// After broadcasting first input becomes {3, 6, 3, 6} and second input
|
||||||
// becomes {2, 3, 2, 3}.
|
// becomes {2, 3, 2, 3}.
|
||||||
test->BuildAndRun(input_data, &output_data);
|
test->BuildAndRun(
|
||||||
|
input_data, &output_data,
|
||||||
|
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||||
if (node_def.op() == "Add") {
|
if (node_def.op() == "Add") {
|
||||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||||
ElementsAre(CType(5), CType(8), CType(6), CType(9)));
|
ElementsAre(CType(5), CType(8), CType(6), CType(9)));
|
||||||
@ -1974,10 +2008,13 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpConverterTest, ConvertBinary) {
|
TEST_F(OpConverterTest, ConvertBinary) {
|
||||||
|
AttrValue dtype;
|
||||||
|
dtype.set_type(DT_FLOAT);
|
||||||
// Input size doesn't match, should fail.
|
// Input size doesn't match, should fail.
|
||||||
for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
|
for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"});
|
NodeDef node_def =
|
||||||
|
MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}});
|
||||||
AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
|
AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
|
||||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||||
StrCat("Add got ", std::to_string(num_inputs),
|
StrCat("Add got ", std::to_string(num_inputs),
|
||||||
@ -1987,7 +2024,8 @@ TEST_F(OpConverterTest, ConvertBinary) {
|
|||||||
{
|
{
|
||||||
// Both inputs are weights.
|
// Both inputs are weights.
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"});
|
NodeDef node_def =
|
||||||
|
MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
|
||||||
AddTestWeights<float>("weights1", {1}, {1});
|
AddTestWeights<float>("weights1", {1}, {1});
|
||||||
AddTestWeights<float>("weights2", {1}, {1});
|
AddTestWeights<float>("weights2", {1}, {1});
|
||||||
RunValidationAndConversion(
|
RunValidationAndConversion(
|
||||||
@ -2002,15 +2040,12 @@ TEST_F(OpConverterTest, ConvertBinary) {
|
|||||||
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
|
||||||
#if 0
|
|
||||||
// TODO(b/119560144): it doesn't support FP16 constants and the following test
|
|
||||||
// will fail.
|
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
|
||||||
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
|
TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
|
||||||
#endif
|
|
||||||
|
|
||||||
// Test BinaryTensorOpWeight() with channel-wise broadcasting.
|
// Test BinaryTensorOpWeight() with channel-wise broadcasting.
|
||||||
TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
|
TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
|
||||||
@ -2192,7 +2227,8 @@ void TestConvertSquare(OpConverterTest* test) {
|
|||||||
auto square = ops::Square(s.WithOpName("my_square"), input);
|
auto square = ops::Square(s.WithOpName("my_square"), input);
|
||||||
NodeDef node_def = square.operation.node()->def();
|
NodeDef node_def = square.operation.node()->def();
|
||||||
|
|
||||||
test->AddTestTensor("input", {1, 20});
|
test->AddTestTensor("input", {1, 20}, /*batch_size=*/1,
|
||||||
|
TfDataTypeToTrt(dtype));
|
||||||
test->RunValidationAndConversion(node_def);
|
test->RunValidationAndConversion(node_def);
|
||||||
TRT_TensorOrWeights output;
|
TRT_TensorOrWeights output;
|
||||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
|
TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
|
||||||
@ -2202,14 +2238,18 @@ void TestConvertSquare(OpConverterTest* test) {
|
|||||||
const int num_inputs = 20;
|
const int num_inputs = 20;
|
||||||
std::vector<CType> inputs(num_inputs);
|
std::vector<CType> inputs(num_inputs);
|
||||||
std::vector<CType> expected_outputs(num_inputs);
|
std::vector<CType> expected_outputs(num_inputs);
|
||||||
for (int i = 0; i < 20; i++) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
const CType value = CType(i - 9);
|
const CType value = CType(i - 9);
|
||||||
inputs[i] = value;
|
inputs[i] = value;
|
||||||
expected_outputs[i] = value * value;
|
expected_outputs[i] = value * value;
|
||||||
}
|
}
|
||||||
const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
|
const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
|
||||||
|
// Engine outputs are converted to FP16 automatically if we set FP16 mode in
|
||||||
|
// the builder.
|
||||||
DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
|
DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
|
||||||
test->BuildAndRun(input_data, &output_data);
|
test->BuildAndRun(
|
||||||
|
input_data, &output_data,
|
||||||
|
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||||
ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
|
ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2237,9 +2277,7 @@ TEST_F(OpConverterTest, ConvertSquare) {
|
|||||||
// OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
|
// OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
|
||||||
// test DT_INT32 type here.
|
// test DT_INT32 type here.
|
||||||
TestConvertSquare<DT_FLOAT>(this);
|
TestConvertSquare<DT_FLOAT>(this);
|
||||||
// TODO(tmorris): Looks like there may be a bug with this layer for FP16
|
TestConvertSquare<DT_HALF>(this);
|
||||||
// inputs. Disabling for now.
|
|
||||||
// TestConvertSquare<DT_HALF>(this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpConverterTest, ConvertActivation) {
|
TEST_F(OpConverterTest, ConvertActivation) {
|
||||||
@ -2269,10 +2307,10 @@ TEST_F(OpConverterTest, ConvertActivation) {
|
|||||||
Scope s = Scope::NewRootScope();
|
Scope s = Scope::NewRootScope();
|
||||||
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||||
if (op_name == "LeakyRelu") {
|
if (op_name == "LeakyRelu") {
|
||||||
// LeakyRelu does not have a C++ API
|
auto act =
|
||||||
NodeDef node_def = MakeNodeDef("my_act", "LeakyRelu", {"input"});
|
ops::internal::LeakyRelu(s.WithOpName("my_act"), input,
|
||||||
(*node_def.mutable_attr())["alpha"].set_f(kAlpha);
|
ops::internal::LeakyRelu::Alpha(kAlpha));
|
||||||
return node_def;
|
return act.operation.node()->def();
|
||||||
} else if (op_name == "Relu") {
|
} else if (op_name == "Relu") {
|
||||||
auto act = ops::Relu(s.WithOpName("my_act"), input);
|
auto act = ops::Relu(s.WithOpName("my_act"), input);
|
||||||
return act.operation.node()->def();
|
return act.operation.node()->def();
|
||||||
|
Loading…
Reference in New Issue
Block a user