[TF:TRT] Avoid generating segments that won't be accepted for TRTEngineOp.
When implicit batch and static engine are used, it is required that input dimensions rather than the batch dimensions are fully specified when contructing the TRTEngineOp. Previously, we may generate segments with inputs that do not meet this requirement and abandon the segments when constructing TRTEngineOp. This change avoids generating segments with dynamic non-batch dimensions. Add two test cases. Modify a test that is affected by this change due to a bug in GraphProperties::InferStatically. PiperOrigin-RevId: 308626836 Change-Id: Ie6fc04fef43ab7a575f6a6e3761150b915dc3524
This commit is contained in:
parent
020835b576
commit
ebe525f27c
@ -491,7 +491,9 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
],
|
||||
)
|
||||
|
@ -77,6 +77,19 @@ Status BuildNodeMap(const Graph& graph,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
EngineInfo::EngineType GetEngineType(const ConversionParams& params) {
|
||||
return (params.is_dyn_op || params.use_calibration)
|
||||
? EngineInfo::EngineType::TRTDynamic
|
||||
: EngineInfo::EngineType::TRTStatic;
|
||||
}
|
||||
|
||||
// Returns true when use_implicit_batch is false or when we are building dynamic
|
||||
// engine, to allow unknown size for dimensions rather than dimension 0.
|
||||
bool AllowDynamicNonBatchDimension(const ConversionParams& params) {
|
||||
return !params.use_implicit_batch ||
|
||||
GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct EdgePtrCompare {
|
||||
@ -393,9 +406,8 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
for (int i = 1; i < conn.outside_shape.dims(); i++) {
|
||||
if (conn.outside_shape.dim_size(i) <= 0) {
|
||||
return errors::Internal(
|
||||
"Input shapes must be fully defined when in static mode. "
|
||||
"Please try is_dynamic_op=True (shape was ",
|
||||
conn.outside_shape.DebugString(), ")");
|
||||
"Not fully defined input shape when in static mode which "
|
||||
"should have been excluded by the segmenter. ");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -645,11 +657,15 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
segment_options.exclude_node_list.insert(node);
|
||||
}
|
||||
segment_options.minimum_segment_size = params.minimum_segment_size;
|
||||
segment_options.use_implicit_batch = params.use_implicit_batch;
|
||||
segment_options.allow_dynamic_non_batch_dim =
|
||||
AllowDynamicNonBatchDimension(params);
|
||||
|
||||
segment::SegmentNodesVector initial_segments;
|
||||
TrtNodeValidator validator(*params.graph_properties, params.precision_mode,
|
||||
params.use_calibration, params.use_implicit_batch);
|
||||
TF_RETURN_IF_ERROR(segment::SegmentGraph(
|
||||
&graph,
|
||||
&graph, params.graph_properties,
|
||||
std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
|
||||
std::placeholders::_1),
|
||||
// Input validation is already done by TrtNodeValidator, so we don't
|
||||
@ -686,9 +702,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
continue;
|
||||
}
|
||||
curr_engine.precision_mode = params.precision_mode;
|
||||
curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration)
|
||||
? EngineInfo::EngineType::TRTDynamic
|
||||
: EngineInfo::EngineType::TRTStatic);
|
||||
curr_engine.engine_type = GetEngineType(params);
|
||||
curr_engine.use_calibration = params.use_calibration;
|
||||
curr_engine.maximum_cached_engines = params.max_cached_engines;
|
||||
curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
|
||||
@ -764,6 +778,7 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
} else {
|
||||
// Graph is not modified.
|
||||
LOG(WARNING) << "Cannot replace " << msg
|
||||
<< " reason: " << status.error_message()
|
||||
<< " (keeping original segment).";
|
||||
}
|
||||
if (VLOG_IS_ON(1)) {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -39,8 +40,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace segment {
|
||||
namespace {
|
||||
using absl::StrAppend;
|
||||
using absl::StrAppendFormat;
|
||||
using absl::StrCat;
|
||||
using absl::StrJoin;
|
||||
|
||||
// A simple graph representation to mirror Graph. This structure
|
||||
// helps saving memory since segmenter modifies the graph in place, preventing
|
||||
@ -243,8 +247,6 @@ struct NodePtrCompare {
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Copied from TF ReverseDFS, which only works for Graph.
|
||||
void StableDFS(const SimpleGraph& g, bool reverse,
|
||||
const std::vector<const SimpleNode*>& start,
|
||||
@ -344,7 +346,68 @@ bool CanContractEdge(const SimpleEdge* edge,
|
||||
});
|
||||
return !has_cycle;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// TODO(bixia): put this to a common utility file.
|
||||
string TensorPropertiesToString(const OpInfo::TensorProperties& prop) {
|
||||
string s = StrCat(DataTypeString(prop.dtype()), ": ");
|
||||
StrAppend(&s, "[");
|
||||
if (prop.shape().unknown_rank()) {
|
||||
StrAppend(&s, "?");
|
||||
} else {
|
||||
StrAppend(&s, StrJoin(prop.shape().dim(), ",",
|
||||
[](string* out, const TensorShapeProto_Dim& d) {
|
||||
StrAppendFormat(out, "%d", d.size());
|
||||
}));
|
||||
}
|
||||
StrAppend(&s, "]");
|
||||
return s;
|
||||
}
|
||||
|
||||
string TensorPropertiesToString(
|
||||
const std::vector<OpInfo::TensorProperties>& properties) {
|
||||
return StrJoin(properties, "; ",
|
||||
[](string* out, const OpInfo::TensorProperties& prop) {
|
||||
StrAppend(out, TensorPropertiesToString(prop));
|
||||
});
|
||||
}
|
||||
|
||||
// Returns true if we can't be sure that the operand with the given properties
|
||||
// won't have negative values for non-batch dimensions.
|
||||
//
|
||||
bool HasDynamicNonBatchDimension(const OpInfo::TensorProperties& prop) {
|
||||
const TensorShapeProto& shape = prop.shape();
|
||||
if (shape.unknown_rank()) return true;
|
||||
|
||||
// Scalar is a well specified shape, and TRT supports implicit broadcasting
|
||||
// from scalar to other shapes.
|
||||
if (shape.dim_size() == 0) return false;
|
||||
for (int i = 1; i < shape.dim_size(); ++i) {
|
||||
// The value of a dynamic dimension can be other negative values besides
|
||||
// -1, representing the symbolic group of the dimension.
|
||||
if (shape.dim(i).size() <= -1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns true if we can't be sure that the operation won't have dynamic
|
||||
// non-batch dimension involved. We only check the shape of the first output
|
||||
// assuming shape inference already propagates the shapes.
|
||||
bool OperationHasDynamicNonBatchDimension(
|
||||
const grappler::GraphProperties* graph_properties, const Node* node) {
|
||||
VLOG(3) << "process node " << node->name();
|
||||
// If the node doesn't have any input or output, not computation is involved.
|
||||
if (node->num_inputs() == 0 || node->num_outputs() == 0) return false;
|
||||
|
||||
// If the node doesn't have output properties, return true to be conservative.
|
||||
if (!graph_properties->HasOutputProperties(node->name())) return true;
|
||||
VLOG(3) << "output shapes "
|
||||
<< TensorPropertiesToString(
|
||||
graph_properties->GetOutputProperties(node->name()));
|
||||
return HasDynamicNonBatchDimension(
|
||||
graph_properties->GetOutputProperties(node->name()).at(0));
|
||||
}
|
||||
|
||||
void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
|
||||
std::vector<const SimpleEdge*>* remove_edges) {
|
||||
@ -404,12 +467,25 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SegmentGraph(const Graph* tf_graph,
|
||||
const grappler::GraphProperties* graph_properties,
|
||||
const std::function<Status(const Node*)>& candidate_fn,
|
||||
const std::function<bool(const Edge*)>& input_candidate_fn,
|
||||
const std::function<bool(const Edge*)>& output_candidate_fn,
|
||||
const SegmentOptions& options,
|
||||
SegmentNodesVector* segments) {
|
||||
if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
|
||||
return errors::Internal(
|
||||
"Explicit batch mode should allow dynamic non-batch dimensions");
|
||||
}
|
||||
|
||||
if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
|
||||
return errors::Internal(
|
||||
"Need graph propertities to disallow dynamic non-batch dimensions");
|
||||
}
|
||||
|
||||
// Steps:
|
||||
// 1. run the segmentation algorithm to find all the segments, which uses
|
||||
// candidate_fn to determine the candidates segment nodes;
|
||||
@ -441,34 +517,31 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
std::vector<UnionFind<SimpleNode*>> node_segments;
|
||||
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
||||
SimpleNode* node = graph->FindNodeId(i);
|
||||
if (options.exclude_node_list.count(node->name()) != 0) {
|
||||
auto exclude_node = [&](absl::string_view reason) {
|
||||
VLOG(1) << "Not a TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
<< "(Op name: " << node->name() << "), "
|
||||
<< "(Reason: excluded by segmenter option)";
|
||||
<< "(Reason: " << reason << ")";
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
};
|
||||
if (options.exclude_node_list.count(node->name()) != 0) {
|
||||
exclude_node("excluded by segmenter option");
|
||||
} else if (!options.allow_dynamic_non_batch_dim &&
|
||||
OperationHasDynamicNonBatchDimension(graph_properties,
|
||||
node->tf_node())) {
|
||||
exclude_node("dynamic non-batch dimensions not allowed");
|
||||
} else {
|
||||
const Status status = candidate_fn(node->tf_node());
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << "Not a TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
<< "(Op name: " << node->name() << "), "
|
||||
<< "(Reason: " << status << ")";
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
exclude_node(status.error_message());
|
||||
} else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) {
|
||||
// WARNING verbosity since the user explicitly requests this behavior.
|
||||
LOG(WARNING)
|
||||
<< "Blacklisted as TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
<< "(Op name: " << node->name() << "), "
|
||||
<< "(Reason: Blacklisted with the env var TF_TRT_OP_BLACKLIST)";
|
||||
unsupported_ops.emplace(node->tf_node()->type_string());
|
||||
num_unsupported_ops++;
|
||||
node = nullptr;
|
||||
LOG(WARNING) << "Blacklisted as TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
<< "(Op name: " << node->name() << ")";
|
||||
exclude_node("Blacklisted with the env var TF_TRT_OP_BLACKLIST");
|
||||
} else {
|
||||
VLOG(2) << "Accepted as a TF-TRT candidate, "
|
||||
<< "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -37,12 +38,17 @@ using SegmentNodesVector = std::vector<std::set<const Node*>>;
|
||||
struct SegmentOptions {
|
||||
// Segment must contain at least this many nodes.
|
||||
int minimum_segment_size = 2;
|
||||
bool use_implicit_batch = true;
|
||||
// When use_implicit_batch is false or when we are building dynamic engines,
|
||||
// we allow dynamic non-batch dimensions.
|
||||
bool allow_dynamic_non_batch_dim = false;
|
||||
std::set<string> exclude_node_list;
|
||||
};
|
||||
|
||||
// Get the subgraphs of a graph that can be handled by TensorRT.
|
||||
//
|
||||
// @param graph Graph of the network
|
||||
// @param tf_graph Graph of the network.
|
||||
// @graph_properties is the static graph properties.
|
||||
// @param candidate_fn A function that returns OK for a Node* if
|
||||
// that node can be handled by TensorRT.
|
||||
// @param segments Returns the TensorRT segments/subgraphs. Each entry
|
||||
@ -50,6 +56,7 @@ struct SegmentOptions {
|
||||
// all the NodeDefs in that subgraph.
|
||||
// @return the status.
|
||||
Status SegmentGraph(const Graph* tf_graph,
|
||||
const grappler::GraphProperties* graph_properties,
|
||||
const std::function<Status(const Node*)>& candidate_fn,
|
||||
const std::function<bool(const Edge*)>& input_candidate_fn,
|
||||
const std::function<bool(const Edge*)>& output_candidate_fn,
|
||||
|
@ -42,7 +42,7 @@ class SegmentTest : public ::testing::Test {
|
||||
if (node_names.find(node->name()) != node_names.end()) {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::NotFound("");
|
||||
return errors::NotFound("Not a user specified candidate");
|
||||
};
|
||||
}
|
||||
|
||||
@ -60,18 +60,29 @@ class SegmentTest : public ::testing::Test {
|
||||
};
|
||||
}
|
||||
|
||||
void RunTest(const Graph* graph, const std::set<string>& candidates,
|
||||
void RunTest(const Graph* graph,
|
||||
const grappler::GraphProperties* graph_properties,
|
||||
const std::set<string>& candidates,
|
||||
const std::set<string>& input_candidates,
|
||||
const std::set<string>& output_candidates,
|
||||
const std::vector<std::set<string>>& expected_segments) {
|
||||
SegmentNodesVector segments;
|
||||
TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates),
|
||||
TF_EXPECT_OK(SegmentGraph(graph, graph_properties,
|
||||
MakeCandidateFn(candidates),
|
||||
MakeInputEdgeCandidateFn(input_candidates),
|
||||
MakeOutputEdgeCandidateFn(output_candidates),
|
||||
default_options_, &segments));
|
||||
segment_options_, &segments));
|
||||
ValidateSegment(segments, expected_segments);
|
||||
}
|
||||
|
||||
void RunTest(const Graph* graph, const std::set<string>& candidates,
|
||||
const std::set<string>& input_candidates,
|
||||
const std::set<string>& output_candidates,
|
||||
const std::vector<std::set<string>>& expected_segments) {
|
||||
RunTest(graph, nullptr, candidates, input_candidates, output_candidates,
|
||||
expected_segments);
|
||||
}
|
||||
|
||||
void ValidateSegment(const SegmentNodesVector& segments,
|
||||
const std::vector<std::set<string>>& expected_segments) {
|
||||
EXPECT_EQ(expected_segments.size(), segments.size());
|
||||
@ -93,7 +104,17 @@ class SegmentTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
|
||||
SegmentOptions default_options_;
|
||||
void DisableImplicitBatchMode() {
|
||||
segment_options_.use_implicit_batch = false;
|
||||
segment_options_.allow_dynamic_non_batch_dim = true;
|
||||
}
|
||||
|
||||
void EnableImplicitBatchModeForStaticEngine() {
|
||||
segment_options_.use_implicit_batch = true;
|
||||
segment_options_.allow_dynamic_non_batch_dim = false;
|
||||
}
|
||||
|
||||
SegmentOptions segment_options_;
|
||||
};
|
||||
|
||||
std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
|
||||
@ -107,6 +128,7 @@ TEST_F(SegmentTest, Empty) {
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_EXPECT_OK(s.ToGraph(&g));
|
||||
// Expect no segments/subgraphs.
|
||||
DisableImplicitBatchMode();
|
||||
RunTest(&g, {}, {}, {}, {});
|
||||
}
|
||||
|
||||
@ -133,6 +155,7 @@ TEST_F(SegmentTest, Simple) {
|
||||
// All Add operations are candidates, and we expect all of them to be
|
||||
// collapsed into a single segment
|
||||
const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4"};
|
||||
DisableImplicitBatchMode();
|
||||
RunTest(&g, all_adds, all_adds, all_adds, {all_adds});
|
||||
|
||||
// Make add1 not a candidate, and we expect all other Add operations to be
|
||||
@ -179,6 +202,7 @@ TEST_F(SegmentTest, AvoidCycle) {
|
||||
|
||||
// add2 is not a TRT candidate so there should be no segments generated.
|
||||
const std::set<string> without_add2 = {"add0", "add1", "add3", "add4"};
|
||||
DisableImplicitBatchMode();
|
||||
RunTest(&g, without_add2, without_add2, without_add2, {});
|
||||
}
|
||||
|
||||
@ -212,6 +236,7 @@ TEST_F(SegmentTest, Multiple) {
|
||||
"add5", "add6", "add7", "add8"};
|
||||
// Make add5 not a TRT candidate, and we expect two segments.
|
||||
auto without_add5 = all_adds - "add5";
|
||||
DisableImplicitBatchMode();
|
||||
RunTest(&g, without_add5, without_add5, without_add5,
|
||||
{{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});
|
||||
|
||||
@ -258,6 +283,7 @@ TEST_F(SegmentTest, BigIfElse) {
|
||||
// Make add2 not a TRT candidate, and we expect 2 segments.
|
||||
const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
|
||||
"add4", "add5", "add6", "add7"};
|
||||
DisableImplicitBatchMode();
|
||||
RunTest(&g, all_adds - "add2", all_adds, all_adds,
|
||||
{{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
|
||||
}
|
||||
@ -276,9 +302,73 @@ TEST_F(SegmentTest, IdentityOps) {
|
||||
"identity2", "identity3"};
|
||||
// Identity ops are not counted as effective ops in the segment, so no segment
|
||||
// will be formed in this case.
|
||||
DisableImplicitBatchMode();
|
||||
RunTest(&g, all_identities, all_identities, all_identities, {});
|
||||
}
|
||||
|
||||
// Testing implicit batch mode segmentation: it excludes the add-2 operation
|
||||
// with a dynamic non-batch dimension.
|
||||
TEST_F(SegmentTest, ExcludeAddWithDynamicNonBatchDimension) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto feed_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3}));
|
||||
auto feed_1_shape = ops::Placeholder::Shape(PartialTensorShape({-1, -1, 3}));
|
||||
auto const_val = ops::Const<float>(s, {1.0}, {});
|
||||
auto feed_0 =
|
||||
ops::Placeholder(s.WithOpName("feed-1"), DT_FLOAT, feed_0_shape);
|
||||
auto feed_1 =
|
||||
ops::Placeholder(s.WithOpName("feed-2"), DT_FLOAT, feed_1_shape);
|
||||
auto add_0 = ops::Add(s.WithOpName("add-0"), feed_0, const_val);
|
||||
auto add_1 = ops::Add(s.WithOpName("add-1"), add_0, feed_0);
|
||||
auto add_2 = ops::Add(s.WithOpName("add-2"), const_val, feed_1);
|
||||
|
||||
grappler::GrapplerItem item;
|
||||
item.fetch.push_back("add-2");
|
||||
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
grappler::GraphProperties static_graph_properties(item);
|
||||
TF_EXPECT_OK(static_graph_properties.InferStatically(true));
|
||||
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_CHECK_OK(
|
||||
ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));
|
||||
|
||||
const std::set<string> all_nodes = {"add-0", "add-1", "add-2"};
|
||||
EnableImplicitBatchModeForStaticEngine();
|
||||
RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
|
||||
{all_nodes - "add-2"});
|
||||
}
|
||||
|
||||
// Testing implicit batch mode segmentation: It excludes the reshape operation
|
||||
// with a dynamic non-batch output dimension.
|
||||
// TODO(bixia): hoist the check for reshape should not change batch size from
|
||||
// the converter to the segmenter and add another test case for excluding
|
||||
// a reshape without dynamic dimensions involved.
|
||||
TEST_F(SegmentTest, ExcludeReshapeWithDynamicNonBatchDimensionInOutput) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto feed_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3}));
|
||||
auto const_val = ops::Const<float>(s, {1.0}, {});
|
||||
auto feed_0 =
|
||||
ops::Placeholder(s.WithOpName("feed-1"), DT_FLOAT, feed_0_shape);
|
||||
auto add_0 = ops::Add(s.WithOpName("add-0"), feed_0, const_val);
|
||||
auto reshape = ops::Reshape(s.WithOpName("reshape"), add_0, Input({6, -1}));
|
||||
auto add_1 = ops::Add(s.WithOpName("add-1"), reshape, const_val);
|
||||
|
||||
grappler::GrapplerItem item;
|
||||
item.fetch.push_back("add-1");
|
||||
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
grappler::GraphProperties static_graph_properties(item);
|
||||
TF_EXPECT_OK(static_graph_properties.InferStatically(true));
|
||||
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_CHECK_OK(
|
||||
ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));
|
||||
|
||||
const std::set<string> all_nodes = {"add-0", "reshape", "add-1"};
|
||||
EnableImplicitBatchModeForStaticEngine();
|
||||
RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace segment
|
||||
} // namespace tensorrt
|
||||
|
@ -65,7 +65,8 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
|
||||
x5 = math_ops.matmul(x, b)
|
||||
b = self._ConstOp((48,))
|
||||
x5 = nn.bias_add(x5, b)
|
||||
x5 = gen_array_ops.reshape(x5, [4, -1])
|
||||
# TODO(b/154672994): Put the reshape back when the bug is fixed.
|
||||
# x5 = gen_array_ops.reshape(x5, [4, -1])
|
||||
|
||||
x6 = gen_array_ops.reshape(x, [4, 24, 6])
|
||||
b = self._ConstOp((6,))
|
||||
|
Loading…
x
Reference in New Issue
Block a user