Add a graph transform utility for hexagon
Change: 147385554
This commit is contained in:
parent
bad13d7a31
commit
223ae54143
tensorflow
contrib/makefile
core/kernels/hexagon
BUILDgraph_transfer_utils.cchexagon_graph_execution_test.cchexagon_rewriter_transform.cchexagon_rewriter_transform_test.cc
tools/graph_transforms
@ -90,7 +90,7 @@ if [[ "${DOWNLOAD_AND_USE_HEXAGON}" == "true" ]]; then
|
||||
fi
|
||||
|
||||
if [[ ! -z "${HEXAGON_LIB_PATH}" ]]; then
|
||||
echo "Copy hexagon libraries"
|
||||
echo "Copy hexagon libraries from ${HEXAGON_LIB_PATH}"
|
||||
|
||||
mkdir -p "${HEXAGON_DOWNLOAD_PATH}/libs"
|
||||
cp -fv "${HEXAGON_LIB_PATH}/libhexagon_controller.so" \
|
||||
|
@ -50,6 +50,7 @@ tf_cc_test(
|
||||
"graph_transferer_test.cc",
|
||||
"hexagon_graph_execution_test.cc",
|
||||
],
|
||||
data = ["//tensorflow/core:example_parser_configuration_testdata"],
|
||||
deps = [
|
||||
":graph_transferer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
@ -87,7 +88,6 @@ tf_kernel_library(
|
||||
"i_graph_transfer_ops_definitions.h",
|
||||
"i_soc_control_wrapper.h",
|
||||
],
|
||||
data = ["//tensorflow/core:example_parser_configuration_testdata"],
|
||||
deps = [
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:remote_fused_graph_ops",
|
||||
@ -99,3 +99,40 @@ tf_kernel_library(
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hexagon_rewriter_transform",
|
||||
srcs = [
|
||||
"hexagon_rewriter_transform.cc",
|
||||
],
|
||||
deps = [
|
||||
":graph_transferer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:remote_fused_graph_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hexagon_rewriter_transform_test",
|
||||
size = "small",
|
||||
srcs = ["hexagon_rewriter_transform_test.cc"],
|
||||
deps = [
|
||||
":hexagon_rewriter_transform",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||
],
|
||||
)
|
||||
|
@ -58,7 +58,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
|
||||
CHECK(gt != nullptr);
|
||||
GraphTransferer::OutputTensorInfo output_tensor_info;
|
||||
Status status = gt->DryRunInferenceForAllNode(
|
||||
def, inputs, false /* initialize_by_zero */, &output_tensor_info);
|
||||
def, inputs, true /* initialize_by_zero */, &output_tensor_info);
|
||||
CHECK(status.ok());
|
||||
status = gt->LoadGraphFromProto(ops_definitions, def, inputs, outputs, false,
|
||||
output_tensor_info.output_tensor_map);
|
||||
|
@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Before calling this test program, download a model as follows.
|
||||
// $ curl https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb \
|
||||
// -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
|
||||
// adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
|
||||
// /data/local/tmp
|
||||
// $ curl
|
||||
// https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt
|
||||
// -o /tmp/imagenet_comp_graph_label_strings.txt
|
||||
// adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
|
||||
/* Before calling this test program, download a model as follows.
|
||||
$ curl
|
||||
https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb
|
||||
\ -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
|
||||
$ adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
|
||||
/data/local/tmp
|
||||
$ curl
|
||||
https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt
|
||||
-o /tmp/imagenet_comp_graph_label_strings.txt
|
||||
adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
@ -49,8 +51,12 @@ using ConstByteArray = ISocControlWrapper::ConstByteArray;
|
||||
constexpr const char* const IMAGE_FILENAME = "/data/local/tmp/img_299x299.bmp";
|
||||
constexpr const char* const MODEL_FILENAME =
|
||||
"/data/local/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb";
|
||||
constexpr const char* const FUSED_MODEL_FILENAME =
|
||||
"/data/local/tmp/"
|
||||
"tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb";
|
||||
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME =
|
||||
"remote_fused_graph_execute_node";
|
||||
|
||||
const bool USE_TF_RUNTIME = true;
|
||||
const bool DBG_DUMP_FLOAT_DATA = false;
|
||||
const int WIDTH = 299;
|
||||
const int HEIGHT = 299;
|
||||
@ -58,6 +64,13 @@ const int DEPTH = 3;
|
||||
const int EXPECTED_FIRST_RESULT_ID = 59;
|
||||
const int EXECUTION_REPEAT_COUNT = 3;
|
||||
|
||||
static void CheckHexagonControllerVersion() {
|
||||
HexagonControlWrapper hexagon_control_wrapper;
|
||||
const int version = hexagon_control_wrapper.GetVersion();
|
||||
ASSERT_GE(version, 1);
|
||||
LOG(INFO) << "Hexagon controller version is " << version;
|
||||
}
|
||||
|
||||
static void DumpTop10Results(const int byte_size,
|
||||
const float* const float_array) {
|
||||
const int element_count = byte_size / sizeof(float);
|
||||
@ -159,9 +172,6 @@ static void RunInferenceByHexagonControlWrapper(
|
||||
img_floats.size() * sizeof(float), DT_FLOAT);
|
||||
|
||||
HexagonControlWrapper hexagon_control_wrapper;
|
||||
const int version = hexagon_control_wrapper.GetVersion();
|
||||
ASSERT_GE(version, 1);
|
||||
LOG(INFO) << "Hexagon controller version is " << version;
|
||||
// 1. Initialize hexagon
|
||||
hexagon_control_wrapper.Init();
|
||||
|
||||
@ -196,13 +206,61 @@ static void RunInferenceByHexagonControlWrapper(
|
||||
hexagon_control_wrapper.Finalize();
|
||||
}
|
||||
|
||||
static void RunFusedGraph(const GraphDef& fused_graph_def) {
|
||||
// Setup input tensor
|
||||
std::vector<float> img_floats;
|
||||
LoadImage(&img_floats);
|
||||
|
||||
LOG(INFO) << "Ioading image finished.";
|
||||
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
|
||||
ASSERT_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
|
||||
ASSERT_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
|
||||
|
||||
LOG(INFO) << "Copy data to tensor.";
|
||||
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
|
||||
img_tensor.TotalBytes());
|
||||
|
||||
// Setup session
|
||||
std::vector<Tensor> output_tensors;
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
std::unique_ptr<Session> session =
|
||||
std::unique_ptr<Session>(NewSession(session_options));
|
||||
Status status = session->Create(fused_graph_def);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
// Setup session arguments
|
||||
RunOptions run_options;
|
||||
run_options.set_trace_level(RunOptions::FULL_TRACE);
|
||||
RunMetadata run_metadata;
|
||||
|
||||
std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
|
||||
input_tensors.emplace_back("Mul", img_tensor);
|
||||
std::vector<string> output_node_names;
|
||||
output_node_names.emplace_back(REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME);
|
||||
|
||||
LOG(INFO) << "Run graph";
|
||||
// Run inference with all node as output
|
||||
status = session->Run(run_options, input_tensors, output_node_names, {},
|
||||
&output_tensors, &run_metadata);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(1, output_tensors.size());
|
||||
const Tensor& output_tensor = output_tensors.at(0);
|
||||
LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
|
||||
LOG(INFO) << "Output shape = " << output_tensor.shape().DebugString();
|
||||
DumpTop10Results(output_tensor.TotalBytes(),
|
||||
output_tensor.flat<float>().data());
|
||||
}
|
||||
|
||||
// CAVEAT: This test only runs when you specify hexagon library using
|
||||
// makefile.
|
||||
// TODO(satok): Make this generic so that this can run without any
|
||||
// additional steps.
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
|
||||
if (USE_TF_RUNTIME) return;
|
||||
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
||||
CheckHexagonControllerVersion();
|
||||
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
&HexagonOpsDefinitions::getInstance();
|
||||
std::vector<GraphTransferer::InputNodeInfo> input_node_info_list = {
|
||||
@ -226,31 +284,22 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
|
||||
}
|
||||
|
||||
TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
|
||||
if (!USE_TF_RUNTIME) return;
|
||||
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
|
||||
CheckHexagonControllerVersion();
|
||||
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
&HexagonOpsDefinitions::getInstance();
|
||||
std::vector<GraphTransferer::InputNodeInfo> inputs = {
|
||||
GraphTransferer::InputNodeInfo{
|
||||
"Mul", Tensor{DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}}}};
|
||||
std::vector<string> outputs = {"softmax"};
|
||||
const bool is_text_proto = false;
|
||||
|
||||
std::vector<float> img_floats;
|
||||
LoadImage(&img_floats);
|
||||
|
||||
LOG(INFO) << "Ioading image finished.";
|
||||
|
||||
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
|
||||
ASSERT_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
|
||||
ASSERT_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
|
||||
|
||||
LOG(INFO) << "Copy data to tensor.";
|
||||
|
||||
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
|
||||
img_tensor.TotalBytes());
|
||||
|
||||
GraphDef graph_def;
|
||||
|
||||
Status status = ReadBinaryProto(Env::Default(), MODEL_FILENAME, &graph_def);
|
||||
|
||||
ASSERT_TRUE(status.ok());
|
||||
@ -259,40 +308,22 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
|
||||
GraphTransferer gt;
|
||||
gt.EnableStrictCheckMode(false);
|
||||
GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef(
|
||||
HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node",
|
||||
inputs, outputs, graph_def, >);
|
||||
HexagonOpsDefinitions::getInstance(),
|
||||
REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, graph_def, >);
|
||||
|
||||
// Setup session
|
||||
std::vector<Tensor> output_tensors;
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
std::unique_ptr<Session> session =
|
||||
std::unique_ptr<Session>(NewSession(session_options));
|
||||
status = session->Create(fused_graph_def);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
// Setup session arguments
|
||||
RunOptions run_options;
|
||||
run_options.set_trace_level(RunOptions::FULL_TRACE);
|
||||
RunMetadata run_metadata;
|
||||
|
||||
std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
|
||||
input_tensors.emplace_back("Mul", img_tensor);
|
||||
std::vector<string> output_node_names;
|
||||
output_node_names.emplace_back("remote_fused_graph_execute_node");
|
||||
|
||||
LOG(INFO) << "Run graph";
|
||||
// Run inference with all node as output
|
||||
status = session->Run(run_options, input_tensors, output_node_names, {},
|
||||
&output_tensors, &run_metadata);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(1, output_tensors.size());
|
||||
const Tensor& output_tensor = output_tensors.at(0);
|
||||
LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
|
||||
LOG(INFO) << "Output shape = " << output_tensor.shape().DebugString();
|
||||
DumpTop10Results(output_tensor.TotalBytes(),
|
||||
output_tensor.flat<float>().data());
|
||||
RunFusedGraph(fused_graph_def);
|
||||
}
|
||||
|
||||
TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithFusedGraph) {
|
||||
LOG(INFO) << "Run inception v3 with fused graph";
|
||||
CheckHexagonControllerVersion();
|
||||
|
||||
GraphDef fused_graph_def;
|
||||
Status status =
|
||||
ReadBinaryProto(Env::Default(), FUSED_MODEL_FILENAME, &fused_graph_def);
|
||||
RunFusedGraph(fused_graph_def);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,91 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Wraps the hexagon rewriter in a transform so it can be used as part of the
|
||||
// graph transform tool.
|
||||
// A usage example, based on the Image Understanding pipeline:
|
||||
/*
|
||||
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||
--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
|
||||
--out_graph=\
|
||||
/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
|
||||
--inputs='Mul' \
|
||||
--outputs='softmax' \
|
||||
--transforms='\
|
||||
rewrite_quantized_stripped_model_for_hexagon(
|
||||
input_shape0="1,299,299,3" \
|
||||
input_type0="float" \
|
||||
)'
|
||||
*/
|
||||
|
||||
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
constexpr const char* const INPUT_SHAPE_PREFIX = "input_shape";
|
||||
constexpr const char* const INPUT_TYPE_PREFIX = "input_type";
|
||||
|
||||
Status RewriteQuantizedStrippedModelForHexagon(
|
||||
const GraphDef& input_graph_def, const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
LOG(INFO) << "Transforming quantized stripped model to a remote fused "
|
||||
"graph execute op...";
|
||||
std::vector<GraphTransferer::InputNodeInfo> inputs;
|
||||
std::vector<string> outputs;
|
||||
for (int i = 0; i < context.input_names.size(); ++i) {
|
||||
const string& input_name = context.input_names.at(i);
|
||||
|
||||
// Get input shape
|
||||
string shape_string;
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||
INPUT_SHAPE_PREFIX + std::to_string(i), "", &shape_string));
|
||||
std::vector<int64> dims;
|
||||
CHECK(str_util::SplitAndParseAsInts(shape_string, ',', &dims));
|
||||
|
||||
// Get input data type
|
||||
string data_type_string;
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||
INPUT_TYPE_PREFIX + std::to_string(i), "", &data_type_string));
|
||||
DataType data_type;
|
||||
CHECK(DataTypeFromString(data_type_string, &data_type))
|
||||
<< "\"" << data_type_string << "\" was an invalid type";
|
||||
|
||||
LOG(INFO) << "Input(" << i << "): name = " << input_name
|
||||
<< ", shape = " << shape_string
|
||||
<< ", type = " << data_type_string;
|
||||
|
||||
inputs.emplace_back(GraphTransferer::InputNodeInfo{
|
||||
input_name, {data_type, TensorShape(dims)}});
|
||||
}
|
||||
|
||||
for (const string& output_name : context.output_names) {
|
||||
outputs.emplace_back(output_name);
|
||||
}
|
||||
GraphTransferer gt;
|
||||
gt.EnableStrictCheckMode(false);
|
||||
*output_graph_def = GraphTransferUtils::BuildFusedGraphDef(
|
||||
HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node",
|
||||
inputs, outputs, input_graph_def, >);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("rewrite_quantized_stripped_model_for_hexagon",
|
||||
RewriteQuantizedStrippedModelForHexagon);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -0,0 +1,81 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declared here so we don't have to put it in a public header.
|
||||
Status RewriteQuantizedStrippedModelForHexagon(
|
||||
const GraphDef& input_graph_def, const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(HexagonRewriteTransformTest, BasicRun) {
|
||||
Scope root = tensorflow::Scope::NewRootScope();
|
||||
|
||||
// Create a simple graph that calculates (a + b) * placeholder.
|
||||
Tensor a_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
|
||||
test::FillIota<float>(&a_data, 1.0f);
|
||||
Output a_const = ops::Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||
|
||||
Tensor b_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
|
||||
test::FillIota<float>(&b_data, 1.0f);
|
||||
Output b_const = ops::Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), a_const, b_const);
|
||||
|
||||
Output placeholder =
|
||||
ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||
|
||||
Output mul = ops::Mul(root.WithOpName("output"), add, placeholder);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {"placeholder"};
|
||||
context.output_names = {"output"};
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{"input_shape0", {string("1,1,1,1")}}));
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{"input_type0", {string("float")}}));
|
||||
TF_ASSERT_OK(
|
||||
RewriteQuantizedStrippedModelForHexagon(graph_def, context, &result));
|
||||
|
||||
// Node in the input graph is fused to
|
||||
// 1 input placeholder node + 1 fused output node
|
||||
EXPECT_EQ(2, result.node_size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -95,6 +95,7 @@ cc_library(
|
||||
"//tensorflow/core:tensorflow",
|
||||
] + if_not_windows([
|
||||
"//tensorflow/core/kernels:quantized_ops",
|
||||
"//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user