LSTM converted to generic GPUOperation.
PiperOrigin-RevId: 328189510 Change-Id: Ic2791b36123d374fa1d3521e66b18dd7b82e5c4a
This commit is contained in:
parent
84258bceed
commit
ae517dba72
@ -24,33 +24,14 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
LSTM::LSTM(const OperationDef& definition, const DeviceInfo& device_info)
|
||||
: GPUOperation(definition) {
|
||||
code_ = GetLSTMCode(definition_, device_info);
|
||||
}
|
||||
|
||||
LSTM::LSTM(LSTM&& kernel) : GPUOperation(std::move(kernel)) {}
|
||||
|
||||
LSTM& LSTM::operator=(LSTM&& kernel) {
|
||||
if (this != &kernel) {
|
||||
GPUOperation::operator=(std::move(kernel));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string LSTM::GetLSTMCode(const OperationDef& op_def,
|
||||
const DeviceInfo& device_info) {
|
||||
AddSrcTensor("intermediate", op_def.src_tensors[0]);
|
||||
AddSrcTensor("prev_state", op_def.src_tensors[1]);
|
||||
AddDstTensor("new_state", op_def.dst_tensors[0]);
|
||||
AddDstTensor("activation", op_def.dst_tensors[1]);
|
||||
|
||||
namespace {
|
||||
std::string GetLSTMCode(const OperationDef& op_def,
|
||||
const DeviceInfo& device_info) {
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += "$0) {\n";
|
||||
c += " int B = get_global_id(0);\n";
|
||||
c += " int Z = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (Z >= args.activation.Slices() || B >= args.activation.Batch()) "
|
||||
"return;\n";
|
||||
c += " FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n";
|
||||
@ -105,15 +86,18 @@ std::string LSTM::GetLSTMCode(const OperationDef& op_def,
|
||||
return c;
|
||||
}
|
||||
|
||||
int3 LSTM::GetGridSize() const {
|
||||
const int grid_x = dst_[0]->Batch();
|
||||
const int grid_y = dst_[0]->Slices();
|
||||
const int grid_z = 1;
|
||||
return int3(grid_x, grid_y, grid_z);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info) {
|
||||
return LSTM(definition, device_info);
|
||||
GPUOperation CreateLSTM(const OperationDef& definition,
|
||||
const DeviceInfo& device_info) {
|
||||
GPUOperation op(definition);
|
||||
op.AddSrcTensor("intermediate", definition.src_tensors[0]);
|
||||
op.AddSrcTensor("prev_state", definition.src_tensors[1]);
|
||||
op.AddDstTensor("new_state", definition.dst_tensors[0]);
|
||||
op.AddDstTensor("activation", definition.dst_tensors[1]);
|
||||
op.code_ = GetLSTMCode(definition, device_info);
|
||||
op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
|
||||
return op;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
|
@ -25,23 +25,8 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
class LSTM : public GPUOperation {
|
||||
public:
|
||||
LSTM(const OperationDef& definition, const DeviceInfo& device_info);
|
||||
int3 GetGridSize() const override;
|
||||
|
||||
// Move only
|
||||
LSTM(LSTM&& kernel);
|
||||
LSTM& operator=(LSTM&& kernel);
|
||||
LSTM(const LSTM&) = delete;
|
||||
LSTM& operator=(const LSTM&) = delete;
|
||||
|
||||
private:
|
||||
std::string GetLSTMCode(const OperationDef& op_def,
|
||||
const DeviceInfo& device_info);
|
||||
};
|
||||
|
||||
LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info);
|
||||
GPUOperation CreateLSTM(const OperationDef& definition,
|
||||
const DeviceInfo& device_info);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -67,7 +67,7 @@ TEST_F(OpenCLOperationTest, LSTM) {
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
|
||||
TensorFloat32 new_state;
|
||||
TensorFloat32 new_activ;
|
||||
LSTM operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
|
||||
GPUOperation operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
|
||||
ASSERT_OK(ExecuteGPUOperation(
|
||||
{src_tensor, prev_state}, creation_context_, &operation,
|
||||
{BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ}));
|
||||
|
@ -246,7 +246,7 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::LSTM: {
|
||||
SelectLSTM(op_def, device_info, gpu_op);
|
||||
*gpu_op = SelectLSTM(op_def, device_info);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::MAX_UNPOOLING_2D: {
|
||||
|
@ -45,10 +45,9 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
LSTM operation = CreateLSTM(op_def, device_info);
|
||||
*ptr = absl::make_unique<LSTM>(std::move(operation));
|
||||
std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
|
||||
const DeviceInfo& device_info) {
|
||||
return absl::make_unique<GPUOperation>(CreateLSTM(op_def, device_info));
|
||||
}
|
||||
|
||||
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
|
||||
|
@ -28,8 +28,8 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
|
||||
const DeviceInfo& device_info);
|
||||
|
||||
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
|
||||
const OperationDef& op_def);
|
||||
|
Loading…
x
Reference in New Issue
Block a user