Fixed bugs introduced by recent CL.

PiperOrigin-RevId: 322619249
Change-Id: Id16e10d8f5fb8d77d7a213a69b3d3409b4e609aa
This commit is contained in:
Raman Sarokin 2020-07-22 11:39:09 -07:00 committed by TensorFlower Gardener
parent a9b7e06aa8
commit 232a4118c8
2 changed files with 16 additions and 14 deletions

View File

@ -491,14 +491,14 @@ class InferenceRunnerImpl : public InferenceRunner {
absl::Status SetInputObject(int index, TensorObject object) override { absl::Status SetInputObject(int index, TensorObject object) override {
if (index < 0 || index >= inputs_.size()) { if (index < 0 || index >= inputs_.size()) {
return absl::OutOfRangeError("Index is out of range"); return absl::OutOfRangeError("Input index is out of range");
} }
return inputs_[index]->SetExternalObject(object); return inputs_[index]->SetExternalObject(object);
} }
absl::Status SetOutputObject(int index, TensorObject object) override { absl::Status SetOutputObject(int index, TensorObject object) override {
if (index < 0 || index >= outputs_.size()) { if (index < 0 || index >= outputs_.size()) {
return absl::OutOfRangeError("Index is out of range"); return absl::OutOfRangeError("Output index is out of range");
} }
return outputs_[index]->SetExternalObject(object); return outputs_[index]->SetExternalObject(object);
} }
@ -623,13 +623,13 @@ class InferenceBuilderImpl : public InferenceBuilder {
absl::Status SetInputObjectDef(int index, ObjectDef new_def) override { absl::Status SetInputObjectDef(int index, ObjectDef new_def) override {
if (index < 0 || index >= inputs_.size()) { if (index < 0 || index >= inputs_.size()) {
return absl::OutOfRangeError("Index is out of range"); return absl::OutOfRangeError("Input index is out of range");
} }
auto def = inputs_[index]; auto def = inputs_[index];
def.external_def.object_def = new_def; def.external_def.object_def = new_def;
if (!tie_factory_->IsSupported(def)) { if (!tie_factory_->IsSupported(def)) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"New object definition is not supported."); "New input object definition is not supported.");
} }
inputs_[index] = def; inputs_[index] = def;
return absl::OkStatus(); return absl::OkStatus();
@ -637,13 +637,13 @@ class InferenceBuilderImpl : public InferenceBuilder {
absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override { absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override {
if (index < 0 || index >= outputs_.size()) { if (index < 0 || index >= outputs_.size()) {
return absl::OutOfRangeError("Index is out of range"); return absl::OutOfRangeError("Output index is out of range");
} }
auto def = outputs_[index]; auto def = outputs_[index];
def.external_def.object_def = new_def; def.external_def.object_def = new_def;
if (!tie_factory_->IsSupported(def)) { if (!tie_factory_->IsSupported(def)) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
"New object definition is not supported."); "New output object definition is not supported.");
} }
outputs_[index] = def; outputs_[index] = def;
return absl::OkStatus(); return absl::OkStatus();

View File

@ -57,6 +57,8 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
} }
absl::Status WinogradFromNode(const CreationContext& creation_context, absl::Status WinogradFromNode(const CreationContext& creation_context,
const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs,
const OperationDef& op_def, ModelHints hints, const OperationDef& op_def, ModelHints hints,
const BHWC& input_shape, const BHWC& output_shape, const BHWC& input_shape, const BHWC& output_shape,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
@ -95,7 +97,7 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
auto& winograd_up = gpu_subgraph->operations[0]; auto& winograd_up = gpu_subgraph->operations[0];
RETURN_IF_ERROR(SelectWinograd4x4To36( RETURN_IF_ERROR(SelectWinograd4x4To36(
creation_context, attr.padding, winograd_up_def, &winograd_up.operation)); creation_context, attr.padding, winograd_up_def, &winograd_up.operation));
winograd_up.input_ids = {0}; winograd_up.input_ids = {static_cast<int>(inputs[0]->id)};
winograd_up.output_ids = {-1}; winograd_up.output_ids = {-1};
OperationDef conv_def; OperationDef conv_def;
@ -114,7 +116,7 @@ absl::Status WinogradFromNode(const CreationContext& creation_context,
winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]); winograd_down_def.dst_tensors.push_back(op_def.dst_tensors[0]);
auto& winograd_down = gpu_subgraph->operations[2]; auto& winograd_down = gpu_subgraph->operations[2];
winograd_down.input_ids = {-2}; winograd_down.input_ids = {-2};
winograd_down.output_ids = {0}; winograd_down.output_ids = {static_cast<int>(outputs[0]->id)};
auto bias_copy = attr.bias; auto bias_copy = attr.bias;
if (bias_copy.shape.v < attr.weights.shape.o) { if (bias_copy.shape.v < attr.weights.shape.o) {
bias_copy.shape = Linear(attr.weights.shape.o); bias_copy.shape = Linear(attr.weights.shape.o);
@ -202,8 +204,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
auto input_shape = inputs[0]->tensor.shape; auto input_shape = inputs[0]->tensor.shape;
auto output_shape = outputs[0]->tensor.shape; auto output_shape = outputs[0]->tensor.shape;
if (inputs.size() == 1) { if (inputs.size() == 1) {
if (WinogradFromNode(creation_context, op_def, hints, input_shape, if (WinogradFromNode(creation_context, inputs, outputs, op_def, hints,
output_shape, attr, gpu_subgraph) input_shape, output_shape, attr, gpu_subgraph)
.ok()) { .ok()) {
return absl::OkStatus(); return absl::OkStatus();
} else { } else {
@ -215,13 +217,13 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
auto weights_shape = inputs[1]->tensor.shape; auto weights_shape = inputs[1]->tensor.shape;
TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type, TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
TensorStorageType::BUFFER, TensorStorageType::BUFFER,
Layout::UNKNOWN}; Layout::BHWC};
gpu_subgraph->operations.clear(); gpu_subgraph->operations.clear();
gpu_subgraph->operations.resize(2); gpu_subgraph->operations.resize(2);
auto& converter_op = gpu_subgraph->operations[0]; auto& converter_op = gpu_subgraph->operations[0];
auto& conv_op = gpu_subgraph->operations[1]; auto& conv_op = gpu_subgraph->operations[1];
conv_op.input_ids = {0, -1}; conv_op.input_ids = {static_cast<int>(inputs[0]->id), -1};
conv_op.output_ids = {0}; conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
OperationDef conv_def = op_def; OperationDef conv_def = op_def;
conv_def.src_tensors[1] = weights_desc; conv_def.src_tensors[1] = weights_desc;
ConvWeightsDescription conv_weights_desc; ConvWeightsDescription conv_weights_desc;
@ -242,7 +244,7 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
converter_def.src_tensors.push_back(op_def.src_tensors[1]); converter_def.src_tensors.push_back(op_def.src_tensors[1]);
converter_def.dst_tensors.push_back(weights_desc); converter_def.dst_tensors.push_back(weights_desc);
converter_op.input_ids = {1}; converter_op.input_ids = {static_cast<int>(inputs[1]->id)};
converter_op.output_ids = {-1}; converter_op.output_ids = {-1};
return SelectConverterToConvWeights(conv_weights_desc, creation_context, return SelectConverterToConvWeights(conv_weights_desc, creation_context,
converter_def, hints, converter_def, hints,