Fixed bugs introduced by recent CL.
PiperOrigin-RevId: 322619249 Change-Id: Id16e10d8f5fb8d77d7a213a69b3d3409b4e609aa
This commit is contained in:
parent
a9b7e06aa8
commit
232a4118c8
@ -491,14 +491,14 @@ class InferenceRunnerImpl : public InferenceRunner {
|
|||||||
|
|
||||||
absl::Status SetInputObject(int index, TensorObject object) override {
|
absl::Status SetInputObject(int index, TensorObject object) override {
|
||||||
if (index < 0 || index >= inputs_.size()) {
|
if (index < 0 || index >= inputs_.size()) {
|
||||||
return absl::OutOfRangeError("Index is out of range");
|
return absl::OutOfRangeError("Input index is out of range");
|
||||||
}
|
}
|
||||||
return inputs_[index]->SetExternalObject(object);
|
return inputs_[index]->SetExternalObject(object);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status SetOutputObject(int index, TensorObject object) override {
|
absl::Status SetOutputObject(int index, TensorObject object) override {
|
||||||
if (index < 0 || index >= outputs_.size()) {
|
if (index < 0 || index >= outputs_.size()) {
|
||||||
return absl::OutOfRangeError("Index is out of range");
|
return absl::OutOfRangeError("Output index is out of range");
|
||||||
}
|
}
|
||||||
return outputs_[index]->SetExternalObject(object);
|
return outputs_[index]->SetExternalObject(object);
|
||||||
}
|
}
|
||||||
@ -623,13 +623,13 @@ class InferenceBuilderImpl : public InferenceBuilder {
|
|||||||
|
|
||||||
absl::Status SetInputObjectDef(int index, ObjectDef new_def) override {
|
absl::Status SetInputObjectDef(int index, ObjectDef new_def) override {
|
||||||
if (index < 0 || index >= inputs_.size()) {
|
if (index < 0 || index >= inputs_.size()) {
|
||||||
return absl::OutOfRangeError("Index is out of range");
|
return absl::OutOfRangeError("Input index is out of range");
|
||||||
}
|
}
|
||||||
auto def = inputs_[index];
|
auto def = inputs_[index];
|
||||||
def.external_def.object_def = new_def;
|
def.external_def.object_def = new_def;
|
||||||
if (!tie_factory_->IsSupported(def)) {
|
if (!tie_factory_->IsSupported(def)) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"New object definition is not supported.");
|
"New input object definition is not supported.");
|
||||||
}
|
}
|
||||||
inputs_[index] = def;
|
inputs_[index] = def;
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
@ -637,13 +637,13 @@ class InferenceBuilderImpl : public InferenceBuilder {
|
|||||||
|
|
||||||
absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override {
|
absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override {
|
||||||
if (index < 0 || index >= outputs_.size()) {
|
if (index < 0 || index >= outputs_.size()) {
|
||||||
return absl::OutOfRangeError("Index is out of range");
|
return absl::OutOfRangeError("Output index is out of range");
|
||||||
}
|
}
|
||||||
auto def = outputs_[index];
|
auto def = outputs_[index];
|
||||||
def.external_def.object_def = new_def;
|
def.external_def.object_def = new_def;
|
||||||
if (!tie_factory_->IsSupported(def)) {
|
if (!tie_factory_->IsSupported(def)) {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"New object definition is not supported.");
|
"New output object definition is not supported.");
|
||||||
}
|
}
|
||||||
outputs_[index] = def;
|
outputs_[index] = def;
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -57,6 +57,8 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
absl::Status WinogradFromNode(const CreationContext& creation_context,
|
absl::Status WinogradFromNode(const CreationContext& creation_context,
|
||||||
|
const std::vector<Value*>& inputs,
|
||||||
|
const std::vector<Value*>& outputs,
|
||||||
const OperationDef& op_def, ModelHints hints,
|
const OperationDef& op_def, ModelHints hints,
|
||||||
const BHWC& input_shape, const BHWC& output_shape,
|
const BHWC& input_shape, const BHWC& output_shape,
|
||||||
const Convolution2DAttributes& attr,
|
const Convolution2DAttributes& attr,
|
||||||
@ -95,7 +97,7 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
|
|||||||
auto& winograd_up = gpu_subgraph->operations[0];
|
auto& winograd_up = gpu_subgraph->operations[0];
|
||||||
RETURN_IF_ERROR(SelectWinograd4x4To36(
|
RETURN_IF_ERROR(SelectWinograd4x4To36(
|
||||||
creation_context, attr.padding, winograd_up_def, &winograd_up.operation));
|
creation_context, attr.padding, winograd_up_def, &winograd_up.operation));
|
||||||
winograd_up.input_ids = {0};
|
winograd_up.input_ids = {static_cast<int>(inputs[0]->id)};
|
||||||
winograd_up.output_ids = {-1};
|
winograd_up.output_ids = {-1};
|
||||||
|
|
||||||
OperationDef conv_def;
|
OperationDef conv_def;
|
||||||
@ -114,7 +116,7 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
|
|||||||
winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
|
winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
|
||||||
auto& winograd_down = gpu_subgraph->operations[2];
|
auto& winograd_down = gpu_subgraph->operations[2];
|
||||||
winograd_down.input_ids = {-2};
|
winograd_down.input_ids = {-2};
|
||||||
winograd_down.output_ids = {0};
|
winograd_down.output_ids = {static_cast<int>(outputs[0]->id)};
|
||||||
auto bias_copy = attr.bias;
|
auto bias_copy = attr.bias;
|
||||||
if (bias_copy.shape.v < attr.weights.shape.o) {
|
if (bias_copy.shape.v < attr.weights.shape.o) {
|
||||||
bias_copy.shape = Linear(attr.weights.shape.o);
|
bias_copy.shape = Linear(attr.weights.shape.o);
|
||||||
@ -202,8 +204,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
|||||||
auto input_shape = inputs[0]->tensor.shape;
|
auto input_shape = inputs[0]->tensor.shape;
|
||||||
auto output_shape = outputs[0]->tensor.shape;
|
auto output_shape = outputs[0]->tensor.shape;
|
||||||
if (inputs.size() == 1) {
|
if (inputs.size() == 1) {
|
||||||
if (WinogradFromNode(creation_context, op_def, hints, input_shape,
|
if (WinogradFromNode(creation_context, inputs, outputs, op_def, hints,
|
||||||
output_shape, attr, gpu_subgraph)
|
input_shape, output_shape, attr, gpu_subgraph)
|
||||||
.ok()) {
|
.ok()) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
} else {
|
} else {
|
||||||
@ -215,13 +217,13 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
|||||||
auto weights_shape = inputs[1]->tensor.shape;
|
auto weights_shape = inputs[1]->tensor.shape;
|
||||||
TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
|
TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
|
||||||
TensorStorageType::BUFFER,
|
TensorStorageType::BUFFER,
|
||||||
Layout::UNKNOWN};
|
Layout::BHWC};
|
||||||
gpu_subgraph->operations.clear();
|
gpu_subgraph->operations.clear();
|
||||||
gpu_subgraph->operations.resize(2);
|
gpu_subgraph->operations.resize(2);
|
||||||
auto& converter_op = gpu_subgraph->operations[0];
|
auto& converter_op = gpu_subgraph->operations[0];
|
||||||
auto& conv_op = gpu_subgraph->operations[1];
|
auto& conv_op = gpu_subgraph->operations[1];
|
||||||
conv_op.input_ids = {0, -1};
|
conv_op.input_ids = {static_cast<int>(inputs[0]->id), -1};
|
||||||
conv_op.output_ids = {0};
|
conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
|
||||||
OperationDef conv_def = op_def;
|
OperationDef conv_def = op_def;
|
||||||
conv_def.src_tensors[1] = weights_desc;
|
conv_def.src_tensors[1] = weights_desc;
|
||||||
ConvWeightsDescription conv_weights_desc;
|
ConvWeightsDescription conv_weights_desc;
|
||||||
@ -242,7 +244,7 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
|||||||
converter_def.src_tensors.push_back(op_def.src_tensors[1]);
|
converter_def.src_tensors.push_back(op_def.src_tensors[1]);
|
||||||
converter_def.dst_tensors.push_back(weights_desc);
|
converter_def.dst_tensors.push_back(weights_desc);
|
||||||
|
|
||||||
converter_op.input_ids = {1};
|
converter_op.input_ids = {static_cast<int>(inputs[1]->id)};
|
||||||
converter_op.output_ids = {-1};
|
converter_op.output_ids = {-1};
|
||||||
return SelectConverterToConvWeights(conv_weights_desc, creation_context,
|
return SelectConverterToConvWeights(conv_weights_desc, creation_context,
|
||||||
converter_def, hints,
|
converter_def, hints,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user