Merge pull request #22874 from trevor-m:tmorris_tftrt_transpose_reshape

PiperOrigin-RevId: 217444920
TensorFlower Gardener 2018-10-16 23:07:37 -07:00
commit 24294c6288
10 changed files with 1599 additions and 729 deletions

tensorflow/contrib/tensorrt/BUILD

@@ -326,6 +326,34 @@ tf_cuda_cc_test(
]),
)
tf_cuda_cc_test(
name = "convert_nodes_test",
size = "medium",
srcs = ["convert/convert_nodes_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":trt_logging",
":trt_conversion",
":trt_plugins",
"@com_google_googletest//:gtest",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
] + if_tensorrt([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_tensorrt//:nv_infer",
]),
)
# Library for the segmenting portion of TensorRT operation creation
cc_library(
name = "segment",
@@ -455,6 +483,7 @@ cuda_py_tests(
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
"test/rank_two_test.py",
"test/reshape_transpose_test.py",
"test/vgg_block_nchw_test.py",
"test/vgg_block_test.py",
],

tensorflow/contrib/tensorrt/convert/convert_graph.cc

@@ -115,6 +115,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
"Sqrt",
"Abs",
"Neg",
"Transpose",
"Reshape",
#if NV_TENSORRT_MAJOR > 3
"MatMul",
"BatchMatMul",

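For context, the list above is the whitelist that IsTensorRTCandidate consults when deciding which TF ops may enter a TRT segment; this change adds "Transpose" and "Reshape" to it. Below is a minimal sketch of the membership check under the usual pattern (the helper name IsCandidateOp and the trimmed op list are illustrative, not the real function):

#include <set>
#include <string>

// Sketch only: the real IsTensorRTCandidate also validates plugin ops and
// other per-node conditions before admitting a node into a TRT segment.
bool IsCandidateOp(const std::string& op) {
  static const std::set<std::string> candidate_ops = {
      "Sqrt", "Abs", "Neg",
      "Transpose", "Reshape",  // Newly whitelisted by this change.
#if NV_TENSORRT_MAJOR > 3
      "MatMul", "BatchMatMul",
#endif
  };
  return candidate_ops.count(op) != 0;
}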
tensorflow/contrib/tensorrt/convert/convert_nodes.cc (file diff suppressed because it is too large)

tensorflow/contrib/tensorrt/convert/convert_nodes.h

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#include <list>
#include <set>
#include <string>
#include <unordered_map>
@@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -33,6 +35,7 @@ limitations under the License.
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
@@ -170,6 +173,162 @@ class OutputEdgeValidator {
bool operator()(const tensorflow::Edge* out_edge) const;
};
////////////////////////////////////////////////////////////////////////////////
// Classes/functions below are exposed for testing purposes only.
////////////////////////////////////////////////////////////////////////////////
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::ITensor& tensor);
int64_t TrtDimsNumElements(const nvinfer1::Dims& dims);
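// A plausible implementation sketch (assumption: the real definition lives
// in convert_nodes.cc, whose diff is suppressed above). The element count
// is the plain product of the listed dims; TRT dims exclude the implicit
// batch dimension, so it is not part of the product:
//
//   int64_t TrtDimsNumElements(const nvinfer1::Dims& dims) {
//     int64_t count = 1;
//     for (int i = 0; i < dims.nbDims; ++i) count *= dims.d[i];
//     return count;
//   }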
// Class to convert TF weight to TRT weight.
class TRT_ShapedWeights {
public:
TRT_ShapedWeights(tensorflow::DataType type, const void* values,
nvinfer1::Dims shape);
explicit TRT_ShapedWeights(tensorflow::DataType type);
// TODO(aaroey): use rvalue reference.
TRT_ShapedWeights(const TRT_ShapedWeights& rhs);
nvinfer1::Weights GetWeightsForTRT() const;
const void* GetValues() const { return values_; }
int64_t count() const;
size_t size_bytes() const;
// Default converter
operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
string DebugString() const;
// TODO(aaroey): make these private.
nvinfer1::Dims shape_; // Note: shape.type[] is not used.
tensorflow::DataType type_;
private:
// TODO(aaroey): this should not be const as it's always from TRTWeightStore.
const void* values_;
friend bool operator==(const TRT_ShapedWeights& lhs,
const TRT_ShapedWeights& rhs);
};
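// Usage sketch (assumed; mirrors TRT_ShapedWeights_Test below): wrap an
// existing buffer without copying, then hand it to TRT through the
// implicit nvinfer1::Weights conversion.
//
//   float raw[6] = {1, 2, 3, 4, 5, 6};
//   nvinfer1::Dims dims;
//   dims.nbDims = 2;
//   dims.d[0] = 2;
//   dims.d[1] = 3;
//   TRT_ShapedWeights weights(DT_FLOAT, raw, dims);
//   nvinfer1::Weights trt_weights = weights;  // Calls GetWeightsForTRT().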
class TRT_TensorOrWeights {
public:
explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor);
explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);
// TODO(aaroey): use rvalue reference.
TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);
bool is_tensor() const { return is_tensor_; }
bool is_weights() const { return !is_tensor_; }
nvinfer1::ITensor* tensor() {
CHECK(is_tensor());
return tensor_;
}
const nvinfer1::ITensor* tensor() const {
CHECK(is_tensor());
return tensor_;
}
TRT_ShapedWeights& weights() {
CHECK(is_weights());
return weights_;
}
const TRT_ShapedWeights& weights() const {
CHECK(is_weights());
return weights_;
}
// TODO(aaroey): rename to dims() to be consistent.
nvinfer1::Dims shape() const;
string DebugString() const;
private:
nvinfer1::ITensor* tensor_;
TRT_ShapedWeights weights_;
const bool is_tensor_;
};
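// Usage sketch (assumed): branch on the payload kind before unwrapping,
// since tensor() and weights() CHECK-fail when called on the wrong variant.
//
//   void Inspect(const TRT_TensorOrWeights& input) {
//     if (input.is_tensor()) {
//       const nvinfer1::ITensor* t = input.tensor();  // OK: holds a tensor.
//       VLOG(2) << t->getName();
//     } else {
//       const TRT_ShapedWeights& w = input.weights();  // OK: holds weights.
//       VLOG(2) << w.DebugString();
//     }
//   }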
// Class to convert TF nodes to TRT network.
class Converter {
public:
Converter(nvinfer1::INetworkDefinition* trt_network, bool fp16,
int max_batch_size);
virtual ~Converter() {}
nvinfer1::INetworkDefinition* network() { return trt_network_; }
TRTWeightStore* weight_store() { return &weight_store_; }
bool IsFP16() const { return fp16_; }
int GetMaxBatchSize() const { return max_batch_size_; }
TRT_ShapedWeights GetTempWeights(tensorflow::DataType type,
const nvinfer1::Dims& dims);
TRT_ShapedWeights GetTempWeightsLike(const TRT_ShapedWeights& weights) {
return GetTempWeights(weights.type_, weights.shape_);
}
Status ConvertNode(const tensorflow::NodeDef& node_def);
TRT_TensorOrWeights GetTensorOrWeights(const string& name);
Status AddInputTensor(const string& name, nvinfer1::ITensor* tensor);
Status TransposeTensor(nvinfer1::ITensor* input_tensor,
const std::vector<int>& order_with_batch_dim,
const nvinfer1::ITensor** output_tensor);
// Converts input into tensor with shape specified by dims.
Status PrepareTensorForShape(const TRT_TensorOrWeights& input,
const nvinfer1::Dims& dims,
const nvinfer1::ITensor** tensor);
// Exposed for testing purposes.
Status GetInputs(const tensorflow::NodeDef& node_def,
std::vector<TRT_TensorOrWeights>* inputs) const;
private:
using OpConverter =
std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
const std::vector<TRT_TensorOrWeights>&,
std::vector<TRT_TensorOrWeights>*)>;
void RegisterOpConverters();
std::unordered_map<string, OpConverter> op_registry_;
std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
// TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
// operate on the stored weights instead of accessing them directly.
TRTWeightStore weight_store_;
bool fp16_;
int max_batch_size_;
friend class ConverterForTest;
};
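// Hedged usage sketch for the two tensor helpers above (assumes a caller
// that returns Status; `network` and `input` come from an existing TRT
// builder and are not defined here):
//
//   Converter converter(network, /*fp16=*/false, /*max_batch_size=*/8);
//   const nvinfer1::ITensor* transposed = nullptr;
//   // Position 0 is the implicit batch dimension and must stay in place.
//   TF_RETURN_IF_ERROR(
//       converter.TransposeTensor(input, {0, 3, 1, 2}, &transposed));
//
//   const nvinfer1::ITensor* reshaped = nullptr;
//   nvinfer1::Dims new_shape;  // Target shape, excluding the batch dim.
//   new_shape.nbDims = 2;
//   new_shape.d[0] = -1;       // A single -1 entry is inferred from the
//   new_shape.d[1] = 6;        // element count, as the tests below show.
//   TF_RETURN_IF_ERROR(converter.PrepareTensorForShape(
//       TRT_TensorOrWeights(input), new_shape, &reshaped));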
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow

tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc

@@ -0,0 +1,646 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include <memory>
#include <unordered_map>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda.h"
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {
using ::testing::ElementsAre;
void ExpectStatus(Status status, error::Code code, const char* substr) {
EXPECT_EQ(code, status.code()) << status;
EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status;
}
nvinfer1::Dims GetTestDims(const std::vector<int>& d) {
nvinfer1::Dims dims;
dims.nbDims = d.size();
for (int i = 0; i < d.size(); ++i) {
dims.d[i] = d[i];
}
return dims;
}
// Fake ITensor implementation for testing purposes.
class FakeITensor : public nvinfer1::ITensor {
public:
FakeITensor() {}
FakeITensor(const nvinfer1::Dims& dims, const string& name = "")
: name_(name), dims_(dims) {}
FakeITensor(const string& name, const std::vector<int>& dims)
: name_(name), dims_(GetTestDims(dims)) {}
void setName(const char* name) override { name_ = name; }
const char* getName() const override { return name_.c_str(); }
void setDimensions(nvinfer1::Dims dimensions) override { dims_ = dimensions; }
nvinfer1::Dims getDimensions() const override { return dims_; }
void setType(nvinfer1::DataType type) override { type_ = type; }
nvinfer1::DataType getType() const override { return type_; }
bool isNetworkInput() const override { return false; }
bool isNetworkOutput() const override { return false; }
void setBroadcastAcrossBatch(bool broadcastAcrossBatch) override {}
bool getBroadcastAcrossBatch() const override { return false; }
nvinfer1::TensorLocation getLocation() const override { return location_; }
void setLocation(nvinfer1::TensorLocation location) override {
location_ = location;
}
#if NV_TENSORRT_MAJOR >= 5
bool setDynamicRange(float min, float max) override { return true; }
#endif
private:
string name_;
nvinfer1::Dims dims_;
nvinfer1::DataType type_;
nvinfer1::TensorLocation location_;
};
bool Equals(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) {
if (lhs.nbDims != rhs.nbDims) return false;
for (int i = 0; i < lhs.nbDims; ++i) {
if (lhs.d[i] != rhs.d[i]) return false;
// We don't check the types in the tests.
}
return true;
}
bool operator==(const TRT_ShapedWeights& lhs, const TRT_ShapedWeights& rhs) {
return Equals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ &&
lhs.values_ == rhs.values_;
}
TEST(TRT_ShapedWeights_Test, Basic) {
{
float raw_weights[10];
TRT_ShapedWeights weights(DT_FLOAT, raw_weights, GetTestDims({2, 5}));
nvinfer1::Weights trt_weights = weights.GetWeightsForTRT();
EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
EXPECT_EQ(static_cast<void*>(raw_weights), trt_weights.values);
EXPECT_EQ(10, trt_weights.count);
EXPECT_EQ(static_cast<void*>(raw_weights), weights.GetValues());
EXPECT_EQ(10, weights.count());
EXPECT_EQ(40, weights.size_bytes());
}
{
int32 raw_weights = 0;
TRT_ShapedWeights weights(DT_INT32, &raw_weights, GetTestDims({1, 1, 1}));
nvinfer1::Weights trt_weights = weights.GetWeightsForTRT();
EXPECT_EQ(nvinfer1::DataType::kINT32, trt_weights.type);
EXPECT_EQ(static_cast<void*>(&raw_weights), trt_weights.values);
EXPECT_EQ(1, trt_weights.count);
EXPECT_EQ(static_cast<void*>(&raw_weights), weights.GetValues());
EXPECT_EQ(1, weights.count());
EXPECT_EQ(4, weights.size_bytes());
}
{
TRT_ShapedWeights weights(DT_FLOAT);
nvinfer1::Weights trt_weights = weights.GetWeightsForTRT();
EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
EXPECT_EQ(nullptr, trt_weights.values);
EXPECT_EQ(0, trt_weights.count);
EXPECT_EQ(nullptr, weights.GetValues());
EXPECT_EQ(0, weights.count());
EXPECT_EQ(0, weights.size_bytes());
}
}
TEST(TRT_TensorOrWeights_Test, Basic) {
{
nvinfer1::Dims dims;
dims.nbDims = 1;
dims.d[0] = 1;
FakeITensor itensor(dims);
TRT_TensorOrWeights tw(&itensor);
EXPECT_EQ(true, tw.is_tensor());
EXPECT_EQ(false, tw.is_weights());
EXPECT_EQ(&itensor, tw.tensor());
EXPECT_TRUE(Equals(dims, tw.shape()))
<< "- expected: " << DebugString(dims)
<< "\n vs\n- actual: " << DebugString(tw.shape());
}
{
TRT_ShapedWeights weights(DT_FLOAT);
TRT_TensorOrWeights tw(weights);
EXPECT_EQ(false, tw.is_tensor());
EXPECT_EQ(true, tw.is_weights());
EXPECT_EQ(weights, tw.weights());
nvinfer1::Dims dims;
dims.nbDims = 0;
EXPECT_TRUE(Equals(dims, tw.shape()))
<< "- expected: " << DebugString(dims)
<< "\n vs\n- actual: " << DebugString(tw.shape());
}
}
class ConverterForTest : public Converter {
public:
ConverterForTest()
: Converter(nullptr, /*fp16=*/false, /*max_batch_size=*/1) {
QCHECK_EQ(0, cudaStreamCreate(&stream_));
Reset();
}
~ConverterForTest() override { QCHECK_EQ(0, cudaStreamDestroy(stream_)); }
// Helper methods for testing purposes.
void AddOpConverter(const string& op_name, OpConverter op_converter) {
op_registry_[op_name] = op_converter;
}
void AddTensorOrWeights(const string& name, TRT_TensorOrWeights tw) {
ASSERT_TRUE(trt_tensors_.insert({name, tw}).second);
}
void Reset() {
// Clear the tensor map.
trt_tensors_.clear();
// Reset the INetworkDefinition.
engine_.reset(nullptr);
network_.reset(nullptr);
builder_.reset(nullptr);
builder_.reset(nvinfer1::createInferBuilder(logger_));
network_.reset(builder_->createNetwork());
trt_network_ = network_.get();
}
void BuildAndRun(const char* input_name, const std::vector<float>& input_data,
const char* output_name, std::vector<float>* output_data) {
// Mark the output tensor as TRT engine output.
TRT_TensorOrWeights tensor = GetTensorOrWeights(output_name);
tensor.tensor()->setName(output_name);
network()->markOutput(*tensor.tensor());
// Build the TRT engine.
QCHECK_EQ(nullptr, engine_.get());
engine_.reset(builder_->buildCudaEngine(*network()));
CHECK_NOTNULL(engine_.get());
// Execute the TRT engine.
const int input_size = input_data.size() * sizeof(float);
const int output_size = output_data->size() * sizeof(float);
const int input_index = engine_->getBindingIndex(input_name);
const int output_index = engine_->getBindingIndex(output_name);
ASSERT_EQ(engine_->getNbBindings(), 2);
void* buffers[2];
ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size));
ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size));
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input_data.data(),
input_size, cudaMemcpyHostToDevice, stream_));
TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
engine_->createExecutionContext());
execution_context->enqueue(1, buffers, stream_, nullptr);
ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index],
output_size, cudaMemcpyDeviceToHost, stream_));
cudaStreamSynchronize(stream_);
ASSERT_EQ(0, cudaFree(buffers[input_index]));
ASSERT_EQ(0, cudaFree(buffers[output_index]));
}
private:
Logger logger_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
cudaStream_t stream_;
};
class ConverterTest : public ::testing::Test {
protected:
nvinfer1::ITensor* AddTestTensor(const char* name,
const std::vector<int>& dims) {
nvinfer1::ITensor* tensor = converter_.network()->addInput(
name, nvinfer1::DataType::kFLOAT, GetTestDims(dims));
converter_.AddTensorOrWeights(name, TRT_TensorOrWeights{tensor});
return tensor;
}
template <typename CType>
TRT_ShapedWeights AddTestWeights(const char* name, const DataType dtype,
const std::vector<int>& dims,
const std::vector<CType>& values) {
const nvinfer1::Dims trt_dims = GetTestDims(dims);
const int64_t num_elements = TrtDimsNumElements(trt_dims);
QCHECK_EQ(num_elements, values.size())
<< num_elements << " vs " << values.size();
TRT_ShapedWeights weights(dtype);
if (num_elements) {
const int64_t size_bytes = DataTypeSize(dtype) * num_elements;
QCHECK_EQ(size_bytes, sizeof(CType) * values.size())
<< size_bytes << " vs " << sizeof(CType) * values.size();
converter_.weight_store()->store_.push_back(
std::vector<uint8_t>(size_bytes));
void* dst =
static_cast<void*>(converter_.weight_store()->store_.back().data());
memcpy(dst, values.data(), size_bytes);
weights = TRT_ShapedWeights(dtype, dst, trt_dims);
}
converter_.AddTensorOrWeights(name, TRT_TensorOrWeights{weights});
return weights;
}
NodeDef MakeNodeDef(const string& name, const string& op,
const std::vector<string>& inputs) {
NodeDef node_def;
node_def.set_name(name);
node_def.set_op(op);
for (const string& input : inputs) {
node_def.add_input(input);
}
return node_def;
}
ConverterForTest converter_;
};
TEST_F(ConverterTest, GetTempWeights) {
TRT_ShapedWeights weights =
converter_.GetTempWeights(DT_FLOAT, GetTestDims({2, 3}));
nvinfer1::Weights trt_weights = weights.GetWeightsForTRT();
EXPECT_EQ(nvinfer1::DataType::kFLOAT, trt_weights.type);
EXPECT_NE(nullptr, trt_weights.values);
EXPECT_EQ(6, trt_weights.count);
EXPECT_NE(nullptr, weights.GetValues());
EXPECT_EQ(6, weights.count());
EXPECT_EQ(24, weights.size_bytes());
// TODO(aaroey): test the case where shape element count is 0.
}
TEST_F(ConverterTest, GetInputs) {
NodeDef node_def;
node_def.add_input("^control_input");
node_def.add_input("input");
node_def.add_input("input:0");
node_def.add_input("input:1");
node_def.add_input("weird_input:2:3:4:0");
FakeITensor input, input_1, input_2;
TF_EXPECT_OK(converter_.AddInputTensor("input", &input));
TF_EXPECT_OK(converter_.AddInputTensor("input:1", &input_1));
TF_EXPECT_OK(converter_.AddInputTensor("weird_input:2:3:4", &input_2));
std::vector<TRT_TensorOrWeights> inputs;
TF_EXPECT_OK(converter_.GetInputs(node_def, &inputs));
EXPECT_EQ(4, inputs.size());
EXPECT_EQ(&input, inputs[0].tensor());
EXPECT_EQ(&input, inputs[1].tensor());
EXPECT_EQ(&input_1, inputs[2].tensor());
EXPECT_EQ(&input_2, inputs[3].tensor());
}
TEST_F(ConverterTest, ConvertNode) {
FakeITensor output_tensors[2];
auto op_converter = [&output_tensors](
Converter& ctx, const NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) -> Status {
nvinfer1::Dims dims = inputs[0].tensor()->getDimensions();
for (int i = 0; i < 2; ++i) {
dims.d[0] += 1;
output_tensors[i].setDimensions(dims);
outputs->push_back(TRT_TensorOrWeights(&output_tensors[i]));
}
return Status::OK();
};
converter_.AddOpConverter("MyOp", op_converter);
FakeITensor input_tensor("my_input", {12345});
TF_EXPECT_OK(converter_.AddInputTensor("my_input", &input_tensor));
NodeDef node_def = MakeNodeDef("my_op", "MyOp", {"my_input"});
TF_EXPECT_OK(converter_.ConvertNode(node_def));
TRT_TensorOrWeights actual_output_1 = converter_.GetTensorOrWeights("my_op");
EXPECT_EQ(&output_tensors[0], actual_output_1.tensor());
EXPECT_EQ(12346, actual_output_1.tensor()->getDimensions().d[0]);
TRT_TensorOrWeights actual_output_2 =
converter_.GetTensorOrWeights("my_op:1");
EXPECT_EQ(&output_tensors[1], actual_output_2.tensor());
EXPECT_EQ(12347, actual_output_2.tensor()->getDimensions().d[0]);
}
TEST_F(ConverterTest, TransposeTensor) {
nvinfer1::ITensor* input_tensor = AddTestTensor("", {2, 3, 5});
const nvinfer1::ITensor* output_tensor = nullptr;
// Rank doesn't match.
ExpectStatus(
converter_.TransposeTensor(input_tensor, {0, 1}, &output_tensor),
error::INVALID_ARGUMENT,
"Rank of perm for transpose does not match with that of the input");
// Transpose at batch dimension.
ExpectStatus(
converter_.TransposeTensor(input_tensor, {1, 0, 2, 3}, &output_tensor),
error::UNIMPLEMENTED, "Transpose at batch dimension is not supported.");
// OK.
TF_EXPECT_OK(
converter_.TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor));
EXPECT_TRUE(Equals(GetTestDims({5, 2, 3}), output_tensor->getDimensions()))
<< DebugString(*output_tensor);
}
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
nvinfer1::ITensor* input_tensor = AddTestTensor("", {2, 3, 5});
TRT_TensorOrWeights tw(input_tensor);
const nvinfer1::ITensor* output_tensor = nullptr;
// Shape size doesn't match.
ExpectStatus(converter_.PrepareTensorForShape(tw, GetTestDims({2, 3, 6}),
&output_tensor),
error::INVALID_ARGUMENT, "Reshape shapes are not compatible.");
// TODO(aaroey): we should check the case where uninferred dimensions are not
// an exact divisor of input dimensions, e.g. for dims {-1, 7}.
// Infer shape, ok.
TF_EXPECT_OK(converter_.PrepareTensorForShape(tw, GetTestDims({-1, 2}),
&output_tensor));
EXPECT_TRUE(Equals(GetTestDims({15, 2}), output_tensor->getDimensions()))
<< DebugString(*output_tensor);
// Regular shape.
TF_EXPECT_OK(converter_.PrepareTensorForShape(tw, GetTestDims({10, 3}),
&output_tensor));
EXPECT_TRUE(Equals(GetTestDims({10, 3}), output_tensor->getDimensions()))
<< DebugString(*output_tensor);
}
#if NV_TENSORRT_MAJOR > 3
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
TRT_ShapedWeights weights =
converter_.GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5}));
TRT_TensorOrWeights tw(weights);
const nvinfer1::ITensor* output_tensor = nullptr;
TF_EXPECT_OK(converter_.PrepareTensorForShape(tw, GetTestDims({10, 3}),
&output_tensor));
EXPECT_TRUE(Equals(GetTestDims({10, 3}), output_tensor->getDimensions()))
<< DebugString(*output_tensor);
}
#endif
template <DataType dtype, typename InputCType, typename OutputCType>
void TestConvertConst(ConverterForTest* converter) {
NodeDef node_def;
node_def.set_name("my_const");
node_def.set_op("Const");
auto reset_and_test = [&node_def, converter](
const Tensor& tensor, const bool as_tensor_content,
const std::vector<int>& expected_dims,
const std::vector<OutputCType>& expected_value) {
converter->Reset();
auto& attr = *node_def.mutable_attr();
if (as_tensor_content) {
tensor.AsProtoTensorContent(attr["value"].mutable_tensor());
} else {
tensor.AsProtoField(attr["value"].mutable_tensor());
}
TF_EXPECT_OK(converter->ConvertNode(node_def));
TRT_TensorOrWeights output = converter->GetTensorOrWeights("my_const");
EXPECT_TRUE(Equals(GetTestDims(expected_dims), output.weights().shape_))
<< output.DebugString();
ASSERT_EQ(expected_value.size(), output.weights().count())
<< output.DebugString();
const OutputCType* actual_values =
static_cast<const OutputCType*>(output.weights().GetValues());
for (int i = 0; i < expected_value.size(); ++i) {
EXPECT_EQ(expected_value[i], actual_values[i]);
}
};
auto& attr = *node_def.mutable_attr();
attr["dtype"].set_type(dtype);
{
// By default an empty tensor picks DT_FLOAT as its data type, so we fix it
// here.
attr["value"].mutable_tensor()->set_dtype(dtype);
Tensor t; // Empty tensor.
reset_and_test(t, false, {}, {});
}
{
Tensor t = ::tensorflow::test::AsScalar<InputCType>(12);
reset_and_test(t, false, {1}, {12});
reset_and_test(t, true, {1}, {12});
}
{
Tensor t = ::tensorflow::test::AsTensor<InputCType>({1, 2});
reset_and_test(t, false, {2}, {1, 2});
reset_and_test(t, true, {2}, {1, 2});
}
{
Tensor t = ::tensorflow::test::AsTensor<InputCType>({1, 2, 3, 4, 5, 6},
TensorShape({2, 3}));
reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6});
reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6});
}
}
TEST_F(ConverterTest, ConvertConst) {
{
converter_.Reset();
NodeDef node_def = MakeNodeDef("my_const", "Const", {"input"});
AddTestTensor("input", {1});
ExpectStatus(
converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Constant node is expected to have empty input list: my_const");
}
{
converter_.Reset();
NodeDef node_def = MakeNodeDef("my_const", "Const", {});
(*node_def.mutable_attr())["dtype"].set_type(DT_DOUBLE);
ExpectStatus(converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Unsupported data type");
}
TestConvertConst<DT_FLOAT, float, float>(&converter_);
TestConvertConst<DT_INT8, int8, int32>(&converter_);
#if NV_TENSORRT_MAJOR > 3
TestConvertConst<DT_INT32, int32, int32>(&converter_);
#endif
}
TEST_F(ConverterTest, ConvertTranspose) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {});
ExpectStatus(converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Input expects tensor and weights, at my_transpose");
}
NodeDef node_def =
MakeNodeDef("my_transpose", "Transpose", {"input", "weights"});
{
// Permutation is a tensor, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("weights", {3});
ExpectStatus(converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Input expects tensor and weights, at my_transpose");
}
{
// Transpose at batch dimension, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {4}, {1, 0, 2, 3});
ExpectStatus(converter_.ConvertNode(node_def), error::UNIMPLEMENTED,
"Transpose at batch dimension is not supported");
}
{
// Permutation rank doesn't match, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {3}, {0, 1, 2});
ExpectStatus(
converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Rank of perm for transpose does not match with that of the input.");
}
{
// Ok.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {4}, {0, 3, 1, 2});
TF_EXPECT_OK(converter_.ConvertNode(node_def));
TRT_TensorOrWeights output = converter_.GetTensorOrWeights("my_transpose");
EXPECT_TRUE(output.is_tensor());
EXPECT_TRUE(
Equals(GetTestDims({3, 1, 2}), output.tensor()->getDimensions()))
<< output.DebugString();
std::vector<float> output_data(6);
converter_.BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_transpose",
&output_data);
EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6));
}
}
TEST_F(ConverterTest, ConvertReshape) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {});
ExpectStatus(converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Input expects weights for shape, at my_reshape");
}
NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {"input", "weights"});
{
// Shape is a tensor, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("weights", {3});
ExpectStatus(converter_.ConvertNode(node_def), error::INVALID_ARGUMENT,
"Input expects weights for shape, at my_reshape");
}
{
// Reshape to scalar, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {}, {});
ExpectStatus(converter_.ConvertNode(node_def), error::UNIMPLEMENTED,
"Reshape to shape=[] is not supported, at my_reshape");
}
{
// Reshape at batch dimension, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {4}, {-1, 1, 1, 2});
ExpectStatus(converter_.ConvertNode(node_def), error::UNIMPLEMENTED,
"Reshape on batch dimension is not supported, at my_reshape");
}
{
// Reshape at batch dimension, should fail.
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {4}, {3, 1, 1, 2});
ExpectStatus(converter_.ConvertNode(node_def), error::UNIMPLEMENTED,
"Reshape on batch dimension is not supported, at my_reshape");
}
// Reshape on non-batch dimensions, ok.
for (int batch_dim : {-1, 1}) {
converter_.Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", DT_INT32, {4}, {batch_dim, 1, 3, 2});
TF_EXPECT_OK(converter_.ConvertNode(node_def));
TRT_TensorOrWeights output = converter_.GetTensorOrWeights("my_reshape");
EXPECT_TRUE(output.is_tensor());
EXPECT_TRUE(
Equals(GetTestDims({1, 3, 2}), output.tensor()->getDimensions()))
<< output.DebugString();
std::vector<float> output_data(6);
converter_.BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_reshape",
&output_data);
EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6));
}
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

tensorflow/contrib/tensorrt/test/base_test.py

@@ -136,6 +136,16 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
# - my_trt_op_1 should have ["weights","conv", "div"]
return ["my_trt_op_0", "my_trt_op_1"]
def ShouldRunTest(self, run_params):
# TODO(aaroey): LayoutOptimizer adds Transpose(Const, Const) to the graph,
# which breaks the conversion. We should fix it by:
# - Detecting the invalid NodeDef earlier, before adding it to the segment.
# - Allowing the RewriterConfig to be changed when calling
#   create_inference_graph().
# It would also be good to add a debugging feature to Grappler that prints
# the graph after running each optimizer.
return False
class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):

tensorflow/contrib/tensorrt/test/batch_matmul_test.py

@@ -50,17 +50,22 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
w2 = array_ops.placeholder(dtype=dtype, shape=w2_dims, name=w2_name)
with g.device("/GPU:0"):
b = constant_op.constant(np.random.randn(12, 5, 12, 7), dtype=dtype)
c = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype)
d = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype)
x1 = math_ops.matmul(inp, b)
c = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype)
x1 = x1 + c
x2 = math_ops.matmul(inp, w1)
d = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype)
x2 = x2 * d
e = gen_array_ops.reshape(inp, [12, 40, 12])
e = self.trt_incompatible_op(inp)
e = gen_array_ops.reshape(e, [12, 40, 12])
x3 = math_ops.matmul(e, w2)
f = constant_op.constant(np.random.randn(40, 1), dtype=dtype)
x3 = x3 + f
x3 = gen_array_ops.reshape(x3, [12, 5, 8, 7])
x3 = self.trt_incompatible_op(x3)
out = x1 + x2 + x3
array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(

tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py

@@ -33,95 +33,100 @@ from tensorflow.python.platform import test
class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
def _ConstOp(self, shape):
return constant_op.constant(np.random.randn(*shape), dtype=dtypes.float32)
def GetParams(self):
"""Testing conversion of BiasAdd MatMul in TF-TRT conversion."""
dtype = dtypes.float32
input_name = "input"
input_dims = [48, 12]
input_matrix_rows = 4
input_matrix_columns = 144
input_dims = [input_matrix_rows, input_matrix_columns]
output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x = array_ops.placeholder(
dtype=dtypes.float32, shape=input_dims, name=input_name)
b = constant_op.constant(np.random.randn(12, 4), dtype=dtype)
b = self._ConstOp((input_matrix_columns, 4))
x1 = math_ops.matmul(x, b)
b = constant_op.constant(np.random.randn(1, 4), dtype=dtype)
b = self._ConstOp((1, 4))
x1 = x1 + b
b = constant_op.constant(np.random.randn(48, 4), dtype=dtype)
x2 = math_ops.matmul(x, b, transpose_a=True)
x2 = gen_array_ops.reshape(x2, [48, 1])
b = self._ConstOp((input_matrix_rows, 144))
x2 = self.trt_incompatible_op(x)
x2 = math_ops.matmul(x2, b, transpose_a=True)
x2 = gen_array_ops.reshape(x2, [4, -1])
x2 = self.trt_incompatible_op(x2)
b = constant_op.constant(np.random.randn(4, 12), dtype=dtype)
b = self._ConstOp((4, input_matrix_columns))
x3 = math_ops.matmul(x, b, transpose_b=True)
b = constant_op.constant(np.random.randn(16, 48), dtype=dtype)
x4 = math_ops.matmul(x, b, transpose_b=True, transpose_a=True)
x4 = gen_array_ops.reshape(x4, [48, 4])
b = self._ConstOp((16, input_matrix_rows))
x4 = self.trt_incompatible_op(x)
x4 = math_ops.matmul(x4, b, transpose_b=True, transpose_a=True)
x4 = gen_array_ops.reshape(x4, [4, -1])
x4 = self.trt_incompatible_op(x4)
x5 = gen_array_ops.reshape(x, [4, 144])
b = constant_op.constant(np.random.randn(144, 48), dtype=dtype)
x5 = math_ops.matmul(x5, b)
b = constant_op.constant(np.random.randn(48), dtype=dtype)
b = self._ConstOp((input_matrix_columns, 48))
x5 = math_ops.matmul(x, b)
b = self._ConstOp((48,))
x5 = nn.bias_add(x5, b)
x5 = gen_array_ops.reshape(x5, [48, 4])
x5 = gen_array_ops.reshape(x5, [4, -1])
x6 = gen_array_ops.reshape(x, [4, 12, 12])
b = constant_op.constant(np.random.randn(12), dtype=dtype)
b = self._ConstOp((12,))
x6 = nn.bias_add(x6, b, data_format="NHWC")
x6 = gen_array_ops.reshape(x6, [48, -1])
x6 = gen_array_ops.reshape(x6, [4, -1])
x7 = gen_array_ops.reshape(x, [4, 12, 3, 4])
b = constant_op.constant(np.random.randn(4), dtype=dtype)
b = self._ConstOp((4,))
x7 = nn.bias_add(x7, b, data_format="NHWC")
x7 = gen_array_ops.reshape(x7, [48, -1])
x7 = gen_array_ops.reshape(x7, [4, -1])
x8 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2])
b = constant_op.constant(np.random.randn(2), dtype=dtype)
b = self._ConstOp((2,))
x8 = nn.bias_add(x8, b, data_format="NHWC")
x8 = gen_array_ops.reshape(x8, [48, -1])
x8 = gen_array_ops.reshape(x8, [4, -1])
x9 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2])
b = constant_op.constant(np.random.randn(12), dtype=dtype)
b = self._ConstOp((12,))
x9 = nn.bias_add(x9, b, data_format="NCHW")
x9 = gen_array_ops.reshape(x9, [48, -1])
x9 = gen_array_ops.reshape(x9, [4, -1])
x10 = gen_array_ops.reshape(x, [4, 12, 3, 4])
b = constant_op.constant(np.random.randn(12), dtype=dtype)
b = self._ConstOp((12,))
x10 = nn.bias_add(x10, b, data_format="NCHW")
x10 = gen_array_ops.reshape(x10, [48, -1])
x10 = gen_array_ops.reshape(x10, [4, -1])
x11 = gen_array_ops.reshape(x, [4, 12, 12])
b = constant_op.constant(np.random.randn(12), dtype=dtype)
b = self._ConstOp((12,))
x11 = nn.bias_add(x11, b, data_format="NCHW")
x11 = gen_array_ops.reshape(x11, [48, -1])
x11 = gen_array_ops.reshape(x11, [4, -1])
out = array_ops.concat(
[x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11], axis=-1)
out = array_ops.concat([x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11],
axis=-1)
out = array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
output_names=[output_name],
expected_output_dims=[(48, 89)])
expected_output_dims=[(4, 6680)])
def GetConversionParams(self, run_params):
"""Return a ConversionParams for test."""
return super(BiasaddMatMulTest,
self).GetConversionParams(run_params)._replace(
max_batch_size=48, maximum_cached_engines=2)
max_batch_size=4, maximum_cached_engines=2)
def _ValidEngines(self):
"""Engines expected to build and run."""
return [
"my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
"my_trt_op_6", "my_trt_op_7", "my_trt_op_8", "my_trt_op_9"
]
return ["my_trt_op_0"]
def _InvalidEngines(self):
"""Engines that will cause conversion error at building time."""
return ["my_trt_op_4", "my_trt_op_5"]
return ["my_trt_op_1", "my_trt_op_2"]
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""

tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py

@@ -32,79 +32,34 @@ from tensorflow.python.platform import test
class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
def _ConstOp(self, shape):
return constant_op.constant(np.random.randn(*shape), dtype=dtypes.float32)
def GetParams(self):
"""Tests for scale & elementwise layers in TF-TRT."""
dtype = dtypes.float32
input_name = "input"
input_dims = [10, 24, 24, 20]
output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
# scale
a = constant_op.constant(np.random.randn(1), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# scale
a = constant_op.constant(np.random.randn(1), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# scale
a = constant_op.constant(np.random.randn(24, 1, 1), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# scale
a = constant_op.constant(np.random.randn(24, 1, 1), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# scale
a = constant_op.constant(np.random.randn(24, 24, 20), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# scale
a = constant_op.constant(np.random.randn(24, 24, 20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(20), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(1, 24, 1, 1), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(1, 24, 1, 1), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(1, 24, 24, 1), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(1, 24, 24, 1), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(1, 24, 24, 20), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(1, 24, 24, 20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(24, 20), dtype=dtype)
f = a + x
x = math_ops.sigmoid(f)
# elementwise
a = constant_op.constant(np.random.randn(24, 20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
x = array_ops.placeholder(
dtype=dtypes.float32, shape=input_dims, name=input_name)
for weights_shape in [
(1,), # scale
(24, 1, 1), # scale
(24, 24, 20), # scale
(20,), # elementwise
(1, 24, 1, 1), # elementwise
(1, 24, 24, 1), # elementwise
(1, 24, 24, 20), # elementwise
(24, 20), # elementwise
]:
a = self._ConstOp(weights_shape)
f = x + a
x = math_ops.sigmoid(f)
a = self._ConstOp(weights_shape)
f = a + x
x = math_ops.sigmoid(f)
gen_array_ops.reshape(x, [5, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
@@ -115,24 +70,7 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""
return [
"my_trt_op_0",
"my_trt_op_1",
"my_trt_op_2",
"my_trt_op_3",
"my_trt_op_4",
"my_trt_op_5",
"my_trt_op_6",
"my_trt_op_7",
"my_trt_op_8",
"my_trt_op_9",
"my_trt_op_10",
"my_trt_op_11",
"my_trt_op_12",
"my_trt_op_13",
"my_trt_op_14",
"my_trt_op_15",
]
return ["my_trt_op_%d" % i for i in range(16)]
if __name__ == "__main__":

tensorflow/contrib/tensorrt/test/reshape_transpose_test.py

@@ -0,0 +1,152 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic tests for TF-TensorRT integration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class ReshapeTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
outputs = []
# Here we test two types of reshapes: one changes the batch dimension and
# the other does not. Note that we're not able to test reshaping to a
# scalar, since TRT requires the input tensor to have rank at least 2, so
# a reshape with a scalar input will be filtered out of the segment before
# conversion.
with g.device("/GPU:0"):
# These reshapes happen at the batch dimension, so conversion should fail.
for shape in [[2, 50, 24, 24, 2], [-1, 50, 24, 24, 2],
[2, 50, -1, 24, 2]]:
incompatible_reshape = array_ops.reshape(inp, shape)
reshape_back = array_ops.reshape(incompatible_reshape,
[-1, 24, 24, 2])
outputs.append(self.trt_incompatible_op(reshape_back))
# Add another block with many reshapes that don't change the batch
# dimension.
compatible_reshape = array_ops.reshape(
inp, [-1, 24 * 24, 2], name="reshape-0")
compatible_reshape = array_ops.reshape(
compatible_reshape, [100, 24, -1], name="reshape-1")
compatible_reshape = array_ops.reshape(
compatible_reshape, [100, 24 * 2, 24], name="reshape-2")
compatible_reshape = array_ops.reshape(
compatible_reshape, [-1, 24, 24 * 2], name="reshape-3")
compatible_reshape = array_ops.reshape(
compatible_reshape, [-1, 6, 4, 24, 2], name="reshape-4")
compatible_reshape = array_ops.reshape(
compatible_reshape, [-1, 6, 4, 6, 4, 2, 1], name="reshape-5")
compatible_reshape = array_ops.reshape(
compatible_reshape, [-1, 24, 24, 2], name="reshape-6")
outputs.append(self.trt_incompatible_op(compatible_reshape))
math_ops.add_n(outputs, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
output_names=[output_name],
expected_output_dims=[tuple(input_dims)])
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""
return {
"my_trt_op_3": ["reshape-%d" % i for i in range(7)] +
["reshape-%d/shape" % i for i in range(7)]
}
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
return (not trt_test.IsQuantizationMode(run_params.precision_mode) and
not run_params.dynamic_engine)
class TransposeTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing single segment."""
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
with g.device("/GPU:0"):
# Add a block with compatible transposes.
compatible_transpose = array_ops.transpose(
inp, [0, 3, 1, 2], name="transpose-1")
compatible_transpose = array_ops.transpose(
compatible_transpose, [0, 2, 3, 1], name="transposeback")
# Add an incompatible op so the first block will not be in the same
# subgraph as the following block.
bridge = self.trt_incompatible_op(compatible_transpose)
# Add a block with incompatible transposes.
#
# Note: by default Grappler runs the TRT optimizer twice. The first time,
# it groups the two transpose ops below into the same segment and then
# fails the conversion due to the expected batch-dimension problem. The
# second time, since the input of the bridge op is my_trt_op_0, shape
# inference fails, which again causes the conversion to fail.
# TODO(laigd): support shape inference, make the TRT optimizer run only
# once, and fix this.
incompatible_transpose = array_ops.transpose(
bridge, [2, 1, 0, 3], name="transpose-2")
excluded_transpose = array_ops.transpose(
incompatible_transpose, [0, 2, 3, 1], name="transpose-3")
array_ops.identity(excluded_transpose, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
output_names=[output_name],
expected_output_dims=[(24, 100, 2, 24)])
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""
return {
"my_trt_op_0": [
"transpose-1", "transpose-1/perm", "transposeback",
"transposeback/perm"
]
}
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
return (not trt_test.IsQuantizationMode(run_params.precision_mode) and
not run_params.dynamic_engine)
if __name__ == "__main__":
test.main()