[TF:XLA] Create helper for broadcasting ops to same dimensions.

PiperOrigin-RevId: 251603109
This commit is contained in:
A. Unique TensorFlower 2019-06-05 02:12:56 -07:00 committed by TensorFlower Gardener
parent 57c0440a73
commit 6db4b700fe
3 changed files with 29 additions and 0 deletions

View File

@ -19,11 +19,13 @@ cc_library(
srcs = ["broadcast.cc"],
hdrs = ["broadcast.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:framework",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",

View File

@ -19,9 +19,12 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@ -88,4 +91,23 @@ xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
return output;
}
Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) {
TF_ASSIGN_OR_RETURN(auto lhs_xla_shape, lhs->builder()->GetShape(*lhs));
TF_ASSIGN_OR_RETURN(auto rhs_xla_shape, rhs->builder()->GetShape(*rhs));
TensorShape lhs_tf_shape;
TensorShape rhs_tf_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(lhs_xla_shape, &lhs_tf_shape));
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(rhs_xla_shape, &rhs_tf_shape));
if (!lhs_tf_shape.IsSameSize(rhs_tf_shape)) {
BCast bcast(BCast::FromShape(lhs_tf_shape), BCast::FromShape(rhs_tf_shape));
if (!bcast.IsValid()) {
return errors::InvalidArgument(
"Dimensions cannot be made to match through broadcasting");
}
TF_ASSIGN_OR_RETURN(*lhs, BroadcastTo(*lhs, bcast.output_shape()));
TF_ASSIGN_OR_RETURN(*rhs, BroadcastTo(*rhs, bcast.output_shape()));
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -27,6 +27,11 @@ namespace tensorflow {
xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
absl::Span<int64 const> output_dims);
// Both ops are broadcasted to the same dimensions, so that each dimension is
// the max of the two.
// An InvalidArgument will be returned if the operations are of different rank
// or they share a dimension where they are unequal and neither is 1.
Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_