Add a way to fuse a graph by remote graph executor so that users don't need to be aware of supported op types, node names, subgraph stracture etc.

PiperOrigin-RevId: 162411763
This commit is contained in:
A. Unique TensorFlower 2017-07-18 15:20:03 -07:00 committed by TensorFlower Gardener
parent d8672f1839
commit d5f4d9bbac
31 changed files with 474 additions and 299 deletions

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
@ -110,7 +111,7 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) {
static void CheckOpsSupport(const GraphDef& graph_def, static void CheckOpsSupport(const GraphDef& graph_def,
const bool dump_all_nodes, const bool dump_all_nodes,
const bool dump_shape_and_type) { const bool dump_shape_and_type) {
const IGraphTransferOpsDefinitions& ops_definition = const IRemoteFusedGraphOpsDefinitions& ops_definition =
HexagonOpsDefinitions::getInstance(); HexagonOpsDefinitions::getInstance();
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
LOG(INFO) << "dump_all_nodes = " << dump_all_nodes LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
@ -127,7 +128,7 @@ static void CheckOpsSupport(const GraphDef& graph_def,
} }
// TODO(satok): Set correct data type if it's given. // TODO(satok): Set correct data type if it's given.
const int op_id = ops_definition.GetOpIdFor(node.op(), {}); const int op_id = ops_definition.GetOpIdFor(node.op(), {});
if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) { if (op_id == IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
all_supported = false; all_supported = false;
LOG(ERROR) << "OP type: " << node.op() << " is not supported on hvx. " LOG(ERROR) << "OP type: " << node.op() << " is not supported on hvx. "
<< "Name = " << node.name(); << "Name = " << node.name();

View File

@ -82,7 +82,7 @@ fi
if [[ "${USE_HEXAGON}" == "true" ]]; then if [[ "${USE_HEXAGON}" == "true" ]]; then
HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" >/dev/null && pwd) HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" >/dev/null && pwd)
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs" HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
HEXAGON_INCLUDE=$(cd "tensorflow/core/platform/hexagon" >/dev/null && pwd) HEXAGON_INCLUDE=$(cd "tensorflow/core/kernels/hexagon" >/dev/null && pwd)
fi fi
if [[ "${ENABLE_EXPERIMENTAL_HEXNN_OPS}" == "true" ]]; then if [[ "${ENABLE_EXPERIMENTAL_HEXNN_OPS}" == "true" ]]; then

View File

@ -56,7 +56,6 @@ tensorflow/core/platform/posix/test.cc
QUANTIZATION_TEST_SRCS := \ QUANTIZATION_TEST_SRCS := \
$(GRAPH_TRANSFER_SRCS) \ $(GRAPH_TRANSFER_SRCS) \
tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc \
tensorflow/core/kernels/hexagon/graph_transferer_test.cc \ tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
tensorflow/contrib/makefile/test/test_main.cc tensorflow/contrib/makefile/test/test_main.cc

View File

@ -5057,9 +5057,13 @@ tf_kernel_library(
cc_library( cc_library(
name = "remote_fused_graph_execute_utils", name = "remote_fused_graph_execute_utils",
srcs = ["remote_fused_graph_execute_utils.cc"], srcs = [
"i_remote_fused_graph_ops_definitions.cc",
"remote_fused_graph_execute_utils.cc",
],
hdrs = [ hdrs = [
"i_remote_fused_graph_executor.h", "i_remote_fused_graph_executor.h",
"i_remote_fused_graph_ops_definitions.h",
"remote_fused_graph_execute_utils.h", "remote_fused_graph_execute_utils.h",
], ],
deps = [ deps = [
@ -5078,6 +5082,7 @@ cc_library(
srcs = ["remote_fused_graph_execute_op_test_utils.cc"], srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
hdrs = ["remote_fused_graph_execute_op_test_utils.h"], hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
deps = [ deps = [
":remote_fused_graph_execute_utils",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
"//tensorflow/cc:scope", "//tensorflow/cc:scope",

View File

@ -26,23 +26,6 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
) )
tf_cc_test(
name = "quantized_matmul_op_for_hexagon_test",
size = "small",
srcs = ["quantized_matmul_op_for_hexagon_test.cc"],
tags = ["nomsan"], # http://b/32242946
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:quantized_ops",
],
)
tf_cc_test( tf_cc_test(
name = "graph_transferer_test", name = "graph_transferer_test",
size = "small", size = "small",
@ -79,14 +62,14 @@ tf_kernel_library(
"graph_transferer.cc", "graph_transferer.cc",
"hexagon_control_wrapper.cc", "hexagon_control_wrapper.cc",
"hexagon_ops_definitions.cc", "hexagon_ops_definitions.cc",
"i_graph_transfer_ops_definitions.cc", "soc_interface.cc",
], ],
hdrs = [ hdrs = [
"graph_transfer_utils.h", "graph_transfer_utils.h",
"graph_transferer.h", "graph_transferer.h",
"hexagon_control_wrapper.h", "hexagon_control_wrapper.h",
"hexagon_ops_definitions.h", "hexagon_ops_definitions.h",
"i_graph_transfer_ops_definitions.h", "soc_interface.h",
], ],
deps = [ deps = [
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
@ -111,6 +94,7 @@ cc_library(
":graph_transferer", ":graph_transferer",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/cc:remote_fused_graph_ops", "//tensorflow/cc:remote_fused_graph_ops",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/tools/graph_transforms:transform_utils", "//tensorflow/tools/graph_transforms:transform_utils",
], ],
alwayslink = 1, alwayslink = 1,
@ -121,6 +105,7 @@ tf_cc_test(
size = "small", size = "small",
srcs = ["hexagon_rewriter_transform_test.cc"], srcs = ["hexagon_rewriter_transform_test.cc"],
deps = [ deps = [
":graph_transferer",
":hexagon_rewriter_transform", ":hexagon_rewriter_transform",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
@ -129,6 +114,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/tools/graph_transforms:transform_utils", "//tensorflow/tools/graph_transforms:transform_utils",
], ],
) )

View File

@ -96,7 +96,7 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
} }
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef( /* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name, const string& remote_graph_execute_name,
const std::vector<std::pair<string, Tensor>>& inputs, const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs, GraphDef* original_def) { const std::vector<string>& outputs, GraphDef* original_def) {

View File

@ -39,7 +39,7 @@ class GraphTransferUtils {
const int element_count, const int top_n); const int element_count, const int top_n);
static GraphDef BuildFusedGraphDef( static GraphDef BuildFusedGraphDef(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name, const string& remote_graph_execute_name,
const std::vector<std::pair<string, Tensor>>& inputs, const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs, GraphDef* original_def); const std::vector<string>& outputs, GraphDef* original_def);

View File

@ -81,7 +81,7 @@ static Node* FindMutableNodeByName(const string& name, Graph* graph) {
* of node to transfer the graph to SOC. * of node to transfer the graph to SOC.
*/ */
Status GraphTransferer::LoadGraphFromProto( Status GraphTransferer::LoadGraphFromProto(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const GraphDef& graph_def, const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const std::vector<string>& output_node_names,
@ -177,9 +177,6 @@ Status GraphTransferer::LoadGraphFromProto(
} }
} }
graph_transfer_info_.set_destination(
ops_definitions.GetTransferDestination());
ClearCache(); ClearCache();
if (DBG_DUMP_PARAMS) { if (DBG_DUMP_PARAMS) {
DumpNodeTransferParams(); DumpNodeTransferParams();
@ -191,7 +188,7 @@ Status GraphTransferer::LoadGraphFromProto(
} }
Status GraphTransferer::LoadGraphFromProtoFile( Status GraphTransferer::LoadGraphFromProtoFile(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& graph_def_path, const string& graph_def_path,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const bool is_text_proto, const std::vector<string>& output_node_names, const bool is_text_proto,
@ -415,7 +412,7 @@ Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
} }
Status GraphTransferer::RegisterNode( Status GraphTransferer::RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node, const ShapeRefiner& shape_refiner, const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names) { const std::vector<string>& output_node_names) {
@ -438,7 +435,7 @@ Status GraphTransferer::RegisterNode(
} else if (IsNodeFlattenReshape(node, shape_refiner)) { } else if (IsNodeFlattenReshape(node, shape_refiner)) {
RegisterFlattenNode(ops_definitions, shape_refiner, node); RegisterFlattenNode(ops_definitions, shape_refiner, node);
} else if (ops_definitions.GetOpIdFor(node.type_string(), {}) != } else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
IGraphTransferOpsDefinitions::INVALID_OP_ID) { IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
// TODO(satok): Set correct data type if it's given. // TODO(satok): Set correct data type if it's given.
RegisterGenericNode(ops_definitions, shape_refiner, node); RegisterGenericNode(ops_definitions, shape_refiner, node);
} else { } else {
@ -637,7 +634,7 @@ bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
} }
void GraphTransferer::RegisterNodeWithPaddingAndStrides( void GraphTransferer::RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) { const ShapeRefiner& shape_refiner, const Node& node) {
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()]; const int id = node_name_to_id_cache_map_[node.name()];
@ -671,7 +668,7 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
} }
void GraphTransferer::RegisterNodeWithRank( void GraphTransferer::RegisterNodeWithRank(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) { const ShapeRefiner& shape_refiner, const Node& node) {
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()]; const int id = node_name_to_id_cache_map_[node.name()];
@ -704,7 +701,7 @@ void GraphTransferer::RegisterNodeWithRank(
} }
void GraphTransferer::RegisterPadNode( void GraphTransferer::RegisterPadNode(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) { const ShapeRefiner& shape_refiner, const Node& node) {
static constexpr int PAD_WIDTH = 4; static constexpr int PAD_WIDTH = 4;
static constexpr int PAD_HEIGHT = 2; static constexpr int PAD_HEIGHT = 2;
@ -779,7 +776,7 @@ void GraphTransferer::RegisterPadNode(
} }
void GraphTransferer::RegisterInputNode( void GraphTransferer::RegisterInputNode(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) { const ShapeRefiner& shape_refiner, const Node& node) {
const string op_type = node.type_string(); const string op_type = node.type_string();
VLOG(1) << "Register input node: " << node.name() << ", " << op_type; VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
@ -797,12 +794,13 @@ void GraphTransferer::RegisterInputNode(
} }
void GraphTransferer::RegisterFlattenNode( void GraphTransferer::RegisterFlattenNode(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) { const ShapeRefiner& shape_refiner, const Node& node) {
VLOG(1) << "Register flatten node: " << node.name(); VLOG(1) << "Register flatten node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()]; const int id = node_name_to_id_cache_map_[node.name()];
const string op_type = IGraphTransferOpsDefinitions::FLATTEN_OP_NAME; // TODO(satok): Remove dependency to specific type
const string op_type = "FLATTEN";
// TODO(satok): Set correct data type if it's given. // TODO(satok): Set correct data type if it's given.
const int op_type_id = ops_definitions.GetOpIdFor(op_type, {}); const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
@ -814,7 +812,7 @@ void GraphTransferer::RegisterFlattenNode(
} }
void GraphTransferer::RegisterGenericNode( void GraphTransferer::RegisterGenericNode(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) { const ShapeRefiner& shape_refiner, const Node& node) {
VLOG(1) << "Register generic node: " << node.name(); VLOG(1) << "Register generic node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
@ -832,7 +830,7 @@ void GraphTransferer::RegisterGenericNode(
// TODO(satok): Remove this function. // TODO(satok): Remove this function.
// TODO(satok): Remove only_register_const_node. // TODO(satok): Remove only_register_const_node.
Status GraphTransferer::RegisterNodeIfAllInputsAreCached( Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node, const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node, const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_transfer_info.pb.h" #include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h" #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/protobuf.h"
@ -53,7 +53,7 @@ class GraphTransferer {
// TODO(satok): Pass a pair of TensorShape and DataType instead of // TODO(satok): Pass a pair of TensorShape and DataType instead of
// Tensor as input_node_info_list. // Tensor as input_node_info_list.
Status LoadGraphFromProto( Status LoadGraphFromProto(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const GraphDef& graph_def, const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const std::vector<string>& output_node_names,
@ -63,7 +63,7 @@ class GraphTransferer {
// TODO(satok): Pass a pair of TensorShape and DataType instead of // TODO(satok): Pass a pair of TensorShape and DataType instead of
// Tensor as input_node_info_list. // Tensor as input_node_info_list.
Status LoadGraphFromProtoFile( Status LoadGraphFromProtoFile(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& graph_def_path, const string& graph_def_path,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const bool is_text_proto, const std::vector<string>& output_node_names, const bool is_text_proto,
@ -112,7 +112,7 @@ class GraphTransferer {
Graph* graph, ShapeRefiner* shape_refiner); Graph* graph, ShapeRefiner* shape_refiner);
Status RegisterNode( Status RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node, const ShapeRefiner& shape_refiner, const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names); const std::vector<string>& output_node_names);
@ -140,30 +140,29 @@ class GraphTransferer {
const ShapeRefiner& shape_refiner); const ShapeRefiner& shape_refiner);
void RegisterNodeWithPaddingAndStrides( void RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node); const ShapeRefiner& shape_refiner, const Node& node);
void RegisterNodeWithRank(const IGraphTransferOpsDefinitions& ops_definitions, void RegisterNodeWithRank(
const ShapeRefiner& shape_refiner, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const Node& node); const ShapeRefiner& shape_refiner, const Node& node);
void RegisterPadNode(const IGraphTransferOpsDefinitions& ops_definitions, void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node); const ShapeRefiner& shape_refiner, const Node& node);
void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions, void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const ShapeRefiner& shape_refiner, const Node& node);
const Node& node);
void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions, void RegisterFlattenNode(
const ShapeRefiner& shape_refiner, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const Node& node); const ShapeRefiner& shape_refiner, const Node& node);
void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions, void RegisterGenericNode(
const ShapeRefiner& shape_refiner, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const Node& node); const ShapeRefiner& shape_refiner, const Node& node);
Status RegisterNodeIfAllInputsAreCached( Status RegisterNodeIfAllInputsAreCached(
const IGraphTransferOpsDefinitions& ops_definitions, const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node, const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node, const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list, const std::vector<std::pair<string, Tensor>>& input_node_info_list,

View File

@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h" #include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
@ -50,7 +50,7 @@ class GraphTransfererTest : public ::testing::Test {
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP; const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions { class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
public: public:
int GetTotalOpsCount() const final { return op_types_.size(); } int GetTotalOpsCount() const final { return op_types_.size(); }
@ -63,10 +63,6 @@ class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
return -1; return -1;
} }
GraphTransferInfo::Destination GetTransferDestination() const final {
return GraphTransferInfo::NOP;
}
private: private:
const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D", const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
"MaxPool", "NoOp", "Add", "MaxPool", "NoOp", "Add",
@ -371,14 +367,14 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
} }
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) { TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
const IGraphTransferOpsDefinitions& ops_definitions = const IRemoteFusedGraphOpsDefinitions& ops_definitions =
HexagonOpsDefinitions::getInstance(); HexagonOpsDefinitions::getInstance();
const int total_ops_count = ops_definitions.GetTotalOpsCount(); const int total_ops_count = ops_definitions.GetTotalOpsCount();
EXPECT_GT(total_ops_count, 0); EXPECT_GT(total_ops_count, 0);
} }
TEST(GraphTransferer, LoadGraphFromProtoFile) { TEST(GraphTransferer, LoadGraphFromProtoFile) {
const IGraphTransferOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS; &TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename = string filename =
io::JoinPath(testing::TensorFlowSrcRoot(), io::JoinPath(testing::TensorFlowSrcRoot(),
@ -441,7 +437,7 @@ void CompareGraphTransferInfo(const GraphTransferInfo& a,
} // anonymous namespace } // anonymous namespace
TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) { TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
const IGraphTransferOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS; &TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename = string filename =
io::JoinPath(testing::TensorFlowSrcRoot(), io::JoinPath(testing::TensorFlowSrcRoot(),

View File

@ -16,18 +16,19 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h" #include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
#ifdef USE_HEXAGON_LIBS
#include "tensorflow/core/platform/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h" #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#endif
namespace tensorflow { namespace tensorflow {
constexpr const char* const INPUT_OP_NAME = "INPUT";
constexpr const char* const OUTPUT_OP_NAME = "OUTPUT"; constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX =
"hexagon_remote_fused_graph";
/* static */ constexpr const char* const
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
constexpr int ALIGNMENT_BYTES = 16; constexpr int ALIGNMENT_BYTES = 16;
constexpr int MAX_IN_OUT_COUNT = 128;
const bool DBG_DUMP_VERIFICATION_STRING = false; const bool DBG_DUMP_VERIFICATION_STRING = false;
const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info
@ -63,7 +64,6 @@ static uint8* FindAlignedPointer(uint8* ptr) {
return nullptr; return nullptr;
} }
#ifdef USE_HEXAGON_LIBS
int HexagonControlWrapper::GetVersion() { int HexagonControlWrapper::GetVersion() {
return soc_interface_GetSocControllerVersion(); return soc_interface_GetSocControllerVersion();
} }
@ -95,7 +95,6 @@ bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
LOG(ERROR) << "Hexagon initialization was failed. See log output."; LOG(ERROR) << "Hexagon initialization was failed. See log output.";
return false; return false;
} }
const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo();
std::vector<int> input_sizes; std::vector<int> input_sizes;
std::vector<int> output_sizes; std::vector<int> output_sizes;
CHECK_NOTNULL(execute_info_); CHECK_NOTNULL(execute_info_);
@ -207,8 +206,9 @@ bool HexagonControlWrapper::SetupGraph() {
for (const GraphTransferInfo::NodeInputInfo& input_params : for (const GraphTransferInfo::NodeInputInfo& input_params :
graph_transfer_info.node_input_info()) { graph_transfer_info.node_input_info()) {
const int count = input_params.node_input_size(); const int count = input_params.node_input_size();
int node_ids[count]; CHECK(count <= MAX_IN_OUT_COUNT);
int ports[count]; int node_ids[MAX_IN_OUT_COUNT];
int ports[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const GraphTransferInfo::NodeInput& node_input = const GraphTransferInfo::NodeInput& node_input =
input_params.node_input(i); input_params.node_input(i);
@ -226,7 +226,8 @@ bool HexagonControlWrapper::SetupGraph() {
for (const GraphTransferInfo::NodeOutputInfo& output_params : for (const GraphTransferInfo::NodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) { graph_transfer_info.node_output_info()) {
const int count = output_params.max_byte_size_size(); const int count = output_params.max_byte_size_size();
int sizes[count]; CHECK(count <= MAX_IN_OUT_COUNT);
int sizes[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
const int size = output_params.max_byte_size(i); const int size = output_params.max_byte_size(i);
sizes[i] = size; sizes[i] = size;
@ -373,6 +374,7 @@ bool HexagonControlWrapper::ReadOutputNode(
<< output_tensor->TotalBytes() << ", " << std::get<1>(output); << output_tensor->TotalBytes() << ", " << std::get<1>(output);
TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor( TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
std::get<0>(output), std::get<1>(output), output_tensor)); std::get<0>(output), std::get<1>(output), output_tensor));
return true;
} }
bool HexagonControlWrapper::ReadOutputNode( bool HexagonControlWrapper::ReadOutputNode(
@ -382,14 +384,30 @@ bool HexagonControlWrapper::ReadOutputNode(
const string tensor_name = AddPort(node_name); const string tensor_name = AddPort(node_name);
CHECK(output_port_map_.count(tensor_name) > 0); CHECK(output_port_map_.count(tensor_name) > 0);
const int port = output_port_map_.at(tensor_name); const int port = output_port_map_.at(tensor_name);
soc_interface_ReadOutputNodeWithPort(port, &std::get<0>(output), soc_interface_ReadOutputNodeWithPort(
&std::get<1>(output)); port, &std::get<0>(output),
reinterpret_cast<uint64_t*>(&std::get<1>(output)));
// TODO: Accept all results // TODO: Accept all results
// std::get<2>(output) = DT_FLOAT; // std::get<2>(output) = DT_FLOAT;
outputs->emplace_back(output); outputs->emplace_back(output);
return true; return true;
} }
Status HexagonControlWrapper::FuseRemoteGraph(
const GraphDef& original_graph_def, const std::vector<string>& inputs,
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
const std::unordered_set<string> fused_node_names =
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
original_graph_def, HexagonOpsDefinitions::getInstance());
// TODO(satok): We may want to place shape and type inside this function
// if they are not placed in the given graph.
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
original_graph_def, inputs, outputs, REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX,
fused_node_names, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
/*require_shape_type=*/true, fused_graph_def));
return Status::OK();
}
bool HexagonControlWrapper::FillInputNode(const string& node_name, bool HexagonControlWrapper::FillInputNode(const string& node_name,
const Tensor& tensor) { const Tensor& tensor) {
StringPiece tensor_data = tensor.tensor_data(); StringPiece tensor_data = tensor.tensor_data();
@ -415,31 +433,5 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name,
return true; return true;
} }
#else bool HexagonControlWrapper::IsEnabled() const { return true; };
int HexagonControlWrapper::GetVersion() { return -1; }
bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo&) {
return false;
}
bool HexagonControlWrapper::Finalize() { return false; }
bool HexagonControlWrapper::SetupGraph() { return false; }
bool HexagonControlWrapper::ExecuteGraph() { return false; }
bool HexagonControlWrapper::TeardownGraph() { return false; }
bool HexagonControlWrapper::FillInputNode(
const string&, const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>&,
const ConstByteArray) {
return false;
}
bool HexagonControlWrapper::FillInputNode(const string&, const Tensor&) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(
const string& node_name, TensorAllocatorFunc tensor_allocator) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(const string&,
std::vector<ByteArray>* const) {
return false;
}
#endif
} // namespace tensorflow } // namespace tensorflow

View File

@ -35,6 +35,8 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
public: public:
using ByteArray = using ByteArray =
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>; std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
static constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
"build_hexagon_remote_fused_graph_executor";
HexagonControlWrapper() = default; HexagonControlWrapper() = default;
int GetVersion() final; int GetVersion() final;
@ -46,6 +48,11 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
bool FillInputNode(const string& node_name, const Tensor& tensor) final; bool FillInputNode(const string& node_name, const Tensor& tensor) final;
bool ReadOutputNode(const string& node_name, bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) final; TensorAllocatorFunc tensor_allocator) final;
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) final;
bool IsEnabled() const final;
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs); bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
private: private:

View File

@ -35,8 +35,8 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
#include "tensorflow/core/kernels/hexagon/graph_transferer.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h" #include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/quantization_utils.h" #include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -389,9 +389,6 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong()); EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
EXPECT_EQ(goni0.DebugString(), goni1.DebugString()); EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
} }
// 7. check destination
EXPECT_EQ(gfi0.destination(), gfi1.destination());
} }
// CAVEAT: This test only runs when you specify hexagon library using // CAVEAT: This test only runs when you specify hexagon library using
@ -407,7 +404,7 @@ TEST(GraphTransferer,
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller"; LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
CheckHexagonControllerVersion(); CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance(); &HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs; std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH})); inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
@ -474,7 +471,7 @@ TEST(GraphTransferer,
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller"; LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
CheckHexagonControllerVersion(); CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance(); &HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs; std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH})); inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
@ -505,7 +502,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime"; LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
CheckHexagonControllerVersion(); CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance(); &HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs; std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH})); inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
@ -543,7 +540,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
CheckHexagonControllerVersion(); CheckHexagonControllerVersion();
profile_utils::CpuUtils::EnableClockCycleProfiling(true); profile_utils::CpuUtils::EnableClockCycleProfiling(true);
const IGraphTransferOpsDefinitions* ops_definitions = const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance(); &HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs; std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH})); inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));

View File

@ -304,8 +304,8 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map); EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map); EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map); EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
EmplaceOpType(IGraphTransferOpsDefinitions::FLATTEN_OP_NAME, {}, // Special op type for hexagon
SupportedOpType::FLATTEN, &op_map); EmplaceOpType("FLATTEN", {}, SupportedOpType::FLATTEN, &op_map);
// Tensorflow op name // Tensorflow op name
// CAVEAT: Keep order of SupportedOpType // CAVEAT: Keep order of SupportedOpType
EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map); EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
@ -373,7 +373,7 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
HexagonOpsDefinitions::HexagonOpsDefinitions() HexagonOpsDefinitions::HexagonOpsDefinitions()
: op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {} : op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}
/* static */ const IGraphTransferOpsDefinitions& /* static */ const IRemoteFusedGraphOpsDefinitions&
HexagonOpsDefinitions::getInstance() { HexagonOpsDefinitions::getInstance() {
const static HexagonOpsDefinitions instance{}; const static HexagonOpsDefinitions instance{};
return instance; return instance;
@ -393,17 +393,17 @@ int HexagonOpsDefinitions::GetOpIdFor(const string& op_type,
if (dt_vec.empty()) { if (dt_vec.empty()) {
return static_cast<int>(std::get<1>(dt_to_op_vec.front())); return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
} }
// If there is only one op_id registered for empty op_vec, we assume
// that the op supports any data types.
if (dt_to_op_vec.size() == 1 && std::get<0>(dt_to_op_vec.front()).empty()) {
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
}
for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) { for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
if (std::get<0>(data_type_to_op) == dt_vec) { if (std::get<0>(data_type_to_op) == dt_vec) {
return static_cast<int>(std::get<1>(data_type_to_op)); return static_cast<int>(std::get<1>(data_type_to_op));
} }
} }
} }
return IGraphTransferOpsDefinitions::INVALID_OP_ID; return IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
}
GraphTransferInfo::Destination HexagonOpsDefinitions::GetTransferDestination()
const {
return GraphTransferInfo::HEXAGON;
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,21 +18,20 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include "i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
namespace tensorflow { namespace tensorflow {
// HexagonOpsDefinitions provides ops definitions supported in hexagon library // HexagonOpsDefinitions provides ops definitions supported in hexagon library
// TODO(satok): add a functionality to call functions in hexagon library // TODO(satok): add a functionality to call functions in hexagon library
class HexagonOpsDefinitions final : public IGraphTransferOpsDefinitions { class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
public: public:
static const IGraphTransferOpsDefinitions& getInstance(); static const IRemoteFusedGraphOpsDefinitions& getInstance();
int GetTotalOpsCount() const final; int GetTotalOpsCount() const final;
int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final; int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;
GraphTransferInfo::Destination GetTransferDestination() const final;
private: private:
enum class SupportedOpType; enum class SupportedOpType;

View File

@ -27,7 +27,7 @@ Status BuildRemoteFusedGraphExecutor(
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build( k_hexagon_remote_fused_graph_executor_build(
"build_hexagon_remote_fused_graph_executor", HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
BuildRemoteFusedGraphExecutor); BuildRemoteFusedGraphExecutor);
} // namespace hexagon_remote_fused_graph_executor_build } // namespace hexagon_remote_fused_graph_executor_build

View File

@ -1,136 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tests in this file are designed to evaluate hexagon DSP operations.
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#ifdef USE_HEXAGON_LIBS
#include "tensorflow/core/platform/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#endif
namespace tensorflow {
class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
protected:
void SetUp() final {
#ifdef USE_HEXAGON_LIBS
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
<< soc_interface_GetWrapperVersion()
<< ", hexagon binary version = "
<< soc_interface_GetSocControllerVersion() << ")";
LOG(INFO) << "Cpu frequency = "
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
#else
LOG(WARNING) << "Hexagon libs are not linked.";
#endif
}
};
// Shows some statistics of hexagon dsp using hexagon specific APIs
#ifdef USE_HEXAGON_LIBS
TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
const uint64 overhead_shared_lib_start =
profile_utils::CpuUtils::GetCurrentClockCycle();
const int wrapper_version = soc_interface_GetWrapperVersion();
const uint64 overhead_shared_lib_end =
profile_utils::CpuUtils::GetCurrentClockCycle();
const uint64 overhead_shared_lib_diff =
(overhead_shared_lib_end - overhead_shared_lib_start);
const uint64 overhead_hexagon_rpc_start =
profile_utils::CpuUtils::GetCurrentClockCycle();
const int hexagon_binary_version = soc_interface_GetSocControllerVersion();
const uint64 overhead_hexagon_rpc_end =
profile_utils::CpuUtils::GetCurrentClockCycle();
const uint64 overhead_hexagon_rpc_diff =
(overhead_hexagon_rpc_end - overhead_hexagon_rpc_start);
LOG(INFO) << "Shared lib (ver = " << wrapper_version << ") overhead is "
<< overhead_shared_lib_diff << " cycles, time = "
<< std::chrono::duration_cast<std::chrono::microseconds>(
profile_utils::CpuUtils::ConvertClockCycleToTime(
overhead_shared_lib_diff))
.count()
<< " usec";
LOG(INFO) << "hexagon rpc (ver = " << hexagon_binary_version
<< ") overhead is " << overhead_hexagon_rpc_diff
<< " cycles, time = "
<< std::chrono::duration_cast<std::chrono::microseconds>(
profile_utils::CpuUtils::ConvertClockCycleToTime(
overhead_hexagon_rpc_diff))
.count()
<< " usec";
}
#endif
// Runs two small matrices through the operator, and leaves all the parameters
// at their default values.
// This test is a sample to execute matmul on hexagon.
TEST_F(QuantizedMatMulOpForHexagonTest, Small_NoParams) {
TF_ASSERT_OK(NodeDefBuilder("quantized_mat_mul_op", "QuantizedMatMul")
.Input(FakeInput(DT_QUINT8))
.Input(FakeInput(DT_QUINT8))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("Toutput", DataTypeToEnum<qint32>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// A matrix is:
// | 1 | 2 | 3 |
// | 4 | 5 | 6 |
AddInputFromArray<quint8>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
// B matrix is:
// | 7 | 8 | 9 | 10 |
// | 11 | 12 | 13 | 14 |
// | 15 | 16 | 17 | 18 |
AddInputFromArray<quint8>(TensorShape({3, 4}),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
AddInputFromArray<float>(TensorShape({1}), {0});
AddInputFromArray<float>(TensorShape({1}), {255.0f});
AddInputFromArray<float>(TensorShape({1}), {0});
AddInputFromArray<float>(TensorShape({1}), {255.0f});
TF_ASSERT_OK(RunOpKernel());
// Here are the results we expect, from hand calculations:
// (1 * 7) + (2 * 11) + (3 * 15) = 74
// (1 * 8) + (2 * 12) + (3 * 16) = 80
// (1 * 9) + (2 * 13) + (3 * 17) = 86
// (1 * 10) + (2 * 14) + (3 * 18) = 92
// (4 * 7) + (5 * 11) + (6 * 15) = 173
// (4 * 8) + (5 * 12) + (6 * 16) = 188
// (4 * 9) + (5 * 13) + (6 * 17) = 203
// (4 * 10) + (5 * 14) + (6 * 18) = 218
Tensor expected(allocator(), DT_QINT32, TensorShape({2, 4}));
test::FillValues<qint32>(&expected, {74, 80, 86, 92, 173, 188, 203, 218});
test::ExpectTensorEqual<qint32>(expected, *GetOutput(0));
}
} // namespace tensorflow

View File

@ -0,0 +1,83 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
vcyou may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
// Dummy implementation of soc_interface.
int soc_interface_GetWrapperVersion() { return -1; }
int soc_interface_GetSocControllerVersion() { return -1; }
bool soc_interface_Init() { return false; }
bool soc_interface_Finalize() { return false; }
bool soc_interface_ExecuteGraph() { return false; }
bool soc_interface_TeardownGraph() { return false; }
bool soc_interface_AllocateInOutNodeBuffers(int /*input_count*/,
int* /*input_sizes*/,
int /*output_count*/,
int* /*output_sizes*/) {
return false;
}
bool soc_interface_FillInputNodeWithPort(int /*port*/, int /*x*/, int /*y*/,
int /*z*/, int /*d*/,
const uint8_t* const /*buf*/,
uint64_t /*buf_byte_size*/) {
return false;
}
bool soc_interface_FillInputNodeFloat(int /*x*/, int /*y*/, int /*z*/,
int /*d*/, const uint8_t* const /*buf*/,
uint64_t /*buf_byte_size*/) {
return false;
}
bool soc_interface_ReadOutputNodeWithPort(int /*port*/, uint8_t** /*buf*/,
uint64_t* /*buf_byte_size*/) {
return false;
}
bool soc_interface_ReadOutputNodeFloat(const char* const /*node_name*/,
uint8_t** /*buf*/,
uint64_t* /*buf_byte_size*/) {
return false;
}
bool soc_interface_setupDummyGraph(int /*version*/) { return false; }
bool soc_interface_AllocateNodeInputAndNodeOutputArray(
int /*total_input_count*/, int /*total_output_count*/) {
return false;
}
bool soc_interface_ReleaseNodeInputAndNodeOutputArray() { return false; }
void* soc_interface_SetOneNodeInputs(int /*input_count*/,
const int* const /*node_id*/,
const int* const /*port*/) {
return nullptr;
}
void* soc_interface_SetOneNodeOutputs(int /*output_count*/, int* /*max_size*/) {
return nullptr;
}
bool soc_interface_AppendConstNode(const char* const /*name*/, int /*node_id*/,
int /*batch*/, int /*height*/, int /*width*/,
int /*depth*/, const uint8_t* const /*data*/,
int /*data_length*/) {
return false;
}
bool soc_interface_AppendNode(const char* const /*name*/, int /*node_id*/,
int /*op_id*/, int /*padding_id*/,
const void* const /*inputs*/,
int /*inputs_count*/,
const void* const /*outputs*/,
int /*outputs_count*/) {
return false;
}
bool soc_interface_InstantiateGraph() { return false; }
bool soc_interface_ConstructGraph() { return false; }
void soc_interface_SetLogLevel(int /*log_level*/) {}
void soc_interface_SetDebugFlag(uint64_t /*flag*/) {}

View File

@ -21,6 +21,7 @@ limitations under the License.
// All functions defined here must have prefix "soc_interface" to avoid // All functions defined here must have prefix "soc_interface" to avoid
// naming conflicts. // naming conflicts.
#ifdef __cplusplus #ifdef __cplusplus
#include <cstdint>
extern "C" { extern "C" {
#else #else
#include <stdbool.h> #include <stdbool.h>

View File

@ -59,6 +59,13 @@ class IRemoteFusedGraphExecutor {
virtual bool ReadOutputNode(const string& node_name, virtual bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) = 0; TensorAllocatorFunc tensor_allocator) = 0;
virtual Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) = 0;
virtual bool IsEnabled() const = 0;
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor); TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor);
}; };

View File

@ -13,11 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "i_graph_transfer_ops_definitions.h" #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
namespace tensorflow { namespace tensorflow {
/* static */ constexpr int IGraphTransferOpsDefinitions::INVALID_OP_ID; /* static */ constexpr int IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
// TODO(satok): Remove
/* static */ constexpr const char* const
IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
} }

View File

@ -13,39 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_ #ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_ #define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
namespace tensorflow { namespace tensorflow {
// IGraphTransferOpsDefinitions is an interface class which provides interfaces // IRemoteFusedGraphOpsDefinitions is an interface class which provides
// about ops supported by SOC. // APIs to provide information about op types supported by SOC.
// TODO(satok): Provide ways to transfer graph definitions into SOC // TODO(satok): Provide ways to transfer graph definitions into SOC
class IGraphTransferOpsDefinitions { class IRemoteFusedGraphOpsDefinitions {
public: public:
// op id which is not supported by SOC // op id which is not supported by SOC
static constexpr int INVALID_OP_ID = -1; static constexpr int INVALID_OP_ID = -1;
// Custom op name for flatten node
static constexpr const char* const FLATTEN_OP_NAME = "FLATTEN";
IGraphTransferOpsDefinitions() = default; IRemoteFusedGraphOpsDefinitions() = default;
virtual ~IGraphTransferOpsDefinitions() = default; virtual ~IRemoteFusedGraphOpsDefinitions() = default;
// Return total ops count supported by SOC // Return total ops count supported by SOC
virtual int GetTotalOpsCount() const = 0; virtual int GetTotalOpsCount() const = 0;
// Return op id for given string op name // Return op id for given string op name
virtual int GetOpIdFor(const string& op_name, virtual int GetOpIdFor(const string& op_type,
const DataTypeVector& dt) const = 0; const DataTypeVector& dt) const = 0;
// Return destination of transfer
virtual GraphTransferInfo::Destination GetTransferDestination() const = 0;
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(IGraphTransferOpsDefinitions); TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphOpsDefinitions);
}; };
} // namespace tensorflow } // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_ #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_

View File

@ -41,7 +41,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc( RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
execute_info_.executor_name()); execute_info_.executor_name());
if (build_func != nullptr) { if (build_func != nullptr) {
Status status = (*build_func)(&remote_fused_graph_executor_); TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
CHECK(remote_fused_graph_executor_->IsEnabled());
} else { } else {
LOG(ERROR) << "Executor not found for " LOG(ERROR) << "Executor not found for "
<< execute_info_.executor_name(); << execute_info_.executor_name();

View File

@ -159,8 +159,8 @@ static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
return execute_info; return execute_info;
} }
// 1. Create TestRemoteFusedGraphExecutor to execute your fused graph // 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor { class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
public: public:
int GetVersion() final { return 1; } int GetVersion() final { return 1; }
bool Init(const RemoteFusedGraphExecuteInfo& info) final { bool Init(const RemoteFusedGraphExecuteInfo& info) final {
@ -214,6 +214,16 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
return true; return true;
} }
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& /*inputs*/,
const std::vector<string>& /*outputs*/,
GraphDef* fused_graph_def) final {
*fused_graph_def = original_graph_def;
return Status::OK();
}
bool IsEnabled() const final { return true; }
private: private:
const RemoteFusedGraphExecuteInfo* info_; const RemoteFusedGraphExecuteInfo* info_;
std::unordered_map<string, Tensor> input_tensor_cache_; std::unordered_map<string, Tensor> input_tensor_cache_;
@ -225,7 +235,7 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
namespace remote_fused_graph_execute_op { namespace remote_fused_graph_execute_op {
Status BuildRemoteFusedGraphExecutor( Status BuildRemoteFusedGraphExecutor(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) { std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(new TestRemoteFusedGraphExecutor()); executor->reset(new SampleRemoteFusedGraphExecutor());
return Status::OK(); return Status::OK();
} }

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -92,4 +93,36 @@ namespace tensorflow {
return Status::OK(); return Status::OK();
} }
TestRemoteFusedGraphExecutor::TestRemoteFusedGraphExecutor(
const std::unordered_set<string>& fused_op_types,
const string& executor_name)
: fused_op_types_(fused_op_types), executor_name_(executor_name) {}
int TestRemoteFusedGraphExecutor::GetVersion() { return 0; }
bool TestRemoteFusedGraphExecutor::Init(const RemoteFusedGraphExecuteInfo&) {
return true;
}
bool TestRemoteFusedGraphExecutor::Finalize() { return true; }
bool TestRemoteFusedGraphExecutor::SetupGraph() { return true; }
bool TestRemoteFusedGraphExecutor::ExecuteGraph() { return true; }
bool TestRemoteFusedGraphExecutor::TeardownGraph() { return true; }
bool TestRemoteFusedGraphExecutor::FillInputNode(const string&, const Tensor&) {
return true;
}
bool TestRemoteFusedGraphExecutor::ReadOutputNode(const string&,
TensorAllocatorFunc) {
return true;
}
Status TestRemoteFusedGraphExecutor::FuseRemoteGraph(
const GraphDef& original_graph_def, const std::vector<string>& inputs,
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
original_graph_def, inputs, outputs, "remote_fused_graph_node_names",
fused_op_types_, executor_name_,
/*require_shape_type=*/false, fused_graph_def);
return Status::OK();
}
bool TestRemoteFusedGraphExecutor::IsEnabled() const { return true; }
} // namespace tensorflow } // namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
namespace tensorflow { namespace tensorflow {
@ -59,6 +60,30 @@ class RemoteFusedGraphExecuteOpTestUtils {
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils); TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
}; };
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
public:
TestRemoteFusedGraphExecutor(const std::unordered_set<string>& fused_op_types,
const string& executor_name);
int GetVersion() final;
bool Init(const RemoteFusedGraphExecuteInfo&) final;
bool Finalize() final;
bool SetupGraph() final;
bool ExecuteGraph() final;
bool TeardownGraph() final;
bool FillInputNode(const string&, const Tensor&) final;
bool ReadOutputNode(const string&, TensorAllocatorFunc) final;
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) final;
bool IsEnabled() const final;
private:
const std::unordered_set<string> fused_op_types_;
const string executor_name_;
};
} // namespace tensorflow } // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_ #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_

View File

@ -161,6 +161,8 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS; RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const /* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES; RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const /* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES; RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const /* static */ constexpr const char* const
@ -1084,6 +1086,26 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
require_shape_type, output_graph_def); require_shape_type, output_graph_def);
} }
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
const GraphDef& input_graph_def, const std::vector<string>& inputs,
const std::vector<string>& outputs, const string& executor_name,
GraphDef* output_graph_def) {
const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
if (build_func == nullptr) {
return errors::InvalidArgument("Unknown executor name: " + executor_name);
}
std::unique_ptr<IRemoteFusedGraphExecutor> executor;
TF_RETURN_IF_ERROR((*build_func)(&executor));
CHECK_NOTNULL(executor.get());
if (!executor->IsEnabled()) {
// As this executor is not enabled, just return original graph as is.
*output_graph_def = input_graph_def;
return Status::OK();
}
return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
output_graph_def);
}
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments( /* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
const std::vector<string>& inputs, const std::vector<string>& outputs, const std::vector<string>& inputs, const std::vector<string>& outputs,
const std::unordered_set<string>& fused_node_names, const std::unordered_set<string>& fused_node_names,
@ -1387,6 +1409,28 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
return retval; return retval;
} }
// Returns the names of all nodes in `graph_def` whose op (paired with the
// node's inferred output data types) resolves to a valid op id in
// `ops_definitions` -- i.e. the nodes a remote fused graph executor backed by
// these op definitions is able to run.
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
    const GraphDef& graph_def,
    const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    const Status status =
        GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
    if (!status.ok()) {
      // NOTE(review): on failure only shape_vec is cleared, yet shape_vec is
      // never read below; dt_vec may be left partially filled and is still
      // passed to GetOpIdFor -- confirm whether dt_vec should be cleared
      // here instead (or as well).
      shape_vec.clear();
    }
    // A node is supported when the op name / data-type combination maps to a
    // valid op id in the definitions.
    if (ops_definitions.GetOpIdFor(
            node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
        IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder( /* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
const string& input, const DataType type, const TensorShape& shape, const string& input, const DataType type, const TensorShape& shape,
GraphDef* graph_def) { GraphDef* graph_def) {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
@ -59,6 +60,8 @@ class RemoteFusedGraphExecuteUtils {
"border_outputs"; "border_outputs";
static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES = static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
"fused_op_types"; "fused_op_types";
static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR =
"fuse_by_executor";
static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types"; static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES = static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
"input_shapes"; "input_shapes";
@ -257,6 +260,12 @@ class RemoteFusedGraphExecuteUtils {
const std::vector<std::pair<string, Tensor>>& input_tensors, const std::vector<std::pair<string, Tensor>>& input_tensors,
GraphDef* output_graph_def); GraphDef* output_graph_def);
static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
const string& executor_name,
GraphDef* output_graph_def);
static bool IsFuseReady( static bool IsFuseReady(
const GraphDef& input_graph_def, const GraphDef& input_graph_def,
const std::vector<std::pair<string, Tensor>>& input_tensors); const std::vector<std::pair<string, Tensor>>& input_tensors);
@ -273,6 +282,10 @@ class RemoteFusedGraphExecuteUtils {
static std::unordered_set<string> BuildNodeMapFromOpTypes( static std::unordered_set<string> BuildNodeMapFromOpTypes(
const GraphDef& graph_def, const std::unordered_set<string>& op_types); const GraphDef& graph_def, const std::unordered_set<string>& op_types);
static std::unordered_set<string> BuildNodeMapFromOpsDefinitions(
const GraphDef& graph_def,
const IRemoteFusedGraphOpsDefinitions& ops_definitions);
private: private:
static void EmplaceTensorShapeType(const string& name, const Tensor& tensor, static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
TensorShapeMap* tensor_shape_map); TensorShapeMap* tensor_shape_map);

View File

@ -33,6 +33,10 @@ constexpr const char* const NAME_A_PLUS_B = "A_PLUS_B";
constexpr float NODE_A_VAL = 2.0f; constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_B_VAL = 3.0f; constexpr float NODE_B_VAL = 3.0f;
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f; constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
// Names under which the tests below register their TestRemoteFusedGraphExecutor
// instances to exercise fuse-by-executor.
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
    "fuse_test_remote_fused_graph_executor0";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
    "fuse_test_remote_fused_graph_executor1";
static NodeDef* GetNodeDef(const string& name, GraphDef* def) { static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
CHECK_NE(def, nullptr); CHECK_NE(def, nullptr);
@ -44,17 +48,38 @@ static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
return nullptr; return nullptr;
} }
// Build function for a test executor that fuses only "Mul" nodes.
Status BuildRemoteFusedGraphExecutor0(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  *executor = std::unique_ptr<IRemoteFusedGraphExecutor>(
      new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
  return Status::OK();
}
// Build function for a test executor that fuses "Const" and "Mul" nodes.
Status BuildRemoteFusedGraphExecutor1(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  *executor = std::unique_ptr<IRemoteFusedGraphExecutor>(
      new TestRemoteFusedGraphExecutor({"Const", "Mul"},
                                       REMOTE_FUSED_EXECUTOR_NAME1));
  return Status::OK();
}
class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test { class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
protected: protected:
void SetUp() final { void SetUp() final {
TF_ASSERT_OK( TF_ASSERT_OK(
RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_)); RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build( hexagon_remote_fused_graph_executor_build(
"remote_graph_executor_name", "remote_graph_executor_name",
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status { [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
return Status::OK(); return Status::OK();
}); });
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
BuildRemoteFusedGraphExecutor0);
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
BuildRemoteFusedGraphExecutor1);
} }
void TearDown() final {} void TearDown() final {}
@ -87,6 +112,18 @@ class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
/*require_shape_type=*/false, &result_graph_def_); /*require_shape_type=*/false, &result_graph_def_);
} }
// Fuses graph_def_ via the test executor that supports only "Mul" ops,
// writing the result into result_graph_def_.
Status FuseByExecutor0() {
  return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
      graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME0,
      &result_graph_def_);
}
// Fuses graph_def_ via the test executor that supports "Const" and "Mul"
// ops, writing the result into result_graph_def_.
Status FuseByExecutor1() {
  return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
      graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME1,
      &result_graph_def_);
}
Status BuildAndAddTensorShape() { Status BuildAndAddTensorShape() {
return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
input_tensors_, /*dry_run_inference=*/true, &graph_def_); input_tensors_, /*dry_run_inference=*/true, &graph_def_);
@ -694,6 +731,30 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_FGHIJ) {
<< SummarizeGraphDef(result_graph_def_); << SummarizeGraphDef(result_graph_def_);
} }
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_HIJ) {
  // Turn H, I and J into the one op type the executor supports, so the
  // executor fuses exactly that three-node subgraph.
  ReplaceOpType({"H", "I", "J"}, "Mul");
  TF_ASSERT_OK(FuseByExecutor0());
  // The input graph must be left untouched by fusing.
  EXPECT_EQ(11, graph_def_.node_size());
  // Three fused nodes collapse into one, shrinking 11 nodes to 9.
  EXPECT_EQ(9, result_graph_def_.node_size())
      << "=== Before: \n"
      << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
      << SummarizeGraphDef(result_graph_def_);
}
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_FGHIJ) {
  // Make F..J fusable; executor 1 also supports "Const", so it can swallow
  // the surrounding constant nodes as well.
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
  TF_ASSERT_OK(FuseByExecutor1());
  // The input graph must be left untouched by fusing.
  EXPECT_EQ(11, graph_def_.node_size());
  // Most of the graph collapses into the fused node: 11 nodes become 3.
  EXPECT_EQ(3, result_graph_def_.node_size())
      << "=== Before: \n"
      << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
      << SummarizeGraphDef(result_graph_def_);
}
TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) { TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) {
subgraph_node_names_ = {"H"}; subgraph_node_names_ = {"H"};

View File

@ -66,7 +66,7 @@ static Status ParseArguments(const TransformFuncContext& context,
string* input_types_str, string* input_shapes_str, string* input_types_str, string* input_shapes_str,
string* fused_nodes_str, string* border_inputs_str, string* fused_nodes_str, string* border_inputs_str,
string* border_outputs_str, string* border_outputs_str,
string* fused_op_types_str, string* fused_op_types_str, bool* fuse_by_executor,
string* remote_fused_graph_node_name, string* remote_fused_graph_node_name,
string* remote_graph_executor_name) { string* remote_graph_executor_name) {
TF_RETURN_IF_ERROR(context.GetOneStringParameter( TF_RETURN_IF_ERROR(context.GetOneStringParameter(
@ -87,6 +87,9 @@ static Status ParseArguments(const TransformFuncContext& context,
TF_RETURN_IF_ERROR(context.GetOneStringParameter( TF_RETURN_IF_ERROR(context.GetOneStringParameter(
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "", RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "",
fused_op_types_str)); fused_op_types_str));
TF_RETURN_IF_ERROR(context.GetOneBoolParameter(
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, false,
fuse_by_executor));
TF_RETURN_IF_ERROR(context.GetOneStringParameter( TF_RETURN_IF_ERROR(context.GetOneStringParameter(
RemoteFusedGraphExecuteUtils:: RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME, TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
@ -140,12 +143,14 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
string border_inputs_str; string border_inputs_str;
string border_outputs_str; string border_outputs_str;
string fused_op_types_str; string fused_op_types_str;
bool fuse_by_executor = false;
string remote_fused_graph_node_name; string remote_fused_graph_node_name;
string remote_graph_executor_name; string remote_graph_executor_name;
TF_RETURN_IF_ERROR(ParseArguments( TF_RETURN_IF_ERROR(ParseArguments(
context, &input_types_str, &input_shapes_str, &fused_nodes_str, context, &input_types_str, &input_shapes_str, &fused_nodes_str,
&border_inputs_str, &border_outputs_str, &fused_op_types_str, &border_inputs_str, &border_outputs_str, &fused_op_types_str,
&remote_fused_graph_node_name, &remote_graph_executor_name)); &fuse_by_executor, &remote_fused_graph_node_name,
&remote_graph_executor_name));
if (!input_types_str.empty()) { if (!input_types_str.empty()) {
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str, TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
@ -187,6 +192,10 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name, mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
fused_op_types, remote_graph_executor_name, require_shape_type, fused_op_types, remote_graph_executor_name, require_shape_type,
output_graph_def)); output_graph_def));
} else if (fuse_by_executor) {
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
mutable_input_graph_def, inputs, outputs, remote_graph_executor_name,
output_graph_def));
} else { } else {
CHECK(false) << "Fuse targets are not specified."; CHECK(false) << "Fuse targets are not specified.";
} }
@ -205,15 +214,17 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
string input_types_str; string input_types_str;
string input_shapes_str; string input_shapes_str;
string fused_nodes_str; string fused_nodes_str;
string fused_op_types_str;
string border_inputs_str; string border_inputs_str;
string border_outputs_str; string border_outputs_str;
string fused_op_types_str;
bool fuse_by_executor = false;
string remote_fused_graph_node_name; string remote_fused_graph_node_name;
string remote_graph_executor_name; string remote_graph_executor_name;
TF_RETURN_IF_ERROR(ParseArguments( TF_RETURN_IF_ERROR(ParseArguments(
context, &input_types_str, &input_shapes_str, &fused_nodes_str, context, &input_types_str, &input_shapes_str, &fused_nodes_str,
&border_inputs_str, &border_outputs_str, &fused_op_types_str, &border_inputs_str, &border_outputs_str, &fused_op_types_str,
&remote_fused_graph_node_name, &remote_graph_executor_name)); &fuse_by_executor, &remote_fused_graph_node_name,
&remote_graph_executor_name));
if (!input_types_str.empty()) { if (!input_types_str.empty()) {
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str, TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
@ -43,11 +44,28 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
GraphDef* output_graph_def); GraphDef* output_graph_def);
namespace { namespace {
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME = constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
"remote_fused_graph_executor_name"; "remote_fused_graph_executor_name";
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME = constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
"remote_fused_graph_node_name"; "remote_fused_graph_node_name";
// Names under which the rewriter tests register their
// TestRemoteFusedGraphExecutor instances to exercise fuse-by-executor.
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
    "fuse_test_remote_fused_graph_executor0";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
    "fuse_test_remote_fused_graph_executor1";
// Build function for a test executor that fuses only "Mul" nodes.
Status BuildRemoteFusedGraphExecutor0(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  *executor = std::unique_ptr<IRemoteFusedGraphExecutor>(
      new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
  return Status::OK();
}
// Build function for a test executor that fuses "Const" and "Mul" nodes.
Status BuildRemoteFusedGraphExecutor1(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  *executor = std::unique_ptr<IRemoteFusedGraphExecutor>(
      new TestRemoteFusedGraphExecutor({"Const", "Mul"},
                                       REMOTE_FUSED_EXECUTOR_NAME1));
  return Status::OK();
}
class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test { class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
protected: protected:
@ -55,11 +73,18 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph( TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
&input_graph_def_)); &input_graph_def_));
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build( hexagon_remote_fused_graph_executor_build(
REMOTE_FUSED_GRAPH_EXECUTOR_NAME, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status { [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
return Status::OK(); return Status::OK();
}); });
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
BuildRemoteFusedGraphExecutor0);
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
BuildRemoteFusedGraphExecutor1);
} }
void TearDown() final {} void TearDown() final {}
@ -113,10 +138,16 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
{fused_op_types_str_}})); {fused_op_types_str_}}));
} }
if (fuse_by_executor_) {
context.params.insert(std::pair<string, std::vector<string>>(
{RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
{"true"}}));
}
context.params.insert(std::pair<string, std::vector<string>>( context.params.insert(std::pair<string, std::vector<string>>(
{RemoteFusedGraphExecuteUtils:: {RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME, TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
{REMOTE_FUSED_GRAPH_EXECUTOR_NAME}})); {remote_fused_graph_executor_name_}}));
context.params.insert(std::pair<string, std::vector<string>>( context.params.insert(std::pair<string, std::vector<string>>(
{RemoteFusedGraphExecuteUtils:: {RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME, TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
@ -160,7 +191,7 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
&serialized_proto)); &serialized_proto));
info.ParseFromString(serialized_proto); info.ParseFromString(serialized_proto);
CHECK_EQ(REMOTE_FUSED_GRAPH_EXECUTOR_NAME, info.executor_name()); CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
} }
} }
EXPECT_EQ(expected_cluster_count, cluster_count); EXPECT_EQ(expected_cluster_count, cluster_count);
@ -178,6 +209,8 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
string border_inputs_str_; string border_inputs_str_;
string border_outputs_str_; string border_outputs_str_;
string fused_op_types_str_; string fused_op_types_str_;
string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
bool fuse_by_executor_{false};
}; };
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
@ -260,6 +293,24 @@ TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
CheckGraph(3, 1); CheckGraph(3, 1);
} }
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByExecutor_HIJ) {
  // Make H, I and J fusable by the executor that supports only "Mul", then
  // run the graph transform with fuse_by_executor enabled.
  ReplaceOpType({"H", "I", "J"}, "Mul");
  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
  fuse_by_executor_ = true;
  TF_ASSERT_OK(Fuse());
  // Expect 9 nodes remaining with exactly one fused cluster.
  CheckGraph(9, 1);
}
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
       FuseRemoteGraphByExecutor_FGHIJ) {
  // Make F..J fusable; executor 1 also supports "Const", so the transform
  // can fuse the constants along with the Mul nodes.
  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
  fuse_by_executor_ = true;
  TF_ASSERT_OK(Fuse());
  // Expect 3 nodes remaining with exactly one fused cluster.
  CheckGraph(3, 1);
}
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) { TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
fused_node_names_str_ = "H,I,J"; fused_node_names_str_ = "H,I,J";
TF_ASSERT_OK(PlaceFuseArgs()); TF_ASSERT_OK(PlaceFuseArgs());