Add a way to fuse a graph by remote graph executor so that users don't need to be aware of supported op types, node names, subgraph structure, etc.
PiperOrigin-RevId: 162411763
This commit is contained in:
parent
d8672f1839
commit
d5f4d9bbac
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
@ -110,7 +111,7 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) {
|
||||
static void CheckOpsSupport(const GraphDef& graph_def,
|
||||
const bool dump_all_nodes,
|
||||
const bool dump_shape_and_type) {
|
||||
const IGraphTransferOpsDefinitions& ops_definition =
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definition =
|
||||
HexagonOpsDefinitions::getInstance();
|
||||
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
|
||||
LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
|
||||
@ -127,7 +128,7 @@ static void CheckOpsSupport(const GraphDef& graph_def,
|
||||
}
|
||||
// TODO(satok): Set correct data type if it's given.
|
||||
const int op_id = ops_definition.GetOpIdFor(node.op(), {});
|
||||
if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) {
|
||||
if (op_id == IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
|
||||
all_supported = false;
|
||||
LOG(ERROR) << "OP type: " << node.op() << " is not supported on hvx. "
|
||||
<< "Name = " << node.name();
|
||||
|
@ -82,7 +82,7 @@ fi
|
||||
if [[ "${USE_HEXAGON}" == "true" ]]; then
|
||||
HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" >/dev/null && pwd)
|
||||
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
|
||||
HEXAGON_INCLUDE=$(cd "tensorflow/core/platform/hexagon" >/dev/null && pwd)
|
||||
HEXAGON_INCLUDE=$(cd "tensorflow/core/kernels/hexagon" >/dev/null && pwd)
|
||||
fi
|
||||
|
||||
if [[ "${ENABLE_EXPERIMENTAL_HEXNN_OPS}" == "true" ]]; then
|
||||
|
@ -56,7 +56,6 @@ tensorflow/core/platform/posix/test.cc
|
||||
|
||||
QUANTIZATION_TEST_SRCS := \
|
||||
$(GRAPH_TRANSFER_SRCS) \
|
||||
tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc \
|
||||
tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
|
||||
tensorflow/contrib/makefile/test/test_main.cc
|
||||
|
||||
|
@ -5057,9 +5057,13 @@ tf_kernel_library(
|
||||
|
||||
cc_library(
|
||||
name = "remote_fused_graph_execute_utils",
|
||||
srcs = ["remote_fused_graph_execute_utils.cc"],
|
||||
srcs = [
|
||||
"i_remote_fused_graph_ops_definitions.cc",
|
||||
"remote_fused_graph_execute_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"i_remote_fused_graph_executor.h",
|
||||
"i_remote_fused_graph_ops_definitions.h",
|
||||
"remote_fused_graph_execute_utils.h",
|
||||
],
|
||||
deps = [
|
||||
@ -5078,6 +5082,7 @@ cc_library(
|
||||
srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
|
||||
hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
|
||||
deps = [
|
||||
":remote_fused_graph_execute_utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
|
@ -26,23 +26,6 @@ filegroup(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "quantized_matmul_op_for_hexagon_test",
|
||||
size = "small",
|
||||
srcs = ["quantized_matmul_op_for_hexagon_test.cc"],
|
||||
tags = ["nomsan"], # http://b/32242946
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//tensorflow/core/kernels:quantized_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "graph_transferer_test",
|
||||
size = "small",
|
||||
@ -79,14 +62,14 @@ tf_kernel_library(
|
||||
"graph_transferer.cc",
|
||||
"hexagon_control_wrapper.cc",
|
||||
"hexagon_ops_definitions.cc",
|
||||
"i_graph_transfer_ops_definitions.cc",
|
||||
"soc_interface.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"graph_transfer_utils.h",
|
||||
"graph_transferer.h",
|
||||
"hexagon_control_wrapper.h",
|
||||
"hexagon_ops_definitions.h",
|
||||
"i_graph_transfer_ops_definitions.h",
|
||||
"soc_interface.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/cc:cc_ops",
|
||||
@ -111,6 +94,7 @@ cc_library(
|
||||
":graph_transferer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:remote_fused_graph_ops",
|
||||
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -121,6 +105,7 @@ tf_cc_test(
|
||||
size = "small",
|
||||
srcs = ["hexagon_rewriter_transform_test.cc"],
|
||||
deps = [
|
||||
":graph_transferer",
|
||||
":hexagon_rewriter_transform",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -129,6 +114,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||
],
|
||||
)
|
||||
|
@ -96,7 +96,7 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
|
||||
}
|
||||
|
||||
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const string& remote_graph_execute_name,
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& outputs, GraphDef* original_def) {
|
||||
|
@ -39,7 +39,7 @@ class GraphTransferUtils {
|
||||
const int element_count, const int top_n);
|
||||
|
||||
static GraphDef BuildFusedGraphDef(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const string& remote_graph_execute_name,
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& outputs, GraphDef* original_def);
|
||||
|
@ -81,7 +81,7 @@ static Node* FindMutableNodeByName(const string& name, Graph* graph) {
|
||||
* of node to transfer the graph to SOC.
|
||||
*/
|
||||
Status GraphTransferer::LoadGraphFromProto(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const GraphDef& graph_def,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
const std::vector<string>& output_node_names,
|
||||
@ -177,9 +177,6 @@ Status GraphTransferer::LoadGraphFromProto(
|
||||
}
|
||||
}
|
||||
|
||||
graph_transfer_info_.set_destination(
|
||||
ops_definitions.GetTransferDestination());
|
||||
|
||||
ClearCache();
|
||||
if (DBG_DUMP_PARAMS) {
|
||||
DumpNodeTransferParams();
|
||||
@ -191,7 +188,7 @@ Status GraphTransferer::LoadGraphFromProto(
|
||||
}
|
||||
|
||||
Status GraphTransferer::LoadGraphFromProtoFile(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const string& graph_def_path,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
const std::vector<string>& output_node_names, const bool is_text_proto,
|
||||
@ -415,7 +412,7 @@ Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
|
||||
}
|
||||
|
||||
Status GraphTransferer::RegisterNode(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
const std::vector<string>& output_node_names) {
|
||||
@ -438,7 +435,7 @@ Status GraphTransferer::RegisterNode(
|
||||
} else if (IsNodeFlattenReshape(node, shape_refiner)) {
|
||||
RegisterFlattenNode(ops_definitions, shape_refiner, node);
|
||||
} else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
|
||||
IGraphTransferOpsDefinitions::INVALID_OP_ID) {
|
||||
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
|
||||
// TODO(satok): Set correct data type if it's given.
|
||||
RegisterGenericNode(ops_definitions, shape_refiner, node);
|
||||
} else {
|
||||
@ -637,7 +634,7 @@ bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
|
||||
}
|
||||
|
||||
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||
const int id = node_name_to_id_cache_map_[node.name()];
|
||||
@ -671,7 +668,7 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
|
||||
}
|
||||
|
||||
void GraphTransferer::RegisterNodeWithRank(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||
const int id = node_name_to_id_cache_map_[node.name()];
|
||||
@ -704,7 +701,7 @@ void GraphTransferer::RegisterNodeWithRank(
|
||||
}
|
||||
|
||||
void GraphTransferer::RegisterPadNode(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||
static constexpr int PAD_WIDTH = 4;
|
||||
static constexpr int PAD_HEIGHT = 2;
|
||||
@ -779,7 +776,7 @@ void GraphTransferer::RegisterPadNode(
|
||||
}
|
||||
|
||||
void GraphTransferer::RegisterInputNode(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||
const string op_type = node.type_string();
|
||||
VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
|
||||
@ -797,12 +794,13 @@ void GraphTransferer::RegisterInputNode(
|
||||
}
|
||||
|
||||
void GraphTransferer::RegisterFlattenNode(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||
VLOG(1) << "Register flatten node: " << node.name();
|
||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||
const int id = node_name_to_id_cache_map_[node.name()];
|
||||
const string op_type = IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
|
||||
// TODO(satok): Remove dependency to specific type
|
||||
const string op_type = "FLATTEN";
|
||||
// TODO(satok): Set correct data type if it's given.
|
||||
const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
|
||||
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
|
||||
@ -814,7 +812,7 @@ void GraphTransferer::RegisterFlattenNode(
|
||||
}
|
||||
|
||||
void GraphTransferer::RegisterGenericNode(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node) {
|
||||
VLOG(1) << "Register generic node: " << node.name();
|
||||
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
|
||||
@ -832,7 +830,7 @@ void GraphTransferer::RegisterGenericNode(
|
||||
// TODO(satok): Remove this function.
|
||||
// TODO(satok): Remove only_register_const_node.
|
||||
Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node,
|
||||
const bool only_register_const_node,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -53,7 +53,7 @@ class GraphTransferer {
|
||||
// TODO(satok): Pass a pair of TensorShape and DataType instead of
|
||||
// Tensor as input_node_info_list.
|
||||
Status LoadGraphFromProto(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const GraphDef& graph_def,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
const std::vector<string>& output_node_names,
|
||||
@ -63,7 +63,7 @@ class GraphTransferer {
|
||||
// TODO(satok): Pass a pair of TensorShape and DataType instead of
|
||||
// Tensor as input_node_info_list.
|
||||
Status LoadGraphFromProtoFile(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const string& graph_def_path,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
const std::vector<string>& output_node_names, const bool is_text_proto,
|
||||
@ -112,7 +112,7 @@ class GraphTransferer {
|
||||
Graph* graph, ShapeRefiner* shape_refiner);
|
||||
|
||||
Status RegisterNode(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
const std::vector<string>& output_node_names);
|
||||
@ -140,30 +140,29 @@ class GraphTransferer {
|
||||
const ShapeRefiner& shape_refiner);
|
||||
|
||||
void RegisterNodeWithPaddingAndStrides(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node);
|
||||
|
||||
void RegisterNodeWithRank(const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner,
|
||||
const Node& node);
|
||||
void RegisterNodeWithRank(
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node);
|
||||
|
||||
void RegisterPadNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node);
|
||||
|
||||
void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner,
|
||||
const Node& node);
|
||||
void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node);
|
||||
|
||||
void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner,
|
||||
const Node& node);
|
||||
void RegisterFlattenNode(
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node);
|
||||
|
||||
void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner,
|
||||
const Node& node);
|
||||
void RegisterGenericNode(
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node);
|
||||
|
||||
Status RegisterNodeIfAllInputsAreCached(
|
||||
const IGraphTransferOpsDefinitions& ops_definitions,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
|
||||
const ShapeRefiner& shape_refiner, const Node& node,
|
||||
const bool only_register_const_node,
|
||||
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
|
||||
|
@ -22,8 +22,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
|
||||
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
@ -50,7 +50,7 @@ class GraphTransfererTest : public ::testing::Test {
|
||||
|
||||
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
|
||||
|
||||
class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
|
||||
class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
|
||||
public:
|
||||
int GetTotalOpsCount() const final { return op_types_.size(); }
|
||||
|
||||
@ -63,10 +63,6 @@ class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
|
||||
return -1;
|
||||
}
|
||||
|
||||
GraphTransferInfo::Destination GetTransferDestination() const final {
|
||||
return GraphTransferInfo::NOP;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
|
||||
"MaxPool", "NoOp", "Add",
|
||||
@ -371,14 +367,14 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
|
||||
}
|
||||
|
||||
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
|
||||
const IGraphTransferOpsDefinitions& ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions =
|
||||
HexagonOpsDefinitions::getInstance();
|
||||
const int total_ops_count = ops_definitions.GetTotalOpsCount();
|
||||
EXPECT_GT(total_ops_count, 0);
|
||||
}
|
||||
|
||||
TEST(GraphTransferer, LoadGraphFromProtoFile) {
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
|
||||
string filename =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
@ -441,7 +437,7 @@ void CompareGraphTransferInfo(const GraphTransferInfo& a,
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
|
||||
string filename =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||
|
@ -16,18 +16,19 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
||||
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
#include "tensorflow/core/platform/hexagon/soc_interface.h"
|
||||
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
|
||||
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
constexpr const char* const INPUT_OP_NAME = "INPUT";
|
||||
constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
|
||||
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX =
|
||||
"hexagon_remote_fused_graph";
|
||||
/* static */ constexpr const char* const
|
||||
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
|
||||
|
||||
constexpr int ALIGNMENT_BYTES = 16;
|
||||
constexpr int MAX_IN_OUT_COUNT = 128;
|
||||
|
||||
const bool DBG_DUMP_VERIFICATION_STRING = false;
|
||||
const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info
|
||||
@ -63,7 +64,6 @@ static uint8* FindAlignedPointer(uint8* ptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
int HexagonControlWrapper::GetVersion() {
|
||||
return soc_interface_GetSocControllerVersion();
|
||||
}
|
||||
@ -95,7 +95,6 @@ bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
|
||||
LOG(ERROR) << "Hexagon initialization was failed. See log output.";
|
||||
return false;
|
||||
}
|
||||
const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo();
|
||||
std::vector<int> input_sizes;
|
||||
std::vector<int> output_sizes;
|
||||
CHECK_NOTNULL(execute_info_);
|
||||
@ -207,8 +206,9 @@ bool HexagonControlWrapper::SetupGraph() {
|
||||
for (const GraphTransferInfo::NodeInputInfo& input_params :
|
||||
graph_transfer_info.node_input_info()) {
|
||||
const int count = input_params.node_input_size();
|
||||
int node_ids[count];
|
||||
int ports[count];
|
||||
CHECK(count <= MAX_IN_OUT_COUNT);
|
||||
int node_ids[MAX_IN_OUT_COUNT];
|
||||
int ports[MAX_IN_OUT_COUNT];
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const GraphTransferInfo::NodeInput& node_input =
|
||||
input_params.node_input(i);
|
||||
@ -226,7 +226,8 @@ bool HexagonControlWrapper::SetupGraph() {
|
||||
for (const GraphTransferInfo::NodeOutputInfo& output_params :
|
||||
graph_transfer_info.node_output_info()) {
|
||||
const int count = output_params.max_byte_size_size();
|
||||
int sizes[count];
|
||||
CHECK(count <= MAX_IN_OUT_COUNT);
|
||||
int sizes[MAX_IN_OUT_COUNT];
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const int size = output_params.max_byte_size(i);
|
||||
sizes[i] = size;
|
||||
@ -373,6 +374,7 @@ bool HexagonControlWrapper::ReadOutputNode(
|
||||
<< output_tensor->TotalBytes() << ", " << std::get<1>(output);
|
||||
TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
|
||||
std::get<0>(output), std::get<1>(output), output_tensor));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HexagonControlWrapper::ReadOutputNode(
|
||||
@ -382,14 +384,30 @@ bool HexagonControlWrapper::ReadOutputNode(
|
||||
const string tensor_name = AddPort(node_name);
|
||||
CHECK(output_port_map_.count(tensor_name) > 0);
|
||||
const int port = output_port_map_.at(tensor_name);
|
||||
soc_interface_ReadOutputNodeWithPort(port, &std::get<0>(output),
|
||||
&std::get<1>(output));
|
||||
soc_interface_ReadOutputNodeWithPort(
|
||||
port, &std::get<0>(output),
|
||||
reinterpret_cast<uint64_t*>(&std::get<1>(output)));
|
||||
// TODO: Accept all results
|
||||
// std::get<2>(output) = DT_FLOAT;
|
||||
outputs->emplace_back(output);
|
||||
return true;
|
||||
}
|
||||
|
||||
Status HexagonControlWrapper::FuseRemoteGraph(
|
||||
const GraphDef& original_graph_def, const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
|
||||
const std::unordered_set<string> fused_node_names =
|
||||
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
|
||||
original_graph_def, HexagonOpsDefinitions::getInstance());
|
||||
// TODO(satok): We may want to place shape and type inside this function
|
||||
// if they are not placed in the given graph.
|
||||
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
|
||||
original_graph_def, inputs, outputs, REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX,
|
||||
fused_node_names, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||
/*require_shape_type=*/true, fused_graph_def));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool HexagonControlWrapper::FillInputNode(const string& node_name,
|
||||
const Tensor& tensor) {
|
||||
StringPiece tensor_data = tensor.tensor_data();
|
||||
@ -415,31 +433,5 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name,
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
int HexagonControlWrapper::GetVersion() { return -1; }
|
||||
bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo&) {
|
||||
return false;
|
||||
}
|
||||
bool HexagonControlWrapper::Finalize() { return false; }
|
||||
bool HexagonControlWrapper::SetupGraph() { return false; }
|
||||
bool HexagonControlWrapper::ExecuteGraph() { return false; }
|
||||
bool HexagonControlWrapper::TeardownGraph() { return false; }
|
||||
bool HexagonControlWrapper::FillInputNode(
|
||||
const string&, const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>&,
|
||||
const ConstByteArray) {
|
||||
return false;
|
||||
}
|
||||
bool HexagonControlWrapper::FillInputNode(const string&, const Tensor&) {
|
||||
return false;
|
||||
}
|
||||
bool HexagonControlWrapper::ReadOutputNode(
|
||||
const string& node_name, TensorAllocatorFunc tensor_allocator) {
|
||||
return false;
|
||||
}
|
||||
bool HexagonControlWrapper::ReadOutputNode(const string&,
|
||||
std::vector<ByteArray>* const) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool HexagonControlWrapper::IsEnabled() const { return true; };
|
||||
} // namespace tensorflow
|
||||
|
@ -35,6 +35,8 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
|
||||
public:
|
||||
using ByteArray =
|
||||
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
|
||||
static constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
|
||||
"build_hexagon_remote_fused_graph_executor";
|
||||
|
||||
HexagonControlWrapper() = default;
|
||||
int GetVersion() final;
|
||||
@ -46,6 +48,11 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
|
||||
bool FillInputNode(const string& node_name, const Tensor& tensor) final;
|
||||
bool ReadOutputNode(const string& node_name,
|
||||
TensorAllocatorFunc tensor_allocator) final;
|
||||
Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||
const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs,
|
||||
GraphDef* fused_graph_def) final;
|
||||
bool IsEnabled() const final;
|
||||
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
|
||||
|
||||
private:
|
||||
|
@ -35,8 +35,8 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
|
||||
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
||||
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/core/casts.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -389,9 +389,6 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
|
||||
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
|
||||
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
|
||||
}
|
||||
|
||||
// 7. check destination
|
||||
EXPECT_EQ(gfi0.destination(), gfi1.destination());
|
||||
}
|
||||
|
||||
// CAVEAT: This test only runs when you specify hexagon library using
|
||||
@ -407,7 +404,7 @@ TEST(GraphTransferer,
|
||||
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
||||
CheckHexagonControllerVersion();
|
||||
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||
&HexagonOpsDefinitions::getInstance();
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||
@ -474,7 +471,7 @@ TEST(GraphTransferer,
|
||||
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
|
||||
CheckHexagonControllerVersion();
|
||||
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||
&HexagonOpsDefinitions::getInstance();
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||
@ -505,7 +502,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
|
||||
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
|
||||
CheckHexagonControllerVersion();
|
||||
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||
&HexagonOpsDefinitions::getInstance();
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||
@ -543,7 +540,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
|
||||
CheckHexagonControllerVersion();
|
||||
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
|
||||
|
||||
const IGraphTransferOpsDefinitions* ops_definitions =
|
||||
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
|
||||
&HexagonOpsDefinitions::getInstance();
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
|
||||
|
@ -304,8 +304,8 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
|
||||
EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
|
||||
EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
|
||||
EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
|
||||
EmplaceOpType(IGraphTransferOpsDefinitions::FLATTEN_OP_NAME, {},
|
||||
SupportedOpType::FLATTEN, &op_map);
|
||||
// Special op type for hexagon
|
||||
EmplaceOpType("FLATTEN", {}, SupportedOpType::FLATTEN, &op_map);
|
||||
// Tensorflow op name
|
||||
// CAVEAT: Keep order of SupportedOpType
|
||||
EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
|
||||
@ -373,7 +373,7 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
|
||||
HexagonOpsDefinitions::HexagonOpsDefinitions()
|
||||
: op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}
|
||||
|
||||
/* static */ const IGraphTransferOpsDefinitions&
|
||||
/* static */ const IRemoteFusedGraphOpsDefinitions&
|
||||
HexagonOpsDefinitions::getInstance() {
|
||||
const static HexagonOpsDefinitions instance{};
|
||||
return instance;
|
||||
@ -393,17 +393,17 @@ int HexagonOpsDefinitions::GetOpIdFor(const string& op_type,
|
||||
if (dt_vec.empty()) {
|
||||
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
|
||||
}
|
||||
// If there is only one op_id registered for empty op_vec, we assume
|
||||
// that the op supports any data types.
|
||||
if (dt_to_op_vec.size() == 1 && std::get<0>(dt_to_op_vec.front()).empty()) {
|
||||
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
|
||||
}
|
||||
for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
|
||||
if (std::get<0>(data_type_to_op) == dt_vec) {
|
||||
return static_cast<int>(std::get<1>(data_type_to_op));
|
||||
}
|
||||
}
|
||||
}
|
||||
return IGraphTransferOpsDefinitions::INVALID_OP_ID;
|
||||
}
|
||||
|
||||
GraphTransferInfo::Destination HexagonOpsDefinitions::GetTransferDestination()
|
||||
const {
|
||||
return GraphTransferInfo::HEXAGON;
|
||||
return IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -18,21 +18,20 @@ limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "i_graph_transfer_ops_definitions.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// HexagonOpsDefinitions provides ops definitions supported in hexagon library
|
||||
// TODO(satok): add a functionality to call functions in hexagon library
|
||||
class HexagonOpsDefinitions final : public IGraphTransferOpsDefinitions {
|
||||
class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
|
||||
public:
|
||||
static const IGraphTransferOpsDefinitions& getInstance();
|
||||
static const IRemoteFusedGraphOpsDefinitions& getInstance();
|
||||
|
||||
int GetTotalOpsCount() const final;
|
||||
int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;
|
||||
GraphTransferInfo::Destination GetTransferDestination() const final;
|
||||
|
||||
private:
|
||||
enum class SupportedOpType;
|
||||
|
@ -27,7 +27,7 @@ Status BuildRemoteFusedGraphExecutor(
|
||||
|
||||
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
k_hexagon_remote_fused_graph_executor_build(
|
||||
"build_hexagon_remote_fused_graph_executor",
|
||||
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||
BuildRemoteFusedGraphExecutor);
|
||||
|
||||
} // namespace hexagon_remote_fused_graph_executor_build
|
||||
|
@ -1,136 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Tests in this file are designed to evaluate hexagon DSP operations.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
#include "tensorflow/core/platform/hexagon/soc_interface.h"
|
||||
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
|
||||
protected:
|
||||
void SetUp() final {
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
|
||||
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
|
||||
<< soc_interface_GetWrapperVersion()
|
||||
<< ", hexagon binary version = "
|
||||
<< soc_interface_GetSocControllerVersion() << ")";
|
||||
LOG(INFO) << "Cpu frequency = "
|
||||
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
|
||||
#else
|
||||
LOG(WARNING) << "Hexagon libs are not linked.";
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Shows some statistics of hexagon dsp using hexagon specific APIs
|
||||
#ifdef USE_HEXAGON_LIBS
|
||||
TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
|
||||
const uint64 overhead_shared_lib_start =
|
||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||
const int wrapper_version = soc_interface_GetWrapperVersion();
|
||||
const uint64 overhead_shared_lib_end =
|
||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||
const uint64 overhead_shared_lib_diff =
|
||||
(overhead_shared_lib_end - overhead_shared_lib_start);
|
||||
const uint64 overhead_hexagon_rpc_start =
|
||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||
const int hexagon_binary_version = soc_interface_GetSocControllerVersion();
|
||||
const uint64 overhead_hexagon_rpc_end =
|
||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||
const uint64 overhead_hexagon_rpc_diff =
|
||||
(overhead_hexagon_rpc_end - overhead_hexagon_rpc_start);
|
||||
LOG(INFO) << "Shared lib (ver = " << wrapper_version << ") overhead is "
|
||||
<< overhead_shared_lib_diff << " cycles, time = "
|
||||
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
profile_utils::CpuUtils::ConvertClockCycleToTime(
|
||||
overhead_shared_lib_diff))
|
||||
.count()
|
||||
<< " usec";
|
||||
LOG(INFO) << "hexagon rpc (ver = " << hexagon_binary_version
|
||||
<< ") overhead is " << overhead_hexagon_rpc_diff
|
||||
<< " cycles, time = "
|
||||
<< std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
profile_utils::CpuUtils::ConvertClockCycleToTime(
|
||||
overhead_hexagon_rpc_diff))
|
||||
.count()
|
||||
<< " usec";
|
||||
}
|
||||
#endif
|
||||
|
||||
// Runs two small matrices through the operator, and leaves all the parameters
|
||||
// at their default values.
|
||||
// This test is a sample to execute matmul on hexagon.
|
||||
TEST_F(QuantizedMatMulOpForHexagonTest, Small_NoParams) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("quantized_mat_mul_op", "QuantizedMatMul")
|
||||
.Input(FakeInput(DT_QUINT8))
|
||||
.Input(FakeInput(DT_QUINT8))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("Toutput", DataTypeToEnum<qint32>::v())
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
// A matrix is:
|
||||
// | 1 | 2 | 3 |
|
||||
// | 4 | 5 | 6 |
|
||||
AddInputFromArray<quint8>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
|
||||
// B matrix is:
|
||||
// | 7 | 8 | 9 | 10 |
|
||||
// | 11 | 12 | 13 | 14 |
|
||||
// | 15 | 16 | 17 | 18 |
|
||||
AddInputFromArray<quint8>(TensorShape({3, 4}),
|
||||
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
AddInputFromArray<float>(TensorShape({1}), {0});
|
||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {0});
|
||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
// Here are the results we expect, from hand calculations:
|
||||
// (1 * 7) + (2 * 11) + (3 * 15) = 74
|
||||
// (1 * 8) + (2 * 12) + (3 * 16) = 80
|
||||
// (1 * 9) + (2 * 13) + (3 * 17) = 86
|
||||
// (1 * 10) + (2 * 14) + (3 * 18) = 92
|
||||
// (4 * 7) + (5 * 11) + (6 * 15) = 173
|
||||
// (4 * 8) + (5 * 12) + (6 * 16) = 188
|
||||
// (4 * 9) + (5 * 13) + (6 * 17) = 203
|
||||
// (4 * 10) + (5 * 14) + (6 * 18) = 218
|
||||
Tensor expected(allocator(), DT_QINT32, TensorShape({2, 4}));
|
||||
test::FillValues<qint32>(&expected, {74, 80, 86, 92, 173, 188, 203, 218});
|
||||
test::ExpectTensorEqual<qint32>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
83
tensorflow/core/kernels/hexagon/soc_interface.cc
Normal file
83
tensorflow/core/kernels/hexagon/soc_interface.cc
Normal file
@ -0,0 +1,83 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
vcyou may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
|
||||
|
||||
// Dummy implementation of soc_interface.
|
||||
|
||||
int soc_interface_GetWrapperVersion() { return -1; }
|
||||
int soc_interface_GetSocControllerVersion() { return -1; }
|
||||
bool soc_interface_Init() { return false; }
|
||||
bool soc_interface_Finalize() { return false; }
|
||||
bool soc_interface_ExecuteGraph() { return false; }
|
||||
bool soc_interface_TeardownGraph() { return false; }
|
||||
bool soc_interface_AllocateInOutNodeBuffers(int /*input_count*/,
|
||||
int* /*input_sizes*/,
|
||||
int /*output_count*/,
|
||||
int* /*output_sizes*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_FillInputNodeWithPort(int /*port*/, int /*x*/, int /*y*/,
|
||||
int /*z*/, int /*d*/,
|
||||
const uint8_t* const /*buf*/,
|
||||
uint64_t /*buf_byte_size*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_FillInputNodeFloat(int /*x*/, int /*y*/, int /*z*/,
|
||||
int /*d*/, const uint8_t* const /*buf*/,
|
||||
uint64_t /*buf_byte_size*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_ReadOutputNodeWithPort(int /*port*/, uint8_t** /*buf*/,
|
||||
uint64_t* /*buf_byte_size*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_ReadOutputNodeFloat(const char* const /*node_name*/,
|
||||
uint8_t** /*buf*/,
|
||||
uint64_t* /*buf_byte_size*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_setupDummyGraph(int /*version*/) { return false; }
|
||||
bool soc_interface_AllocateNodeInputAndNodeOutputArray(
|
||||
int /*total_input_count*/, int /*total_output_count*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_ReleaseNodeInputAndNodeOutputArray() { return false; }
|
||||
void* soc_interface_SetOneNodeInputs(int /*input_count*/,
|
||||
const int* const /*node_id*/,
|
||||
const int* const /*port*/) {
|
||||
return nullptr;
|
||||
}
|
||||
void* soc_interface_SetOneNodeOutputs(int /*output_count*/, int* /*max_size*/) {
|
||||
return nullptr;
|
||||
}
|
||||
bool soc_interface_AppendConstNode(const char* const /*name*/, int /*node_id*/,
|
||||
int /*batch*/, int /*height*/, int /*width*/,
|
||||
int /*depth*/, const uint8_t* const /*data*/,
|
||||
int /*data_length*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_AppendNode(const char* const /*name*/, int /*node_id*/,
|
||||
int /*op_id*/, int /*padding_id*/,
|
||||
const void* const /*inputs*/,
|
||||
int /*inputs_count*/,
|
||||
const void* const /*outputs*/,
|
||||
int /*outputs_count*/) {
|
||||
return false;
|
||||
}
|
||||
bool soc_interface_InstantiateGraph() { return false; }
|
||||
bool soc_interface_ConstructGraph() { return false; }
|
||||
void soc_interface_SetLogLevel(int /*log_level*/) {}
|
||||
void soc_interface_SetDebugFlag(uint64_t /*flag*/) {}
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
// All functions defined here must have prefix "soc_interface" to avoid
|
||||
// naming conflicts.
|
||||
#ifdef __cplusplus
|
||||
#include <cstdint>
|
||||
extern "C" {
|
||||
#else
|
||||
#include <stdbool.h>
|
@ -59,6 +59,13 @@ class IRemoteFusedGraphExecutor {
|
||||
virtual bool ReadOutputNode(const string& node_name,
|
||||
TensorAllocatorFunc tensor_allocator) = 0;
|
||||
|
||||
virtual Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||
const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs,
|
||||
GraphDef* fused_graph_def) = 0;
|
||||
|
||||
virtual bool IsEnabled() const = 0;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor);
|
||||
};
|
||||
|
@ -13,11 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "i_graph_transfer_ops_definitions.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
|
||||
namespace tensorflow {
|
||||
/* static */ constexpr int IGraphTransferOpsDefinitions::INVALID_OP_ID;
|
||||
// TODO(satok): Remove
|
||||
/* static */ constexpr const char* const
|
||||
IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
|
||||
/* static */ constexpr int IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
|
||||
}
|
@ -13,39 +13,34 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
|
||||
|
||||
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// IGraphTransferOpsDefinitions is an interface class which provides interfaces
|
||||
// about ops supported by SOC.
|
||||
// IRemoteFusedGraphOpsDefinitions is an interface class which provides
|
||||
// APIs to provide information about op types supported by SOC.
|
||||
// TODO(satok): Provide ways to transfer graph definitions into SOC
|
||||
class IGraphTransferOpsDefinitions {
|
||||
class IRemoteFusedGraphOpsDefinitions {
|
||||
public:
|
||||
// op id which is not supported by SOC
|
||||
static constexpr int INVALID_OP_ID = -1;
|
||||
// Custom op name for flatten node
|
||||
static constexpr const char* const FLATTEN_OP_NAME = "FLATTEN";
|
||||
|
||||
IGraphTransferOpsDefinitions() = default;
|
||||
virtual ~IGraphTransferOpsDefinitions() = default;
|
||||
IRemoteFusedGraphOpsDefinitions() = default;
|
||||
virtual ~IRemoteFusedGraphOpsDefinitions() = default;
|
||||
// Return total ops count supported by SOC
|
||||
virtual int GetTotalOpsCount() const = 0;
|
||||
// Return op id for given string op name
|
||||
virtual int GetOpIdFor(const string& op_name,
|
||||
virtual int GetOpIdFor(const string& op_type,
|
||||
const DataTypeVector& dt) const = 0;
|
||||
// Return destination of transfer
|
||||
virtual GraphTransferInfo::Destination GetTransferDestination() const = 0;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IGraphTransferOpsDefinitions);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphOpsDefinitions);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
|
@ -41,7 +41,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
|
||||
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
|
||||
execute_info_.executor_name());
|
||||
if (build_func != nullptr) {
|
||||
Status status = (*build_func)(&remote_fused_graph_executor_);
|
||||
TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
|
||||
CHECK(remote_fused_graph_executor_->IsEnabled());
|
||||
} else {
|
||||
LOG(ERROR) << "Executor not found for "
|
||||
<< execute_info_.executor_name();
|
||||
|
@ -159,8 +159,8 @@ static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
|
||||
return execute_info;
|
||||
}
|
||||
|
||||
// 1. Create TestRemoteFusedGraphExecutor to execute your fused graph
|
||||
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||
// 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph
|
||||
class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||
public:
|
||||
int GetVersion() final { return 1; }
|
||||
bool Init(const RemoteFusedGraphExecuteInfo& info) final {
|
||||
@ -214,6 +214,16 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||
return true;
|
||||
}
|
||||
|
||||
Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||
const std::vector<string>& /*inputs*/,
|
||||
const std::vector<string>& /*outputs*/,
|
||||
GraphDef* fused_graph_def) final {
|
||||
*fused_graph_def = original_graph_def;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsEnabled() const final { return true; }
|
||||
|
||||
private:
|
||||
const RemoteFusedGraphExecuteInfo* info_;
|
||||
std::unordered_map<string, Tensor> input_tensor_cache_;
|
||||
@ -225,7 +235,7 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||
namespace remote_fused_graph_execute_op {
|
||||
Status BuildRemoteFusedGraphExecutor(
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||
executor->reset(new TestRemoteFusedGraphExecutor());
|
||||
executor->reset(new SampleRemoteFusedGraphExecutor());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
@ -92,4 +93,36 @@ namespace tensorflow {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TestRemoteFusedGraphExecutor::TestRemoteFusedGraphExecutor(
|
||||
const std::unordered_set<string>& fused_op_types,
|
||||
const string& executor_name)
|
||||
: fused_op_types_(fused_op_types), executor_name_(executor_name) {}
|
||||
|
||||
int TestRemoteFusedGraphExecutor::GetVersion() { return 0; }
|
||||
bool TestRemoteFusedGraphExecutor::Init(const RemoteFusedGraphExecuteInfo&) {
|
||||
return true;
|
||||
}
|
||||
bool TestRemoteFusedGraphExecutor::Finalize() { return true; }
|
||||
bool TestRemoteFusedGraphExecutor::SetupGraph() { return true; }
|
||||
bool TestRemoteFusedGraphExecutor::ExecuteGraph() { return true; }
|
||||
bool TestRemoteFusedGraphExecutor::TeardownGraph() { return true; }
|
||||
bool TestRemoteFusedGraphExecutor::FillInputNode(const string&, const Tensor&) {
|
||||
return true;
|
||||
}
|
||||
bool TestRemoteFusedGraphExecutor::ReadOutputNode(const string&,
|
||||
TensorAllocatorFunc) {
|
||||
return true;
|
||||
}
|
||||
Status TestRemoteFusedGraphExecutor::FuseRemoteGraph(
|
||||
const GraphDef& original_graph_def, const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
|
||||
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
|
||||
original_graph_def, inputs, outputs, "remote_fused_graph_node_names",
|
||||
fused_op_types_, executor_name_,
|
||||
/*require_shape_type=*/false, fused_graph_def);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TestRemoteFusedGraphExecutor::IsEnabled() const { return true; }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -59,6 +60,30 @@ class RemoteFusedGraphExecuteOpTestUtils {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
|
||||
};
|
||||
|
||||
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
|
||||
public:
|
||||
TestRemoteFusedGraphExecutor(const std::unordered_set<string>& fused_op_types,
|
||||
const string& executor_name);
|
||||
|
||||
int GetVersion() final;
|
||||
bool Init(const RemoteFusedGraphExecuteInfo&) final;
|
||||
bool Finalize() final;
|
||||
bool SetupGraph() final;
|
||||
bool ExecuteGraph() final;
|
||||
bool TeardownGraph() final;
|
||||
bool FillInputNode(const string&, const Tensor&) final;
|
||||
bool ReadOutputNode(const string&, TensorAllocatorFunc) final;
|
||||
Status FuseRemoteGraph(const GraphDef& original_graph_def,
|
||||
const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs,
|
||||
GraphDef* fused_graph_def) final;
|
||||
bool IsEnabled() const final;
|
||||
|
||||
private:
|
||||
const std::unordered_set<string> fused_op_types_;
|
||||
const string executor_name_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
|
||||
|
@ -161,6 +161,8 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
|
||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
|
||||
/* static */ constexpr const char* const
|
||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
|
||||
/* static */ constexpr const char* const
|
||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
|
||||
/* static */ constexpr const char* const
|
||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
|
||||
/* static */ constexpr const char* const
|
||||
@ -1084,6 +1086,26 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
|
||||
require_shape_type, output_graph_def);
|
||||
}
|
||||
|
||||
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||
const GraphDef& input_graph_def, const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs, const string& executor_name,
|
||||
GraphDef* output_graph_def) {
|
||||
const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
|
||||
if (build_func == nullptr) {
|
||||
return errors::InvalidArgument("Unknown executor name: " + executor_name);
|
||||
}
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor> executor;
|
||||
TF_RETURN_IF_ERROR((*build_func)(&executor));
|
||||
CHECK_NOTNULL(executor.get());
|
||||
if (!executor->IsEnabled()) {
|
||||
// As this executor is not enabled, just return original graph as is.
|
||||
*output_graph_def = input_graph_def;
|
||||
return Status::OK();
|
||||
}
|
||||
return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
|
||||
output_graph_def);
|
||||
}
|
||||
|
||||
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
|
||||
const std::vector<string>& inputs, const std::vector<string>& outputs,
|
||||
const std::unordered_set<string>& fused_node_names,
|
||||
@ -1387,6 +1409,28 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
|
||||
return retval;
|
||||
}
|
||||
|
||||
/* static */ std::unordered_set<string>
|
||||
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
|
||||
const GraphDef& graph_def,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
|
||||
std::unordered_set<string> retval;
|
||||
for (const NodeDef& node_def : graph_def.node()) {
|
||||
std::vector<DataType> dt_vec;
|
||||
std::vector<TensorShape> shape_vec;
|
||||
const Status status =
|
||||
GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
|
||||
if (!status.ok()) {
|
||||
shape_vec.clear();
|
||||
}
|
||||
if (ops_definitions.GetOpIdFor(
|
||||
node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
|
||||
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
|
||||
retval.emplace(node_def.name());
|
||||
}
|
||||
}
|
||||
return retval;
|
||||
}
|
||||
|
||||
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
|
||||
const string& input, const DataType type, const TensorShape& shape,
|
||||
GraphDef* graph_def) {
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
@ -59,6 +60,8 @@ class RemoteFusedGraphExecuteUtils {
|
||||
"border_outputs";
|
||||
static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
|
||||
"fused_op_types";
|
||||
static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR =
|
||||
"fuse_by_executor";
|
||||
static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
|
||||
static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
|
||||
"input_shapes";
|
||||
@ -257,6 +260,12 @@ class RemoteFusedGraphExecuteUtils {
|
||||
const std::vector<std::pair<string, Tensor>>& input_tensors,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def,
|
||||
const std::vector<string>& inputs,
|
||||
const std::vector<string>& outputs,
|
||||
const string& executor_name,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
static bool IsFuseReady(
|
||||
const GraphDef& input_graph_def,
|
||||
const std::vector<std::pair<string, Tensor>>& input_tensors);
|
||||
@ -273,6 +282,10 @@ class RemoteFusedGraphExecuteUtils {
|
||||
static std::unordered_set<string> BuildNodeMapFromOpTypes(
|
||||
const GraphDef& graph_def, const std::unordered_set<string>& op_types);
|
||||
|
||||
static std::unordered_set<string> BuildNodeMapFromOpsDefinitions(
|
||||
const GraphDef& graph_def,
|
||||
const IRemoteFusedGraphOpsDefinitions& ops_definitions);
|
||||
|
||||
private:
|
||||
static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
|
||||
TensorShapeMap* tensor_shape_map);
|
||||
|
@ -33,6 +33,10 @@ constexpr const char* const NAME_A_PLUS_B = "A_PLUS_B";
|
||||
constexpr float NODE_A_VAL = 2.0f;
|
||||
constexpr float NODE_B_VAL = 3.0f;
|
||||
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
|
||||
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
|
||||
"fuse_test_remote_fused_graph_executor0";
|
||||
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
|
||||
"fuse_test_remote_fused_graph_executor1";
|
||||
|
||||
static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
|
||||
CHECK_NE(def, nullptr);
|
||||
@ -44,17 +48,38 @@ static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status BuildRemoteFusedGraphExecutor0(
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||
executor->reset(
|
||||
new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildRemoteFusedGraphExecutor1(
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||
executor->reset(new TestRemoteFusedGraphExecutor(
|
||||
{"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() final {
|
||||
TF_ASSERT_OK(
|
||||
RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
|
||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
k_hexagon_remote_fused_graph_executor_build(
|
||||
hexagon_remote_fused_graph_executor_build(
|
||||
"remote_graph_executor_name",
|
||||
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
|
||||
return Status::OK();
|
||||
});
|
||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
|
||||
BuildRemoteFusedGraphExecutor0);
|
||||
|
||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
|
||||
BuildRemoteFusedGraphExecutor1);
|
||||
}
|
||||
|
||||
void TearDown() final {}
|
||||
@ -87,6 +112,18 @@ class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
|
||||
/*require_shape_type=*/false, &result_graph_def_);
|
||||
}
|
||||
|
||||
Status FuseByExecutor0() {
|
||||
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||
graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME0,
|
||||
&result_graph_def_);
|
||||
}
|
||||
|
||||
Status FuseByExecutor1() {
|
||||
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||
graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME1,
|
||||
&result_graph_def_);
|
||||
}
|
||||
|
||||
Status BuildAndAddTensorShape() {
|
||||
return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
|
||||
input_tensors_, /*dry_run_inference=*/true, &graph_def_);
|
||||
@ -694,6 +731,30 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_FGHIJ) {
|
||||
<< SummarizeGraphDef(result_graph_def_);
|
||||
}
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_HIJ) {
|
||||
ReplaceOpType({"H", "I", "J"}, "Mul");
|
||||
|
||||
TF_ASSERT_OK(FuseByExecutor0());
|
||||
|
||||
EXPECT_EQ(11, graph_def_.node_size());
|
||||
EXPECT_EQ(9, result_graph_def_.node_size())
|
||||
<< "=== Before: \n"
|
||||
<< SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
|
||||
<< SummarizeGraphDef(result_graph_def_);
|
||||
}
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_FGHIJ) {
|
||||
ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
|
||||
|
||||
TF_ASSERT_OK(FuseByExecutor1());
|
||||
|
||||
EXPECT_EQ(11, graph_def_.node_size());
|
||||
EXPECT_EQ(3, result_graph_def_.node_size())
|
||||
<< "=== Before: \n"
|
||||
<< SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
|
||||
<< SummarizeGraphDef(result_graph_def_);
|
||||
}
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) {
|
||||
subgraph_node_names_ = {"H"};
|
||||
|
||||
|
@ -66,7 +66,7 @@ static Status ParseArguments(const TransformFuncContext& context,
|
||||
string* input_types_str, string* input_shapes_str,
|
||||
string* fused_nodes_str, string* border_inputs_str,
|
||||
string* border_outputs_str,
|
||||
string* fused_op_types_str,
|
||||
string* fused_op_types_str, bool* fuse_by_executor,
|
||||
string* remote_fused_graph_node_name,
|
||||
string* remote_graph_executor_name) {
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||
@ -87,6 +87,9 @@ static Status ParseArguments(const TransformFuncContext& context,
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "",
|
||||
fused_op_types_str));
|
||||
TF_RETURN_IF_ERROR(context.GetOneBoolParameter(
|
||||
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, false,
|
||||
fuse_by_executor));
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
|
||||
RemoteFusedGraphExecuteUtils::
|
||||
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||
@ -140,12 +143,14 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
|
||||
string border_inputs_str;
|
||||
string border_outputs_str;
|
||||
string fused_op_types_str;
|
||||
bool fuse_by_executor = false;
|
||||
string remote_fused_graph_node_name;
|
||||
string remote_graph_executor_name;
|
||||
TF_RETURN_IF_ERROR(ParseArguments(
|
||||
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
|
||||
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
|
||||
&remote_fused_graph_node_name, &remote_graph_executor_name));
|
||||
&fuse_by_executor, &remote_fused_graph_node_name,
|
||||
&remote_graph_executor_name));
|
||||
|
||||
if (!input_types_str.empty()) {
|
||||
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
|
||||
@ -187,6 +192,10 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
|
||||
mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
|
||||
fused_op_types, remote_graph_executor_name, require_shape_type,
|
||||
output_graph_def));
|
||||
} else if (fuse_by_executor) {
|
||||
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
|
||||
mutable_input_graph_def, inputs, outputs, remote_graph_executor_name,
|
||||
output_graph_def));
|
||||
} else {
|
||||
CHECK(false) << "Fuse targets are not specified.";
|
||||
}
|
||||
@ -205,15 +214,17 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
|
||||
string input_types_str;
|
||||
string input_shapes_str;
|
||||
string fused_nodes_str;
|
||||
string fused_op_types_str;
|
||||
string border_inputs_str;
|
||||
string border_outputs_str;
|
||||
string fused_op_types_str;
|
||||
bool fuse_by_executor = false;
|
||||
string remote_fused_graph_node_name;
|
||||
string remote_graph_executor_name;
|
||||
TF_RETURN_IF_ERROR(ParseArguments(
|
||||
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
|
||||
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
|
||||
&remote_fused_graph_node_name, &remote_graph_executor_name));
|
||||
&fuse_by_executor, &remote_fused_graph_node_name,
|
||||
&remote_graph_executor_name));
|
||||
|
||||
if (!input_types_str.empty()) {
|
||||
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
|
||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
|
||||
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -43,11 +44,28 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
|
||||
"remote_fused_graph_executor_name";
|
||||
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
|
||||
"remote_fused_graph_node_name";
|
||||
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
|
||||
"fuse_test_remote_fused_graph_executor0";
|
||||
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
|
||||
"fuse_test_remote_fused_graph_executor1";
|
||||
|
||||
Status BuildRemoteFusedGraphExecutor0(
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||
executor->reset(
|
||||
new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildRemoteFusedGraphExecutor1(
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
|
||||
executor->reset(new TestRemoteFusedGraphExecutor(
|
||||
{"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -55,11 +73,18 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
||||
TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
|
||||
&input_graph_def_));
|
||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
k_hexagon_remote_fused_graph_executor_build(
|
||||
hexagon_remote_fused_graph_executor_build(
|
||||
REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
|
||||
return Status::OK();
|
||||
});
|
||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
|
||||
BuildRemoteFusedGraphExecutor0);
|
||||
|
||||
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
|
||||
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
|
||||
BuildRemoteFusedGraphExecutor1);
|
||||
}
|
||||
|
||||
void TearDown() final {}
|
||||
@ -113,10 +138,16 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
||||
{fused_op_types_str_}}));
|
||||
}
|
||||
|
||||
if (fuse_by_executor_) {
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
|
||||
{"true"}}));
|
||||
}
|
||||
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{RemoteFusedGraphExecuteUtils::
|
||||
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
|
||||
{REMOTE_FUSED_GRAPH_EXECUTOR_NAME}}));
|
||||
{remote_fused_graph_executor_name_}}));
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{RemoteFusedGraphExecuteUtils::
|
||||
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
|
||||
@ -160,7 +191,7 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
||||
ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
|
||||
&serialized_proto));
|
||||
info.ParseFromString(serialized_proto);
|
||||
CHECK_EQ(REMOTE_FUSED_GRAPH_EXECUTOR_NAME, info.executor_name());
|
||||
CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(expected_cluster_count, cluster_count);
|
||||
@ -178,6 +209,8 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
|
||||
string border_inputs_str_;
|
||||
string border_outputs_str_;
|
||||
string fused_op_types_str_;
|
||||
string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
|
||||
bool fuse_by_executor_{false};
|
||||
};
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||
@ -260,6 +293,24 @@ TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||
CheckGraph(3, 1);
|
||||
}
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||
FuseRemoteGraphByExecutor_HIJ) {
|
||||
ReplaceOpType({"H", "I", "J"}, "Mul");
|
||||
remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
|
||||
fuse_by_executor_ = true;
|
||||
TF_ASSERT_OK(Fuse());
|
||||
CheckGraph(9, 1);
|
||||
}
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
|
||||
FuseRemoteGraphByExecutor_FGHIJ) {
|
||||
ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
|
||||
remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
|
||||
fuse_by_executor_ = true;
|
||||
TF_ASSERT_OK(Fuse());
|
||||
CheckGraph(3, 1);
|
||||
}
|
||||
|
||||
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
|
||||
fused_node_names_str_ = "H,I,J";
|
||||
TF_ASSERT_OK(PlaceFuseArgs());
|
||||
|
Loading…
Reference in New Issue
Block a user