Allow state tensors to use device memories in NNAPI delegate.
PiperOrigin-RevId: 347025011 Change-Id: I7441e254d020f89102b85023c5816894b3684b1f
parent 01eab284ac · commit 37c9633c50
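Illustrative usage sketch (not part of the commit): assuming an interpreter already built from a model with variable (state) tensors, a client could enable the new option and drive the host/device state syncing added by this change roughly as follows. The function name RunStatefulSequence and the model_path / sequence_size parameters are placeholders; the delegate option and the SetSyncStatesToDevice / SetSyncStatesFromDevice calls come from this commit.

    #include <memory>

    #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/kernels/register.h"
    #include "tensorflow/lite/model.h"

    // Hypothetical driver loop; error checking omitted for brevity.
    void RunStatefulSequence(const char* model_path, int sequence_size) {
      auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(*model, resolver)(&interpreter);

      // Enable device memories for state tensors (option added by this change).
      tflite::StatefulNnApiDelegate::Options options;
      options.use_device_memory_for_state_tensors = true;
      tflite::StatefulNnApiDelegate delegate(options);
      interpreter->ModifyGraphWithDelegate(&delegate);
      interpreter->AllocateTensors();

      for (int i = 0; i < sequence_size; ++i) {
        // Push initial states to the device only before the first invocation.
        delegate.SetSyncStatesToDevice(i == 0);
        // Read states back to host CPU buffers only after the final invocation.
        delegate.SetSyncStatesFromDevice(i == sequence_size - 1);
        // ... fill the model's input tensors for step i ...
        interpreter->Invoke();
      }
    }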
@@ -449,6 +449,78 @@ ANeuralNetworksOperandType ConvertTensorTypeToNNType(
  return nn_operand_type;
}

// Copy the CPU buffer of the input tensor to a shared memory address. Will
// apply data type conversion if needed. The returned tensor_size is the size
// after the potential data type conversion.
TfLiteStatus CopyOrConvertInputData(TfLiteContext* context,
                                    TfLiteType ann_type_equivalent,
                                    bool use_int8_asymm_signed,
                                    TfLiteTensor* tensor, uint8_t* dst,
                                    int* tensor_size) {
  if (ann_type_equivalent != kTfLiteNoType) {
    const auto num_elements = NumElements(tensor);
    if (tensor->type == kTfLiteUInt8 && ann_type_equivalent == kTfLiteInt32) {
      for (int i = 0; i < num_elements; ++i) {
        reinterpret_cast<int32_t*>(dst)[i] =
            static_cast<const int32_t>(tensor->data.uint8[i]);
      }
    } else if (tensor->type == kTfLiteInt8 &&
               ann_type_equivalent == kTfLiteUInt8) {
      // Explicitly convert int8 values to uint8 values.
      for (int i = 0; i < num_elements; ++i) {
        dst[i] = static_cast<const uint8_t>(
            static_cast<int32_t>(tensor->data.int8[i]) + 128);
      }
    } else if (tensor->type == kTfLiteInt8 &&
               ann_type_equivalent == kTfLiteInt32) {
      if (use_int8_asymm_signed) {
        for (int i = 0; i < num_elements; ++i) {
          reinterpret_cast<int32_t*>(dst)[i] =
              static_cast<const int32_t>(tensor->data.int8[i]);
        }
      } else {
        for (int i = 0; i < num_elements; ++i) {
          reinterpret_cast<int32_t*>(dst)[i] =
              static_cast<const int32_t>(tensor->data.int8[i]) + 128;
        }
      }
    } else {
      TF_LITE_KERNEL_LOG(
          context,
          "NN API Delegate: unsupported tensor types conversion: "
          "from type code %d to type code %d.\n",
          tensor->type, ann_type_equivalent);
      return kTfLiteError;
    }
    size_t type_size;
    TF_LITE_ENSURE_OK(context,
                      GetSizeOfType(context, ann_type_equivalent, &type_size));
    *tensor_size = NumElements(tensor) * type_size;
  } else {
    // copy data to pre-allocated shared memory.
    memcpy(dst, tensor->data.raw, tensor->bytes);
    *tensor_size = tensor->bytes;
  }
  return kTfLiteOk;
}

// Copy into the CPU buffer of the output tensor from a shared memory address.
// Will apply data type conversion if needed.
TfLiteStatus CopyOrConvertOutputData(TfLiteType ann_type_equivalent,
                                     const uint8_t* src, TfLiteTensor* tensor) {
  if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
    // Explicitly convert uint8 values to int8 values.
    int8_t* output_ptr = tensor->data.int8;
    const auto num_elements = NumElements(tensor);
    for (int i = 0; i < num_elements; ++i) {
      output_ptr[i] = static_cast<int8_t>(static_cast<int32_t>(src[i]) - 128);
    }
  } else {
    memcpy(tensor->data.raw, src, tensor->bytes);
  }
  return kTfLiteOk;
}

constexpr size_t kDefaultByteAlignmentForNNAPI = 16;

static size_t getNumPaddingBytes(size_t byte_size) {
@@ -3642,6 +3714,7 @@ TfLiteStatus NNAPIDelegateKernel::Prepare(TfLiteContext* context,
    return kTfLiteOk;
  }

  const auto& delegate_data = StatefulNnApiDelegate::GetData(node->delegate);
  ANeuralNetworksCompilation* compilation = nullptr;
  if (!nnapi_devices_.empty()) {
    // Compile for the selected accelerator.
@@ -3709,6 +3782,67 @@ TfLiteStatus NNAPIDelegateKernel::Prepare(TfLiteContext* context,
  }
  RETURN_TFLITE_ERROR_IF_NN_ERROR(context, finish_result,
                                  "completing NNAPI compilation", nnapi_errno);

  const bool use_device_memory_for_state_tensors =
      nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13 &&
      delegate_data.use_device_memory_for_state_tensors &&
      delegate_data.single_partition_delegated &&
      // State tensors with dynamic shapes are currently not supported.
      std::all_of(model_state_tfl_inputs_.begin(),
                  model_state_tfl_inputs_.end(), [&context](int tfl_index) {
                    TfLiteTensor* tensor = &context->tensors[tfl_index];
                    return !IsDynamicTensor(tensor);
                  });
  if (use_device_memory_for_state_tensors) {
    for (int tfl_index : model_state_tfl_inputs_) {
      auto& info = nn_state_tensor_info_map_.at(tfl_index);

      // prepare device memory descriptor
      ANeuralNetworksMemoryDesc* desc = nullptr;
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context, nnapi_->ANeuralNetworksMemoryDesc_create(&desc),
          "creating device memory descriptor", nnapi_errno);
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context,
          nnapi_->ANeuralNetworksMemoryDesc_addInputRole(
              desc, compilation, info.nn_input_index, 1.0f),
          "adding input role to the device memory descriptor", nnapi_errno);
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context,
          nnapi_->ANeuralNetworksMemoryDesc_addOutputRole(
              desc, compilation, info.nn_output_index, 1.0f),
          "adding output role to the device memory descriptor", nnapi_errno);
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context, nnapi_->ANeuralNetworksMemoryDesc_finish(desc),
          "finishing device memory descriptor", nnapi_errno);

      // allocate two device memories for each state tensor
      ANeuralNetworksMemory* state_input_memory = nullptr;
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context,
          nnapi_->ANeuralNetworksMemory_createFromDesc(desc,
                                                       &state_input_memory),
          "creating input device memory from the descriptor", nnapi_errno);
      info.nn_input_memory_handle.reset(state_input_memory);

      ANeuralNetworksMemory* state_output_memory = nullptr;
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context,
          nnapi_->ANeuralNetworksMemory_createFromDesc(desc,
                                                       &state_output_memory),
          "creating output device memory from the descriptor", nnapi_errno);
      info.nn_output_memory_handle.reset(state_output_memory);
      nnapi_->ANeuralNetworksMemoryDesc_free(desc);

      // we need a temporary buffer to sync states to raw pointers
      TfLiteTensor* tensor = &context->tensors[tfl_index];
      if (tensor->buffer_handle == kTfLiteNullBufferHandle) {
        info.nn_temp_buffer.reset(
            new NNMemory(nnapi_, "temp state tensor", info.tensor_size));
      }
    }
  }

  nn_compilation_.reset(compilation);

  return kTfLiteOk;
@@ -3770,6 +3904,7 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
  // Set compilation timeout if applicable.
  const auto delegate_options =
      StatefulNnApiDelegate::GetOptions(node->delegate);
  const auto& delegate_data = StatefulNnApiDelegate::GetData(node->delegate);
  if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13) {
    if (delegate_options.max_execution_timeout_duration_ns > 0) {
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
@@ -3835,14 +3970,24 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
    }
  }

  const bool use_device_memory_for_state_tensors =
      nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13 &&
      delegate_data.use_device_memory_for_state_tensors &&
      // TODO(b/174612931): Even if the model is not fully supported, we can
      // still use device memories for state tensors if they are only used in
      // one single partition.
      delegate_data.single_partition_delegated &&
      std::all_of(model_state_tfl_inputs_.begin(),
                  model_state_tfl_inputs_.end(), [&context](int tfl_index) {
                    TfLiteTensor* tensor = &context->tensors[tfl_index];
                    return !IsDynamicTensor(tensor);
                  });

  // Set the input tensor buffers. Note: we access tflite tensors using
  // absolute indices but NN api indices inputs by relative indices.
  int relative_input_index = 0;

  const bool use_int8_asymm_signed =
      target_sdk_version_ >= kMinSdkVersionForNNAPI13;

  size_t input_offset = 0;
  size_t input_offset_accumulator = 0;
  for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
    if (absolute_input_index == kTfLiteOptionalTensor) {
      continue;
@@ -3860,90 +4005,58 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
      input_nn_operand_type_ptr = &input_nn_operand_type;
    }
    if (tensor->allocation_type != kTfLiteMmapRo) {
      if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
          tensor->buffer_handle < tensor_memory_map_->size()) {
        RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
            context,
            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
                execution, relative_input_index, input_nn_operand_type_ptr,
                tensor_memory_map_->at(tensor->buffer_handle).memory, 0,
                tensor->bytes),
            "associating NNAPI execution input with a memory object", tensor,
            nnapi_errno);
        relative_input_index++;
        continue;
      }
      int tensor_size = 0;
      if (ann_type_equivalent != kTfLiteNoType) {
        const auto num_elements = NumElements(tensor);
        uint8_t* input_ptr = nn_input_memory_->get_data_ptr() + input_offset;
        if (tensor->type == kTfLiteUInt8 &&
            ann_type_equivalent == kTfLiteInt32) {
          for (int i = 0; i < num_elements; ++i) {
            reinterpret_cast<int32_t*>(input_ptr)[i] =
                static_cast<const int32_t>(tensor->data.uint8[i]);
          }
        } else if (tensor->type == kTfLiteInt8 &&
                   ann_type_equivalent == kTfLiteUInt8) {
          // Explicitly convert int8 values to uint8 values.
          for (int i = 0; i < num_elements; ++i) {
            input_ptr[i] = static_cast<const uint8_t>(
                static_cast<int32_t>(tensor->data.int8[i]) + 128);
          }
        } else if (tensor->type == kTfLiteInt8 &&
                   ann_type_equivalent == kTfLiteInt32) {
          if (use_int8_asymm_signed) {
            for (int i = 0; i < num_elements; ++i) {
              reinterpret_cast<int32_t*>(input_ptr)[i] =
                  static_cast<const int32_t>(tensor->data.int8[i]);
            }
          } else {
            for (int i = 0; i < num_elements; ++i) {
              reinterpret_cast<int32_t*>(input_ptr)[i] =
                  static_cast<const int32_t>(tensor->data.int8[i]) + 128;
            }
          }
        } else {
          context->ReportError(
              context,
              "NN API Delegate: unsupported tensor types conversion: "
              "from type code %d to type code %d.\n",
              tensor->type, ann_type_equivalent);
          return kTfLiteError;
        }
        size_t type_size;
        TF_LITE_ENSURE_OK(
            context, GetSizeOfType(context, ann_type_equivalent, &type_size));
        tensor_size = NumElements(tensor) * type_size;
        RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
            context,
            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
                execution, relative_input_index, input_nn_operand_type_ptr,
                nn_input_memory_->get_handle(), input_offset, tensor_size),
            "associating NNAPI execution input with a memory object", tensor,
            nnapi_errno);
      ANeuralNetworksMemory* input_memory_handle = nullptr;
      uint32_t input_offset = 0;
      uint32_t input_length = 0;
      const bool is_state_tensor =
          nn_state_tensor_info_map_.count(absolute_input_index) > 0;
      if (is_state_tensor && use_device_memory_for_state_tensors &&
          // If the client requests to sync states to device, we will use the
          // shared memory directly as input instead of explicitly copying into
          // the device memory.
          !delegate_data.sync_states_to_device) {
        const auto& state_tensor_info =
            nn_state_tensor_info_map_.at(absolute_input_index);
        input_memory_handle = state_tensor_info.nn_input_memory_handle.get();
        input_offset = 0;
        input_length = 0;
      } else if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
                 tensor->buffer_handle < tensor_memory_map_->size()) {
        input_memory_handle =
            tensor_memory_map_->at(tensor->buffer_handle).memory;
        input_offset = 0;
        input_length = tensor->bytes;
      } else {
        // copy data to pre-allocated shared memory.
        memcpy(nn_input_memory_->get_data_ptr() + input_offset,
               tensor->data.raw, tensor->bytes);
        RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
        int tensor_size = 0;
        // copy or convert tensor data to pre-allocated shared memory.
        const bool use_int8_asymm_signed =
            target_sdk_version_ >= kMinSdkVersionForNNAPI13;
        TF_LITE_ENSURE_OK(
            context,
            nnapi_->ANeuralNetworksExecution_setInputFromMemory(
                execution, relative_input_index, input_nn_operand_type_ptr,
                nn_input_memory_->get_handle(), input_offset, tensor->bytes),
            "associating NNAPI execution input with a memory object", tensor,
            nnapi_errno);
        tensor_size = tensor->bytes;
            CopyOrConvertInputData(
                context, ann_type_equivalent, use_int8_asymm_signed, tensor,
                nn_input_memory_->get_data_ptr() + input_offset_accumulator,
                &tensor_size));
        input_memory_handle = nn_input_memory_->get_handle();
        input_offset = input_offset_accumulator;
        input_length = tensor_size;
        input_offset_accumulator += tensor_size;
        input_offset_accumulator += getNumPaddingBytes(tensor_size);
      }
      input_offset += tensor_size;
      input_offset += getNumPaddingBytes(tensor_size);
      RETURN_TFLITE_ERROR_IF_NN_ERROR_FOR_TENSOR(
          context,
          nnapi_->ANeuralNetworksExecution_setInputFromMemory(
              execution, relative_input_index, input_nn_operand_type_ptr,
              input_memory_handle, input_offset, input_length),
          "associating NNAPI execution input with a memory object", tensor,
          nnapi_errno);
      relative_input_index++;
    }
  }

  // Set the output tensor buffers.
  int relative_output_index = 0;
  size_t output_offset = 0;
  size_t output_offset_accumulator = 0;
  for (auto output_index : TfLiteIntArrayView(node->outputs)) {
    // If the NNAPI implementation doesn't have some of the outputs
    // they are left unmapped and we should not try to read their value here
@@ -3977,11 +4090,12 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
          context,
          nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
              execution, relative_output_index, output_nn_operand_type_ptr,
              nn_output_memory_->get_handle(), output_offset, tensor->bytes),
              nn_output_memory_->get_handle(), output_offset_accumulator,
              tensor->bytes),
          "associating NNAPI execution output to a memory object", tensor,
          nnapi_errno);
      output_offset += tensor->bytes;
      output_offset += getNumPaddingBytes(tensor->bytes);
      output_offset_accumulator += tensor->bytes;
      output_offset_accumulator += getNumPaddingBytes(tensor->bytes);
    }
    relative_output_index++;
  }
@@ -3990,16 +4104,27 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
  // current invocation.
  for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) {
    int state_tensor_idx = model_state_tfl_inputs_[i];
    TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
    // Here we are using a deep copy for state_in tensors so that we are not
    // reading and writing into the same buffer during a invocation.
    // TODO(b/110369471): using double shared buffer to minimize the copies.
    RETURN_TFLITE_ERROR_IF_NN_ERROR(
        context,
        nnapi_->ANeuralNetworksExecution_setOutput(
            execution, relative_output_index, nullptr, tensor->data.raw,
            tensor->bytes),
        "associating NNAPI execution output to a buffer", nnapi_errno);
    if (use_device_memory_for_state_tensors) {
      auto* device_memory = nn_state_tensor_info_map_.at(state_tensor_idx)
                                .nn_output_memory_handle.get();
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context,
          nnapi_->ANeuralNetworksExecution_setOutputFromMemory(
              execution, relative_output_index, nullptr, device_memory, 0, 0),
          "associating NNAPI execution output with a device memory object",
          nnapi_errno);
    } else {
      TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
      // Here we are using a deep copy for state_in tensors so that we are not
      // reading and writing into the same buffer during a invocation.
      // TODO(b/110369471): using double shared buffer to minimize the copies.
      RETURN_TFLITE_ERROR_IF_NN_ERROR(
          context,
          nnapi_->ANeuralNetworksExecution_setOutput(
              execution, relative_output_index, nullptr, tensor->data.raw,
              tensor->bytes),
          "associating NNAPI execution output to a buffer", nnapi_errno);
    }
    relative_output_index++;
  }
  // Invoke ANN in blocking fashion.
@@ -4022,39 +4147,70 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context,
  }

  // copy results from shared memory to the destination.
  output_offset = 0;
  output_offset_accumulator = 0;
  for (auto output_index : TfLiteIntArrayView(node->outputs)) {
    TfLiteTensor* tensor = &context->tensors[output_index];
    if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
      continue;
    }
    TfLiteType ann_type_equivalent =
    const TfLiteType ann_type_equivalent =
        operand_mapping_.lite_index_to_ann_type_conversion(output_index);
    if (tensor->type == kTfLiteInt8 && ann_type_equivalent == kTfLiteUInt8) {
      // Explicitly convert uint8 values to int8 values.
      uint8_t* output_ptr = reinterpret_cast<uint8_t*>(
          nn_output_memory_->get_data_ptr() + output_offset);
      const auto num_elements = NumElements(tensor);
      for (int i = 0; i < num_elements; ++i) {
        output_ptr[i] =
            static_cast<uint8_t>(static_cast<int32_t>(output_ptr[i]) - 128);
    TF_LITE_ENSURE_OK(
        context, CopyOrConvertOutputData(ann_type_equivalent,
                                         nn_output_memory_->get_data_ptr() +
                                             output_offset_accumulator,
                                         tensor));
    output_offset_accumulator += tensor->bytes;
    output_offset_accumulator += getNumPaddingBytes(tensor->bytes);
  }

  // sync state tensors from device memories
  if (use_device_memory_for_state_tensors &&
      delegate_data.sync_states_from_device) {
    for (auto& [tfl_index, info] : nn_state_tensor_info_map_) {
      TfLiteTensor* tensor = &context->tensors[tfl_index];
      if (tensor->buffer_handle != kTfLiteNullBufferHandle &&
          tensor->buffer_handle < tensor_memory_map_->size()) {
        RETURN_TFLITE_ERROR_IF_NN_ERROR(
            context,
            nnapi_->ANeuralNetworksMemory_copy(
                info.nn_output_memory_handle.get(),
                tensor_memory_map_->at(tensor->buffer_handle).memory),
            "syncing device memory from device", nnapi_errno);
      } else {
        // For pointer tensor data, we need to copy twice:
        // 1. device memory -> shared memory
        // 2. shared memory -> raw pointer
        // The second copy may also need type conversion from uint8 -> int8.
        RETURN_TFLITE_ERROR_IF_NN_ERROR(context,
                                        nnapi_->ANeuralNetworksMemory_copy(
                                            info.nn_output_memory_handle.get(),
                                            info.nn_temp_buffer->get_handle()),
                                        "syncing device memory from device",
                                        nnapi_errno);
        const TfLiteType ann_type_equivalent =
            operand_mapping_.lite_index_to_ann_type_conversion(tfl_index);
        TF_LITE_ENSURE_OK(context,
                          CopyOrConvertOutputData(
                              ann_type_equivalent,
                              info.nn_temp_buffer->get_data_ptr(), tensor));
      }
    }
    memcpy(tensor->data.raw, nn_output_memory_->get_data_ptr() + output_offset,
           tensor->bytes);
    output_offset += tensor->bytes;
    output_offset += getNumPaddingBytes(tensor->bytes);
  }

  // swap device memory handles so that the state output of the current
  // invocation will be used as the state input of the next invocation
  if (use_device_memory_for_state_tensors) {
    for (auto& [tfl_index, info] : nn_state_tensor_info_map_) {
      std::swap(info.nn_input_memory_handle, info.nn_output_memory_handle);
    }
  }

  // copy output of all output tensors in feedback_loops_ into the
  // associated input
  for (auto feedback_loop : feedback_loops_) {
    int output_tensor_idx;
    int input_tensor_idx;
    std::tie(output_tensor_idx, input_tensor_idx) = feedback_loop;
  for (auto [output_tensor_idx, input_tensor_idx] : feedback_loops_) {
    TfLiteTensor& src = context->tensors[output_tensor_idx];
    TfLiteTensor& dest = context->tensors[input_tensor_idx];

    memcpy(dest.data.raw, src.data.raw, src.bytes);
  }

@@ -4622,6 +4778,17 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
  std::vector<uint32_t> outputs;
  outputs.reserve(output_tensors->size);

  for (int tfl_index : model_state_tfl_inputs_) {
    NNStateTensorInfo info = {
        .nn_input_memory_handle =
            std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory>(
                nullptr, NNFreeMemory(nnapi_)),
        .nn_output_memory_handle =
            std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory>(
                nullptr, NNFreeMemory(nnapi_))};
    nn_state_tensor_info_map_.emplace(tfl_index, std::move(info));
  }

  size_t total_input_byte_size = 0;
  // Make the TensorFlow Lite inputs and outputs to ann_indices.
  for (int i : TfLiteIntArrayView(input_tensors)) {
@@ -4631,10 +4798,6 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
        // The delegate might not have mapped this input (this can
        // happen if one tensor is split in several ones)
        operand_mapping_.lite_index_to_ann(i) != -1) {
      inputs.push_back(operand_mapping_.lite_index_to_ann(i));
      if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
        continue;
      }
      const TfLiteType nn_type_conversion =
          operand_mapping_.lite_index_to_ann_type_conversion(i);
      int tensor_size = 0;
@@ -4646,6 +4809,15 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
            context, GetSizeOfType(context, nn_type_conversion, &type_size));
        tensor_size = NumElements(&context->tensors[i]) * type_size;
      }
      if (auto it = nn_state_tensor_info_map_.find(i);
          it != nn_state_tensor_info_map_.end()) {
        it->second.nn_input_index = inputs.size();
        it->second.tensor_size = tensor_size;
      }
      inputs.push_back(operand_mapping_.lite_index_to_ann(i));
      if (context->tensors[i].buffer_handle != kTfLiteNullBufferHandle) {
        continue;
      }
      total_input_byte_size += tensor_size;
      total_input_byte_size += getNumPaddingBytes(tensor_size);
    }
@@ -4666,8 +4838,11 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph(
  }

  // Add state output tensors as model outputs.
  for (int i : model_state_outputs_) {
    outputs.push_back(i);
  for (int i = 0; i < model_state_outputs_.size(); i++) {
    const int tfl_index = model_state_tfl_inputs_[i];
    const int nn_model_index = model_state_outputs_[i];
    nn_state_tensor_info_map_.at(tfl_index).nn_output_index = outputs.size();
    outputs.push_back(nn_model_index);
  }

  // Tell ANN to declare inputs/outputs
@@ -4772,6 +4947,8 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(const NnApi* nnapi,
  if (nnapi->android_sdk_version >= kMinSdkVersionForNNAPI11) {
    delegate_data_.allow_dynamic_dimensions = options.allow_dynamic_dimensions;
  }
  delegate_data_.use_device_memory_for_state_tensors =
      options.use_device_memory_for_state_tensors;
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                       "Created TensorFlow Lite delegate for NNAPI.");
  Prepare = DoPrepare;
@@ -4814,9 +4991,17 @@ const StatefulNnApiDelegate::Options StatefulNnApiDelegate::GetOptions(
  options.max_execution_loop_timeout_duration_ns =
      delegate_data->max_execution_loop_timeout_duration_ns;
  options.allow_dynamic_dimensions = delegate_data->allow_dynamic_dimensions;
  options.use_device_memory_for_state_tensors =
      delegate_data->use_device_memory_for_state_tensors;
  return options;
}

const StatefulNnApiDelegate::Data& StatefulNnApiDelegate::GetData(
    TfLiteDelegate* delegate) {
  auto* delegate_data = reinterpret_cast<Data*>(delegate->data_);
  return *delegate_data;
}

const std::vector<StatefulNnApiDelegate::MemoryRegistration>&
StatefulNnApiDelegate::GetTensorMemoryMap(TfLiteDelegate* delegate) {
  auto delegate_data = reinterpret_cast<Data*>(delegate->data_);
@@ -4877,6 +5062,24 @@ int StatefulNnApiDelegate::GetNnApiErrno() const {
  return delegate_data_.nnapi_errno;
}

TfLiteStatus StatefulNnApiDelegate::SetSyncStatesToDevice(
    bool sync_states_to_device) {
  if (!delegate_data_.use_device_memory_for_state_tensors) {
    return kTfLiteError;
  }
  delegate_data_.sync_states_to_device = sync_states_to_device;
  return kTfLiteOk;
}

TfLiteStatus StatefulNnApiDelegate::SetSyncStatesFromDevice(
    bool sync_states_from_device) {
  if (!delegate_data_.use_device_memory_for_state_tensors) {
    return kTfLiteError;
  }
  delegate_data_.sync_states_from_device = sync_states_from_device;
  return kTfLiteOk;
}

// static
TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
    TfLiteContext* context, TfLiteDelegate* delegate, const NnApi* nnapi,
@@ -4908,9 +5111,9 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
                                    supported_partition_nodes.begin(),
                                    supported_partition_nodes.end());

    bool model_fully_supported = (supported_partition_nodes.size() ==
                                  partition_params.nodes_to_replace->size);
    if (model_fully_supported) {
    bool single_partition_delegated = (supported_partition_nodes.size() ==
                                       partition_params.nodes_to_replace->size);
    if (single_partition_delegated) {
      delegate_data->CacheDelegateKernel(&partition_params,
                                         kernel_state.release());
    }
@@ -5125,6 +5328,10 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
          params_array, params_array + num_partitions),
      &nodes_to_delegate));

  if (!nodes_to_delegate.empty() && num_partitions == 1) {
    delegate_data->single_partition_delegated = true;
  }

  if (nodes_to_delegate.empty()) {
    return kTfLiteOk;
  } else {

@@ -125,6 +125,33 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
    // accelerator. This should only be enabled if the target device supports
    // dynamic dimensions of the model.
    bool allow_dynamic_dimensions = false;

    // When set to true, the delegate will allocate device memory for state
    // tensors to reduce data copying and transformation overhead. In such a
    // case, the user must explicitly specify whether they would like to sync
    // states between host and device before and after each invocation by
    // SetSyncStatesToDevice and SetSyncStatesFromDevice. The following code
    // snippet demonstrates the usage:
    //
    //   StatefulNnapiDelegate::Options options;
    //   options.use_device_memory_for_state_tensors = true;
    //   ...
    //
    //   for (int i = 0; i < sequence_size; i++) {
    //     ...
    //
    //     // Push initial states to the device before the first invocation.
    //     delegate->SetSyncStatesToDevice(i == 0);
    //
    //     // Get states data back to the host CPU buffer after the final
    //     // invocation.
    //     delegate->SetSyncStatesFromDevice(i == sequence_size - 1);
    //
    //     interpreter->Invoke();
    //   }
    //
    // WARNING: This is an experimental interface that is subject to change.
    bool use_device_memory_for_state_tensors = false;
  };

  // Uses default options.
@@ -186,7 +213,23 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
  // (i.e. when calling interpreter.ModifyGraphWithDelegate(delegate)).
  int GetNnApiErrno() const;

  // Specifies whether the device memories should be initialized from the
  // content of CPU buffers of state tensors before the execution or not.
  // Will return an error if the delegate is not initialized with
  // use_device_memory_for_state_tensors set to true.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetSyncStatesToDevice(bool sync_states_to_device);

  // Specifies whether the device memories should be copied to the content of
  // CPU buffers of state tensors after the execution or not.
  // Will return an error if the delegate is not initialized with
  // use_device_memory_for_state_tensors set to true.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetSyncStatesFromDevice(bool sync_states_from_device);

 private:
  friend NNAPIDelegateKernel;

  // Encapsulates all delegate data.
  struct Data {
    // Pointer to NNAPI implementation to be used by this delegate as
@@ -235,6 +278,17 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
    uint64_t max_execution_loop_timeout_duration_ns = 0;
    // Whether to allow dynamic dimension sizes without re-compilation.
    bool allow_dynamic_dimensions = false;
    // When set to true, the delegate will allocate device memories for state
    // tensors to reduce data copying and transformation overhead.
    bool use_device_memory_for_state_tensors = false;
    // When set to true, the device memories will be initialized from the
    // content of CPU buffers of state tensors before the execution.
    bool sync_states_to_device = false;
    // When set to true, the device memories will be copied to the content of
    // CPU buffers of state tensors after the execution.
    bool sync_states_from_device = false;
    // Whether the model is fully supported by the delegate.
    bool single_partition_delegated = false;

    explicit Data(const NnApi* nnapi);
    ~Data();
@@ -248,6 +302,9 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
                                const TfLiteDelegateParams* delegate_params);
  };

  // Returns the delegate data.
  static const Data& GetData(TfLiteDelegate* delegate);

  // Implements TfLiteDelegate::Prepare. Please refer to TFLiteDelegate
  // documentation for more info.
  static TfLiteStatus DoPrepare(TfLiteContext* context,

@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"

namespace tflite {
@@ -154,6 +155,18 @@ class NNFreeExecution {
  // NnApi instance to use. Not owned by this object.
  const NnApi* nnapi_;
};
// RAII NN API Memory Destructor for use with std::unique_ptr
class NNFreeMemory {
 public:
  explicit NNFreeMemory(const NnApi* nnapi) : nnapi_(nnapi) {}
  void operator()(ANeuralNetworksMemory* memory) {
    nnapi_->ANeuralNetworksMemory_free(memory);
  }

 private:
  // NnApi instance to use. Not owned by this object.
  const NnApi* nnapi_;
};

// Manage NNAPI shared memory handle
class NNMemory {
@@ -175,6 +188,19 @@ class NNMemory {
  ANeuralNetworksMemory* nn_memory_handle_ = nullptr;
};

// Basic info and NN device memory handles for state tensors.
struct NNStateTensorInfo {
  uint32_t nn_input_index = 0;
  uint32_t nn_output_index = 0;
  // The size of the NN state tensor after applying any potential data type
  // conversion.
  int tensor_size = 0;
  std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory> nn_input_memory_handle;
  std::unique_ptr<ANeuralNetworksMemory, NNFreeMemory> nn_output_memory_handle;
  // The shared memory used to sync the state from the device.
  std::unique_ptr<NNMemory> nn_temp_buffer;
};

enum class NNAPIValidationFailureType : int {
  // The operator is not supported by either NNAPI or the NNAPI Delegate.
@@ -340,6 +366,9 @@ class NNAPIDelegateKernel {
  // data available for TFLite model users
  std::vector<std::tuple<int, int>> feedback_loops_;

  // TfLite index -> state tensor info.
  std::map<int, NNStateTensorInfo> nn_state_tensor_info_map_;

  std::unique_ptr<NNMemory> nn_input_memory_;
  std::unique_ptr<NNMemory> nn_output_memory_;

@@ -2718,24 +2718,15 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
 public:
  RNNOpModel(int batches, int units, int size,
             const TensorType weights = TensorType_FLOAT32,
             const TensorType recurrent_weights = TensorType_FLOAT32) {
    Init(batches, units, size, weights, recurrent_weights);
  }

  RNNOpModel(const StatefulNnApiDelegate::Options& options, int batches,
             int units, int size, const TensorType weights = TensorType_FLOAT32,
             const TensorType recurrent_weights = TensorType_FLOAT32)
      : batches_(batches), units_(units), input_size_(size) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(weights);
    recurrent_weights_ = AddInput(recurrent_weights);
    bias_ = AddInput(TensorType_FLOAT32);
    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
    BuildInterpreterWithNNAPI({
        {batches_, input_size_},  // input tensor
        {units_, input_size_},    // weights tensor
        {units_, units_},         // recurrent weights tensor
        {units_},                 // bias tensor
        {batches_, units_}        // hidden state tensor
    });
      : SingleOpModelWithNNAPI(options) {
    Init(batches, units, size, weights, recurrent_weights);
  }

  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -2756,8 +2747,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
    PopulateTensor(input_, offset, begin, end);
  }

  void SetHiddenState(const std::vector<float>& data) {
    PopulateTensor(hidden_state_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  std::vector<float> GetHiddenState() {
    return ExtractVector<float>(hidden_state_);
  }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }
@@ -2773,8 +2772,50 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
  int batches_;
  int units_;
  int input_size_;

 private:
  // Performs initialization logic shared across all constructors.
  void Init(int batches, int units, int size, const TensorType weights,
            const TensorType recurrent_weights) {
    batches_ = batches;
    units_ = units;
    input_size_ = size;
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(weights);
    recurrent_weights_ = AddInput(recurrent_weights);
    bias_ = AddInput(TensorType_FLOAT32);
    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
    BuildInterpreterWithNNAPI({
        {batches_, input_size_},  // input tensor
        {units_, input_size_},    // weights tensor
        {units_, units_},         // recurrent weights tensor
        {units_},                 // bias tensor
        {batches_, units_}        // hidden state tensor
    });
  }
};

static void InvokeAndTestSingleRnnStep(int step_index, RNNOpModel* rnn) {
  float* batch_start = rnn_input + step_index * rnn->input_size();
  float* batch_end = batch_start + rnn->input_size();
  rnn->SetInput(0, batch_start, batch_end);
  rnn->SetInput(rnn->input_size(), batch_start, batch_end);

  rnn->Invoke();

  float* golden_start = rnn_golden_output + step_index * rnn->num_units();
  float* golden_end = golden_start + rnn->num_units();
  std::vector<float> expected;
  expected.insert(expected.end(), golden_start, golden_end);
  expected.insert(expected.end(), golden_start, golden_end);

  EXPECT_THAT(rnn->GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

TEST(NNAPIDelegate, RnnBlackBoxTest) {
  RNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(rnn_weights);
@@ -2785,20 +2826,66 @@ TEST(NNAPIDelegate, RnnBlackBoxTest) {
                                  (rnn.input_size() * rnn.num_batches());

  for (int i = 0; i < input_sequence_size; i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);
    InvokeAndTestSingleRnnStep(i, &rnn);
  }
}

    rnn.Invoke();
TEST(NNAPIDelegate, RnnDeviceMemoryBasicTest) {
  StatefulNnApiDelegate::Options options;
  options.use_device_memory_for_state_tensors = true;

    float* golden_start = rnn_golden_output + i * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);
  RNNOpModel rnn(options, 2, 16, 8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  auto* delegate = rnn.GetDelegate();
  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
                                  (rnn.input_size() * rnn.num_batches());

  // Only sync the state to device in the first invocation, all subsequent
  // states are kept inside the driver.
  for (int i = 0; i < input_sequence_size; i++) {
    delegate->SetSyncStatesToDevice(i == 0);
    InvokeAndTestSingleRnnStep(i, &rnn);
  }
}

TEST(NNAPIDelegate, RnnDeviceMemorySyncTest) {
  StatefulNnApiDelegate::Options options;
  options.use_device_memory_for_state_tensors = true;

  RNNOpModel rnn(options, 2, 16, 8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  auto* delegate = rnn.GetDelegate();
  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
                                  (rnn.input_size() * rnn.num_batches());
  const int sync_output_index = input_sequence_size / 2;

  // The following steps test SetSyncStatesFromDevice and SetSyncStatesToDevice:
  // 1. Invoke RNN sequence until sync_output_index;
  // 2. Extract the hidden output state at sync_output_index by
  //    SetSyncStatesFromDevice(true);
  // 3. Continue RNN sequence until the end;
  // 4. Reset the hidden state by SetSyncStatesToDevice(true), the state should
  //    go back to sync_output_index;
  // 5. Continue RNN sequence from sync_output_index + 1 until the end.
  std::vector<float> hidden_state_data;
  for (int i = 0; i < input_sequence_size; i++) {
    delegate->SetSyncStatesToDevice(i == 0);
    delegate->SetSyncStatesFromDevice(i == sync_output_index);
    InvokeAndTestSingleRnnStep(i, &rnn);
    if (i == sync_output_index) {
      hidden_state_data = rnn.GetHiddenState();
    }
  }
  rnn.SetHiddenState(hidden_state_data);
  for (int i = sync_output_index + 1; i < input_sequence_size; i++) {
    delegate->SetSyncStatesToDevice(i == (sync_output_index + 1));
    InvokeAndTestSingleRnnStep(i, &rnn);
  }
}