Reshape&Reshapex4 converted to new style.
PiperOrigin-RevId: 316904479 Change-Id: I7c1fb0ca5a31fc1f82545d70cfcdcfb7d63bcd6a
This commit is contained in:
parent
b780ee931b
commit
d8e0beacd9
@ -25,92 +25,24 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GetReshapeBatchedCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data",
|
||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data",
|
||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GetReshapeBatchedCode(const OperationDef& op_def, Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size, \n";
|
||||
c += " int src_channels, \n";
|
||||
c += " int dst_channels \n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / dst_size.w;\n";
|
||||
c += " int B = linear_id % dst_size.w;\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z || B >= "
|
||||
"dst_size.w) return;\n";
|
||||
c += " FLT temps[4];\n";
|
||||
c += " temps[0] = (FLT)(0.0f);\n";
|
||||
c += " temps[1] = (FLT)(0.0f);\n";
|
||||
c += " temps[2] = (FLT)(0.0f);\n";
|
||||
c += " temps[3] = (FLT)(0.0f);\n";
|
||||
c += " int base = ((B * dst_size.y + Y)* dst_size.x + X)* dst_channels + Z "
|
||||
"* 4;\n";
|
||||
c += " for (int i = 0; i < 4; ++i) {\n";
|
||||
c += " int dst_channel = Z * 4 + i;\n";
|
||||
c += " if (dst_channel < dst_channels) {;\n";
|
||||
c += " int p = base + i;\n";
|
||||
c += " int src_c = p % src_channels;\n";
|
||||
c += " p = p / src_channels;\n";
|
||||
c += " int src_x = p % src_size.x;\n";
|
||||
c += " p = p / src_size.x;\n";
|
||||
c += " int src_y = p % src_size.y;\n";
|
||||
c += " int src_b = p / src_size.y;\n";
|
||||
c += " int src_z = src_c / 4;\n";
|
||||
c += " int src_sub_ch = src_c % 4;\n";
|
||||
c += " FLT4 t =" +
|
||||
src_tensor.ReadWHSB("src_x", "src_y", "src_z", "src_b") + ";\n";
|
||||
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
|
||||
c += " temps[i] = t_ar[src_sub_ch];\n";
|
||||
c += " }\n";
|
||||
c += " }\n";
|
||||
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
|
||||
const LinkingContext context{"result", "X * dst_size.w + B", "Y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", "B");
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
||||
std::string GetReshapeCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size, \n";
|
||||
c += " int src_channels, \n";
|
||||
c += " int dst_channels \n";
|
||||
c += ") {\n";
|
||||
c += " int X = get_global_id(0);\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " FLT temps[4];\n";
|
||||
@ -118,25 +50,73 @@ std::string GetReshapeCode(
|
||||
c += " temps[1] = (FLT)(0.0f);\n";
|
||||
c += " temps[2] = (FLT)(0.0f);\n";
|
||||
c += " temps[3] = (FLT)(0.0f);\n";
|
||||
c += " int base = ((B * args.dst_tensor.Height() + Y) * "
|
||||
"args.dst_tensor.Width() + X) * args.dst_tensor.Channels() + Z * 4;\n";
|
||||
c += " for (int i = 0; i < 4; ++i) {\n";
|
||||
c += " int dst_channel = Z * 4 + i;\n";
|
||||
c += " if (dst_channel < dst_channels) {;\n";
|
||||
c += " int p = dst_channel + dst_channels * (X + dst_size.x * Y);\n";
|
||||
c += " int src_c = p % src_channels;\n";
|
||||
c += " p = p / src_channels;\n";
|
||||
c += " int src_x = p % src_size.x;\n";
|
||||
c += " int src_y = p / src_size.x;\n";
|
||||
c += " if (dst_channel < args.dst_tensor.Channels()) {;\n";
|
||||
c += " int p = base + i;\n";
|
||||
c += " int src_c = p % args.src_tensor.Channels();\n";
|
||||
c += " p = p / args.src_tensor.Channels();\n";
|
||||
c += " int src_x = p % args.src_tensor.Width();\n";
|
||||
c += " p = p / args.src_tensor.Width();\n";
|
||||
c += " int src_y = p % args.src_tensor.Height();\n";
|
||||
c += " int src_b = p / args.src_tensor.Height();\n";
|
||||
c += " int src_z = src_c / 4;\n";
|
||||
c += " int src_sub_ch = src_c % 4;\n";
|
||||
c += " FLT4 t =" + src_tensor.ReadWHS("src_x", "src_y", "src_z") + ";\n";
|
||||
c += " FLT4 t = args.src_tensor.Read(src_x, src_y, src_z, src_b);\n";
|
||||
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
|
||||
c += " temps[i] = t_ar[src_sub_ch];\n";
|
||||
c += " }\n";
|
||||
c += " }\n";
|
||||
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
|
||||
const LinkingContext context{"result", "X", "Y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHS("result", "X", "Y", "Z");
|
||||
c += " args.dst_tensor.Write(result, X, Y, Z, B);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
||||
std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += "$0) {\n";
|
||||
c += " int X = get_global_id(0);\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " FLT temps[4];\n";
|
||||
c += " temps[0] = (FLT)(0.0f);\n";
|
||||
c += " temps[1] = (FLT)(0.0f);\n";
|
||||
c += " temps[2] = (FLT)(0.0f);\n";
|
||||
c += " temps[3] = (FLT)(0.0f);\n";
|
||||
c += " int base = (Y * args.dst_tensor.Width() + X) * "
|
||||
"args.dst_tensor.Channels() + Z * 4;\n";
|
||||
c += " for (int i = 0; i < 4; ++i) {\n";
|
||||
c += " int dst_channel = Z * 4 + i;\n";
|
||||
c += " if (dst_channel < args.dst_tensor.Channels()) {;\n";
|
||||
c += " int p = base + i;\n";
|
||||
c += " int src_c = p % args.src_tensor.Channels();\n";
|
||||
c += " p = p / args.src_tensor.Channels();\n";
|
||||
c += " int src_x = p % args.src_tensor.Width();\n";
|
||||
c += " int src_y = p / args.src_tensor.Width();\n";
|
||||
c += " int src_z = src_c / 4;\n";
|
||||
c += " int src_sub_ch = src_c % 4;\n";
|
||||
c += " FLT4 t = args.src_tensor.Read(src_x, src_y, src_z);\n";
|
||||
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
|
||||
c += " temps[i] = t_ar[src_sub_ch];\n";
|
||||
c += " }\n";
|
||||
c += " }\n";
|
||||
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
|
||||
c += " args.dst_tensor.Write(result, X, Y, Z);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
@ -157,24 +137,25 @@ Reshape& Reshape::operator=(Reshape&& operation) {
|
||||
}
|
||||
|
||||
absl::Status Reshape::Compile(const CreationContext& creation_context) {
|
||||
const auto code = definition_.IsBatchSupported()
|
||||
? GetReshapeBatchedCode(definition_, linked_operations_)
|
||||
: GetReshapeCode(definition_, linked_operations_);
|
||||
std::string code = definition_.IsBatchSupported()
|
||||
? GetReshapeBatchedCode(definition_, &args_)
|
||||
: GetReshapeCode(definition_, &args_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
}
|
||||
|
||||
absl::Status Reshape::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels()));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 Reshape::GetGridSize() const {
|
||||
|
@ -25,82 +25,66 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GetReshapeBatchedCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data",
|
||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data",
|
||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GetReshapeBatchedCode(const OperationDef& op_def, Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size \n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / dst_size.w;\n";
|
||||
c += " int B = linear_id % dst_size.w;\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z || B >= "
|
||||
"dst_size.w) return;\n";
|
||||
c += " int dst_bhwc4 = ((B * dst_size.y + Y) * dst_size.x + X) * dst_size.z "
|
||||
"+ Z;\n";
|
||||
c += " int src_z = dst_bhwc4 % src_size.z;\n";
|
||||
c += " dst_bhwc4 = dst_bhwc4 / src_size.z;\n";
|
||||
c += " int src_x = dst_bhwc4 % src_size.x;\n";
|
||||
c += " dst_bhwc4 = dst_bhwc4 / src_size.x;\n";
|
||||
c += " int src_y = dst_bhwc4 % src_size.y;\n";
|
||||
c += " int src_b = dst_bhwc4 / src_size.y;\n";
|
||||
c += " FLT4 result =" +
|
||||
src_tensor.ReadWHSB("src_x", "src_y", "src_z", "src_b") + ";\n";
|
||||
const LinkingContext context{"result", "X * dst_size.w + B", "Y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", "B");
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " int dst_bhwc4 = ((B * args.dst_tensor.Height() + Y) * "
|
||||
"args.dst_tensor.Width() + X) * args.dst_tensor.Slices() + Z;\n";
|
||||
c += " int src_z = dst_bhwc4 % args.src_tensor.Slices();\n";
|
||||
c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Slices();\n";
|
||||
c += " int src_x = dst_bhwc4 % args.src_tensor.Width();\n";
|
||||
c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Width();\n";
|
||||
c += " int src_y = dst_bhwc4 % args.src_tensor.Height();\n";
|
||||
c += " int src_b = dst_bhwc4 / args.src_tensor.Height();\n";
|
||||
c += " FLT4 result = args.src_tensor.Read(src_x, src_y, src_z, src_b);\n";
|
||||
c += " args.dst_tensor.Write(result, X, Y, Z, B);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
||||
std::string GetReshapeCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GetReshapeCode(const OperationDef& op_def, Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size \n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
c += " int X = get_global_id(0);\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
|
||||
c += " int dst_hwc4 = (Y * dst_size.x + X) * dst_size.z + Z;\n";
|
||||
c += " int src_z = dst_hwc4 % src_size.z;\n";
|
||||
c += " dst_hwc4 = dst_hwc4 / src_size.z;\n";
|
||||
c += " int src_x = dst_hwc4 % src_size.x;\n";
|
||||
c += " int src_y = dst_hwc4 / src_size.x;\n";
|
||||
c +=
|
||||
" FLT4 result =" + src_tensor.ReadWHS("src_x", "src_y", "src_z") + ";\n";
|
||||
const LinkingContext context{"result", "X", "Y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHS("result", "X", "Y", "Z");
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " int dst_hwc4 = (Y * args.dst_tensor.Width() + X) * "
|
||||
"args.dst_tensor.Slices() + Z;\n";
|
||||
c += " int src_z = dst_hwc4 % args.src_tensor.Slices();\n";
|
||||
c += " dst_hwc4 = dst_hwc4 / args.src_tensor.Slices();\n";
|
||||
c += " int src_x = dst_hwc4 % args.src_tensor.Width();\n";
|
||||
c += " int src_y = dst_hwc4 / args.src_tensor.Width();\n";
|
||||
c += " FLT4 result = args.src_tensor.Read(src_x, src_y, src_z);\n";
|
||||
c += " args.dst_tensor.Write(result, X, Y, Z);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
@ -121,22 +105,25 @@ Reshapex4& Reshapex4::operator=(Reshapex4&& operation) {
|
||||
}
|
||||
|
||||
absl::Status Reshapex4::Compile(const CreationContext& creation_context) {
|
||||
const auto code = definition_.IsBatchSupported()
|
||||
? GetReshapeBatchedCode(definition_, linked_operations_)
|
||||
: GetReshapeCode(definition_, linked_operations_);
|
||||
std::string code = definition_.IsBatchSupported()
|
||||
? GetReshapeBatchedCode(definition_, &args_)
|
||||
: GetReshapeCode(definition_, &args_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
}
|
||||
|
||||
absl::Status Reshapex4::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 Reshapex4::GetGridSize() const {
|
||||
|
Loading…
x
Reference in New Issue
Block a user