Merge pull request #39153 from tfeher:trt_unary_dynamic_shape
PiperOrigin-RevId: 310410778 Change-Id: I716deaa6002bbd429ba4f797a5497e7e91e3973f
This commit is contained in:
commit
a65ece1e46
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
@ -66,6 +67,7 @@ namespace convert {
|
||||
using absl::StrCat;
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Matcher;
|
||||
using ::testing::NanSensitiveFloatNear;
|
||||
|
||||
@ -216,6 +218,21 @@ void ExpectTrtDimsEqualsArray(const std::vector<int>& lhs,
|
||||
<< " actual: " << DebugString(rhs);
|
||||
}
|
||||
|
||||
Matcher<std::vector<float>> ArrayFloatNear(const std::vector<float>& values,
|
||||
float max_abs_error = 1e-5,
|
||||
bool nan_sensitive = false) {
|
||||
std::vector<Matcher<float>> matchers;
|
||||
matchers.reserve(values.size());
|
||||
for (const float& v : values) {
|
||||
if (nan_sensitive) {
|
||||
matchers.emplace_back(NanSensitiveFloatNear(v, max_abs_error));
|
||||
} else {
|
||||
matchers.emplace_back(FloatNear(v, max_abs_error));
|
||||
}
|
||||
}
|
||||
return ElementsAreArray(matchers);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExpectArrayNear(const std::vector<T>& lhs, absl::Span<const T> rhs) {
|
||||
ASSERT_EQ(lhs.size(), rhs.size());
|
||||
@ -5114,135 +5131,54 @@ TEST_F(OpConverterTest, ConvertGather) {
|
||||
TestConvertGather<DT_INT32>(this);
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertUnary) {
|
||||
template <typename T>
|
||||
NodeDef CreateUnaryOp() {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||
return T(s.WithOpName("my_unary"), input).operation.node()->def();
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
|
||||
const auto& spec = GetParam();
|
||||
const TrtTestMode trt_mode = std::get<0>(spec);
|
||||
const DataType tf_dtype = std::get<1>(spec);
|
||||
TrtPrecisionMode converter_precision = std::get<2>(spec);
|
||||
{
|
||||
// Input is weights, should fail.
|
||||
Reset();
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||
auto neg = ops::Neg(s.WithOpName("my_unary"), input);
|
||||
const NodeDef& node_def = neg.operation.node()->def();
|
||||
Reset(converter_precision, trt_mode);
|
||||
const NodeDef node_def = CreateUnaryOp<ops::Neg>();
|
||||
AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"The input \"x\" for Neg must be a tensor, at my_unary");
|
||||
}
|
||||
|
||||
// Get nodedef for unary layer.
|
||||
auto get_unary_nodedef = [](string op_name) -> NodeDef {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||
if (op_name == "Abs") {
|
||||
auto unary = ops::Abs(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Acos") {
|
||||
auto unary = ops::Acos(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Acosh") {
|
||||
auto unary = ops::Acosh(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Asin") {
|
||||
auto unary = ops::Asin(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Asinh") {
|
||||
auto unary = ops::Asinh(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Atan") {
|
||||
auto unary = ops::Atan(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Atanh") {
|
||||
auto unary = ops::Atanh(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Ceil") {
|
||||
auto unary = ops::Ceil(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Cos") {
|
||||
auto unary = ops::Cos(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Cosh") {
|
||||
auto unary = ops::Cosh(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Exp") {
|
||||
auto unary = ops::Exp(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Floor") {
|
||||
auto unary = ops::Floor(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Log") {
|
||||
auto unary = ops::Log(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Neg") {
|
||||
auto unary = ops::Neg(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Reciprocal") {
|
||||
auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Rsqrt") {
|
||||
auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Sin") {
|
||||
auto unary = ops::Sin(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Sinh") {
|
||||
auto unary = ops::Sinh(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Sqrt") {
|
||||
auto unary = ops::Sqrt(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
} else if (op_name == "Tan") {
|
||||
auto unary = ops::Tan(s.WithOpName("my_unary"), input);
|
||||
return unary.operation.node()->def();
|
||||
}
|
||||
EXPECT_TRUE(false);
|
||||
return NodeDef();
|
||||
};
|
||||
// Get expected output for unary layer.
|
||||
auto get_unary_output = [](string op_name, float input) -> float {
|
||||
if (op_name == "Abs") {
|
||||
return std::abs(input);
|
||||
} else if (op_name == "Acos") {
|
||||
return std::acos(input);
|
||||
} else if (op_name == "Acosh") {
|
||||
return std::acosh(input);
|
||||
} else if (op_name == "Asin") {
|
||||
return std::asin(input);
|
||||
} else if (op_name == "Asinh") {
|
||||
return std::asinh(input);
|
||||
} else if (op_name == "Atan") {
|
||||
return std::atan(input);
|
||||
} else if (op_name == "Atanh") {
|
||||
return std::atanh(input);
|
||||
} else if (op_name == "Ceil") {
|
||||
return std::ceil(input);
|
||||
} else if (op_name == "Cos") {
|
||||
return std::cos(input);
|
||||
} else if (op_name == "Cosh") {
|
||||
return std::cosh(input);
|
||||
} else if (op_name == "Exp") {
|
||||
return std::exp(input);
|
||||
} else if (op_name == "Floor") {
|
||||
return std::floor(input);
|
||||
} else if (op_name == "Log") {
|
||||
return std::log(input);
|
||||
} else if (op_name == "Neg") {
|
||||
return -input;
|
||||
} else if (op_name == "Reciprocal") {
|
||||
return 1.0 / input;
|
||||
} else if (op_name == "Rsqrt") {
|
||||
return 1.0 / std::sqrt(input);
|
||||
} else if (op_name == "Sin") {
|
||||
return std::sin(input);
|
||||
} else if (op_name == "Sinh") {
|
||||
return std::sinh(input);
|
||||
} else if (op_name == "Sqrt") {
|
||||
return std::sqrt(input);
|
||||
} else if (op_name == "Tan") {
|
||||
return std::tan(input);
|
||||
}
|
||||
EXPECT_TRUE(false);
|
||||
return 0;
|
||||
};
|
||||
|
||||
using OpFunc = std::function<NodeDef(void)>;
|
||||
using ValFunc = float (*)(float);
|
||||
std::map<std::string, std::pair<OpFunc, ValFunc>> op_map;
|
||||
#define ADD_OP(name, op, compute) \
|
||||
op_map[name] = \
|
||||
std::make_pair(CreateUnaryOp<op>, static_cast<ValFunc>(compute))
|
||||
ADD_OP("Abs", ops::Abs, std::abs);
|
||||
ADD_OP("Acos", ops::Acos, std::acos);
|
||||
ADD_OP("Acosh", ops::Acosh, std::acosh);
|
||||
ADD_OP("Asin", ops::Asin, std::asin);
|
||||
ADD_OP("Asinh", ops::Asinh, std::asinh);
|
||||
ADD_OP("Atan", ops::Atan, std::atan);
|
||||
ADD_OP("Atanh", ops::Atanh, std::atanh);
|
||||
ADD_OP("Ceil", ops::Ceil, std::ceil);
|
||||
ADD_OP("Cos", ops::Cos, std::cos);
|
||||
ADD_OP("Cosh", ops::Cosh, std::cosh);
|
||||
ADD_OP("Exp", ops::Exp, std::exp);
|
||||
ADD_OP("Floor", ops::Floor, std::floor);
|
||||
ADD_OP("Log", ops::Log, std::log);
|
||||
ADD_OP("Neg", ops::Neg, [](float x) { return -x; });
|
||||
ADD_OP("Reciprocal", ops::Reciprocal, [](float x) { return 1.0f / x; });
|
||||
ADD_OP("Rsqrt", ops::Rsqrt, [](float x) { return 1.0f / std::sqrt(x); });
|
||||
ADD_OP("Sin", ops::Sin, std::sin);
|
||||
ADD_OP("Sinh", ops::Sinh, std::sinh);
|
||||
ADD_OP("Sqrt", ops::Sqrt, std::sqrt);
|
||||
ADD_OP("Tan", ops::Tan, std::tan);
|
||||
#undef ADD_OP
|
||||
// Get list of ops to test.
|
||||
std::vector<string> ops_to_test;
|
||||
// Add all ops supported by ConvertUnary.
|
||||
@ -5253,26 +5189,30 @@ TEST_F(OpConverterTest, ConvertUnary) {
|
||||
}
|
||||
// Add other unary ops to test.
|
||||
ops_to_test.push_back("Rsqrt");
|
||||
// Ok.
|
||||
// Prepare test parameters
|
||||
auto p = TestParamBase{
|
||||
{1, 1, 2, 3}, // input dims
|
||||
{}, // input partial dims
|
||||
{1, 1, 2, 3}, // expected output dims
|
||||
};
|
||||
for (const string& op_name : ops_to_test) {
|
||||
Reset();
|
||||
NodeDef node_def = get_unary_nodedef(op_name);
|
||||
AddTestTensor("input", {1, 2, 3});
|
||||
RunValidationAndConversion(node_def);
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions());
|
||||
|
||||
const std::vector<float> input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
|
||||
const DataVec input_data{{"input", AsTensor<float>(input)}};
|
||||
DataVec output_data{{"my_unary", ConstructTensor<float>(6)}};
|
||||
BuildAndRun(input_data, &output_data);
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
const float expected_output = get_unary_output(op_name, input[i]);
|
||||
EXPECT_THAT(GetSpanForData<float>(output_data[0])[i],
|
||||
NanSensitiveFloatNear(expected_output, 0.0001));
|
||||
SCOPED_TRACE(op_name);
|
||||
Reset(converter_precision, trt_mode);
|
||||
if (!op_map.count(op_name)) {
|
||||
FAIL() << "Unary op test map does not contain op " << op_name;
|
||||
}
|
||||
NodeDef node_def = op_map[op_name].first();
|
||||
|
||||
AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode);
|
||||
RunValidationAndConversion(node_def, Status::OK(), "my_unary",
|
||||
p.expected_output_dims);
|
||||
|
||||
std::vector<float> input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
|
||||
std::vector<float> output;
|
||||
std::transform(input_values.begin(), input_values.end(),
|
||||
std::back_inserter(output), op_map[op_name].second);
|
||||
InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values,
|
||||
ArrayFloatNear(output, 0.0001, true));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user