Invoke APIs of SOC wrapper library to execute inference on HVX
Change: 141373481
This commit is contained in:
parent
260e64fe09
commit
539fafd32e
@ -66,7 +66,7 @@ fi
|
|||||||
if [[ "${USE_HEXAGON}" == "true" ]]; then
|
if [[ "${USE_HEXAGON}" == "true" ]]; then
|
||||||
HEXAGON_PARENT_DIR=$(cd ../hexagon && pwd)
|
HEXAGON_PARENT_DIR=$(cd ../hexagon && pwd)
|
||||||
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
|
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
|
||||||
HEXAGON_INCLUDE="${HEXAGON_PARENT_DIR}/include"
|
HEXAGON_INCLUDE=$(cd tensorflow/core/platform/hexagon && pwd)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ -z "${BUILD_TARGET}" ]]; then
|
if [[ -z "${BUILD_TARGET}" ]]; then
|
||||||
|
@ -15,49 +15,77 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
|
||||||
|
|
||||||
|
#ifdef USE_HEXAGON_LIBS
|
||||||
|
#include "tensorflow/core/platform/hexagon/soc_interface.h"
|
||||||
|
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
int HexagonControlWrapper::GetVersion() const {
|
#ifdef USE_HEXAGON_LIBS
|
||||||
// TODO: Implement
|
int HexagonControlWrapper::GetVersion() {
|
||||||
return 1;
|
return soc_interface_GetSocControllerVersion();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::Init() {
|
bool HexagonControlWrapper::Init() { return soc_interface_Init(); }
|
||||||
// TODO: Implement
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HexagonControlWrapper::Finalize() {
|
bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); }
|
||||||
// TODO: Implement
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool HexagonControlWrapper::SetupGraph(
|
bool HexagonControlWrapper::SetupGraph(
|
||||||
const GraphTransferer &graph_transferer) {
|
const GraphTransferer &graph_transferer) {
|
||||||
// TODO: Implement
|
return soc_interface_SetupGraphDummy(3 /* inception version */);
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::ExecuteGraph() {
|
bool HexagonControlWrapper::ExecuteGraph() {
|
||||||
// TODO: Implement
|
return soc_interface_ExecuteGraph();
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::TeardownGraph() {
|
bool HexagonControlWrapper::TeardownGraph() {
|
||||||
// TODO: Implement
|
return soc_interface_TeardownGraph();
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::FillInputNode(const string node_name,
|
bool HexagonControlWrapper::FillInputNode(const string node_name,
|
||||||
const ByteArray bytes) {
|
const ByteArray bytes) {
|
||||||
// TODO: Implement
|
// TODO(satok): Use arguments instead of dummy input
|
||||||
return false;
|
const int x = 1;
|
||||||
|
const int y = 299;
|
||||||
|
const int z = 299;
|
||||||
|
const int d = 3;
|
||||||
|
const int array_length = x * y * z * d;
|
||||||
|
const int byte_size = array_length * sizeof(float);
|
||||||
|
dummy_input_float_.resize(array_length);
|
||||||
|
return soc_interface_FillInputNodeFloat(
|
||||||
|
1, 299, 299, 3, reinterpret_cast<uint8 *>(dummy_input_float_.data()),
|
||||||
|
byte_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HexagonControlWrapper::ReadOutputNode(
|
bool HexagonControlWrapper::ReadOutputNode(
|
||||||
const string node_name, std::vector<ByteArray> *const outputs) const {
|
const string node_name, std::vector<ByteArray> *const outputs) {
|
||||||
CHECK(outputs != nullptr);
|
CHECK(outputs != nullptr);
|
||||||
// TODO: Implement
|
ByteArray output;
|
||||||
return false;
|
soc_interface_ReadOutputNodeFloat(node_name.c_str(), &std::get<0>(output),
|
||||||
|
&std::get<1>(output));
|
||||||
|
std::get<2>(output) = DT_FLOAT;
|
||||||
|
outputs->emplace_back(output);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
int HexagonControlWrapper::GetVersion() { return -1; }
|
||||||
|
bool HexagonControlWrapper::Init() { return false; }
|
||||||
|
bool HexagonControlWrapper::Finalize() { return false; }
|
||||||
|
bool HexagonControlWrapper::SetupGraph(const GraphTransferer &) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool HexagonControlWrapper::ExecuteGraph() { return false; }
|
||||||
|
bool HexagonControlWrapper::TeardownGraph() { return false; }
|
||||||
|
bool HexagonControlWrapper::FillInputNode(const string, const ByteArray) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool HexagonControlWrapper::ReadOutputNode(const string,
|
||||||
|
std::vector<ByteArray> *const) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
|||||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
|
||||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
|
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
|
||||||
#include "tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h"
|
#include "tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h"
|
||||||
@ -31,17 +33,20 @@ namespace tensorflow {
|
|||||||
class HexagonControlWrapper final : public ISocControlWrapper {
|
class HexagonControlWrapper final : public ISocControlWrapper {
|
||||||
public:
|
public:
|
||||||
HexagonControlWrapper() = default;
|
HexagonControlWrapper() = default;
|
||||||
int GetVersion() const final;
|
int GetVersion() final;
|
||||||
bool Init() final;
|
bool Init() final;
|
||||||
bool Finalize() final;
|
bool Finalize() final;
|
||||||
bool SetupGraph(const GraphTransferer &graph_transferer) final;
|
bool SetupGraph(const GraphTransferer &graph_transferer) final;
|
||||||
bool ExecuteGraph() final;
|
bool ExecuteGraph() final;
|
||||||
bool TeardownGraph() final;
|
bool TeardownGraph() final;
|
||||||
bool FillInputNode(string node_name, const ByteArray bytes) final;
|
bool FillInputNode(string node_name, const ByteArray bytes) final;
|
||||||
bool ReadOutputNode(string node_name,
|
bool ReadOutputNode(string node_name, std::vector<ByteArray> *outputs) final;
|
||||||
std::vector<ByteArray> *outputs) const final;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Dummy byte array for input node data.
|
||||||
|
// TODO(satok): Use actual data passed by FillInputNode and remove
|
||||||
|
std::vector<float> dummy_input_float_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper);
|
TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
|
|||||||
Status status = gt.LoadGraphFromProtoFile(
|
Status status = gt.LoadGraphFromProtoFile(
|
||||||
*ops_definitions, filename, input_node_info_list, output_node_names,
|
*ops_definitions, filename, input_node_info_list, output_node_names,
|
||||||
is_text_proto, true /* dry_run_for_unknown_shape */, &output_tensor_info);
|
is_text_proto, true /* dry_run_for_unknown_shape */, &output_tensor_info);
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok()) << status;
|
||||||
|
|
||||||
HexagonControlWrapper hexagon_control_wrapper;
|
HexagonControlWrapper hexagon_control_wrapper;
|
||||||
const int version = hexagon_control_wrapper.GetVersion();
|
const int version = hexagon_control_wrapper.GetVersion();
|
||||||
|
@ -32,7 +32,7 @@ class ISocControlWrapper {
|
|||||||
|
|
||||||
// Return version of SOC controller library.
|
// Return version of SOC controller library.
|
||||||
// This function is mainly for a debug purpose to verify SOC controller.
|
// This function is mainly for a debug purpose to verify SOC controller.
|
||||||
virtual int GetVersion() const = 0;
|
virtual int GetVersion() = 0;
|
||||||
|
|
||||||
// Initialize SOC. This function should be called before
|
// Initialize SOC. This function should be called before
|
||||||
// starting graph transfer.
|
// starting graph transfer.
|
||||||
@ -56,7 +56,7 @@ class ISocControlWrapper {
|
|||||||
|
|
||||||
// Read output node's outputs on SOC
|
// Read output node's outputs on SOC
|
||||||
virtual bool ReadOutputNode(string node_name,
|
virtual bool ReadOutputNode(string node_name,
|
||||||
std::vector<ByteArray> *outputs) const = 0;
|
std::vector<ByteArray> *outputs) = 0;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(ISocControlWrapper);
|
TF_DISALLOW_COPY_AND_ASSIGN(ISocControlWrapper);
|
||||||
|
@ -46,7 +46,7 @@ class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
|
|||||||
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
|
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
|
||||||
<< soc_interface_GetWrapperVersion()
|
<< soc_interface_GetWrapperVersion()
|
||||||
<< ", hexagon binary version = "
|
<< ", hexagon binary version = "
|
||||||
<< soc_interface_GetHexagonBinaryVersion() << ")";
|
<< soc_interface_GetSocControllerVersion() << ")";
|
||||||
LOG(INFO) << "Cpu frequency = "
|
LOG(INFO) << "Cpu frequency = "
|
||||||
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
|
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
|
||||||
#else
|
#else
|
||||||
@ -67,7 +67,7 @@ TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
|
|||||||
(overhead_shared_lib_end - overhead_shared_lib_start);
|
(overhead_shared_lib_end - overhead_shared_lib_start);
|
||||||
const uint64 overhead_hexagon_rpc_start =
|
const uint64 overhead_hexagon_rpc_start =
|
||||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||||
const int hexagon_binary_version = soc_interface_GetHexagonBinaryVersion();
|
const int hexagon_binary_version = soc_interface_GetSocControllerVersion();
|
||||||
const uint64 overhead_hexagon_rpc_end =
|
const uint64 overhead_hexagon_rpc_end =
|
||||||
profile_utils::CpuUtils::GetCurrentClockCycle();
|
profile_utils::CpuUtils::GetCurrentClockCycle();
|
||||||
const uint64 overhead_hexagon_rpc_diff =
|
const uint64 overhead_hexagon_rpc_diff =
|
||||||
|
@ -30,8 +30,26 @@ int soc_interface_GetWrapperVersion();
|
|||||||
// Returns the version of hexagon binary.
|
// Returns the version of hexagon binary.
|
||||||
// You should assert that the version matches the expected version before
|
// You should assert that the version matches the expected version before
|
||||||
// calling APIs defined in this header.
|
// calling APIs defined in this header.
|
||||||
int soc_interface_GetHexagonBinaryVersion();
|
int soc_interface_GetSocControllerVersion();
|
||||||
// TODO(satok): Support gemm APIs via RPC
|
// Initialize SOC
|
||||||
|
bool soc_interface_Init();
|
||||||
|
// Finalize SOC
|
||||||
|
bool soc_interface_Finalize();
|
||||||
|
// Execute graph on SOC
|
||||||
|
bool soc_interface_ExecuteGraph();
|
||||||
|
// Teardown graph setup
|
||||||
|
bool soc_interface_TeardownGraph();
|
||||||
|
// Send input data to SOC
|
||||||
|
bool soc_interface_FillInputNodeFloat(int x, int y, int z, int d,
|
||||||
|
const uint8_t* const buf,
|
||||||
|
uint64_t buf_size);
|
||||||
|
// Load output data from SOC
|
||||||
|
bool soc_interface_ReadOutputNodeFloat(const char* const node_name,
|
||||||
|
uint8_t** buf, uint64_t* buf_size);
|
||||||
|
// Setup graph
|
||||||
|
// TODO(satok): Remove and use runtime version
|
||||||
|
bool soc_interface_SetupGraphDummy(int version);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
Loading…
Reference in New Issue
Block a user