Automatic readability finding fixes.
PiperOrigin-RevId: 324682656 Change-Id: I009a949fbf08eeb5e45d19d8a1918a988960311b
This commit is contained in:
parent
b58a8717b1
commit
6dbc50195e
@ -40,8 +40,7 @@ absl::Status CreatePHWC4BufferFromTensorRef(const TensorRef<BHWC>& tensor_ref,
|
|||||||
|
|
||||||
absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer,
|
absl::Status CopyFromPHWC4Buffer(const GlBuffer& buffer,
|
||||||
TensorFloat32* tensor) {
|
TensorFloat32* tensor) {
|
||||||
return buffer.MappedRead<float>(
|
return buffer.MappedRead<float>([tensor](absl::Span<const float> data) {
|
||||||
[tensor, &buffer](absl::Span<const float> data) {
|
|
||||||
tensor->data.resize(tensor->shape.DimensionsProduct());
|
tensor->data.resize(tensor->shape.DimensionsProduct());
|
||||||
return ConvertFromPHWC4(absl::MakeConstSpan(data), tensor->shape,
|
return ConvertFromPHWC4(absl::MakeConstSpan(data), tensor->shape,
|
||||||
absl::MakeSpan(tensor->data));
|
absl::MakeSpan(tensor->data));
|
||||||
|
|||||||
@ -115,7 +115,7 @@ std::vector<ComputeTaskDescriptorPtr> ElementwiseWithTwoInputs(
|
|||||||
|
|
||||||
desc->uniform_buffers = {
|
desc->uniform_buffers = {
|
||||||
{"constant int2&",
|
{"constant int2&",
|
||||||
[input_ids, output_id](const std::map<ValueId, BHWC>& buffers) {
|
[input_ids](const std::map<ValueId, BHWC>& buffers) {
|
||||||
const auto& input_dim_1 = buffers.find(input_ids[1])->second;
|
const auto& input_dim_1 = buffers.find(input_ids[1])->second;
|
||||||
std::vector<int> uniform_params{
|
std::vector<int> uniform_params{
|
||||||
input_dim_1.w,
|
input_dim_1.w,
|
||||||
|
|||||||
@ -99,17 +99,15 @@ std::vector<ComputeTaskDescriptorPtr> MaxUnpooling(
|
|||||||
{input_indices_id, "device FLT4* const src_indices_buffer"},
|
{input_indices_id, "device FLT4* const src_indices_buffer"},
|
||||||
};
|
};
|
||||||
|
|
||||||
desc->output_buffer = {output_id, "device FLT4* output_buffer",
|
desc->output_buffer = {
|
||||||
[input_id, input_indices_id,
|
output_id, "device FLT4* output_buffer",
|
||||||
params](const std::map<ValueId, BHWC>& buffers) {
|
[input_id, params](const std::map<ValueId, BHWC>& buffers) {
|
||||||
return CalculateOutputShape(
|
return CalculateOutputShape(buffers.find(input_id)->second, params);
|
||||||
buffers.find(input_id)->second, params);
|
|
||||||
}};
|
}};
|
||||||
|
|
||||||
desc->uniform_buffers = {
|
desc->uniform_buffers = {
|
||||||
{"constant uniforms& params",
|
{"constant uniforms& params",
|
||||||
[input_id, input_indices_id, output_id,
|
[input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
|
||||||
params](const std::map<ValueId, BHWC>& buffers) {
|
|
||||||
const auto& dimension = buffers.find(input_id)->second;
|
const auto& dimension = buffers.find(input_id)->second;
|
||||||
const auto& output_dimension = buffers.find(output_id)->second;
|
const auto& output_dimension = buffers.find(output_id)->second;
|
||||||
std::vector<int> uniform_params{
|
std::vector<int> uniform_params{
|
||||||
@ -126,7 +124,7 @@ std::vector<ComputeTaskDescriptorPtr> MaxUnpooling(
|
|||||||
}},
|
}},
|
||||||
};
|
};
|
||||||
|
|
||||||
desc->resize_function = [input_id, input_indices_id,
|
desc->resize_function = [input_id,
|
||||||
params](const std::map<ValueId, BHWC>& buffers) {
|
params](const std::map<ValueId, BHWC>& buffers) {
|
||||||
const auto& src_shape = buffers.find(input_id)->second;
|
const auto& src_shape = buffers.find(input_id)->second;
|
||||||
BHWC dst_shape = CalculateOutputShape(src_shape, params);
|
BHWC dst_shape = CalculateOutputShape(src_shape, params);
|
||||||
|
|||||||
@ -130,8 +130,7 @@ std::vector<ComputeTaskDescriptorPtr> Mean(int id, ValueId input_id,
|
|||||||
}};
|
}};
|
||||||
desc->uniform_buffers = {
|
desc->uniform_buffers = {
|
||||||
{"constant uniforms& params",
|
{"constant uniforms& params",
|
||||||
[input_id, output_id,
|
[input_id, work_group_size](const std::map<ValueId, BHWC>& buffers) {
|
||||||
work_group_size](const std::map<ValueId, BHWC>& buffers) {
|
|
||||||
const auto& src_shape = buffers.find(input_id)->second;
|
const auto& src_shape = buffers.find(input_id)->second;
|
||||||
const int src_slices = DivideRoundUp(src_shape.c, 4);
|
const int src_slices = DivideRoundUp(src_shape.c, 4);
|
||||||
struct uniforms {
|
struct uniforms {
|
||||||
|
|||||||
@ -613,8 +613,7 @@ std::vector<ComputeTaskDescriptorPtr> Winograd4x4To36TileX6(
|
|||||||
}},
|
}},
|
||||||
};
|
};
|
||||||
|
|
||||||
desc->resize_function = [output_id,
|
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
|
||||||
attr](const std::map<ValueId, BHWC>& buffers) {
|
|
||||||
const uint3 groups_size{4, 6, 1};
|
const uint3 groups_size{4, 6, 1};
|
||||||
const auto& dst_shape = buffers.find(output_id)->second;
|
const auto& dst_shape = buffers.find(output_id)->second;
|
||||||
int grid_x = dst_shape.w;
|
int grid_x = dst_shape.w;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user