In all cases that would have triggered the special case, BroadcastTo() now builds a BroadcastInDim() such that broadcast_shape == output_dims, so the xla::Reshape is not triggered. Thus, the special case is no longer needed.

PiperOrigin-RevId: 273015433
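A sketch of the two code paths through BroadcastTo(), with illustrative shapes (not part of the original change description):

    // Broadcasting f32[3] to output_dims = {2, 3}:
    //   broadcast_shape = {2, 3} == output_dims, broadcast_dims = {1},
    //   so the xla::BroadcastInDim() result is returned with no Reshape.
    // Broadcasting f32[3] to output_dims = {6} (tiling):
    //   broadcast_shape = {2, 3} != output_dims, so xla::Reshape() is
    //   still needed to flatten the tiled result to [6].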
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/broadcast.h"

#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

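// Broadcasts `input` to the shape `output_dims`. Dimensions are matched
// from the trailing (minor-most) end; an input dimension can be broadcast
// if it equals the corresponding output dimension, is 1, or evenly
// divides it (tiling). Returns an error if the input's rank exceeds the
// output's or a dimension cannot be broadcast.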
xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
                                      absl::Span<int64 const> output_dims) {
  xla::XlaBuilder* builder = input.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
  absl::Span<int64 const> input_dims =
      xla::AsInt64Slice(input_shape.dimensions());

  // Fast path: the shapes already match.
  if (input_dims == output_dims) {
    return input;
  }

  // The output must be at least as high-rank as the input.
  if (input_dims.size() > output_dims.size()) {
    return errors::InvalidArgument(
        "Input shape (", xla::ShapeUtil::HumanString(input_shape),
        ") must have rank less than or equal to the output shape [",
        absl::StrJoin(output_dims, ","), "]");
  }

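  // Walk the dimensions from minor-most to major-most, building the
  // intermediate shape to broadcast to (`broadcast_shape`, still in
  // reversed order here) and the indices within it that the input's
  // dimensions occupy (`broadcast_dims`).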
  std::vector<int64> broadcast_dims;
  std::vector<int64> broadcast_shape;
  auto input_it = input_dims.rbegin();
  for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend();
       ++output_it) {
    if (input_it != input_dims.rend()) {
      if (!(*output_it == 0 && *input_it == 0) &&
          !(*input_it != 0 && *output_it % *input_it == 0)) {
        return errors::InvalidArgument("Invalid shape broadcast from ",
                                       xla::ShapeUtil::HumanString(input_shape),
                                       " to [", absl::StrJoin(output_dims, ","),
                                       "]");
      }

      broadcast_dims.push_back(broadcast_shape.size());
      if (*output_it == *input_it || *input_it == 1) {
        broadcast_shape.push_back(*output_it);
      } else if (*output_it != *input_it) {
        // Add dimensions [I, O/I], which we will later flatten to just
        // [O]. We must do this in two phases since XLA broadcasting does not
        // support tiling.
        broadcast_shape.push_back(*input_it);
        broadcast_shape.push_back(*output_it / *input_it);
      }
      ++input_it;
    } else {
      broadcast_shape.push_back(*output_it);
    }
  }
  TF_RET_CHECK(input_it == input_dims.rend());

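  // `broadcast_dims` and `broadcast_shape` were built in reversed
  // (minor-to-major) order. Reverse both, and flip each dimension index
  // (i -> size - i - 1) to convert to major-to-minor numbering.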
  absl::c_reverse(broadcast_dims);
  int broadcast_shape_size = broadcast_shape.size();
  for (int64& broadcast_dim : broadcast_dims) {
    broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
  }
  absl::c_reverse(broadcast_shape);
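  // By construction, broadcast_shape == output_dims unless the tiling
  // case above split a dimension in two, so the Reshape below only
  // fires when flattening tiled dimensions back to output_dims.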
  xla::XlaOp output =
      xla::BroadcastInDim(input, broadcast_shape, broadcast_dims);
  if (broadcast_shape != output_dims) {
    output = xla::Reshape(output, output_dims);
  }
  return output;
}

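// Broadcasts `*lhs` and `*rhs` in place to a common shape, following the
// usual TensorFlow binary-op broadcasting rules (via BCast). Returns an
// error if the shapes cannot be broadcast against each other.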
Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) {
  TF_ASSIGN_OR_RETURN(auto lhs_xla_shape, lhs->builder()->GetShape(*lhs));
  TF_ASSIGN_OR_RETURN(auto rhs_xla_shape, rhs->builder()->GetShape(*rhs));
  TensorShape lhs_tf_shape;
  TensorShape rhs_tf_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(lhs_xla_shape, &lhs_tf_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(rhs_xla_shape, &rhs_tf_shape));
  if (!lhs_tf_shape.IsSameSize(rhs_tf_shape)) {
    BCast bcast(BCast::FromShape(lhs_tf_shape), BCast::FromShape(rhs_tf_shape));
    if (!bcast.IsValid()) {
      return errors::InvalidArgument(
          "Dimensions cannot be made to match through broadcasting");
    }
    TF_ASSIGN_OR_RETURN(*lhs, BroadcastTo(*lhs, bcast.output_shape()));
    TF_ASSIGN_OR_RETURN(*rhs, BroadcastTo(*rhs, bcast.output_shape()));
  }
  return Status::OK();
}

} // namespace tensorflow
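A minimal usage sketch of BroadcastTo(), not from the source; the function name, builder name, and constant values are illustrative:

    #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"

    // Builds a computation that broadcasts a f32[3] constant to f32[2, 3].
    xla::StatusOr<xla::XlaOp> BuildExample(xla::XlaBuilder* builder) {
      xla::XlaOp input = xla::ConstantR1<float>(builder, {1.0f, 2.0f, 3.0f});
      // Hits the no-Reshape path: broadcast_shape == {2, 3} == output_dims.
      return tensorflow::BroadcastTo(input, {2, 3});
    }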