Add a way to fuse a graph by remote graph executor so that users don't need to be aware of supported op types, node names, subgraph stracture etc.
PiperOrigin-RevId: 162411763
This commit is contained in:
parent
d8672f1839
commit
d5f4d9bbac
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
|
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
@ -110,7 +111,7 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) {
|
|||||||
static void CheckOpsSupport(const GraphDef& graph_def,
|
static void CheckOpsSupport(const GraphDef& graph_def,
|
||||||
const bool dump_all_nodes,
|
const bool dump_all_nodes,
|
||||||
const bool dump_shape_and_type) {
|
const bool dump_shape_and_type) {
|
||||||
const IGraphTransferOpsDefinitions& ops_definition =
|
const IRemoteFusedGraphOpsDefinitions& ops_definition =
|
||||||
HexagonOpsDefinitions::getInstance();
|
HexagonOpsDefinitions::getInstance();
|
||||||
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
|
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
|
||||||
LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
|
LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
|
||||||
@ -127,7 +128,7 @@ static void CheckOpsSupport(const GraphDef& graph_def,
|
|||||||
}
|
}
|
||||||
// TODO(satok): Set correct data type if it's given.
|
// TODO(satok): Set correct data type if it's given.
|
||||||
const int op_id = ops_definition.GetOpIdFor(node.op(), {});
|
const int op_id = ops_definition.GetOpIdFor(node.op(), {});
|
||||||
if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) {
|
if (op_id == IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
|
||||||
all_supported = false;
|
all_supported = false;
|
||||||
LOG(ERROR) << "OP type: " << node.op() << " is not supported on hvx. "
|
LOG(ERROR) << "OP type: " << node.op() << " is not supported on hvx. "
|
||||||
<< "Name = " << node.name();
|
<< "Name = " << node.name();
|
||||||
|
@ -82,7 +82,7 @@ fi
|
|||||||
if [[ "${USE_HEXAGON}" == "true" ]]; then
|
if [[ "${USE_HEXAGON}" == "true" ]]; then
|
||||||
HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" >/dev/null && pwd)
|
HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" >/dev/null && pwd)
|
||||||
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
|
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
|
||||||
HEXAGON_INCLUDE=$(cd "tensorflow/core/platform/hexagon" >/dev/null && pwd)
|
HEXAGON_INCLUDE=$(cd "tensorflow/core/kernels/hexagon" >/dev/null && pwd)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ "${ENABLE_EXPERIMENTAL_HEXNN_OPS}" == "true" ]]; then
|
if [[ "${ENABLE_EXPERIMENTAL_HEXNN_OPS}" == "true" ]]; then
|
||||||
|
@ -56,7 +56,6 @@ tensorflow/core/platform/posix/test.cc
|
|||||||
|
|
||||||
QUANTIZATION_TEST_SRCS := \
|
QUANTIZATION_TEST_SRCS := \
|
||||||
$(GRAPH_TRANSFER_SRCS) \
|
$(GRAPH_TRANSFER_SRCS) \
|
||||||
tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc \
|
|
||||||
tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
|
tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
|
||||||
tensorflow/contrib/makefile/test/test_main.cc
|
tensorflow/contrib/makefile/test/test_main.cc
|
||||||
|
|
||||||
|
@ -5057,9 +5057,13 @@ tf_kernel_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "remote_fused_graph_execute_utils",
|
name = "remote_fused_graph_execute_utils",
|
||||||
srcs = ["remote_fused_graph_execute_utils.cc"],
|
srcs = [
|
||||||
|
"i_remote_fused_graph_ops_definitions.cc",
|
||||||
|
"remote_fused_graph_execute_utils.cc",
|
||||||
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"i_remote_fused_graph_executor.h",
|
"i_remote_fused_graph_executor.h",
|
||||||
|
"i_remote_fused_graph_ops_definitions.h",
|
||||||
"remote_fused_graph_execute_utils.h",
|
"remote_fused_graph_execute_utils.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
@ -5078,6 +5082,7 @@ cc_library(
|
|||||||
srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
|
srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
|
||||||
hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
|
hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":remote_fused_graph_execute_utils",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/cc:scope",
|
"//tensorflow/cc:scope",
|
||||||
|
@ -26,23 +26,6 @@ filegroup(
|
|||||||
visibility = ["//tensorflow:__subpackages__"],
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
|
||||||
name = "quantized_matmul_op_for_hexagon_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["quantized_matmul_op_for_hexagon_test.cc"],
|
|
||||||
tags = ["nomsan"], # http://b/32242946
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core:testlib",
|
|
||||||
"//tensorflow/core/kernels:ops_testutil",
|
|
||||||
"//tensorflow/core/kernels:ops_util",
|
|
||||||
"//tensorflow/core/kernels:quantized_ops",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "graph_transferer_test",
|
name = "graph_transferer_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
@ -79,14 +62,14 @@ tf_kernel_library(
|
|||||||
"graph_transferer.cc",
|
"graph_transferer.cc",
|
||||||
"hexagon_control_wrapper.cc",
|
"hexagon_control_wrapper.cc",
|
||||||
"hexagon_ops_definitions.cc",
|
"hexagon_ops_definitions.cc",
|
||||||
"i_graph_transfer_ops_definitions.cc",
|
"soc_interface.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"graph_transfer_utils.h",
|
"graph_transfer_utils.h",
|
||||||
"graph_transferer.h",
|
"graph_transferer.h",
|
||||||
"hexagon_control_wrapper.h",
|
"hexagon_control_wrapper.h",
|
||||||
"hexagon_ops_definitions.h",
|
"hexagon_ops_definitions.h",
|
||||||
"i_graph_transfer_ops_definitions.h",
|
"soc_interface.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
@ -111,6 +94,7 @@ cc_library(
|
|||||||
":graph_transferer",
|
":graph_transferer",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/cc:remote_fused_graph_ops",
|
"//tensorflow/cc:remote_fused_graph_ops",
|
||||||
|
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
||||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -121,6 +105,7 @@ tf_cc_test(
|
|||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["hexagon_rewriter_transform_test.cc"],
|
srcs = ["hexagon_rewriter_transform_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":graph_transferer",
|
||||||
":hexagon_rewriter_transform",
|
":hexagon_rewriter_transform",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
@ -129,6 +114,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
||||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -96,7 +96,7 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
|
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const string& remote_graph_execute_name,
|
const string& remote_graph_execute_name,
|
||||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
const std::vector<string>& outputs, GraphDef* original_def) {
|
const std::vector<string>& outputs, GraphDef* original_def) {
|
||||||
|
@ -39,7 +39,7 @@ class GraphTransferUtils {
|
|||||||
const int element_count, const int top_n);
|
const int element_count, const int top_n);
|
||||||
|
|
||||||
static GraphDef BuildFusedGraphDef(
|
static GraphDef BuildFusedGraphDef(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const string& remote_graph_execute_name,
|
const string& remote_graph_execute_name,
|
||||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
const std::vector<string>& outputs, GraphDef* original_def);
|
const std::vector<string>& outputs, GraphDef* original_def);
|
||||||
|
@ -81,7 +81,7 @@ static Node* FindMutableNodeByName(const string& name, Graph* graph) {
|
|||||||
* of node to transfer the graph to SOC.
|
* of node to transfer the graph to SOC.
|
||||||
*/
|
*/
|
||||||
Status GraphTransferer::LoadGraphFromProto(
|
Status GraphTransferer::LoadGraphFromProto(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const GraphDef& graph_def,
|
const GraphDef& graph_def,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
const std::vector<string>& output_node_names,
|
const std::vector<string>& output_node_names,
|
||||||
@ -177,9 +177,6 @@ Status GraphTransferer::LoadGraphFromProto(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_transfer_info_.set_destination(
|
|
||||||
ops_definitions.GetTransferDestination());
|
|
||||||
|
|
||||||
ClearCache();
|
ClearCache();
|
||||||
if (DBG_DUMP_PARAMS) {
|
if (DBG_DUMP_PARAMS) {
|
||||||
DumpNodeTransferParams();
|
DumpNodeTransferParams();
|
||||||
@ -191,7 +188,7 @@ Status GraphTransferer::LoadGraphFromProto(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status GraphTransferer::LoadGraphFromProtoFile(
|
Status GraphTransferer::LoadGraphFromProtoFile(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const string& graph_def_path,
|
const string& graph_def_path,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
const std::vector<string>& output_node_names, const bool is_text_proto,
|
const std::vector<string>& output_node_names, const bool is_text_proto,
|
||||||
@ -415,7 +412,7 @@ Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status GraphTransferer::RegisterNode(
|
Status GraphTransferer::RegisterNode(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node,
|
const ShapeRefiner& shape_refiner, const Node& node,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
const std::vector<string>& output_node_names) {
|
const std::vector<string>& output_node_names) {
|
||||||
@ -438,7 +435,7 @@ Status GraphTransferer::RegisterNode(
|
|||||||
} else if (IsNodeFlattenReshape(node, shape_refiner)) {
|
} else if (IsNodeFlattenReshape(node, shape_refiner)) {
|
||||||
RegisterFlattenNode(ops_definitions, shape_refiner, node);
|
RegisterFlattenNode(ops_definitions, shape_refiner, node);
|
||||||
} else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
|
} else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
|
||||||
IGraphTransferOpsDefinitions::INVALID_OP_ID) {
|
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
|
||||||
// TODO(satok): Set correct data type if it's given.
|
// TODO(satok): Set correct data type if it's given.
|
||||||
RegisterGenericNode(ops_definitions, shape_refiner, node);
|
RegisterGenericNode(ops_definitions, shape_refiner, node);
|
||||||
} else {
|
} else {
|
||||||
@ -637,7 +634,7 @@ bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
|
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||||
const int id = node_name_to_id_cache_map_[node.name()];
|
const int id = node_name_to_id_cache_map_[node.name()];
|
||||||
@ -671,7 +668,7 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphTransferer::RegisterNodeWithRank(
|
void GraphTransferer::RegisterNodeWithRank(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||||
const int id = node_name_to_id_cache_map_[node.name()];
|
const int id = node_name_to_id_cache_map_[node.name()];
|
||||||
@ -704,7 +701,7 @@ void GraphTransferer::RegisterNodeWithRank(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphTransferer::RegisterPadNode(
|
void GraphTransferer::RegisterPadNode(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||||
static constexpr int PAD_WIDTH = 4;
|
static constexpr int PAD_WIDTH = 4;
|
||||||
static constexpr int PAD_HEIGHT = 2;
|
static constexpr int PAD_HEIGHT = 2;
|
||||||
@ -779,7 +776,7 @@ void GraphTransferer::RegisterPadNode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphTransferer::RegisterInputNode(
|
void GraphTransferer::RegisterInputNode(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||||
const string op_type = node.type_string();
|
const string op_type = node.type_string();
|
||||||
VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
|
VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
|
||||||
@ -797,12 +794,13 @@ void GraphTransferer::RegisterInputNode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphTransferer::RegisterFlattenNode(
|
void GraphTransferer::RegisterFlattenNode(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||||
VLOG(1) << "Register flatten node: " << node.name();
|
VLOG(1) << "Register flatten node: " << node.name();
|
||||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||||
const int id = node_name_to_id_cache_map_[node.name()];
|
const int id = node_name_to_id_cache_map_[node.name()];
|
||||||
const string op_type = IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
|
// TODO(satok): Remove dependency to specific type
|
||||||
|
const string op_type = "FLATTEN";
|
||||||
// TODO(satok): Set correct data type if it's given.
|
// TODO(satok): Set correct data type if it's given.
|
||||||
const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
|
const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
|
||||||
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
|
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
|
||||||
@ -814,7 +812,7 @@ void GraphTransferer::RegisterFlattenNode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphTransferer::RegisterGenericNode(
|
void GraphTransferer::RegisterGenericNode(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||||
VLOG(1) << "Register generic node: " << node.name();
|
VLOG(1) << "Register generic node: " << node.name();
|
||||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||||
@ -832,7 +830,7 @@ void GraphTransferer::RegisterGenericNode(
|
|||||||
// TODO(satok): Remove this function.
|
// TODO(satok): Remove this function.
|
||||||
// TODO(satok): Remove only_register_const_node.
|
// TODO(satok): Remove only_register_const_node.
|
||||||
Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
|
Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node,
|
const ShapeRefiner& shape_refiner, const Node& node,
|
||||||
const bool only_register_const_node,
|
const bool only_register_const_node,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
|
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -53,7 +53,7 @@ class GraphTransferer {
|
|||||||
// TODO(satok): Pass a pair of TensorShape and DataType instead of
|
// TODO(satok): Pass a pair of TensorShape and DataType instead of
|
||||||
// Tensor as input_node_info_list.
|
// Tensor as input_node_info_list.
|
||||||
Status LoadGraphFromProto(
|
Status LoadGraphFromProto(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const GraphDef& graph_def,
|
const GraphDef& graph_def,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
const std::vector<string>& output_node_names,
|
const std::vector<string>& output_node_names,
|
||||||
@ -63,7 +63,7 @@ class GraphTransferer {
|
|||||||
// TODO(satok): Pass a pair of TensorShape and DataType instead of
|
// TODO(satok): Pass a pair of TensorShape and DataType instead of
|
||||||
// Tensor as input_node_info_list.
|
// Tensor as input_node_info_list.
|
||||||
Status LoadGraphFromProtoFile(
|
Status LoadGraphFromProtoFile(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const string& graph_def_path,
|
const string& graph_def_path,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
const std::vector<string>& output_node_names, const bool is_text_proto,
|
const std::vector<string>& output_node_names, const bool is_text_proto,
|
||||||
@ -112,7 +112,7 @@ class GraphTransferer {
|
|||||||
Graph* graph, ShapeRefiner* shape_refiner);
|
Graph* graph, ShapeRefiner* shape_refiner);
|
||||||
|
|
||||||
Status RegisterNode(
|
Status RegisterNode(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node,
|
const ShapeRefiner& shape_refiner, const Node& node,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
const std::vector<string>& output_node_names);
|
const std::vector<string>& output_node_names);
|
||||||
@ -140,30 +140,29 @@ class GraphTransferer {
|
|||||||
const ShapeRefiner& shape_refiner);
|
const ShapeRefiner& shape_refiner);
|
||||||
|
|
||||||
void RegisterNodeWithPaddingAndStrides(
|
void RegisterNodeWithPaddingAndStrides(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node);
|
const ShapeRefiner& shape_refiner, const Node& node);
|
||||||
|
|
||||||
void RegisterNodeWithRank(const IGraphTransferOpsDefinitions& ops_definitions,
|
void RegisterNodeWithRank(
|
||||||
const ShapeRefiner& shape_refiner,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const Node& node);
|
|
||||||
|
|
||||||
void RegisterPadNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
|
||||||
const ShapeRefiner& shape_refiner, const Node& node);
|
const ShapeRefiner& shape_refiner, const Node& node);
|
||||||
|
|
||||||
void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner,
|
const ShapeRefiner& shape_refiner, const Node& node);
|
||||||
const Node& node);
|
|
||||||
|
|
||||||
void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner,
|
const ShapeRefiner& shape_refiner, const Node& node);
|
||||||
const Node& node);
|
|
||||||
|
|
||||||
void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
void RegisterFlattenNode(
|
||||||
const ShapeRefiner& shape_refiner,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const Node& node);
|
const ShapeRefiner& shape_refiner, const Node& node);
|
||||||
|
|
||||||
|
void RegisterGenericNode(
|
||||||
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
|
const ShapeRefiner& shape_refiner, const Node& node);
|
||||||
|
|
||||||
Status RegisterNodeIfAllInputsAreCached(
|
Status RegisterNodeIfAllInputsAreCached(
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||||
const ShapeRefiner& shape_refiner, const Node& node,
|
const ShapeRefiner& shape_refiner, const Node& node,
|
||||||
const bool only_register_const_node,
|
const bool only_register_const_node,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||||
|
@ -22,8 +22,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
|
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
|
|
||||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
@ -50,7 +50,7 @@ class GraphTransfererTest : public ::testing::Test {
|
|||||||
|
|
||||||
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
|
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
|
||||||
|
|
||||||
class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
|
class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
|
||||||
public:
|
public:
|
||||||
int GetTotalOpsCount() const final { return op_types_.size(); }
|
int GetTotalOpsCount() const final { return op_types_.size(); }
|
||||||
|
|
||||||
@ -63,10 +63,6 @@ class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphTransferInfo::Destination GetTransferDestination() const final {
|
|
||||||
return GraphTransferInfo::NOP;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
|
const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
|
||||||
"MaxPool", "NoOp", "Add",
|
"MaxPool", "NoOp", "Add",
|
||||||
@ -371,14 +367,14 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
|
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
|
||||||
const IGraphTransferOpsDefinitions& ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions =
|
||||||
HexagonOpsDefinitions::getInstance();
|
HexagonOpsDefinitions::getInstance();
|
||||||
const int total_ops_count = ops_definitions.GetTotalOpsCount();
|
const int total_ops_count = ops_definitions.GetTotalOpsCount();
|
||||||
EXPECT_GT(total_ops_count, 0);
|
EXPECT_GT(total_ops_count, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GraphTransferer, LoadGraphFromProtoFile) {
|
TEST(GraphTransferer, LoadGraphFromProtoFile) {
|
||||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||||
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
|
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
|
||||||
string filename =
|
string filename =
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(),
|
io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||||
@ -441,7 +437,7 @@ void CompareGraphTransferInfo(const GraphTransferInfo& a,
|
|||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
|
TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
|
||||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||||
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
|
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
|
||||||
string filename =
|
string filename =
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(),
|
io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||||
|
@ -16,18 +16,19 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||||
|
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
|
||||||
#ifdef USE_HEXAGON_LIBS
|
|
||||||
#include "tensorflow/core/platform/hexagon/soc_interface.h"
|
|
||||||
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
constexpr const char* const INPUT_OP_NAME = "INPUT";
|
|
||||||
constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
|
constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
|
||||||
|
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX =
|
||||||
|
"hexagon_remote_fused_graph";
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
|
||||||
|
|
||||||
constexpr int ALIGNMENT_BYTES = 16;
|
constexpr int ALIGNMENT_BYTES = 16;
|
||||||
|
constexpr int MAX_IN_OUT_COUNT = 128;
|
||||||
|
|
||||||
const bool DBG_DUMP_VERIFICATION_STRING = false;
|
const bool DBG_DUMP_VERIFICATION_STRING = false;
|
||||||
const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info
|
const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info
|
||||||
@ -63,7 +64,6 @@ static uint8* FindAlignedPointer(uint8* ptr) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_HEXAGON_LIBS
|
|
||||||
int HexagonControlWrapper::GetVersion() {
|
int HexagonControlWrapper::GetVersion() {
|
||||||
return soc_interface_GetSocControllerVersion();
|
return soc_interface_GetSocControllerVersion();
|
||||||
}
|
}
|
||||||
@ -95,7 +95,6 @@ bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
|
|||||||
LOG(ERROR) << "Hexagon initialization was failed. See log output.";
|
LOG(ERROR) << "Hexagon initialization was failed. See log output.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo();
|
|
||||||
std::vector<int> input_sizes;
|
std::vector<int> input_sizes;
|
||||||
std::vector<int> output_sizes;
|
std::vector<int> output_sizes;
|
||||||
CHECK_NOTNULL(execute_info_);
|
CHECK_NOTNULL(execute_info_);
|
||||||
@ -207,8 +206,9 @@ bool HexagonControlWrapper::SetupGraph() {
|
|||||||
for (const GraphTransferInfo::NodeInputInfo& input_params :
|
for (const GraphTransferInfo::NodeInputInfo& input_params :
|
||||||
graph_transfer_info.node_input_info()) {
|
graph_transfer_info.node_input_info()) {
|
||||||
const int count = input_params.node_input_size();
|
const int count = input_params.node_input_size();
|
||||||
int node_ids[count];
|
CHECK(count <= MAX_IN_OUT_COUNT);
|
||||||
int ports[count];
|
int node_ids[MAX_IN_OUT_COUNT];
|
||||||
|
int ports[MAX_IN_OUT_COUNT];
|
||||||
for (int i = 0; i < count; ++i) {
|
for (int i = 0; i < count; ++i) {
|
||||||
const GraphTransferInfo::NodeInput& node_input =
|
const GraphTransferInfo::NodeInput& node_input =
|
||||||
input_params.node_input(i);
|
input_params.node_input(i);
|
||||||
@ -226,7 +226,8 @@ bool HexagonControlWrapper::SetupGraph() {
|
|||||||
for (const GraphTransferInfo::NodeOutputInfo& output_params :
|
for (const GraphTransferInfo::NodeOutputInfo& output_params :
|
||||||
graph_transfer_info.node_output_info()) {
|
graph_transfer_info.node_output_info()) {
|
||||||
const int count = output_params.max_byte_size_size();
|
const int count = output_params.max_byte_size_size();
|
||||||
int sizes[count];
|
CHECK(count <= MAX_IN_OUT_COUNT);
|
||||||
|
int sizes[MAX_IN_OUT_COUNT];
|
||||||
for (int i = 0; i < count; ++i) {
|
for (int i = 0; i < count; ++i) {
|
||||||
const int size = output_params.max_byte_size(i);
|
const int size = output_params.max_byte_size(i);
|
||||||
sizes[i] = size;
|
sizes[i] = size;
|
||||||
@ -373,6 +374,7 @@ bool HexagonControlWrapper::ReadOutputNode(
|
|||||||
<< output_tensor->TotalBytes() << ", " << std::get<1>(output);
|
<< output_tensor->TotalBytes() << ", " << std::get<1>(output);
|
||||||
TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
|
TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
|
||||||
std::get<0>(output), std::get<1>(output), output_tensor));
|
std::get<0>(output), std::get<1>(output), output_tensor));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::ReadOutputNode(
|
bool HexagonControlWrapper::ReadOutputNode(
|
||||||
@ -382,14 +384,30 @@ bool HexagonControlWrapper::ReadOutputNode(
|
|||||||
const string tensor_name = AddPort(node_name);
|
const string tensor_name = AddPort(node_name);
|
||||||
CHECK(output_port_map_.count(tensor_name) > 0);
|
CHECK(output_port_map_.count(tensor_name) > 0);
|
||||||
const int port = output_port_map_.at(tensor_name);
|
const int port = output_port_map_.at(tensor_name);
|
||||||
soc_interface_ReadOutputNodeWithPort(port, &std::get<0>(output),
|
soc_interface_ReadOutputNodeWithPort(
|
||||||
&std::get<1>(output));
|
port, &std::get<0>(output),
|
||||||
|
reinterpret_cast<uint64_t*>(&std::get<1>(output)));
|
||||||
// TODO: Accept all results
|
// TODO: Accept all results
|
||||||
// std::get<2>(output) = DT_FLOAT;
|
// std::get<2>(output) = DT_FLOAT;
|
||||||
outputs->emplace_back(output);
|
outputs->emplace_back(output);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status HexagonControlWrapper::FuseRemoteGraph(
|
||||||
|
const GraphDef& original_graph_def, const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
|
||||||
|
const std::unordered_set<string> fused_node_names =
|
||||||
|
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
|
||||||
|
original_graph_def, HexagonOpsDefinitions::getInstance());
|
||||||
|
// TODO(satok): We may want to place shape and type inside this function
|
||||||
|
// if they are not placed in the given graph.
|
||||||
|
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
|
||||||
|
original_graph_def, inputs, outputs, REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX,
|
||||||
|
fused_node_names, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||||
|
/*require_shape_type=*/true, fused_graph_def));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::FillInputNode(const string& node_name,
|
bool HexagonControlWrapper::FillInputNode(const string& node_name,
|
||||||
const Tensor& tensor) {
|
const Tensor& tensor) {
|
||||||
StringPiece tensor_data = tensor.tensor_data();
|
StringPiece tensor_data = tensor.tensor_data();
|
||||||
@ -415,31 +433,5 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
bool HexagonControlWrapper::IsEnabled() const { return true; };
|
||||||
int HexagonControlWrapper::GetVersion() { return -1; }
|
|
||||||
bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo&) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool HexagonControlWrapper::Finalize() { return false; }
|
|
||||||
bool HexagonControlWrapper::SetupGraph() { return false; }
|
|
||||||
bool HexagonControlWrapper::ExecuteGraph() { return false; }
|
|
||||||
bool HexagonControlWrapper::TeardownGraph() { return false; }
|
|
||||||
bool HexagonControlWrapper::FillInputNode(
|
|
||||||
const string&, const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>&,
|
|
||||||
const ConstByteArray) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool HexagonControlWrapper::FillInputNode(const string&, const Tensor&) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool HexagonControlWrapper::ReadOutputNode(
|
|
||||||
const string& node_name, TensorAllocatorFunc tensor_allocator) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool HexagonControlWrapper::ReadOutputNode(const string&,
|
|
||||||
std::vector<ByteArray>* const) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -35,6 +35,8 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
|
|||||||
public:
|
public:
|
||||||
using ByteArray =
|
using ByteArray =
|
||||||
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
|
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
|
||||||
|
static constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
|
||||||
|
"build_hexagon_remote_fused_graph_executor";
|
||||||
|
|
||||||
HexagonControlWrapper() = default;
|
HexagonControlWrapper() = default;
|
||||||
int GetVersion() final;
|
int GetVersion() final;
|
||||||
@ -46,6 +48,11 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
|
|||||||
bool FillInputNode(const string& node_name, const Tensor& tensor) final;
|
bool FillInputNode(const string& node_name, const Tensor& tensor) final;
|
||||||
bool ReadOutputNode(const string& node_name,
|
bool ReadOutputNode(const string& node_name,
|
||||||
TensorAllocatorFunc tensor_allocator) final;
|
TensorAllocatorFunc tensor_allocator) final;
|
||||||
|
Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||||
|
const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs,
|
||||||
|
GraphDef* fused_graph_def) final;
|
||||||
|
bool IsEnabled() const final;
|
||||||
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
|
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -35,8 +35,8 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
|
|||||||
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
|
|
||||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||||
#include "tensorflow/core/lib/core/casts.h"
|
#include "tensorflow/core/lib/core/casts.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -389,9 +389,6 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
|
|||||||
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
|
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
|
||||||
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
|
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. check destination
|
|
||||||
EXPECT_EQ(gfi0.destination(), gfi1.destination());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CAVEAT: This test only runs when you specify hexagon library using
|
// CAVEAT: This test only runs when you specify hexagon library using
|
||||||
@ -407,7 +404,7 @@ TEST(GraphTransferer,
|
|||||||
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
||||||
CheckHexagonControllerVersion();
|
CheckHexagonControllerVersion();
|
||||||
|
|
||||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||||
&HexagonOpsDefinitions::getInstance();
|
&HexagonOpsDefinitions::getInstance();
|
||||||
std::vector<std::pair<string, Tensor>> inputs;
|
std::vector<std::pair<string, Tensor>> inputs;
|
||||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||||
@ -474,7 +471,7 @@ TEST(GraphTransferer,
|
|||||||
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
||||||
CheckHexagonControllerVersion();
|
CheckHexagonControllerVersion();
|
||||||
|
|
||||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||||
&HexagonOpsDefinitions::getInstance();
|
&HexagonOpsDefinitions::getInstance();
|
||||||
std::vector<std::pair<string, Tensor>> inputs;
|
std::vector<std::pair<string, Tensor>> inputs;
|
||||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||||
@ -505,7 +502,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
|
|||||||
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
|
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
|
||||||
CheckHexagonControllerVersion();
|
CheckHexagonControllerVersion();
|
||||||
|
|
||||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||||
&HexagonOpsDefinitions::getInstance();
|
&HexagonOpsDefinitions::getInstance();
|
||||||
std::vector<std::pair<string, Tensor>> inputs;
|
std::vector<std::pair<string, Tensor>> inputs;
|
||||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||||
@ -543,7 +540,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
|
|||||||
CheckHexagonControllerVersion();
|
CheckHexagonControllerVersion();
|
||||||
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
|
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
|
||||||
|
|
||||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||||
&HexagonOpsDefinitions::getInstance();
|
&HexagonOpsDefinitions::getInstance();
|
||||||
std::vector<std::pair<string, Tensor>> inputs;
|
std::vector<std::pair<string, Tensor>> inputs;
|
||||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||||
|
@ -304,8 +304,8 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
|
|||||||
EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
|
EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
|
||||||
EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
|
EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
|
||||||
EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
|
EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
|
||||||
EmplaceOpType(IGraphTransferOpsDefinitions::FLATTEN_OP_NAME, {},
|
// Special op type for hexagon
|
||||||
SupportedOpType::FLATTEN, &op_map);
|
EmplaceOpType("FLATTEN", {}, SupportedOpType::FLATTEN, &op_map);
|
||||||
// Tensorflow op name
|
// Tensorflow op name
|
||||||
// CAVEAT: Keep order of SupportedOpType
|
// CAVEAT: Keep order of SupportedOpType
|
||||||
EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
|
EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
|
||||||
@ -373,7 +373,7 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
|
|||||||
HexagonOpsDefinitions::HexagonOpsDefinitions()
|
HexagonOpsDefinitions::HexagonOpsDefinitions()
|
||||||
: op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}
|
: op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}
|
||||||
|
|
||||||
/* static */ const IGraphTransferOpsDefinitions&
|
/* static */ const IRemoteFusedGraphOpsDefinitions&
|
||||||
HexagonOpsDefinitions::getInstance() {
|
HexagonOpsDefinitions::getInstance() {
|
||||||
const static HexagonOpsDefinitions instance{};
|
const static HexagonOpsDefinitions instance{};
|
||||||
return instance;
|
return instance;
|
||||||
@ -393,17 +393,17 @@ int HexagonOpsDefinitions::GetOpIdFor(const string& op_type,
|
|||||||
if (dt_vec.empty()) {
|
if (dt_vec.empty()) {
|
||||||
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
|
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
|
||||||
}
|
}
|
||||||
|
// If there is only one op_id registered for empty op_vec, we assume
|
||||||
|
// that the op supports any data types.
|
||||||
|
if (dt_to_op_vec.size() == 1 && std::get<0>(dt_to_op_vec.front()).empty()) {
|
||||||
|
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
|
||||||
|
}
|
||||||
for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
|
for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
|
||||||
if (std::get<0>(data_type_to_op) == dt_vec) {
|
if (std::get<0>(data_type_to_op) == dt_vec) {
|
||||||
return static_cast<int>(std::get<1>(data_type_to_op));
|
return static_cast<int>(std::get<1>(data_type_to_op));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return IGraphTransferOpsDefinitions::INVALID_OP_ID;
|
return IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
|
||||||
}
|
|
||||||
|
|
||||||
GraphTransferInfo::Destination HexagonOpsDefinitions::GetTransferDestination()
|
|
||||||
const {
|
|
||||||
return GraphTransferInfo::HEXAGON;
|
|
||||||
}
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -18,21 +18,20 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "i_graph_transfer_ops_definitions.h"
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// HexagonOpsDefinitions provides ops definitions supported in hexagon library
|
// HexagonOpsDefinitions provides ops definitions supported in hexagon library
|
||||||
// TODO(satok): add a functionality to call functions in hexagon library
|
// TODO(satok): add a functionality to call functions in hexagon library
|
||||||
class HexagonOpsDefinitions final : public IGraphTransferOpsDefinitions {
|
class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
|
||||||
public:
|
public:
|
||||||
static const IGraphTransferOpsDefinitions& getInstance();
|
static const IRemoteFusedGraphOpsDefinitions& getInstance();
|
||||||
|
|
||||||
int GetTotalOpsCount() const final;
|
int GetTotalOpsCount() const final;
|
||||||
int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;
|
int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;
|
||||||
GraphTransferInfo::Destination GetTransferDestination() const final;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
enum class SupportedOpType;
|
enum class SupportedOpType;
|
||||||
|
@ -27,7 +27,7 @@ Status BuildRemoteFusedGraphExecutor(
|
|||||||
|
|
||||||
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
k_hexagon_remote_fused_graph_executor_build(
|
k_hexagon_remote_fused_graph_executor_build(
|
||||||
"build_hexagon_remote_fused_graph_executor",
|
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||||
BuildRemoteFusedGraphExecutor);
|
BuildRemoteFusedGraphExecutor);
|
||||||
|
|
||||||
} // namespace hexagon_remote_fused_graph_executor_build
|
} // namespace hexagon_remote_fused_graph_executor_build
|
||||||
|
@ -1,136 +0,0 @@
|
|||||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
// Tests in this file are designed to evaluate hexagon DSP operations.
|
|
||||||
|
|
||||||
#define EIGEN_USE_THREADS
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
|
||||||
#include "tensorflow/core/framework/fake_input.h"
|
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
|
||||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
|
||||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
|
|
||||||
#ifdef USE_HEXAGON_LIBS
|
|
||||||
#include "tensorflow/core/platform/hexagon/soc_interface.h"
|
|
||||||
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
|
|
||||||
protected:
|
|
||||||
void SetUp() final {
|
|
||||||
#ifdef USE_HEXAGON_LIBS
|
|
||||||
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
|
|
||||||
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
|
|
||||||
<< soc_interface_GetWrapperVersion()
|
|
||||||
<< ", hexagon binary version = "
|
|
||||||
<< soc_interface_GetSocControllerVersion() << ")";
|
|
||||||
LOG(INFO) << "Cpu frequency = "
|
|
||||||
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
|
|
||||||
#else
|
|
||||||
LOG(WARNING) << "Hexagon libs are not linked.";
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Shows some statistics of hexagon dsp using hexagon specific APIs
|
|
||||||
#ifdef USE_HEXAGON_LIBS
|
|
||||||
TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
|
|
||||||
const uint64 overhead_shared_lib_start =
|
|
||||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
|
||||||
const int wrapper_version = soc_interface_GetWrapperVersion();
|
|
||||||
const uint64 overhead_shared_lib_end =
|
|
||||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
|
||||||
const uint64 overhead_shared_lib_diff =
|
|
||||||
(overhead_shared_lib_end - overhead_shared_lib_start);
|
|
||||||
const uint64 overhead_hexagon_rpc_start =
|
|
||||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
|
||||||
const int hexagon_binary_version = soc_interface_GetSocControllerVersion();
|
|
||||||
const uint64 overhead_hexagon_rpc_end =
|
|
||||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
|
||||||
const uint64 overhead_hexagon_rpc_diff =
|
|
||||||
(overhead_hexagon_rpc_end - overhead_hexagon_rpc_start);
|
|
||||||
LOG(INFO) << "Shared lib (ver = " << wrapper_version << ") overhead is "
|
|
||||||
<< overhead_shared_lib_diff << " cycles, time = "
|
|
||||||
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
|
||||||
profile_utils::CpuUtils::ConvertClockCycleToTime(
|
|
||||||
overhead_shared_lib_diff))
|
|
||||||
.count()
|
|
||||||
<< " usec";
|
|
||||||
LOG(INFO) << "hexagon rpc (ver = " << hexagon_binary_version
|
|
||||||
<< ") overhead is " << overhead_hexagon_rpc_diff
|
|
||||||
<< " cycles, time = "
|
|
||||||
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
|
||||||
profile_utils::CpuUtils::ConvertClockCycleToTime(
|
|
||||||
overhead_hexagon_rpc_diff))
|
|
||||||
.count()
|
|
||||||
<< " usec";
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Runs two small matrices through the operator, and leaves all the parameters
|
|
||||||
// at their default values.
|
|
||||||
// This test is a sample to execute matmul on hexagon.
|
|
||||||
TEST_F(QuantizedMatMulOpForHexagonTest, Small_NoParams) {
|
|
||||||
TF_ASSERT_OK(NodeDefBuilder("quantized_mat_mul_op", "QuantizedMatMul")
|
|
||||||
.Input(FakeInput(DT_QUINT8))
|
|
||||||
.Input(FakeInput(DT_QUINT8))
|
|
||||||
.Input(FakeInput(DT_FLOAT))
|
|
||||||
.Input(FakeInput(DT_FLOAT))
|
|
||||||
.Input(FakeInput(DT_FLOAT))
|
|
||||||
.Input(FakeInput(DT_FLOAT))
|
|
||||||
.Attr("Toutput", DataTypeToEnum<qint32>::v())
|
|
||||||
.Finalize(node_def()));
|
|
||||||
TF_ASSERT_OK(InitOp());
|
|
||||||
// A matrix is:
|
|
||||||
// | 1 | 2 | 3 |
|
|
||||||
// | 4 | 5 | 6 |
|
|
||||||
AddInputFromArray<quint8>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
|
|
||||||
// B matrix is:
|
|
||||||
// | 7 | 8 | 9 | 10 |
|
|
||||||
// | 11 | 12 | 13 | 14 |
|
|
||||||
// | 15 | 16 | 17 | 18 |
|
|
||||||
AddInputFromArray<quint8>(TensorShape({3, 4}),
|
|
||||||
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
|
||||||
AddInputFromArray<float>(TensorShape({1}), {0});
|
|
||||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
|
||||||
AddInputFromArray<float>(TensorShape({1}), {0});
|
|
||||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
|
||||||
|
|
||||||
TF_ASSERT_OK(RunOpKernel());
|
|
||||||
// Here are the results we expect, from hand calculations:
|
|
||||||
// (1 * 7) + (2 * 11) + (3 * 15) = 74
|
|
||||||
// (1 * 8) + (2 * 12) + (3 * 16) = 80
|
|
||||||
// (1 * 9) + (2 * 13) + (3 * 17) = 86
|
|
||||||
// (1 * 10) + (2 * 14) + (3 * 18) = 92
|
|
||||||
// (4 * 7) + (5 * 11) + (6 * 15) = 173
|
|
||||||
// (4 * 8) + (5 * 12) + (6 * 16) = 188
|
|
||||||
// (4 * 9) + (5 * 13) + (6 * 17) = 203
|
|
||||||
// (4 * 10) + (5 * 14) + (6 * 18) = 218
|
|
||||||
Tensor expected(allocator(), DT_QINT32, TensorShape({2, 4}));
|
|
||||||
test::FillValues<qint32>(&expected, {74, 80, 86, 92, 173, 188, 203, 218});
|
|
||||||
test::ExpectTensorEqual<qint32>(expected, *GetOutput(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
83
tensorflow/core/kernels/hexagon/soc_interface.cc
Normal file
83
tensorflow/core/kernels/hexagon/soc_interface.cc
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
vcyou may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
|
||||||
|
|
||||||
|
// Dummy implementation of soc_interface.
|
||||||
|
|
||||||
|
int soc_interface_GetWrapperVersion() { return -1; }
|
||||||
|
int soc_interface_GetSocControllerVersion() { return -1; }
|
||||||
|
bool soc_interface_Init() { return false; }
|
||||||
|
bool soc_interface_Finalize() { return false; }
|
||||||
|
bool soc_interface_ExecuteGraph() { return false; }
|
||||||
|
bool soc_interface_TeardownGraph() { return false; }
|
||||||
|
bool soc_interface_AllocateInOutNodeBuffers(int /*input_count*/,
|
||||||
|
int* /*input_sizes*/,
|
||||||
|
int /*output_count*/,
|
||||||
|
int* /*output_sizes*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_FillInputNodeWithPort(int /*port*/, int /*x*/, int /*y*/,
|
||||||
|
int /*z*/, int /*d*/,
|
||||||
|
const uint8_t* const /*buf*/,
|
||||||
|
uint64_t /*buf_byte_size*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_FillInputNodeFloat(int /*x*/, int /*y*/, int /*z*/,
|
||||||
|
int /*d*/, const uint8_t* const /*buf*/,
|
||||||
|
uint64_t /*buf_byte_size*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_ReadOutputNodeWithPort(int /*port*/, uint8_t** /*buf*/,
|
||||||
|
uint64_t* /*buf_byte_size*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_ReadOutputNodeFloat(const char* const /*node_name*/,
|
||||||
|
uint8_t** /*buf*/,
|
||||||
|
uint64_t* /*buf_byte_size*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_setupDummyGraph(int /*version*/) { return false; }
|
||||||
|
bool soc_interface_AllocateNodeInputAndNodeOutputArray(
|
||||||
|
int /*total_input_count*/, int /*total_output_count*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_ReleaseNodeInputAndNodeOutputArray() { return false; }
|
||||||
|
void* soc_interface_SetOneNodeInputs(int /*input_count*/,
|
||||||
|
const int* const /*node_id*/,
|
||||||
|
const int* const /*port*/) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
void* soc_interface_SetOneNodeOutputs(int /*output_count*/, int* /*max_size*/) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
bool soc_interface_AppendConstNode(const char* const /*name*/, int /*node_id*/,
|
||||||
|
int /*batch*/, int /*height*/, int /*width*/,
|
||||||
|
int /*depth*/, const uint8_t* const /*data*/,
|
||||||
|
int /*data_length*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_AppendNode(const char* const /*name*/, int /*node_id*/,
|
||||||
|
int /*op_id*/, int /*padding_id*/,
|
||||||
|
const void* const /*inputs*/,
|
||||||
|
int /*inputs_count*/,
|
||||||
|
const void* const /*outputs*/,
|
||||||
|
int /*outputs_count*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool soc_interface_InstantiateGraph() { return false; }
|
||||||
|
bool soc_interface_ConstructGraph() { return false; }
|
||||||
|
void soc_interface_SetLogLevel(int /*log_level*/) {}
|
||||||
|
void soc_interface_SetDebugFlag(uint64_t /*flag*/) {}
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
// All functions defined here must have prefix "soc_interface" to avoid
|
// All functions defined here must have prefix "soc_interface" to avoid
|
||||||
// naming conflicts.
|
// naming conflicts.
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
#include <cstdint>
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#else
|
#else
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
@ -59,6 +59,13 @@ class IRemoteFusedGraphExecutor {
|
|||||||
virtual bool ReadOutputNode(const string& node_name,
|
virtual bool ReadOutputNode(const string& node_name,
|
||||||
TensorAllocatorFunc tensor_allocator) = 0;
|
TensorAllocatorFunc tensor_allocator) = 0;
|
||||||
|
|
||||||
|
virtual Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||||
|
const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs,
|
||||||
|
GraphDef* fused_graph_def) = 0;
|
||||||
|
|
||||||
|
virtual bool IsEnabled() const = 0;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor);
|
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor);
|
||||||
};
|
};
|
||||||
|
@ -13,11 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "i_graph_transfer_ops_definitions.h"
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
/* static */ constexpr int IGraphTransferOpsDefinitions::INVALID_OP_ID;
|
/* static */ constexpr int IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
|
||||||
// TODO(satok): Remove
|
|
||||||
/* static */ constexpr const char* const
|
|
||||||
IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
|
|
||||||
}
|
}
|
@ -13,39 +13,34 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
|
||||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
|
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// IGraphTransferOpsDefinitions is an interface class which provides interfaces
|
// IRemoteFusedGraphOpsDefinitions is an interface class which provides
|
||||||
// about ops supported by SOC.
|
// APIs to provide information about op types supported by SOC.
|
||||||
// TODO(satok): Provide ways to transfer graph definitions into SOC
|
// TODO(satok): Provide ways to transfer graph definitions into SOC
|
||||||
class IGraphTransferOpsDefinitions {
|
class IRemoteFusedGraphOpsDefinitions {
|
||||||
public:
|
public:
|
||||||
// op id which is not supported by SOC
|
// op id which is not supported by SOC
|
||||||
static constexpr int INVALID_OP_ID = -1;
|
static constexpr int INVALID_OP_ID = -1;
|
||||||
// Custom op name for flatten node
|
|
||||||
static constexpr const char* const FLATTEN_OP_NAME = "FLATTEN";
|
|
||||||
|
|
||||||
IGraphTransferOpsDefinitions() = default;
|
IRemoteFusedGraphOpsDefinitions() = default;
|
||||||
virtual ~IGraphTransferOpsDefinitions() = default;
|
virtual ~IRemoteFusedGraphOpsDefinitions() = default;
|
||||||
// Return total ops count supported by SOC
|
// Return total ops count supported by SOC
|
||||||
virtual int GetTotalOpsCount() const = 0;
|
virtual int GetTotalOpsCount() const = 0;
|
||||||
// Return op id for given string op name
|
// Return op id for given string op name
|
||||||
virtual int GetOpIdFor(const string& op_name,
|
virtual int GetOpIdFor(const string& op_type,
|
||||||
const DataTypeVector& dt) const = 0;
|
const DataTypeVector& dt) const = 0;
|
||||||
// Return destination of transfer
|
|
||||||
virtual GraphTransferInfo::Destination GetTransferDestination() const = 0;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(IGraphTransferOpsDefinitions);
|
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphOpsDefinitions);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
|
@ -41,7 +41,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
|
|||||||
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
|
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
|
||||||
execute_info_.executor_name());
|
execute_info_.executor_name());
|
||||||
if (build_func != nullptr) {
|
if (build_func != nullptr) {
|
||||||
Status status = (*build_func)(&remote_fused_graph_executor_);
|
TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
|
||||||
|
CHECK(remote_fused_graph_executor_->IsEnabled());
|
||||||
} else {
|
} else {
|
||||||
LOG(ERROR) << "Executor not found for "
|
LOG(ERROR) << "Executor not found for "
|
||||||
<< execute_info_.executor_name();
|
<< execute_info_.executor_name();
|
||||||
|
@ -159,8 +159,8 @@ static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
|
|||||||
return execute_info;
|
return execute_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. Create TestRemoteFusedGraphExecutor to execute your fused graph
|
// 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph
|
||||||
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||||
public:
|
public:
|
||||||
int GetVersion() final { return 1; }
|
int GetVersion() final { return 1; }
|
||||||
bool Init(const RemoteFusedGraphExecuteInfo& info) final {
|
bool Init(const RemoteFusedGraphExecuteInfo& info) final {
|
||||||
@ -214,6 +214,16 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||||
|
const std::vector<string>& /*inputs*/,
|
||||||
|
const std::vector<string>& /*outputs*/,
|
||||||
|
GraphDef* fused_graph_def) final {
|
||||||
|
*fused_graph_def = original_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsEnabled() const final { return true; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const RemoteFusedGraphExecuteInfo* info_;
|
const RemoteFusedGraphExecuteInfo* info_;
|
||||||
std::unordered_map<string, Tensor> input_tensor_cache_;
|
std::unordered_map<string, Tensor> input_tensor_cache_;
|
||||||
@ -225,7 +235,7 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
|||||||
namespace remote_fused_graph_execute_op {
|
namespace remote_fused_graph_execute_op {
|
||||||
Status BuildRemoteFusedGraphExecutor(
|
Status BuildRemoteFusedGraphExecutor(
|
||||||
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||||
executor->reset(new TestRemoteFusedGraphExecutor());
|
executor->reset(new SampleRemoteFusedGraphExecutor());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/ops/const_op.h"
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
#include "tensorflow/cc/ops/math_ops.h"
|
#include "tensorflow/cc/ops/math_ops.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
@ -92,4 +93,36 @@ namespace tensorflow {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TestRemoteFusedGraphExecutor::TestRemoteFusedGraphExecutor(
|
||||||
|
const std::unordered_set<string>& fused_op_types,
|
||||||
|
const string& executor_name)
|
||||||
|
: fused_op_types_(fused_op_types), executor_name_(executor_name) {}
|
||||||
|
|
||||||
|
int TestRemoteFusedGraphExecutor::GetVersion() { return 0; }
|
||||||
|
bool TestRemoteFusedGraphExecutor::Init(const RemoteFusedGraphExecuteInfo&) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool TestRemoteFusedGraphExecutor::Finalize() { return true; }
|
||||||
|
bool TestRemoteFusedGraphExecutor::SetupGraph() { return true; }
|
||||||
|
bool TestRemoteFusedGraphExecutor::ExecuteGraph() { return true; }
|
||||||
|
bool TestRemoteFusedGraphExecutor::TeardownGraph() { return true; }
|
||||||
|
bool TestRemoteFusedGraphExecutor::FillInputNode(const string&, const Tensor&) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool TestRemoteFusedGraphExecutor::ReadOutputNode(const string&,
|
||||||
|
TensorAllocatorFunc) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
Status TestRemoteFusedGraphExecutor::FuseRemoteGraph(
|
||||||
|
const GraphDef& original_graph_def, const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
|
||||||
|
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
|
||||||
|
original_graph_def, inputs, outputs, "remote_fused_graph_node_names",
|
||||||
|
fused_op_types_, executor_name_,
|
||||||
|
/*require_shape_type=*/false, fused_graph_def);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TestRemoteFusedGraphExecutor::IsEnabled() const { return true; }
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
#include "tensorflow/cc/framework/scope.h"
|
#include "tensorflow/cc/framework/scope.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -59,6 +60,30 @@ class RemoteFusedGraphExecuteOpTestUtils {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
|
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||||
|
public:
|
||||||
|
TestRemoteFusedGraphExecutor(const std::unordered_set<string>& fused_op_types,
|
||||||
|
const string& executor_name);
|
||||||
|
|
||||||
|
int GetVersion() final;
|
||||||
|
bool Init(const RemoteFusedGraphExecuteInfo&) final;
|
||||||
|
bool Finalize() final;
|
||||||
|
bool SetupGraph() final;
|
||||||
|
bool ExecuteGraph() final;
|
||||||
|
bool TeardownGraph() final;
|
||||||
|
bool FillInputNode(const string&, const Tensor&) final;
|
||||||
|
bool ReadOutputNode(const string&, TensorAllocatorFunc) final;
|
||||||
|
Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||||
|
const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs,
|
||||||
|
GraphDef* fused_graph_def) final;
|
||||||
|
bool IsEnabled() const final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::unordered_set<string> fused_op_types_;
|
||||||
|
const string executor_name_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
|
||||||
|
@ -161,6 +161,8 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
|
|||||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
|
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
|
||||||
/* static */ constexpr const char* const
|
/* static */ constexpr const char* const
|
||||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
|
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
|
||||||
/* static */ constexpr const char* const
|
/* static */ constexpr const char* const
|
||||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
|
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
|
||||||
/* static */ constexpr const char* const
|
/* static */ constexpr const char* const
|
||||||
@ -1084,6 +1086,26 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
|
|||||||
require_shape_type, output_graph_def);
|
require_shape_type, output_graph_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||||
|
const GraphDef& input_graph_def, const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs, const string& executor_name,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
|
||||||
|
if (build_func == nullptr) {
|
||||||
|
return errors::InvalidArgument("Unknown executor name: " + executor_name);
|
||||||
|
}
|
||||||
|
std::unique_ptr<IRemoteFusedGraphExecutor> executor;
|
||||||
|
TF_RETURN_IF_ERROR((*build_func)(&executor));
|
||||||
|
CHECK_NOTNULL(executor.get());
|
||||||
|
if (!executor->IsEnabled()) {
|
||||||
|
// As this executor is not enabled, just return original graph as is.
|
||||||
|
*output_graph_def = input_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
|
||||||
|
output_graph_def);
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
|
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
|
||||||
const std::vector<string>& inputs, const std::vector<string>& outputs,
|
const std::vector<string>& inputs, const std::vector<string>& outputs,
|
||||||
const std::unordered_set<string>& fused_node_names,
|
const std::unordered_set<string>& fused_node_names,
|
||||||
@ -1387,6 +1409,28 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
|
|||||||
return retval;
|
return retval;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ std::unordered_set<string>
|
||||||
|
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
|
||||||
|
const GraphDef& graph_def,
|
||||||
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
|
||||||
|
std::unordered_set<string> retval;
|
||||||
|
for (const NodeDef& node_def : graph_def.node()) {
|
||||||
|
std::vector<DataType> dt_vec;
|
||||||
|
std::vector<TensorShape> shape_vec;
|
||||||
|
const Status status =
|
||||||
|
GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
|
||||||
|
if (!status.ok()) {
|
||||||
|
shape_vec.clear();
|
||||||
|
}
|
||||||
|
if (ops_definitions.GetOpIdFor(
|
||||||
|
node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
|
||||||
|
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
|
||||||
|
retval.emplace(node_def.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retval;
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
|
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
|
||||||
const string& input, const DataType type, const TensorShape& shape,
|
const string& input, const DataType type, const TensorShape& shape,
|
||||||
GraphDef* graph_def) {
|
GraphDef* graph_def) {
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
@ -59,6 +60,8 @@ class RemoteFusedGraphExecuteUtils {
|
|||||||
"border_outputs";
|
"border_outputs";
|
||||||
static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
|
static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
|
||||||
"fused_op_types";
|
"fused_op_types";
|
||||||
|
static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR =
|
||||||
|
"fuse_by_executor";
|
||||||
static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
|
static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
|
||||||
static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
|
static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
|
||||||
"input_shapes";
|
"input_shapes";
|
||||||
@ -257,6 +260,12 @@ class RemoteFusedGraphExecuteUtils {
|
|||||||
const std::vector<std::pair<string, Tensor>>& input_tensors,
|
const std::vector<std::pair<string, Tensor>>& input_tensors,
|
||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def,
|
||||||
|
const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs,
|
||||||
|
const string& executor_name,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
static bool IsFuseReady(
|
static bool IsFuseReady(
|
||||||
const GraphDef& input_graph_def,
|
const GraphDef& input_graph_def,
|
||||||
const std::vector<std::pair<string, Tensor>>& input_tensors);
|
const std::vector<std::pair<string, Tensor>>& input_tensors);
|
||||||
@ -273,6 +282,10 @@ class RemoteFusedGraphExecuteUtils {
|
|||||||
static std::unordered_set<string> BuildNodeMapFromOpTypes(
|
static std::unordered_set<string> BuildNodeMapFromOpTypes(
|
||||||
const GraphDef& graph_def, const std::unordered_set<string>& op_types);
|
const GraphDef& graph_def, const std::unordered_set<string>& op_types);
|
||||||
|
|
||||||
|
static std::unordered_set<string> BuildNodeMapFromOpsDefinitions(
|
||||||
|
const GraphDef& graph_def,
|
||||||
|
const IRemoteFusedGraphOpsDefinitions& ops_definitions);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
|
static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
|
||||||
TensorShapeMap* tensor_shape_map);
|
TensorShapeMap* tensor_shape_map);
|
||||||
|
@ -33,6 +33,10 @@ constexpr const char* const NAME_A_PLUS_B = "A_PLUS_B";
|
|||||||
constexpr float NODE_A_VAL = 2.0f;
|
constexpr float NODE_A_VAL = 2.0f;
|
||||||
constexpr float NODE_B_VAL = 3.0f;
|
constexpr float NODE_B_VAL = 3.0f;
|
||||||
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
|
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
|
||||||
|
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
|
||||||
|
"fuse_test_remote_fused_graph_executor0";
|
||||||
|
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
|
||||||
|
"fuse_test_remote_fused_graph_executor1";
|
||||||
|
|
||||||
static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
|
static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
|
||||||
CHECK_NE(def, nullptr);
|
CHECK_NE(def, nullptr);
|
||||||
@ -44,17 +48,38 @@ static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status BuildRemoteFusedGraphExecutor0(
|
||||||
|
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||||
|
executor->reset(
|
||||||
|
new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status BuildRemoteFusedGraphExecutor1(
|
||||||
|
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||||
|
executor->reset(new TestRemoteFusedGraphExecutor(
|
||||||
|
{"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
|
class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUp() final {
|
void SetUp() final {
|
||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
|
RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
|
||||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
k_hexagon_remote_fused_graph_executor_build(
|
hexagon_remote_fused_graph_executor_build(
|
||||||
"remote_graph_executor_name",
|
"remote_graph_executor_name",
|
||||||
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
|
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
|
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
|
||||||
|
BuildRemoteFusedGraphExecutor0);
|
||||||
|
|
||||||
|
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
|
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
|
||||||
|
BuildRemoteFusedGraphExecutor1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TearDown() final {}
|
void TearDown() final {}
|
||||||
@ -87,6 +112,18 @@ class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
|
|||||||
/*require_shape_type=*/false, &result_graph_def_);
|
/*require_shape_type=*/false, &result_graph_def_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status FuseByExecutor0() {
|
||||||
|
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||||
|
graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME0,
|
||||||
|
&result_graph_def_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status FuseByExecutor1() {
|
||||||
|
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||||
|
graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME1,
|
||||||
|
&result_graph_def_);
|
||||||
|
}
|
||||||
|
|
||||||
Status BuildAndAddTensorShape() {
|
Status BuildAndAddTensorShape() {
|
||||||
return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
|
return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
|
||||||
input_tensors_, /*dry_run_inference=*/true, &graph_def_);
|
input_tensors_, /*dry_run_inference=*/true, &graph_def_);
|
||||||
@ -694,6 +731,30 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_FGHIJ) {
|
|||||||
<< SummarizeGraphDef(result_graph_def_);
|
<< SummarizeGraphDef(result_graph_def_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_HIJ) {
|
||||||
|
ReplaceOpType({"H", "I", "J"}, "Mul");
|
||||||
|
|
||||||
|
TF_ASSERT_OK(FuseByExecutor0());
|
||||||
|
|
||||||
|
EXPECT_EQ(11, graph_def_.node_size());
|
||||||
|
EXPECT_EQ(9, result_graph_def_.node_size())
|
||||||
|
<< "=== Before: \n"
|
||||||
|
<< SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
|
||||||
|
<< SummarizeGraphDef(result_graph_def_);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_FGHIJ) {
|
||||||
|
ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
|
||||||
|
|
||||||
|
TF_ASSERT_OK(FuseByExecutor1());
|
||||||
|
|
||||||
|
EXPECT_EQ(11, graph_def_.node_size());
|
||||||
|
EXPECT_EQ(3, result_graph_def_.node_size())
|
||||||
|
<< "=== Before: \n"
|
||||||
|
<< SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
|
||||||
|
<< SummarizeGraphDef(result_graph_def_);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) {
|
TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) {
|
||||||
subgraph_node_names_ = {"H"};
|
subgraph_node_names_ = {"H"};
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ static Status ParseArguments(const TransformFuncContext& context,
|
|||||||
string* input_types_str, string* input_shapes_str,
|
string* input_types_str, string* input_shapes_str,
|
||||||
string* fused_nodes_str, string* border_inputs_str,
|
string* fused_nodes_str, string* border_inputs_str,
|
||||||
string* border_outputs_str,
|
string* border_outputs_str,
|
||||||
string* fused_op_types_str,
|
string* fused_op_types_str, bool* fuse_by_executor,
|
||||||
string* remote_fused_graph_node_name,
|
string* remote_fused_graph_node_name,
|
||||||
string* remote_graph_executor_name) {
|
string* remote_graph_executor_name) {
|
||||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||||
@ -87,6 +87,9 @@ static Status ParseArguments(const TransformFuncContext& context,
|
|||||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "",
|
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "",
|
||||||
fused_op_types_str));
|
fused_op_types_str));
|
||||||
|
TF_RETURN_IF_ERROR(context.GetOneBoolParameter(
|
||||||
|
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, false,
|
||||||
|
fuse_by_executor));
|
||||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||||
RemoteFusedGraphExecuteUtils::
|
RemoteFusedGraphExecuteUtils::
|
||||||
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||||
@ -140,12 +143,14 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
|
|||||||
string border_inputs_str;
|
string border_inputs_str;
|
||||||
string border_outputs_str;
|
string border_outputs_str;
|
||||||
string fused_op_types_str;
|
string fused_op_types_str;
|
||||||
|
bool fuse_by_executor = false;
|
||||||
string remote_fused_graph_node_name;
|
string remote_fused_graph_node_name;
|
||||||
string remote_graph_executor_name;
|
string remote_graph_executor_name;
|
||||||
TF_RETURN_IF_ERROR(ParseArguments(
|
TF_RETURN_IF_ERROR(ParseArguments(
|
||||||
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
|
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
|
||||||
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
|
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
|
||||||
&remote_fused_graph_node_name, &remote_graph_executor_name));
|
&fuse_by_executor, &remote_fused_graph_node_name,
|
||||||
|
&remote_graph_executor_name));
|
||||||
|
|
||||||
if (!input_types_str.empty()) {
|
if (!input_types_str.empty()) {
|
||||||
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
|
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
|
||||||
@ -187,6 +192,10 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
|
|||||||
mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
|
mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
|
||||||
fused_op_types, remote_graph_executor_name, require_shape_type,
|
fused_op_types, remote_graph_executor_name, require_shape_type,
|
||||||
output_graph_def));
|
output_graph_def));
|
||||||
|
} else if (fuse_by_executor) {
|
||||||
|
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||||
|
mutable_input_graph_def, inputs, outputs, remote_graph_executor_name,
|
||||||
|
output_graph_def));
|
||||||
} else {
|
} else {
|
||||||
CHECK(false) << "Fuse targets are not specified.";
|
CHECK(false) << "Fuse targets are not specified.";
|
||||||
}
|
}
|
||||||
@ -205,15 +214,17 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
|
|||||||
string input_types_str;
|
string input_types_str;
|
||||||
string input_shapes_str;
|
string input_shapes_str;
|
||||||
string fused_nodes_str;
|
string fused_nodes_str;
|
||||||
string fused_op_types_str;
|
|
||||||
string border_inputs_str;
|
string border_inputs_str;
|
||||||
string border_outputs_str;
|
string border_outputs_str;
|
||||||
|
string fused_op_types_str;
|
||||||
|
bool fuse_by_executor = false;
|
||||||
string remote_fused_graph_node_name;
|
string remote_fused_graph_node_name;
|
||||||
string remote_graph_executor_name;
|
string remote_graph_executor_name;
|
||||||
TF_RETURN_IF_ERROR(ParseArguments(
|
TF_RETURN_IF_ERROR(ParseArguments(
|
||||||
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
|
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
|
||||||
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
|
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
|
||||||
&remote_fused_graph_node_name, &remote_graph_executor_name));
|
&fuse_by_executor, &remote_fused_graph_node_name,
|
||||||
|
&remote_graph_executor_name));
|
||||||
|
|
||||||
if (!input_types_str.empty()) {
|
if (!input_types_str.empty()) {
|
||||||
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
|
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/default_device.h"
|
#include "tensorflow/core/graph/default_device.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/graph/testlib.h"
|
#include "tensorflow/core/graph/testlib.h"
|
||||||
|
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
|
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
|
||||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -43,11 +44,28 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
|
|||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
|
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
|
||||||
"remote_fused_graph_executor_name";
|
"remote_fused_graph_executor_name";
|
||||||
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
|
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
|
||||||
"remote_fused_graph_node_name";
|
"remote_fused_graph_node_name";
|
||||||
|
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
|
||||||
|
"fuse_test_remote_fused_graph_executor0";
|
||||||
|
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
|
||||||
|
"fuse_test_remote_fused_graph_executor1";
|
||||||
|
|
||||||
|
Status BuildRemoteFusedGraphExecutor0(
|
||||||
|
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||||
|
executor->reset(
|
||||||
|
new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status BuildRemoteFusedGraphExecutor1(
|
||||||
|
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||||
|
executor->reset(new TestRemoteFusedGraphExecutor(
|
||||||
|
{"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
@ -55,11 +73,18 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
|||||||
TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
|
TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
|
||||||
&input_graph_def_));
|
&input_graph_def_));
|
||||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
k_hexagon_remote_fused_graph_executor_build(
|
hexagon_remote_fused_graph_executor_build(
|
||||||
REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||||
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
|
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
|
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
|
||||||
|
BuildRemoteFusedGraphExecutor0);
|
||||||
|
|
||||||
|
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||||
|
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
|
||||||
|
BuildRemoteFusedGraphExecutor1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TearDown() final {}
|
void TearDown() final {}
|
||||||
@ -113,10 +138,16 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
|||||||
{fused_op_types_str_}}));
|
{fused_op_types_str_}}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (fuse_by_executor_) {
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
|
||||||
|
{"true"}}));
|
||||||
|
}
|
||||||
|
|
||||||
context.params.insert(std::pair<string, std::vector<string>>(
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
{RemoteFusedGraphExecuteUtils::
|
{RemoteFusedGraphExecuteUtils::
|
||||||
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||||
{REMOTE_FUSED_GRAPH_EXECUTOR_NAME}}));
|
{remote_fused_graph_executor_name_}}));
|
||||||
context.params.insert(std::pair<string, std::vector<string>>(
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
{RemoteFusedGraphExecuteUtils::
|
{RemoteFusedGraphExecuteUtils::
|
||||||
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
|
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
|
||||||
@ -160,7 +191,7 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
|||||||
ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
|
ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
|
||||||
&serialized_proto));
|
&serialized_proto));
|
||||||
info.ParseFromString(serialized_proto);
|
info.ParseFromString(serialized_proto);
|
||||||
CHECK_EQ(REMOTE_FUSED_GRAPH_EXECUTOR_NAME, info.executor_name());
|
CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EXPECT_EQ(expected_cluster_count, cluster_count);
|
EXPECT_EQ(expected_cluster_count, cluster_count);
|
||||||
@ -178,6 +209,8 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
|||||||
string border_inputs_str_;
|
string border_inputs_str_;
|
||||||
string border_outputs_str_;
|
string border_outputs_str_;
|
||||||
string fused_op_types_str_;
|
string fused_op_types_str_;
|
||||||
|
string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
|
||||||
|
bool fuse_by_executor_{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||||
@ -260,6 +293,24 @@ TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
|||||||
CheckGraph(3, 1);
|
CheckGraph(3, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||||
|
FuseRemoteGraphByExecutor_HIJ) {
|
||||||
|
ReplaceOpType({"H", "I", "J"}, "Mul");
|
||||||
|
remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
|
||||||
|
fuse_by_executor_ = true;
|
||||||
|
TF_ASSERT_OK(Fuse());
|
||||||
|
CheckGraph(9, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||||
|
FuseRemoteGraphByExecutor_FGHIJ) {
|
||||||
|
ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
|
||||||
|
remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
|
||||||
|
fuse_by_executor_ = true;
|
||||||
|
TF_ASSERT_OK(Fuse());
|
||||||
|
CheckGraph(3, 1);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
|
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
|
||||||
fused_node_names_str_ = "H,I,J";
|
fused_node_names_str_ = "H,I,J";
|
||||||
TF_ASSERT_OK(PlaceFuseArgs());
|
TF_ASSERT_OK(PlaceFuseArgs());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user