From 6db4b700fe1a286cc4dda030c1044dcc6c9919b2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Jun 2019 02:12:56 -0700 Subject: [PATCH] [TF:XLA] Create helper for broadcasting ops to same dimensions. PiperOrigin-RevId: 251603109 --- tensorflow/compiler/tf2xla/lib/BUILD | 2 ++ tensorflow/compiler/tf2xla/lib/broadcast.cc | 22 +++++++++++++++++++++ tensorflow/compiler/tf2xla/lib/broadcast.h | 5 +++++ 3 files changed, 29 insertions(+) diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 5b1f92b24c8..8fde48b391c 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -19,11 +19,13 @@ cc_library( srcs = ["broadcast.cc"], hdrs = ["broadcast.h"], deps = [ + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index be31f116686..a0789f982c3 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -19,9 +19,12 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/util/bcast.h" namespace tensorflow { @@ -88,4 +91,23 @@ xla::StatusOr BroadcastTo(xla::XlaOp input, return output; } +Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { + TF_ASSIGN_OR_RETURN(auto lhs_xla_shape, lhs->builder()->GetShape(*lhs)); + TF_ASSIGN_OR_RETURN(auto rhs_xla_shape, rhs->builder()->GetShape(*rhs)); + TensorShape lhs_tf_shape; + TensorShape rhs_tf_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(lhs_xla_shape, &lhs_tf_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(rhs_xla_shape, &rhs_tf_shape)); + if (!lhs_tf_shape.IsSameSize(rhs_tf_shape)) { + BCast bcast(BCast::FromShape(lhs_tf_shape), BCast::FromShape(rhs_tf_shape)); + if (!bcast.IsValid()) { + return errors::InvalidArgument( + "Dimensions cannot be made to match through broadcasting"); + } + TF_ASSIGN_OR_RETURN(*lhs, BroadcastTo(*lhs, bcast.output_shape())); + TF_ASSIGN_OR_RETURN(*rhs, BroadcastTo(*rhs, bcast.output_shape())); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h index 591e696f06b..766c99546e8 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.h +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -27,6 +27,11 @@ namespace tensorflow { xla::StatusOr BroadcastTo(xla::XlaOp input, absl::Span output_dims); +// Both ops are broadcasted to the same dimensions, so that each dimension is +// the max of the two. +// An InvalidArgument will be returned if the operations are of different rank +// or they share a dimension where they are unequal and neither is 1. +Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_