binary add and zeros like update

This commit is contained in:
Katherine Tian 2020-08-11 22:27:04 +00:00
parent a1e5b65f08
commit 61d63b3764

View File

@ -178,7 +178,8 @@ class TensorMapHasKey : public OpKernel {
template <typename Device>
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
const TensorMap& b, TensorMap* out) {
const TensorMap& b, TensorMap* out) {
// binary add returns a union of keys, values with keys in the intersection are added
out->tensors() = a.tensors();
for (const std::pair<TensorKey,Tensor>& p : b.tensors()) {
absl::flat_hash_map<TensorKey,Tensor>::iterator it = out->tensors().find(p.first);
@ -196,11 +197,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
template <typename Device>
Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x, TensorMap* y) {
for (const std::pair<TensorKey,Tensor>& p : x.tensors()) {
Tensor val;
TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, p.second, &val));
y->tensors().emplace(p.first, val);
}
// zeros like returns an empty map
return Status::OK();
}