Invoke APIs of SOC wrapper library to execute inference on HVX

Change: 141373481
This commit is contained in:
A. Unique TensorFlower 2016-12-07 16:28:28 -08:00 committed by TensorFlower Gardener
parent 260e64fe09
commit 539fafd32e
7 changed files with 84 additions and 33 deletions

View File

@ -66,7 +66,7 @@ fi
if [[ "${USE_HEXAGON}" == "true" ]]; then
HEXAGON_PARENT_DIR=$(cd ../hexagon && pwd)
HEXAGON_LIBS="${HEXAGON_PARENT_DIR}/libs"
HEXAGON_INCLUDE="${HEXAGON_PARENT_DIR}/include"
HEXAGON_INCLUDE=$(cd tensorflow/core/platform/hexagon && pwd)
fi
if [[ -z "${BUILD_TARGET}" ]]; then

View File

@ -15,49 +15,77 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#ifdef USE_HEXAGON_LIBS
#include "tensorflow/core/platform/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/types.h"
#endif
namespace tensorflow {
int HexagonControlWrapper::GetVersion() const {
// TODO: Implement
return 1;
#ifdef USE_HEXAGON_LIBS
int HexagonControlWrapper::GetVersion() {
return soc_interface_GetSocControllerVersion();
}
bool HexagonControlWrapper::Init() {
// TODO: Implement
return false;
}
bool HexagonControlWrapper::Init() { return soc_interface_Init(); }
bool HexagonControlWrapper::Finalize() {
// TODO: Implement
return false;
}
bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); }
bool HexagonControlWrapper::SetupGraph(
const GraphTransferer &graph_transferer) {
// TODO: Implement
return false;
return soc_interface_SetupGraphDummy(3 /* inception version */);
}
bool HexagonControlWrapper::ExecuteGraph() {
// TODO: Implement
return false;
return soc_interface_ExecuteGraph();
}
bool HexagonControlWrapper::TeardownGraph() {
// TODO: Implement
return false;
return soc_interface_TeardownGraph();
}
bool HexagonControlWrapper::FillInputNode(const string node_name,
const ByteArray bytes) {
// TODO: Implement
return false;
// TODO(satok): Use arguments instead of dummy input
const int x = 1;
const int y = 299;
const int z = 299;
const int d = 3;
const int array_length = x * y * z * d;
const int byte_size = array_length * sizeof(float);
dummy_input_float_.resize(array_length);
return soc_interface_FillInputNodeFloat(
1, 299, 299, 3, reinterpret_cast<uint8 *>(dummy_input_float_.data()),
byte_size);
}
bool HexagonControlWrapper::ReadOutputNode(
const string node_name, std::vector<ByteArray> *const outputs) const {
const string node_name, std::vector<ByteArray> *const outputs) {
CHECK(outputs != nullptr);
// TODO: Implement
return false;
ByteArray output;
soc_interface_ReadOutputNodeFloat(node_name.c_str(), &std::get<0>(output),
&std::get<1>(output));
std::get<2>(output) = DT_FLOAT;
outputs->emplace_back(output);
return true;
}
#else
int HexagonControlWrapper::GetVersion() { return -1; }
bool HexagonControlWrapper::Init() { return false; }
bool HexagonControlWrapper::Finalize() { return false; }
bool HexagonControlWrapper::SetupGraph(const GraphTransferer &) {
return false;
}
bool HexagonControlWrapper::ExecuteGraph() { return false; }
bool HexagonControlWrapper::TeardownGraph() { return false; }
bool HexagonControlWrapper::FillInputNode(const string, const ByteArray) {
return false;
}
bool HexagonControlWrapper::ReadOutputNode(const string,
std::vector<ByteArray> *const) {
return false;
}
#endif
} // namespace tensorflow

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
#include <vector>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h"
@ -31,17 +33,20 @@ namespace tensorflow {
class HexagonControlWrapper final : public ISocControlWrapper {
public:
HexagonControlWrapper() = default;
int GetVersion() const final;
int GetVersion() final;
bool Init() final;
bool Finalize() final;
bool SetupGraph(const GraphTransferer &graph_transferer) final;
bool ExecuteGraph() final;
bool TeardownGraph() final;
bool FillInputNode(string node_name, const ByteArray bytes) final;
bool ReadOutputNode(string node_name,
std::vector<ByteArray> *outputs) const final;
bool ReadOutputNode(string node_name, std::vector<ByteArray> *outputs) final;
private:
// Dummy byte array for input node data.
// TODO(satok): Use actual data passed by FillInputNode and remove
std::vector<float> dummy_input_float_;
TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper);
};

View File

@ -51,7 +51,7 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
Status status = gt.LoadGraphFromProtoFile(
*ops_definitions, filename, input_node_info_list, output_node_names,
is_text_proto, true /* dry_run_for_unknown_shape */, &output_tensor_info);
EXPECT_TRUE(status.ok());
EXPECT_TRUE(status.ok()) << status;
HexagonControlWrapper hexagon_control_wrapper;
const int version = hexagon_control_wrapper.GetVersion();

View File

@ -32,7 +32,7 @@ class ISocControlWrapper {
// Return version of SOC controller library.
// This function is mainly for a debug purpose to verify SOC controller.
virtual int GetVersion() const = 0;
virtual int GetVersion() = 0;
// Initialize SOC. This function should be called before
// starting graph transfer.
@ -56,7 +56,7 @@ class ISocControlWrapper {
// Read output node's outputs on SOC
virtual bool ReadOutputNode(string node_name,
std::vector<ByteArray> *outputs) const = 0;
std::vector<ByteArray> *outputs) = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ISocControlWrapper);

View File

@ -46,7 +46,7 @@ class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
LOG(INFO) << "Hexagon libs are linked (wrapper version = "
<< soc_interface_GetWrapperVersion()
<< ", hexagon binary version = "
<< soc_interface_GetHexagonBinaryVersion() << ")";
<< soc_interface_GetSocControllerVersion() << ")";
LOG(INFO) << "Cpu frequency = "
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
#else
@ -67,7 +67,7 @@ TEST_F(QuantizedMatMulOpForHexagonTest, EvaluateSharedLibOverhead) {
(overhead_shared_lib_end - overhead_shared_lib_start);
const uint64 overhead_hexagon_rpc_start =
profile_utils::CpuUtils::GetCurrentClockCycle();
const int hexagon_binary_version = soc_interface_GetHexagonBinaryVersion();
const int hexagon_binary_version = soc_interface_GetSocControllerVersion();
const uint64 overhead_hexagon_rpc_end =
profile_utils::CpuUtils::GetCurrentClockCycle();
const uint64 overhead_hexagon_rpc_diff =

View File

@ -30,8 +30,26 @@ int soc_interface_GetWrapperVersion();
// Returns the version of hexagon binary.
// You should assert that the version matches the expected version before
// calling APIs defined in this header.
int soc_interface_GetHexagonBinaryVersion();
// TODO(satok): Support gemm APIs via RPC
int soc_interface_GetSocControllerVersion();
// Initialize SOC
bool soc_interface_Init();
// Finalize SOC
bool soc_interface_Finalize();
// Execute graph on SOC
bool soc_interface_ExecuteGraph();
// Teardown graph setup
bool soc_interface_TeardownGraph();
// Send input data to SOC
bool soc_interface_FillInputNodeFloat(int x, int y, int z, int d,
const uint8_t* const buf,
uint64_t buf_size);
// Load output data from SOC
bool soc_interface_ReadOutputNodeFloat(const char* const node_name,
uint8_t** buf, uint64_t* buf_size);
// Setup graph
// TODO(satok): Remove and use runtime version
bool soc_interface_SetupGraphDummy(int version);
#ifdef __cplusplus
}
#endif // __cplusplus