Port the OpenCL LSTM kernel from hand-written argument plumbing to the
Arguments helper.

What the hunks below change:
  * GetLSTMCode() takes an extra `Arguments* args` and registers the four
    tensors on it by name ("intermediate", "prev_state", "new_state",
    "activation") instead of building TensorCodeGenerator objects.
  * The generated kernel no longer spells out its parameter list; the
    signature is the `$0` placeholder and tensor access goes through
    `args.<name>.Read/Write/Slices/Batch`. Presumably the placeholder and the
    `args.` accessors are expanded by Arguments::TransformToCLCode -- confirm
    against arguments.h.
  * The generated local `activation` is renamed to `act_value`, so it no
    longer shares a name with the `activation` tensor argument.
  * LSTM::Compile() runs args_.TransformToCLCode() on the generated source
    before asking the cache for the kernel.
  * LSTM::BindArguments() sets the four object refs by name and delegates to
    args_.Bind() instead of binding memory and sizes positionally.

NOTE(review): "new_state" is now described by dst_tensors[0] and
"activation" by dst_tensors[1]; the removed code used the opposite
descriptors while binding dst_[0]/dst_[1] in the same positions. Confirm
this pairing is the intended one.

NOTE(review): the explicit <TensorDescriptor> argument on absl::make_unique
is required for the added lines to compile (the type cannot be deduced);
confirm the type name against the upstream commit.

diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
index 77eea07f278..4732d35e987 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
@@ -26,39 +26,34 @@ namespace gpu {
 namespace cl {
 namespace {
 
-std::string GetLSTMCode(const OperationDef& op_def, const CLDevice& device) {
-  const WHSBPoint state_size{"1", "1", "state_size.z", "state_size.w"};
-  const WHSBPoint src_size{"1", "1", "src_size.z", "src_size.w"};
-
-  TensorCodeGenerator intermediate("src_data", src_size, op_def.src_tensors[0]);
-  TensorCodeGenerator prev_state("prev_state", state_size,
-                                 op_def.src_tensors[1]);
-
-  TensorCodeGenerator activation("dst_data", state_size, op_def.dst_tensors[0]);
-  TensorCodeGenerator new_state("new_state", state_size, op_def.dst_tensors[1]);
+std::string GetLSTMCode(const OperationDef& op_def, const CLDevice& device,
+                        Arguments* args) {
+  args->AddObjectRef(
+      "intermediate", AccessType::READ,
+      absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
+  args->AddObjectRef(
+      "prev_state", AccessType::READ,
+      absl::make_unique<TensorDescriptor>(op_def.src_tensors[1]));
+  args->AddObjectRef(
+      "new_state", AccessType::WRITE,
+      absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
+  args->AddObjectRef(
+      "activation", AccessType::WRITE,
+      absl::make_unique<TensorDescriptor>(op_def.dst_tensors[1]));
 
   std::string c = GetCommonDefines(op_def.precision);
-
   c += "__kernel void main_function(\n";
-  c += intermediate.GetDeclaration(AccessType::READ) + ",\n";
-  c += prev_state.GetDeclaration(AccessType::READ) + ",\n";
-  c += new_state.GetDeclaration(AccessType::WRITE) + ",\n";
-  c += activation.GetDeclaration(AccessType::WRITE) + ",\n";
-  c += " int4 src_size, \n";
-  c += " int4 state_size, \n";
-  c += " int BATCH_SIZE \n";
-  c += ") {\n";
+  c += "$0) {\n";
   c += " int B = get_global_id(0);\n";
   c += " int Z = get_global_id(1);\n";
-  c += " if (Z >= state_size.z || B >= state_size.w) return;\n";
-  c += " FLT4 prev_st = " + prev_state.ReadWHSB("0", "0", "Z", "B") + ";\n";
-  c += " FLT4 r0 = " + intermediate.ReadWHSB("0", "0", "Z", "B") + ";\n";
-  c += " FLT4 r1 = " +
-       intermediate.ReadWHSB("0", "0", "Z + state_size.z", "B") + ";\n";
-  c += " FLT4 r2 = " +
-       intermediate.ReadWHSB("0", "0", "Z + state_size.z * 2", "B") + ";\n";
-  c += " FLT4 r3 = " +
-       intermediate.ReadWHSB("0", "0", "Z + state_size.z * 3", "B") + ";\n";
+  c += " if (Z >= args.activation.Slices() || B >= args.activation.Batch()) "
+       "return;\n";
+  c += " FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n";
+  c += " FLT4 r0 = args.intermediate.Read(0, 0, Z, B);\n";
+  c += " int state_stride = args.activation.Slices();\n";
+  c += " FLT4 r1 = args.intermediate.Read(0, 0, Z + state_stride, B);\n";
+  c += " FLT4 r2 = args.intermediate.Read(0, 0, Z + state_stride * 2, B);\n";
+  c += " FLT4 r3 = args.intermediate.Read(0, 0, Z + state_stride * 3, B);\n";
   if (op_def.precision != CalculationsPrecision::F32 && device.IsAdreno()) {
     c += " FLT4 input_gate;\n";
     c += " FLT4 new_input;\n";
@@ -97,9 +92,9 @@ std::string GetLSTMCode(const OperationDef& op_def, const CLDevice& device) {
          "* r3));\n";
   }
   c += " FLT4 new_st = input_gate * new_input + forget_gate * prev_st;\n";
-  c += " FLT4 activation = output_gate * tanh(new_st);\n";
-  c += " " + activation.WriteWHSB("activation", "0", "0", "Z", "B");
-  c += " " + new_state.WriteWHSB("new_st", "0", "0", "Z", "B");
+  c += " FLT4 act_value = output_gate * tanh(new_st);\n";
+  c += " args.activation.Write(act_value, 0, 0, Z, B);\n";
+  c += " args.new_state.Write(new_st, 0, 0, Z, B);\n";
   c += "}\n";
   return c;
 }
@@ -122,22 +117,20 @@ LSTM& LSTM::operator=(LSTM&& kernel) {
 }
 
 absl::Status LSTM::Compile(const CreationContext& creation_context) {
-  const auto code = GetLSTMCode(definition_, *creation_context.device);
+  std::string code = GetLSTMCode(definition_, *creation_context.device, &args_);
+  RETURN_IF_ERROR(
+      args_.TransformToCLCode(creation_context.device->GetInfo(), {}, &code));
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }
 
 absl::Status LSTM::BindArguments() {
-  kernel_.ResetBindingCounter();
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[1]->GetMemoryPtrForWriting()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
-  return absl::OkStatus();
+  RETURN_IF_ERROR(args_.SetObjectRef("intermediate", src_[0]));
+  RETURN_IF_ERROR(args_.SetObjectRef("prev_state", src_[1]));
+  RETURN_IF_ERROR(args_.SetObjectRef("new_state", dst_[0]));
+  RETURN_IF_ERROR(args_.SetObjectRef("activation", dst_[1]));
+  return args_.Bind(kernel_.kernel());
 }
 
 int3 LSTM::GetGridSize() const {