Add a way to fuse a graph by a remote graph executor so that users don't need to be aware of supported op types, node names, subgraph structure, etc.

PiperOrigin-RevId: 162411763
This commit is contained in:
A. Unique TensorFlower 2017-07-18 15:20:03 -07:00 committed by TensorFlower Gardener
parent d8672f1839
commit d5f4d9bbac
31 changed files with 474 additions and 299 deletions

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -110,7 +111,7 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) {
static void CheckOpsSupport(const GraphDef& graph_def,
const bool dump_all_nodes,
const bool dump_shape_and_type) {
const IGraphTransferOpsDefinitions& ops_definition =
const IRemoteFusedGraphOpsDefinitions& ops_definition =
HexagonOpsDefinitions::getInstance();
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
@ -127,7 +128,7 @@ static void CheckOpsSupport(const GraphDef& graph_def,
}
// TODO(satok): Set correct data type if it's given.
const int op_id = ops_definition.GetOpIdFor(node.op(), {});
if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) {
if (op_id == IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
all_supported = false;
LOG(ERROR) << "OP type: " << node.op() << " is not supported on hvx. "
<< "Name = " << node.name();

View File

@ -82,7 +82,7 @@ fi
if [[ "${USE_HEXAGON}" == "true" ]]; then
HEXAGON_PARENT_DIR=$(cd "${HEXAGON_DOWNLOAD_PATH}" >/dev/null && pwd)
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
HEXAGON_INCLUDE=$(cd "tensorflow/core/platform/hexagon" >/dev/null && pwd)
HEXAGON_INCLUDE=$(cd "tensorflow/core/kernels/hexagon" >/dev/null && pwd)
fi
if [[ "${ENABLE_EXPERIMENTAL_HEXNN_OPS}" == "true" ]]; then

View File

@ -56,7 +56,6 @@ tensorflow/core/platform/posix/test.cc
QUANTIZATION_TEST_SRCS := \
$(GRAPH_TRANSFER_SRCS) \
tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc \
tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
tensorflow/contrib/makefile/test/test_main.cc

View File

@ -5057,9 +5057,13 @@ tf_kernel_library(
cc_library(
name = "remote_fused_graph_execute_utils",
srcs = ["remote_fused_graph_execute_utils.cc"],
srcs = [
"i_remote_fused_graph_ops_definitions.cc",
"remote_fused_graph_execute_utils.cc",
],
hdrs = [
"i_remote_fused_graph_executor.h",
"i_remote_fused_graph_ops_definitions.h",
"remote_fused_graph_execute_utils.h",
],
deps = [
@ -5078,6 +5082,7 @@ cc_library(
srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
deps = [
":remote_fused_graph_execute_utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",

View File

@ -26,23 +26,6 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
tf_cc_test(
name = "quantized_matmul_op_for_hexagon_test",
size = "small",
srcs = ["quantized_matmul_op_for_hexagon_test.cc"],
tags = ["nomsan"], # http://b/32242946
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:quantized_ops",
],
)
tf_cc_test(
name = "graph_transferer_test",
size = "small",
@ -79,14 +62,14 @@ tf_kernel_library(
"graph_transferer.cc",
"hexagon_control_wrapper.cc",
"hexagon_ops_definitions.cc",
"i_graph_transfer_ops_definitions.cc",
"soc_interface.cc",
],
hdrs = [
"graph_transfer_utils.h",
"graph_transferer.h",
"hexagon_control_wrapper.h",
"hexagon_ops_definitions.h",
"i_graph_transfer_ops_definitions.h",
"soc_interface.h",
],
deps = [
"//tensorflow/cc:cc_ops",
@ -111,6 +94,7 @@ cc_library(
":graph_transferer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:remote_fused_graph_ops",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/tools/graph_transforms:transform_utils",
],
alwayslink = 1,
@ -121,6 +105,7 @@ tf_cc_test(
size = "small",
srcs = ["hexagon_rewriter_transform_test.cc"],
deps = [
":graph_transferer",
":hexagon_rewriter_transform",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
@ -129,6 +114,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/tools/graph_transforms:transform_utils",
],
)

View File

@ -96,7 +96,7 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
}
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs, GraphDef* original_def) {

View File

@ -39,7 +39,7 @@ class GraphTransferUtils {
const int element_count, const int top_n);
static GraphDef BuildFusedGraphDef(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs, GraphDef* original_def);

View File

@ -81,7 +81,7 @@ static Node* FindMutableNodeByName(const string& name, Graph* graph) {
* of node to transfer the graph to SOC.
*/
Status GraphTransferer::LoadGraphFromProto(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
@ -177,9 +177,6 @@ Status GraphTransferer::LoadGraphFromProto(
}
}
graph_transfer_info_.set_destination(
ops_definitions.GetTransferDestination());
ClearCache();
if (DBG_DUMP_PARAMS) {
DumpNodeTransferParams();
@ -191,7 +188,7 @@ Status GraphTransferer::LoadGraphFromProto(
}
Status GraphTransferer::LoadGraphFromProtoFile(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& graph_def_path,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const bool is_text_proto,
@ -415,7 +412,7 @@ Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
}
Status GraphTransferer::RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names) {
@ -438,7 +435,7 @@ Status GraphTransferer::RegisterNode(
} else if (IsNodeFlattenReshape(node, shape_refiner)) {
RegisterFlattenNode(ops_definitions, shape_refiner, node);
} else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
IGraphTransferOpsDefinitions::INVALID_OP_ID) {
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
// TODO(satok): Set correct data type if it's given.
RegisterGenericNode(ops_definitions, shape_refiner, node);
} else {
@ -637,7 +634,7 @@ bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
}
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) {
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()];
@ -671,7 +668,7 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
}
void GraphTransferer::RegisterNodeWithRank(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) {
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()];
@ -704,7 +701,7 @@ void GraphTransferer::RegisterNodeWithRank(
}
void GraphTransferer::RegisterPadNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) {
static constexpr int PAD_WIDTH = 4;
static constexpr int PAD_HEIGHT = 2;
@ -779,7 +776,7 @@ void GraphTransferer::RegisterPadNode(
}
void GraphTransferer::RegisterInputNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) {
const string op_type = node.type_string();
VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
@ -797,12 +794,13 @@ void GraphTransferer::RegisterInputNode(
}
void GraphTransferer::RegisterFlattenNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) {
VLOG(1) << "Register flatten node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
const int id = node_name_to_id_cache_map_[node.name()];
const string op_type = IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
// TODO(satok): Remove dependency to specific type
const string op_type = "FLATTEN";
// TODO(satok): Set correct data type if it's given.
const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
@ -814,7 +812,7 @@ void GraphTransferer::RegisterFlattenNode(
}
void GraphTransferer::RegisterGenericNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node) {
VLOG(1) << "Register generic node: " << node.name();
CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
@ -832,7 +830,7 @@ void GraphTransferer::RegisterGenericNode(
// TODO(satok): Remove this function.
// TODO(satok): Remove only_register_const_node.
Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@ -53,7 +53,7 @@ class GraphTransferer {
// TODO(satok): Pass a pair of TensorShape and DataType instead of
// Tensor as input_node_info_list.
Status LoadGraphFromProto(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
@ -63,7 +63,7 @@ class GraphTransferer {
// TODO(satok): Pass a pair of TensorShape and DataType instead of
// Tensor as input_node_info_list.
Status LoadGraphFromProtoFile(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& graph_def_path,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const bool is_text_proto,
@ -112,7 +112,7 @@ class GraphTransferer {
Graph* graph, ShapeRefiner* shape_refiner);
Status RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names);
@ -140,30 +140,29 @@ class GraphTransferer {
const ShapeRefiner& shape_refiner);
void RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterNodeWithRank(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const Node& node);
void RegisterNodeWithRank(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterPadNode(const IGraphTransferOpsDefinitions& ops_definitions,
void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const Node& node);
void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const Node& node);
void RegisterFlattenNode(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const Node& node);
void RegisterGenericNode(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
Status RegisterNodeIfAllInputsAreCached(
const IGraphTransferOpsDefinitions& ops_definitions,
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,

View File

@ -22,8 +22,8 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
@ -50,7 +50,7 @@ class GraphTransfererTest : public ::testing::Test {
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
public:
int GetTotalOpsCount() const final { return op_types_.size(); }
@ -63,10 +63,6 @@ class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
return -1;
}
GraphTransferInfo::Destination GetTransferDestination() const final {
return GraphTransferInfo::NOP;
}
private:
const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
"MaxPool", "NoOp", "Add",
@ -371,14 +367,14 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
}
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
const IGraphTransferOpsDefinitions& ops_definitions =
const IRemoteFusedGraphOpsDefinitions& ops_definitions =
HexagonOpsDefinitions::getInstance();
const int total_ops_count = ops_definitions.GetTotalOpsCount();
EXPECT_GT(total_ops_count, 0);
}
TEST(GraphTransferer, LoadGraphFromProtoFile) {
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename =
io::JoinPath(testing::TensorFlowSrcRoot(),
@ -441,7 +437,7 @@ void CompareGraphTransferInfo(const GraphTransferInfo& a,
} // anonymous namespace
TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename =
io::JoinPath(testing::TensorFlowSrcRoot(),

View File

@ -16,18 +16,19 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#ifdef USE_HEXAGON_LIBS
#include "tensorflow/core/platform/hexagon/soc_interface.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#endif
namespace tensorflow {
constexpr const char* const INPUT_OP_NAME = "INPUT";
constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX =
"hexagon_remote_fused_graph";
/* static */ constexpr const char* const
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
constexpr int ALIGNMENT_BYTES = 16;
constexpr int MAX_IN_OUT_COUNT = 128;
const bool DBG_DUMP_VERIFICATION_STRING = false;
const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info
@ -63,7 +64,6 @@ static uint8* FindAlignedPointer(uint8* ptr) {
return nullptr;
}
#ifdef USE_HEXAGON_LIBS
int HexagonControlWrapper::GetVersion() {
return soc_interface_GetSocControllerVersion();
}
@ -95,7 +95,6 @@ bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
LOG(ERROR) << "Hexagon initialization was failed. See log output.";
return false;
}
const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo();
std::vector<int> input_sizes;
std::vector<int> output_sizes;
CHECK_NOTNULL(execute_info_);
@ -207,8 +206,9 @@ bool HexagonControlWrapper::SetupGraph() {
for (const GraphTransferInfo::NodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
const int count = input_params.node_input_size();
int node_ids[count];
int ports[count];
CHECK(count <= MAX_IN_OUT_COUNT);
int node_ids[MAX_IN_OUT_COUNT];
int ports[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) {
const GraphTransferInfo::NodeInput& node_input =
input_params.node_input(i);
@ -226,7 +226,8 @@ bool HexagonControlWrapper::SetupGraph() {
for (const GraphTransferInfo::NodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
const int count = output_params.max_byte_size_size();
int sizes[count];
CHECK(count <= MAX_IN_OUT_COUNT);
int sizes[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) {
const int size = output_params.max_byte_size(i);
sizes[i] = size;
@ -373,6 +374,7 @@ bool HexagonControlWrapper::ReadOutputNode(
<< output_tensor->TotalBytes() << ", " << std::get<1>(output);
TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
std::get<0>(output), std::get<1>(output), output_tensor));
return true;
}
bool HexagonControlWrapper::ReadOutputNode(
@ -382,14 +384,30 @@ bool HexagonControlWrapper::ReadOutputNode(
const string tensor_name = AddPort(node_name);
CHECK(output_port_map_.count(tensor_name) > 0);
const int port = output_port_map_.at(tensor_name);
soc_interface_ReadOutputNodeWithPort(port, &std::get<0>(output),
&std::get<1>(output));
soc_interface_ReadOutputNodeWithPort(
port, &std::get<0>(output),
reinterpret_cast<uint64_t*>(&std::get<1>(output)));
// TODO: Accept all results
// std::get<2>(output) = DT_FLOAT;
outputs->emplace_back(output);
return true;
}
Status HexagonControlWrapper::FuseRemoteGraph(
const GraphDef& original_graph_def, const std::vector<string>& inputs,
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
const std::unordered_set<string> fused_node_names =
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
original_graph_def, HexagonOpsDefinitions::getInstance());
// TODO(satok): We may want to place shape and type inside this function
// if they are not placed in the given graph.
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
original_graph_def, inputs, outputs, REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX,
fused_node_names, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
/*require_shape_type=*/true, fused_graph_def));
return Status::OK();
}
bool HexagonControlWrapper::FillInputNode(const string& node_name,
const Tensor& tensor) {
StringPiece tensor_data = tensor.tensor_data();
@ -415,31 +433,5 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name,
return true;
}
#else
int HexagonControlWrapper::GetVersion() { return -1; }
bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo&) {
return false;
}
bool HexagonControlWrapper::Finalize() { return false; }
bool HexagonControlWrapper::SetupGraph() { return false; }
bool HexagonControlWrapper::ExecuteGraph() { return false; }
bool HexagonControlWrapper::TeardownGraph() { return false; }
bool HexagonControlWrapper::FillInputNode(
const string&, const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>&,
const ConstByteArray) {
return false;
}
bool HexagonControlWrapper::FillInputNode(const string&, const Tensor&) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(
const string& node_name, TensorAllocatorFunc tensor_allocator) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(const string&,
std::vector<ByteArray>* const) {
return false;
}
#endif
bool HexagonControlWrapper::IsEnabled() const { return true; };
} // namespace tensorflow

View File

@ -35,6 +35,8 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
public:
using ByteArray =
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
static constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
"build_hexagon_remote_fused_graph_executor";
HexagonControlWrapper() = default;
int GetVersion() final;
@ -46,6 +48,11 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
bool FillInputNode(const string& node_name, const Tensor& tensor) final;
bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) final;
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) final;
bool IsEnabled() const final;
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
private:

View File

@ -35,8 +35,8 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status.h"
@ -389,9 +389,6 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
}
// 7. check destination
EXPECT_EQ(gfi0.destination(), gfi1.destination());
}
// CAVEAT: This test only runs when you specify hexagon library using
@ -407,7 +404,7 @@ TEST(GraphTransferer,
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
@ -474,7 +471,7 @@ TEST(GraphTransferer,
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
@ -505,7 +502,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
CheckHexagonControllerVersion();
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
@ -543,7 +540,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
CheckHexagonControllerVersion();
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
const IGraphTransferOpsDefinitions* ops_definitions =
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));

View File

@ -304,8 +304,8 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
EmplaceOpType(IGraphTransferOpsDefinitions::FLATTEN_OP_NAME, {},
SupportedOpType::FLATTEN, &op_map);
// Special op type for hexagon
EmplaceOpType("FLATTEN", {}, SupportedOpType::FLATTEN, &op_map);
// Tensorflow op name
// CAVEAT: Keep order of SupportedOpType
EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
@ -373,7 +373,7 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
HexagonOpsDefinitions::HexagonOpsDefinitions()
: op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}
/* static */ const IGraphTransferOpsDefinitions&
/* static */ const IRemoteFusedGraphOpsDefinitions&
HexagonOpsDefinitions::getInstance() {
const static HexagonOpsDefinitions instance{};
return instance;
@ -393,17 +393,17 @@ int HexagonOpsDefinitions::GetOpIdFor(const string& op_type,
if (dt_vec.empty()) {
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
}
// If there is only one op_id registered for empty op_vec, we assume
// that the op supports any data types.
if (dt_to_op_vec.size() == 1 && std::get<0>(dt_to_op_vec.front()).empty()) {
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
}
for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
if (std::get<0>(data_type_to_op) == dt_vec) {
return static_cast<int>(std::get<1>(data_type_to_op));
}
}
}
return IGraphTransferOpsDefinitions::INVALID_OP_ID;
}
GraphTransferInfo::Destination HexagonOpsDefinitions::GetTransferDestination()
const {
return GraphTransferInfo::HEXAGON;
return IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
}
} // namespace tensorflow

View File

@ -18,21 +18,20 @@ limitations under the License.
#include <unordered_map>
#include "i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// HexagonOpsDefinitions provides ops definitions supported in hexagon library
// TODO(satok): add a functionality to call functions in hexagon library
class HexagonOpsDefinitions final : public IGraphTransferOpsDefinitions {
class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
public:
static const IGraphTransferOpsDefinitions& getInstance();
static const IRemoteFusedGraphOpsDefinitions& getInstance();
int GetTotalOpsCount() const final;
int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;
GraphTransferInfo::Destination GetTransferDestination() const final;
private:
enum class SupportedOpType;

View File

@ -27,7 +27,7 @@ Status BuildRemoteFusedGraphExecutor(
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build(
"build_hexagon_remote_fused_graph_executor",
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
BuildRemoteFusedGraphExecutor);
} // namespace hexagon_remote_fused_graph_executor_build

View File

@ -1,136 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tests in this file are designed to evaluate hexagon DSP operations.
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#ifdef USE_HEXAGON_LIBS
#include "tensorflow/core/platform/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#endif
namespace tensorflow {
class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
protected:
void SetUp() final {
#ifdef USE_HEXAGON_LIBS
profile_utils::CpuUtils::EnableClockCycleProfiling(true);
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
<< soc_interface_GetWrapperVersion()
<< ", hexagon binary version = "
<< soc_interface_GetSocControllerVersion() << ")";
LOG(INFO) << "Cpu frequency = "
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
#else
LOG(WARNING) << "Hexagon libs are not linked.";
#endif
}
};
// Shows some statistics of hexagon dsp using hexagon specific APIs
#ifdef USE_HEXAGON_LIBS
TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
const uint64 overhead_shared_lib_start =
profile_utils::CpuUtils::GetCurrentClockCycle();
const int wrapper_version = soc_interface_GetWrapperVersion();
const uint64 overhead_shared_lib_end =
profile_utils::CpuUtils::GetCurrentClockCycle();
const uint64 overhead_shared_lib_diff =
(overhead_shared_lib_end - overhead_shared_lib_start);
const uint64 overhead_hexagon_rpc_start =
profile_utils::CpuUtils::GetCurrentClockCycle();
const int hexagon_binary_version = soc_interface_GetSocControllerVersion();
const uint64 overhead_hexagon_rpc_end =
profile_utils::CpuUtils::GetCurrentClockCycle();
const uint64 overhead_hexagon_rpc_diff =
(overhead_hexagon_rpc_end - overhead_hexagon_rpc_start);
LOG(INFO) << "Shared lib (ver = " << wrapper_version << ") overhead is "
<< overhead_shared_lib_diff << " cycles, time = "
<< std::chrono::duration_cast<std::chrono::microseconds>(
profile_utils::CpuUtils::ConvertClockCycleToTime(
overhead_shared_lib_diff))
.count()
<< " usec";
LOG(INFO) << "hexagon rpc (ver = " << hexagon_binary_version
<< ") overhead is " << overhead_hexagon_rpc_diff
<< " cycles, time = "
<< std::chrono::duration_cast<std::chrono::microseconds>(
profile_utils::CpuUtils::ConvertClockCycleToTime(
overhead_hexagon_rpc_diff))
.count()
<< " usec";
}
#endif
// Runs two small matrices through the operator, and leaves all the parameters
// at their default values.
// This test is a sample to execute matmul on hexagon.
TEST_F(QuantizedMatMulOpForHexagonTest, Small_NoParams) {
  // Build a QuantizedMatMul node with default attributes: two quint8 matrix
  // inputs followed by their float quantization ranges, emitting qint32.
  TF_ASSERT_OK(NodeDefBuilder("quantized_mat_mul_op", "QuantizedMatMul")
                   .Input(FakeInput(DT_QUINT8))
                   .Input(FakeInput(DT_QUINT8))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Attr("Toutput", DataTypeToEnum<qint32>::v())
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  // A is the 2x3 matrix [[1, 2, 3], [4, 5, 6]].
  AddInputFromArray<quint8>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
  // B is the 3x4 matrix [[7, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]].
  AddInputFromArray<quint8>(TensorShape({3, 4}),
                            {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
  // Both inputs are quantized over the full range [0, 255].
  AddInputFromArray<float>(TensorShape({1}), {0});
  AddInputFromArray<float>(TensorShape({1}), {255.0f});
  AddInputFromArray<float>(TensorShape({1}), {0});
  AddInputFromArray<float>(TensorShape({1}), {255.0f});
  TF_ASSERT_OK(RunOpKernel());
  // Hand-computed product A * B, e.g. row 0:
  //   (1 * 7) + (2 * 11) + (3 * 15) = 74
  //   (1 * 8) + (2 * 12) + (3 * 16) = 80, ...
  // and row 1:
  //   (4 * 7) + (5 * 11) + (6 * 15) = 173, ...
  Tensor expected_output(allocator(), DT_QINT32, TensorShape({2, 4}));
  test::FillValues<qint32>(&expected_output,
                           {74, 80, 86, 92, 173, 188, 203, 218});
  test::ExpectTensorEqual<qint32>(expected_output, *GetOutput(0));
}
} // namespace tensorflow

View File

@ -0,0 +1,83 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
// Dummy implementation of soc_interface.
// Version queries: the dummy implementation reports -1 for both versions.
int soc_interface_GetWrapperVersion() { return -1; }
int soc_interface_GetSocControllerVersion() { return -1; }
// Lifecycle and execution hooks: always report failure, as no SOC is present.
bool soc_interface_Init() { return false; }
bool soc_interface_Finalize() { return false; }
bool soc_interface_ExecuteGraph() { return false; }
bool soc_interface_TeardownGraph() { return false; }
// Buffer allocation stub: parameters are ignored and no buffers are created.
bool soc_interface_AllocateInOutNodeBuffers(int /*input_count*/,
                                            int* /*input_sizes*/,
                                            int /*output_count*/,
                                            int* /*output_sizes*/) {
  return false;
}
// Input/output data-transfer stubs: no data is ever moved; all report failure.
bool soc_interface_FillInputNodeWithPort(int /*port*/, int /*x*/, int /*y*/,
                                         int /*z*/, int /*d*/,
                                         const uint8_t* const /*buf*/,
                                         uint64_t /*buf_byte_size*/) {
  return false;
}
bool soc_interface_FillInputNodeFloat(int /*x*/, int /*y*/, int /*z*/,
                                      int /*d*/, const uint8_t* const /*buf*/,
                                      uint64_t /*buf_byte_size*/) {
  return false;
}
bool soc_interface_ReadOutputNodeWithPort(int /*port*/, uint8_t** /*buf*/,
                                          uint64_t* /*buf_byte_size*/) {
  return false;
}
bool soc_interface_ReadOutputNodeFloat(const char* const /*node_name*/,
                                       uint8_t** /*buf*/,
                                       uint64_t* /*buf_byte_size*/) {
  return false;
}
// Graph setup and node-array stubs: nothing is allocated or wired up.
bool soc_interface_setupDummyGraph(int /*version*/) { return false; }
bool soc_interface_AllocateNodeInputAndNodeOutputArray(
    int /*total_input_count*/, int /*total_output_count*/) {
  return false;
}
bool soc_interface_ReleaseNodeInputAndNodeOutputArray() { return false; }
// Node wiring stubs return nullptr since no node arrays ever exist here.
void* soc_interface_SetOneNodeInputs(int /*input_count*/,
                                     const int* const /*node_id*/,
                                     const int* const /*port*/) {
  return nullptr;
}
void* soc_interface_SetOneNodeOutputs(int /*output_count*/, int* /*max_size*/) {
  return nullptr;
}
// Graph construction stubs: appending nodes and building graphs always fail.
bool soc_interface_AppendConstNode(const char* const /*name*/, int /*node_id*/,
                                   int /*batch*/, int /*height*/, int /*width*/,
                                   int /*depth*/, const uint8_t* const /*data*/,
                                   int /*data_length*/) {
  return false;
}
bool soc_interface_AppendNode(const char* const /*name*/, int /*node_id*/,
                              int /*op_id*/, int /*padding_id*/,
                              const void* const /*inputs*/,
                              int /*inputs_count*/,
                              const void* const /*outputs*/,
                              int /*outputs_count*/) {
  return false;
}
bool soc_interface_InstantiateGraph() { return false; }
bool soc_interface_ConstructGraph() { return false; }
// Logging controls are no-ops in the dummy implementation.
void soc_interface_SetLogLevel(int /*log_level*/) {}
void soc_interface_SetDebugFlag(uint64_t /*flag*/) {}

View File

@ -21,6 +21,7 @@ limitations under the License.
// All functions defined here must have prefix "soc_interface" to avoid
// naming conflicts.
#ifdef __cplusplus
#include <cstdint>
extern "C" {
#else
#include <stdbool.h>

View File

@ -59,6 +59,13 @@ class IRemoteFusedGraphExecutor {
virtual bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) = 0;
virtual Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) = 0;
virtual bool IsEnabled() const = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphExecutor);
};

View File

@ -13,11 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "i_graph_transfer_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
namespace tensorflow {
/* static */ constexpr int IGraphTransferOpsDefinitions::INVALID_OP_ID;
// TODO(satok): Remove
/* static */ constexpr const char* const
IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
/* static */ constexpr int IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
}

View File

@ -13,39 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// IGraphTransferOpsDefinitions is an interface class which provides interfaces
// about ops supported by SOC.
// IRemoteFusedGraphOpsDefinitions is an interface class which provides
// APIs to provide information about op types supported by SOC.
// TODO(satok): Provide ways to transfer graph definitions into SOC
class IGraphTransferOpsDefinitions {
class IRemoteFusedGraphOpsDefinitions {
public:
// op id which is not supported by SOC
static constexpr int INVALID_OP_ID = -1;
// Custom op name for flatten node
static constexpr const char* const FLATTEN_OP_NAME = "FLATTEN";
IGraphTransferOpsDefinitions() = default;
virtual ~IGraphTransferOpsDefinitions() = default;
IRemoteFusedGraphOpsDefinitions() = default;
virtual ~IRemoteFusedGraphOpsDefinitions() = default;
// Return total ops count supported by SOC
virtual int GetTotalOpsCount() const = 0;
// Return op id for given string op name
virtual int GetOpIdFor(const string& op_name,
virtual int GetOpIdFor(const string& op_type,
const DataTypeVector& dt) const = 0;
// Return destination of transfer
virtual GraphTransferInfo::Destination GetTransferDestination() const = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(IGraphTransferOpsDefinitions);
TF_DISALLOW_COPY_AND_ASSIGN(IRemoteFusedGraphOpsDefinitions);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_GRAPH_TRANSFER_OPS_DEFINITIONS_H_
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_

View File

@ -41,7 +41,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
execute_info_.executor_name());
if (build_func != nullptr) {
Status status = (*build_func)(&remote_fused_graph_executor_);
TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
CHECK(remote_fused_graph_executor_->IsEnabled());
} else {
LOG(ERROR) << "Executor not found for "
<< execute_info_.executor_name();

View File

@ -159,8 +159,8 @@ static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
return execute_info;
}
// 1. Create TestRemoteFusedGraphExecutor to execute your fused graph
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
// 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph
class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
public:
int GetVersion() final { return 1; }
bool Init(const RemoteFusedGraphExecuteInfo& info) final {
@ -214,6 +214,16 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
return true;
}
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& /*inputs*/,
const std::vector<string>& /*outputs*/,
GraphDef* fused_graph_def) final {
*fused_graph_def = original_graph_def;
return Status::OK();
}
bool IsEnabled() const final { return true; }
private:
const RemoteFusedGraphExecuteInfo* info_;
std::unordered_map<string, Tensor> input_tensor_cache_;
@ -225,7 +235,7 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
namespace remote_fused_graph_execute_op {
Status BuildRemoteFusedGraphExecutor(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(new TestRemoteFusedGraphExecutor());
executor->reset(new SampleRemoteFusedGraphExecutor());
return Status::OK();
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@ -92,4 +93,36 @@ namespace tensorflow {
return Status::OK();
}
TestRemoteFusedGraphExecutor::TestRemoteFusedGraphExecutor(
const std::unordered_set<string>& fused_op_types,
const string& executor_name)
: fused_op_types_(fused_op_types), executor_name_(executor_name) {}
int TestRemoteFusedGraphExecutor::GetVersion() { return 0; }
bool TestRemoteFusedGraphExecutor::Init(const RemoteFusedGraphExecuteInfo&) {
return true;
}
bool TestRemoteFusedGraphExecutor::Finalize() { return true; }
bool TestRemoteFusedGraphExecutor::SetupGraph() { return true; }
bool TestRemoteFusedGraphExecutor::ExecuteGraph() { return true; }
bool TestRemoteFusedGraphExecutor::TeardownGraph() { return true; }
bool TestRemoteFusedGraphExecutor::FillInputNode(const string&, const Tensor&) {
return true;
}
bool TestRemoteFusedGraphExecutor::ReadOutputNode(const string&,
TensorAllocatorFunc) {
return true;
}
// Fuses the given graph for this test executor by delegating to the
// op-type-based fuser, clustering only the op types this executor supports.
//
// Args:
//   original_graph_def: the graph to fuse.
//   inputs/outputs: graph input and output node names.
//   fused_graph_def: output parameter receiving the fused graph.
// Returns the status of the op-type fusion pass.
Status TestRemoteFusedGraphExecutor::FuseRemoteGraph(
    const GraphDef& original_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, GraphDef* fused_graph_def) {
  // Bug fix: the original had an unreachable `return Status::OK();` after
  // this return statement; it has been removed as dead code.
  return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
      original_graph_def, inputs, outputs, "remote_fused_graph_node_names",
      fused_op_types_, executor_name_,
      /*require_shape_type=*/false, fused_graph_def);
}
bool TestRemoteFusedGraphExecutor::IsEnabled() const { return true; }
} // namespace tensorflow

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@ -59,6 +60,30 @@ class RemoteFusedGraphExecuteOpTestUtils {
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOpTestUtils);
};
class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
public:
TestRemoteFusedGraphExecutor(const std::unordered_set<string>& fused_op_types,
const string& executor_name);
int GetVersion() final;
bool Init(const RemoteFusedGraphExecuteInfo&) final;
bool Finalize() final;
bool SetupGraph() final;
bool ExecuteGraph() final;
bool TeardownGraph() final;
bool FillInputNode(const string&, const Tensor&) final;
bool ReadOutputNode(const string&, TensorAllocatorFunc) final;
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) final;
bool IsEnabled() const final;
private:
const std::unordered_set<string> fused_op_types_;
const string executor_name_;
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_

View File

@ -161,6 +161,8 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
@ -1084,6 +1086,26 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
require_shape_type, output_graph_def);
}
/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
const GraphDef& input_graph_def, const std::vector<string>& inputs,
const std::vector<string>& outputs, const string& executor_name,
GraphDef* output_graph_def) {
const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
if (build_func == nullptr) {
return errors::InvalidArgument("Unknown executor name: " + executor_name);
}
std::unique_ptr<IRemoteFusedGraphExecutor> executor;
TF_RETURN_IF_ERROR((*build_func)(&executor));
CHECK_NOTNULL(executor.get());
if (!executor->IsEnabled()) {
// As this executor is not enabled, just return original graph as is.
*output_graph_def = input_graph_def;
return Status::OK();
}
return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
output_graph_def);
}
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
const std::vector<string>& inputs, const std::vector<string>& outputs,
const std::unordered_set<string>& fused_node_names,
@ -1387,6 +1409,28 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
return retval;
}
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
const GraphDef& graph_def,
const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
std::unordered_set<string> retval;
for (const NodeDef& node_def : graph_def.node()) {
std::vector<DataType> dt_vec;
std::vector<TensorShape> shape_vec;
const Status status =
GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
if (!status.ok()) {
shape_vec.clear();
}
if (ops_definitions.GetOpIdFor(
node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
retval.emplace(node_def.name());
}
}
return retval;
}
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
const string& input, const DataType type, const TensorShape& shape,
GraphDef* graph_def) {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
@ -59,6 +60,8 @@ class RemoteFusedGraphExecuteUtils {
"border_outputs";
static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
"fused_op_types";
static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR =
"fuse_by_executor";
static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
"input_shapes";
@ -257,6 +260,12 @@ class RemoteFusedGraphExecuteUtils {
const std::vector<std::pair<string, Tensor>>& input_tensors,
GraphDef* output_graph_def);
static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
const string& executor_name,
GraphDef* output_graph_def);
static bool IsFuseReady(
const GraphDef& input_graph_def,
const std::vector<std::pair<string, Tensor>>& input_tensors);
@ -273,6 +282,10 @@ class RemoteFusedGraphExecuteUtils {
static std::unordered_set<string> BuildNodeMapFromOpTypes(
const GraphDef& graph_def, const std::unordered_set<string>& op_types);
static std::unordered_set<string> BuildNodeMapFromOpsDefinitions(
const GraphDef& graph_def,
const IRemoteFusedGraphOpsDefinitions& ops_definitions);
private:
static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
TensorShapeMap* tensor_shape_map);

View File

@ -33,6 +33,10 @@ constexpr const char* const NAME_A_PLUS_B = "A_PLUS_B";
constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
"fuse_test_remote_fused_graph_executor0";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
"fuse_test_remote_fused_graph_executor1";
static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
CHECK_NE(def, nullptr);
@ -44,17 +48,38 @@ static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
return nullptr;
}
Status BuildRemoteFusedGraphExecutor0(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(
new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
return Status::OK();
}
Status BuildRemoteFusedGraphExecutor1(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(new TestRemoteFusedGraphExecutor(
{"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
return Status::OK();
}
class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
protected:
void SetUp() final {
TF_ASSERT_OK(
RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build(
hexagon_remote_fused_graph_executor_build(
"remote_graph_executor_name",
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
return Status::OK();
});
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
BuildRemoteFusedGraphExecutor0);
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
BuildRemoteFusedGraphExecutor1);
}
void TearDown() final {}
@ -87,6 +112,18 @@ class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
/*require_shape_type=*/false, &result_graph_def_);
}
Status FuseByExecutor0() {
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME0,
&result_graph_def_);
}
Status FuseByExecutor1() {
return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME1,
&result_graph_def_);
}
Status BuildAndAddTensorShape() {
return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
input_tensors_, /*dry_run_inference=*/true, &graph_def_);
@ -694,6 +731,30 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_FGHIJ) {
<< SummarizeGraphDef(result_graph_def_);
}
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_HIJ) {
ReplaceOpType({"H", "I", "J"}, "Mul");
TF_ASSERT_OK(FuseByExecutor0());
EXPECT_EQ(11, graph_def_.node_size());
EXPECT_EQ(9, result_graph_def_.node_size())
<< "=== Before: \n"
<< SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
<< SummarizeGraphDef(result_graph_def_);
}
TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_FGHIJ) {
ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
TF_ASSERT_OK(FuseByExecutor1());
EXPECT_EQ(11, graph_def_.node_size());
EXPECT_EQ(3, result_graph_def_.node_size())
<< "=== Before: \n"
<< SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
<< SummarizeGraphDef(result_graph_def_);
}
TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) {
subgraph_node_names_ = {"H"};

View File

@ -66,7 +66,7 @@ static Status ParseArguments(const TransformFuncContext& context,
string* input_types_str, string* input_shapes_str,
string* fused_nodes_str, string* border_inputs_str,
string* border_outputs_str,
string* fused_op_types_str,
string* fused_op_types_str, bool* fuse_by_executor,
string* remote_fused_graph_node_name,
string* remote_graph_executor_name) {
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
@ -87,6 +87,9 @@ static Status ParseArguments(const TransformFuncContext& context,
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "",
fused_op_types_str));
TF_RETURN_IF_ERROR(context.GetOneBoolParameter(
RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, false,
fuse_by_executor));
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
@ -140,12 +143,14 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
string border_inputs_str;
string border_outputs_str;
string fused_op_types_str;
bool fuse_by_executor = false;
string remote_fused_graph_node_name;
string remote_graph_executor_name;
TF_RETURN_IF_ERROR(ParseArguments(
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
&remote_fused_graph_node_name, &remote_graph_executor_name));
&fuse_by_executor, &remote_fused_graph_node_name,
&remote_graph_executor_name));
if (!input_types_str.empty()) {
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
@ -187,6 +192,10 @@ Status FuseRemoteGraph(const GraphDef& input_graph_def,
mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
fused_op_types, remote_graph_executor_name, require_shape_type,
output_graph_def));
} else if (fuse_by_executor) {
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
mutable_input_graph_def, inputs, outputs, remote_graph_executor_name,
output_graph_def));
} else {
CHECK(false) << "Fuse targets are not specified.";
}
@ -205,15 +214,17 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
string input_types_str;
string input_shapes_str;
string fused_nodes_str;
string fused_op_types_str;
string border_inputs_str;
string border_outputs_str;
string fused_op_types_str;
bool fuse_by_executor = false;
string remote_fused_graph_node_name;
string remote_graph_executor_name;
TF_RETURN_IF_ERROR(ParseArguments(
context, &input_types_str, &input_shapes_str, &fused_nodes_str,
&border_inputs_str, &border_outputs_str, &fused_op_types_str,
&remote_fused_graph_node_name, &remote_graph_executor_name));
&fuse_by_executor, &remote_fused_graph_node_name,
&remote_graph_executor_name));
if (!input_types_str.empty()) {
TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -43,11 +44,28 @@ Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
GraphDef* output_graph_def);
namespace {
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
"remote_fused_graph_executor_name";
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
"remote_fused_graph_node_name";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
"fuse_test_remote_fused_graph_executor0";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
"fuse_test_remote_fused_graph_executor1";
Status BuildRemoteFusedGraphExecutor0(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(
new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
return Status::OK();
}
Status BuildRemoteFusedGraphExecutor1(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(new TestRemoteFusedGraphExecutor(
{"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
return Status::OK();
}
class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
protected:
@ -55,11 +73,18 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
&input_graph_def_));
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build(
hexagon_remote_fused_graph_executor_build(
REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
[](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
return Status::OK();
});
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
BuildRemoteFusedGraphExecutor0);
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
BuildRemoteFusedGraphExecutor1);
}
void TearDown() final {}
@ -113,10 +138,16 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
{fused_op_types_str_}}));
}
if (fuse_by_executor_) {
context.params.insert(std::pair<string, std::vector<string>>(
{RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
{"true"}}));
}
context.params.insert(std::pair<string, std::vector<string>>(
{RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
{REMOTE_FUSED_GRAPH_EXECUTOR_NAME}}));
{remote_fused_graph_executor_name_}}));
context.params.insert(std::pair<string, std::vector<string>>(
{RemoteFusedGraphExecuteUtils::
TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
@ -160,7 +191,7 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
&serialized_proto));
info.ParseFromString(serialized_proto);
CHECK_EQ(REMOTE_FUSED_GRAPH_EXECUTOR_NAME, info.executor_name());
CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
}
}
EXPECT_EQ(expected_cluster_count, cluster_count);
@ -178,6 +209,8 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
string border_inputs_str_;
string border_outputs_str_;
string fused_op_types_str_;
string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
bool fuse_by_executor_{false};
};
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
@ -260,6 +293,24 @@ TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
CheckGraph(3, 1);
}
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
FuseRemoteGraphByExecutor_HIJ) {
ReplaceOpType({"H", "I", "J"}, "Mul");
remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
fuse_by_executor_ = true;
TF_ASSERT_OK(Fuse());
CheckGraph(9, 1);
}
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
FuseRemoteGraphByExecutor_FGHIJ) {
ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
fuse_by_executor_ = true;
TF_ASSERT_OK(Fuse());
CheckGraph(3, 1);
}
TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
fused_node_names_str_ = "H,I,J";
TF_ASSERT_OK(PlaceFuseArgs());