Merge pull request #40483 from MattConley:convert_square_update
PiperOrigin-RevId: 317112570 Change-Id: Ief45a08e3a7064244c373ff385d055f10fd1b6f5
This commit is contained in:
commit
f07f057816
@ -4424,8 +4424,13 @@ Status ConvertSquare(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
|
||||
#if IS_TRT_VERSION_GE(6, 0, 1, 0)
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
#else
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
#endif
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Constant 2 with same rank as input
|
||||
|
@ -2756,58 +2756,40 @@ TEST_F(OpConverterTest, ConvertQuantize) {
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertSquare(OpConverterTest* test) {
|
||||
test->Reset();
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), dtype);
|
||||
auto square = ops::Square(s.WithOpName("my_square"), input);
|
||||
NodeDef node_def = square.operation.node()->def();
|
||||
|
||||
test->AddTestTensor("input", {1, 20}, /*batch_size=*/1,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->RunValidationAndConversion(node_def);
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
|
||||
ASSERT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions());
|
||||
|
||||
const int num_inputs = 20;
|
||||
std::vector<CType> inputs(num_inputs);
|
||||
std::vector<CType> expected_outputs(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const CType value = CType(i - 9);
|
||||
inputs[i] = value;
|
||||
expected_outputs[i] = value * value;
|
||||
}
|
||||
const DataVec input_data{{"input", test->AsTensor<CType>(inputs)}};
|
||||
// Engine outputs are converted to FP16 automatically if we set FP16 mode in
|
||||
// the builder.
|
||||
DataVec output_data{{"my_square", test->ConstructTensor<CType>(num_inputs)}};
|
||||
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
|
||||
ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertSquare) {
|
||||
TEST_P(OpConverterTest2, ConvertSquare) {
|
||||
{
|
||||
// Input is weights, should fail.
|
||||
Reset();
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
|
||||
auto square = ops::Square(s.WithOpName("my_square"), input);
|
||||
NodeDef node_def = square.operation.node()->def();
|
||||
AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6});
|
||||
AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type);
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"The input \"x\" for Square must be a tensor, at my_square");
|
||||
}
|
||||
|
||||
// OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
|
||||
// test DT_INT32 type here.
|
||||
TestConvertSquare<DT_FLOAT>(this);
|
||||
TestConvertSquare<DT_HALF>(this);
|
||||
Reset();
|
||||
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
|
||||
auto square = ops::Square(s.WithOpName("my_square"), input);
|
||||
NodeDef node_def = square.operation.node()->def();
|
||||
|
||||
const int num_inputs = 20;
|
||||
std::vector<float> inputs(num_inputs);
|
||||
std::vector<float> expected_outputs(num_inputs);
|
||||
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const float value = (i - 9);
|
||||
inputs[i] = value;
|
||||
expected_outputs[i] = value * value;
|
||||
}
|
||||
AddTestTensor("input", {1, 1, 20}, tf_type, inputs);
|
||||
|
||||
TestOpConverter("my_square", node_def, {1, 1, 20}, Status::OK(), Status::OK(),
|
||||
ArrayFloatNear(expected_outputs, 0));
|
||||
}
|
||||
|
||||
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
|
||||
|
Loading…
Reference in New Issue
Block a user