Fixed bugs introduced by recent CL.

PiperOrigin-RevId: 322619249
Change-Id: Id16e10d8f5fb8d77d7a213a69b3d3409b4e609aa
This commit is contained in:
Raman Sarokin 2020-07-22 11:39:09 -07:00 committed by TensorFlower Gardener
parent a9b7e06aa8
commit 232a4118c8
2 changed files with 16 additions and 14 deletions

View File

@ -491,14 +491,14 @@ class InferenceRunnerImpl : public InferenceRunner {
absl::Status SetInputObject(int index, TensorObject object) override {
if (index < 0 || index >= inputs_.size()) {
return absl::OutOfRangeError("Index is out of range");
return absl::OutOfRangeError("Input index is out of range");
}
return inputs_[index]->SetExternalObject(object);
}
absl::Status SetOutputObject(int index, TensorObject object) override {
if (index < 0 || index >= outputs_.size()) {
return absl::OutOfRangeError("Index is out of range");
return absl::OutOfRangeError("Output index is out of range");
}
return outputs_[index]->SetExternalObject(object);
}
@ -623,13 +623,13 @@ class InferenceBuilderImpl : public InferenceBuilder {
absl::Status SetInputObjectDef(int index, ObjectDef new_def) override {
if (index < 0 || index >= inputs_.size()) {
return absl::OutOfRangeError("Index is out of range");
return absl::OutOfRangeError("Input index is out of range");
}
auto def = inputs_[index];
def.external_def.object_def = new_def;
if (!tie_factory_->IsSupported(def)) {
return absl::InvalidArgumentError(
"New object definition is not supported.");
"New input object definition is not supported.");
}
inputs_[index] = def;
return absl::OkStatus();
@ -637,13 +637,13 @@ class InferenceBuilderImpl : public InferenceBuilder {
absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override {
if (index < 0 || index >= outputs_.size()) {
return absl::OutOfRangeError("Index is out of range");
return absl::OutOfRangeError("Output index is out of range");
}
auto def = outputs_[index];
def.external_def.object_def = new_def;
if (!tie_factory_->IsSupported(def)) {
return absl::InvalidArgumentError(
"New object definition is not supported.");
"New output object definition is not supported.");
}
outputs_[index] = def;
return absl::OkStatus();

View File

@ -57,6 +57,8 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
}
absl::Status WinogradFromNode(const CreationContext& creation_context,
const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs,
const OperationDef& op_def, ModelHints hints,
const BHWC& input_shape, const BHWC& output_shape,
const Convolution2DAttributes& attr,
@ -95,7 +97,7 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
auto& winograd_up = gpu_subgraph->operations[0];
RETURN_IF_ERROR(SelectWinograd4x4To36(
creation_context, attr.padding, winograd_up_def, &winograd_up.operation));
winograd_up.input_ids = {0};
winograd_up.input_ids = {static_cast<int>(inputs[0]->id)};
winograd_up.output_ids = {-1};
OperationDef conv_def;
@ -114,7 +116,7 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
auto& winograd_down = gpu_subgraph->operations[2];
winograd_down.input_ids = {-2};
winograd_down.output_ids = {0};
winograd_down.output_ids = {static_cast<int>(outputs[0]->id)};
auto bias_copy = attr.bias;
if (bias_copy.shape.v < attr.weights.shape.o) {
bias_copy.shape = Linear(attr.weights.shape.o);
@ -202,8 +204,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
auto input_shape = inputs[0]->tensor.shape;
auto output_shape = outputs[0]->tensor.shape;
if (inputs.size() == 1) {
if (WinogradFromNode(creation_context, op_def, hints, input_shape,
output_shape, attr, gpu_subgraph)
if (WinogradFromNode(creation_context, inputs, outputs, op_def, hints,
input_shape, output_shape, attr, gpu_subgraph)
.ok()) {
return absl::OkStatus();
} else {
@ -215,13 +217,13 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
auto weights_shape = inputs[1]->tensor.shape;
TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
TensorStorageType::BUFFER,
Layout::UNKNOWN};
Layout::BHWC};
gpu_subgraph->operations.clear();
gpu_subgraph->operations.resize(2);
auto& converter_op = gpu_subgraph->operations[0];
auto& conv_op = gpu_subgraph->operations[1];
conv_op.input_ids = {0, -1};
conv_op.output_ids = {0};
conv_op.input_ids = {static_cast<int>(inputs[0]->id), -1};
conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
OperationDef conv_def = op_def;
conv_def.src_tensors[1] = weights_desc;
ConvWeightsDescription conv_weights_desc;
@ -242,7 +244,7 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
converter_def.src_tensors.push_back(op_def.src_tensors[1]);
converter_def.dst_tensors.push_back(weights_desc);
converter_op.input_ids = {1};
converter_op.input_ids = {static_cast<int>(inputs[1]->id)};
converter_op.output_ids = {-1};
return SelectConverterToConvWeights(conv_weights_desc, creation_context,
converter_def, hints,