Cleanup Selectv2 broadcasting
This commit is contained in:
parent
2b31ba7d0a
commit
9b3d8f2212
@ -149,21 +149,9 @@ class SelectV2Op : public OpKernel {
|
||||
|
||||
// The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
|
||||
// This matches the behavior of numpy.
|
||||
// TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple
|
||||
// 2-ary broadcast.
|
||||
|
||||
// Combine `then` and `else`.
|
||||
BCast then_else_bcast(BCast::FromShape(then->shape()),
|
||||
BCast::FromShape(else_->shape()), false);
|
||||
OP_REQUIRES(ctx, then_else_bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"then ", then->shape().DebugString(), " and else ",
|
||||
else_->shape().DebugString(), " must be broadcastable"));
|
||||
// Combine `cond` with `then` and `else`.
|
||||
BCast bcast(
|
||||
BCast::FromShape(cond->shape()),
|
||||
BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
|
||||
false);
|
||||
BCastList<3> bcast({cond->shape().dim_sizes(), then->shape().dim_sizes(),
|
||||
else_->shape().dim_sizes()},
|
||||
false);
|
||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"condition ", cond->shape().DebugString(), ", then ",
|
||||
@ -172,12 +160,9 @@ class SelectV2Op : public OpKernel {
|
||||
|
||||
// Broadcast `cond`, `then` and `else` to combined shape,
|
||||
// in order to obtain the reshape.
|
||||
BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||
BCast::FromShape(cond->shape()), false);
|
||||
BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||
BCast::FromShape(then->shape()), false);
|
||||
BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||
BCast::FromShape(else_->shape()), false);
|
||||
BCast cond_bcast(bcast.output_shape(), cond->shape().dim_sizes(), false);
|
||||
BCast then_bcast(bcast.output_shape(), then->shape().dim_sizes(), false);
|
||||
BCast else_bcast(bcast.output_shape(), else_->shape().dim_sizes(), false);
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user