Rename elem_bcast to then_else_bcast, and remove duplicate template specification
based on review comment. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
d9b98c4d16
commit
e75409c2fe
@ -162,15 +162,15 @@ class SelectV2Op : public OpKernel {
|
|||||||
// 2-ary broadcast.
|
// 2-ary broadcast.
|
||||||
|
|
||||||
// Combine `then` and `else`.
|
// Combine `then` and `else`.
|
||||||
BCast elem_bcast(BCast::FromShape(then->shape()),
|
BCast then_else_bcast(BCast::FromShape(then->shape()),
|
||||||
BCast::FromShape(else_->shape()), false);
|
BCast::FromShape(else_->shape()), false);
|
||||||
OP_REQUIRES(ctx, elem_bcast.IsValid(),
|
OP_REQUIRES(ctx, then_else_bcast.IsValid(),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"then ", then->shape().DebugString(), " and else ",
|
"then ", then->shape().DebugString(), " and else ",
|
||||||
else_->shape().DebugString(), " must be broadcastable"));
|
else_->shape().DebugString(), " must be broadcastable"));
|
||||||
// Combine `cond` with `then` and `else`.
|
// Combine `cond` with `then` and `else`.
|
||||||
BCast bcast(BCast::FromShape(cond->shape()),
|
BCast bcast(BCast::FromShape(cond->shape()),
|
||||||
BCast::FromShape(BCast::ToShape(elem_bcast.output_shape())),
|
BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
|
||||||
false);
|
false);
|
||||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -443,46 +443,6 @@ struct BatchSelectFunctor<CPUDevice, T> {
|
|||||||
d.parallelFor(batch, cost, work);
|
d.parallelFor(batch, cost, work);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int NDIMS>
|
|
||||||
struct BCastSelectFunctor<CPUDevice, T, NDIMS> {
|
|
||||||
void operator()(const CPUDevice& d,
|
|
||||||
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
|
||||||
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
|
||||||
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
|
||||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
|
||||||
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
|
||||||
.select(then_tensor.broadcast(then_bcast),
|
|
||||||
else_tensor.broadcast(else_bcast));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
|
||||||
template <typename T>
|
|
||||||
struct BatchSelectFunctor<SYCLDevice, T>
|
|
||||||
: BatchSelectFunctorBase<SYCLDevice, T> {};
|
|
||||||
|
|
||||||
template <typename T, int NDIMS>
|
|
||||||
struct BCastSelectFunctor<SYCLDevice, T, NDIMS> {
|
|
||||||
void operator()(const SYCLDevice& d,
|
|
||||||
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
|
||||||
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
|
||||||
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
|
||||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
|
||||||
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
|
||||||
.select(then_tensor.broadcast(then_bcast),
|
|
||||||
else_tensor.broadcast(else_bcast));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1217,7 +1217,11 @@ struct BCastSelectFunctor {
|
|||||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||||
|
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||||
|
.select(then_tensor.broadcast(then_bcast),
|
||||||
|
else_tensor.broadcast(else_bcast));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace functor
|
} // end namespace functor
|
||||||
|
Loading…
Reference in New Issue
Block a user