[TF:XLA] Create helper for broadcasting ops to same dimensions.
PiperOrigin-RevId: 251603109
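For context, a minimal sketch (not part of this change) of how a tf2xla kernel might use the new helper before emitting an element-wise XLA op. The function name AddWithBroadcast and the surrounding setup are hypothetical; only BroadcastOpsToSame and BroadcastTo come from this commit.

#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical caller: make both operands the same shape, then emit an
// element-wise Add. Propagates InvalidArgument if the shapes cannot be
// broadcast against each other.
xla::StatusOr<xla::XlaOp> AddWithBroadcast(xla::XlaOp lhs, xla::XlaOp rhs) {
  TF_RETURN_IF_ERROR(BroadcastOpsToSame(&lhs, &rhs));
  return xla::Add(lhs, rhs);
}

}  // namespace tensorflow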
parent 57c0440a73
commit 6db4b700fe
@@ -19,11 +19,13 @@ cc_library(
    srcs = ["broadcast.cc"],
    hdrs = ["broadcast.h"],
    deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:framework",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
@@ -19,9 +19,12 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

@@ -88,4 +91,23 @@ xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
  return output;
}

Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) {
  TF_ASSIGN_OR_RETURN(auto lhs_xla_shape, lhs->builder()->GetShape(*lhs));
  TF_ASSIGN_OR_RETURN(auto rhs_xla_shape, rhs->builder()->GetShape(*rhs));
  TensorShape lhs_tf_shape;
  TensorShape rhs_tf_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(lhs_xla_shape, &lhs_tf_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(rhs_xla_shape, &rhs_tf_shape));
  if (!lhs_tf_shape.IsSameSize(rhs_tf_shape)) {
    BCast bcast(BCast::FromShape(lhs_tf_shape), BCast::FromShape(rhs_tf_shape));
    if (!bcast.IsValid()) {
      return errors::InvalidArgument(
          "Dimensions cannot be made to match through broadcasting");
    }
    TF_ASSIGN_OR_RETURN(*lhs, BroadcastTo(*lhs, bcast.output_shape()));
    TF_ASSIGN_OR_RETURN(*rhs, BroadcastTo(*rhs, bcast.output_shape()));
  }
  return Status::OK();
}

}  // namespace tensorflow
@@ -27,6 +27,11 @@ namespace tensorflow {
xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
                                      absl::Span<int64 const> output_dims);

// Both ops are broadcasted to the same dimensions, so that each dimension is
// the max of the two.
// An InvalidArgument will be returned if the operations are of different rank
// or they share a dimension where they are unequal and neither is 1.
Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
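Below, a stand-alone sketch (not part of this change) of the shape rule described in the header comment above, using the same BCast helper the implementation relies on: each output dimension is the max of the two inputs, and shapes that disagree in a dimension where neither side is 1 are rejected. The concrete shapes and the main() wrapper are illustrative only.

#include <iostream>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/bcast.h"

int main() {
  using tensorflow::BCast;
  using tensorflow::TensorShape;

  // [2, 1, 5] vs. [1, 3, 5]: every dimension pair is equal or has a 1 on one
  // side, so broadcasting succeeds and the output shape is [2, 3, 5].
  BCast ok(BCast::FromShape(TensorShape({2, 1, 5})),
           BCast::FromShape(TensorShape({1, 3, 5})));
  std::cout << ok.IsValid() << " "
            << BCast::ToShape(ok.output_shape()).DebugString() << "\n";

  // [2, 3, 5] vs. [4, 3, 5]: dimension 0 is 2 vs. 4 and neither is 1, so the
  // BCast is invalid and BroadcastOpsToSame would return InvalidArgument.
  BCast bad(BCast::FromShape(TensorShape({2, 3, 5})),
            BCast::FromShape(TensorShape({4, 3, 5})));
  std::cout << bad.IsValid() << "\n";
  return 0;
}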