Merge pull request #42141 from kttian:erase_grad

PiperOrigin-RevId: 326348715
Change-Id: I68e1ff52418ab030d434e91162ebf4bb82c2649e
commit 4f8bc7c25f
TensorFlower Gardener, 2020-08-12 17:42:08 -07:00
7 changed files with 299 additions and 97 deletions

tensorflow/core/kernels/map_kernels.cc

@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/map_kernels.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -38,4 +41,16 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
TensorMapHasKey);
#undef REGISTER_TENSOR_MAP_OPS_CPU
#define REGISTER_TENSOR_MAP_OPS_CPU(T)
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
TensorMap,
TensorMapBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, TensorMap,
TensorMapZerosLike<CPUDevice>);
} // namespace tensorflow
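
These registrations are what make TensorMap cooperate with the gradient tape: ADD_VARIANT_BINARY_OP lets two map-valued gradients be summed, and ZEROS_LIKE_VARIANT_UNARY_OP supplies the identity element for that sum. As a rough illustration, the union-with-addition semantics of TensorMapBinaryAdd can be modeled in plain Python (a hypothetical sketch, not the C++ kernel):

    def binary_add(a, b):
      # Result holds the union of keys; values whose key appears in
      # both maps are added, mirroring TensorMapBinaryAdd.
      out = dict(a)
      for key, value in b.items():
        out[key] = out[key] + value if key in out else value
      return out

    binary_add({1.0: 1.0}, {1.0: 2.0, 2.0: 3.0})  # {1.0: 3.0, 2.0: 3.0}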

tensorflow/core/kernels/map_kernels.h

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/tensor_map.h"
#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
@@ -175,6 +176,34 @@ class TensorMapHasKey : public OpKernel {
}
};
template <typename Device>
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
const TensorMap& b, TensorMap* out) {
// Binary add returns a map containing the union of keys.
// Values with keys in the intersection are added.
out->tensors() = a.tensors();
for (const std::pair<TensorKey, Tensor>& p : b.tensors()) {
absl::flat_hash_map<TensorKey, Tensor>::iterator it =
out->tensors().find(p.first);
if (it != out->tensors().end()) {
Tensor out_tensor;
TF_RETURN_IF_ERROR(
BinaryAddTensors<Device>(c, p.second, it->second, &out_tensor));
it->second = out_tensor;
} else {
out->tensors().emplace(p.first, p.second);
}
}
return Status::OK();
}
template <typename Device>
Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x,
TensorMap* y) {
// Zeros like returns an empty map.
return Status::OK();
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
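
Note that TensorMapZerosLike deliberately returns an empty map rather than a map of zero tensors: combined with the union semantics of binary add, an empty map already behaves as the additive identity, and the gradient functions only ever look up keys they themselves inserted. Continuing the dict sketch above:

    # zeros_like of a map is modeled as {}; it is the identity for
    # binary_add, since binary_add({}, g) == g for any gradient map g.
    assert binary_add({}, {1.0: 5.0}) == {1.0: 5.0}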

tensorflow/core/kernels/tensor_map.cc

@@ -41,22 +41,11 @@ void TensorMap::Encode(VariantTensorData* data) const {
*data->add_tensors() = v;
map_it++;
}
string metadata;
// TODO(b/118838800): Add a proto for storing the metadata.
// Metadata format:
// <element_dtype><element_shape_proto>
core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
TensorShapeProto element_shape_proto;
element_shape.AsProto(&element_shape_proto);
element_shape_proto.AppendToString(&metadata);
data->set_metadata(metadata);
}
static Status TensorMapDeviceCopy(
const TensorMap& from, TensorMap* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
to->element_shape = from.element_shape;
to->element_dtype = from.element_dtype;
for (const std::pair<TensorKey, Tensor>& p : from.tensors()) {
TensorKey to_key(p.first.dtype());
Tensor to_val(p.second.dtype());
@@ -81,11 +70,6 @@ bool TensorMap::Decode(const VariantTensorData& data) {
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
// that we do not have to copy each tensor individually below. This would
// require changing VariantTensorData::tensors() as well.
string metadata;
data.get_metadata(&metadata);
uint64 scratch;
StringPiece iter(metadata);
std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
while (tensors_it != data.tensors().end()) {
@@ -95,13 +79,6 @@ bool TensorMap::Decode(const VariantTensorData& data) {
tensors().emplace(tensors_it[0], tensors_it[1]);
tensors_it += 2;
}
core::GetVarint64(&iter, &scratch);
element_dtype = static_cast<DataType>(scratch);
core::GetVarint64(&iter, &scratch);
TensorShapeProto element_shape_proto;
element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
element_shape = PartialTensorShape(element_shape_proto);
return true;
}
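
With the map-level element_dtype and element_shape metadata removed (each key and value tensor now carries its own dtype, as the string-value tests below exercise), Encode and Decode reduce to serializing an alternating key/value sequence. A hypothetical pure-Python model of the simplified layout:

    def encode(m):
      flat = []                  # [k0, v0, k1, v1, ...] and nothing else
      for k, v in m.items():
        flat += [k, v]
      return flat

    def decode(flat):
      return {flat[i]: flat[i + 1] for i in range(0, len(flat), 2)}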

tensorflow/core/kernels/tensor_map.h

@@ -69,24 +69,16 @@ class TensorMap {
TensorMap() : tensors_(new Tensors) {}
~TensorMap();
TensorMap(const TensorMap& other)
: element_shape(other.element_shape),
element_dtype(other.element_dtype),
tensors_(other.tensors_) {
TensorMap(const TensorMap& other) : tensors_(other.tensors_) {
tensors_->Ref();
}
TensorMap(TensorMap&& rhs)
: element_shape(std::move(rhs.element_shape)),
element_dtype(rhs.element_dtype),
tensors_(rhs.tensors_) {
TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) {
rhs.tensors_ = nullptr;
}
TensorMap& operator=(const TensorMap& rhs) {
if (this == &rhs) return *this;
element_shape = rhs.element_shape;
element_dtype = rhs.element_dtype;
tensors_->Unref();
tensors_ = rhs.tensors_;
tensors_->Ref();
@@ -95,8 +87,6 @@ class TensorMap {
TensorMap& operator=(TensorMap&& rhs) {
if (this == &rhs) return *this;
element_shape = rhs.element_shape;
element_dtype = rhs.element_dtype;
std::swap(tensors_, rhs.tensors_);
return *this;
}
@@ -112,27 +102,18 @@ class TensorMap {
// TODO(apassos) fill this out
string DebugString() const { return "TensorMap"; }
PartialTensorShape element_shape;
DataType element_dtype;
// Access to the underlying tensor container.
absl::flat_hash_map<TensorKey, Tensor>& tensors() {
return tensors_->values_;
}
const absl::flat_hash_map<TensorKey, Tensor>& tensors() const {
return tensors_->values_;
}
// Access to shape and element dtype
PartialTensorShape& shape() { return element_shape; }
DataType dtype() { return element_dtype; }
// Get a new TensorMap containing a copy of the underlying tensor container.
TensorMap Copy() const {
TensorMap out;
out.element_shape = element_shape;
out.element_dtype = element_dtype;
// This performs a copy of the absl::hashmap.
out.tensors_->values_ = tensors_->values_;
return out;
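
The shape and dtype members disappear from the special members and from Copy() as well, leaving only the ref-counted tensor container. The two copy behaviours are worth keeping apart: the copy constructor shares storage via Ref()/Unref(), while Copy() clones the underlying hash map. A toy Python analogy (illustrative only):

    class MapHandle:
      # Toy stand-in for the two copy behaviours above.
      def __init__(self, values):
        self.values = values

      def alias(self):   # ~ TensorMap(const TensorMap&): shares storage
        return MapHandle(self.values)

      def clone(self):   # ~ TensorMap::Copy(): independent storage
        return MapHandle(dict(self.values))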

tensorflow/core/kernels/tensor_map_test.cc

@@ -114,7 +114,6 @@ TEST(TensorMapTest, Copy) {
Tensor v = Tensor(22);
tm.insert(k, v);
TensorMap tmc = tm.Copy();
EXPECT_EQ(tm.dtype(), tmc.dtype());
EXPECT_EQ(tm.size(), tmc.size());
EXPECT_NE(tm.find(k), tm.tensors().end());
EXPECT_NE(tmc.find(k), tmc.tensors().end());
@@ -131,8 +130,6 @@ TEST(TensorMapTest, EncodeDecode) {
tm.Encode(&data);
TensorMap tmc;
tmc.Decode(data);
EXPECT_EQ(tm.dtype(), tmc.dtype());
EXPECT_EQ(tm.size(), tmc.size());
EXPECT_NE(tm.find(k), tm.tensors().end());
EXPECT_NE(tmc.find(k), tmc.tensors().end());

tensorflow/python/kernel_tests/map_ops_test.py

@@ -53,7 +53,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
self.assertAllClose(l, v)
def testTensorMapLookupMissingKeyFails(self):
def testTensorMapLookupFromEmptyMapFails(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
with self.assertRaisesRegex(errors.InvalidArgumentError,
@@ -61,6 +61,17 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
self.evaluate(l)
def testTensorMapLookupMissingKeyFails(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
m = map_ops.tensor_map_insert(m, k, v)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Trying to lookup non-existent key."):
l = map_ops.tensor_map_lookup(m, k2, dtypes.float32)
self.evaluate(l)
def testTensorMapErase(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
@@ -105,85 +116,82 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(b, True)
self.assertAllEqual(b2, False)
def testHasKeyLookup(self):
with self.test_session():
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(2.0)
m = map_ops.tensor_map_insert(m, k, v)
def testIfHasKeyLookup(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(2.0)
m = map_ops.tensor_map_insert(m, k, v)
default_value = array_ops.zeros_like(v)
l = control_flow_ops.cond(
map_ops.tensor_map_has_key(m, k),
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
lambda: default_value)
l2 = control_flow_ops.cond(
map_ops.tensor_map_has_key(m, k2),
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
lambda: default_value)
self.assertAllClose(l, v)
self.assertAllClose(l2, default_value)
default_value = array_ops.zeros_like(v)
l = control_flow_ops.cond(
map_ops.tensor_map_has_key(m, k),
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
lambda: default_value)
l2 = control_flow_ops.cond(
map_ops.tensor_map_has_key(m, k2),
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
lambda: default_value)
self.assertAllClose(l, v)
self.assertAllClose(l2, default_value)
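
testIfHasKeyLookup pins down the guarded-lookup pattern that the gradient code below also relies on. Factored into a helper, the pattern might read as follows (a hypothetical convenience wrapper, not part of map_ops):

    def map_get(m, k, dtype, default):
      # Look up k if the map contains it, otherwise produce default.
      return control_flow_ops.cond(
          map_ops.tensor_map_has_key(m, k),
          lambda: map_ops.tensor_map_lookup(m, k, dtype),
          lambda: default)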
def testInsertLookupGrad(self):
with backprop.GradientTape() as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
v = constant_op.constant(2.0)
v = constant_op.constant(11.0)
tape.watch(v)
m = map_ops.tensor_map_insert(m, k, v)
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
l *= 5
g = tape.gradient(l, v)
self.assertAllClose(g, 5)
self.assertAllEqual(g, 5)
def testMultipleInsertLookupGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
v = constant_op.constant(2.0)
k2 = constant_op.constant(12.0)
v2 = constant_op.constant(22.0)
k3 = constant_op.constant(13.0)
v3 = constant_op.constant(23.0)
k2 = constant_op.constant(2.0)
k3 = constant_op.constant(3.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(12.0)
v3 = constant_op.constant(13.0)
tape.watch(v)
tape.watch(v2)
tape.watch(v3)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
m = map_ops.tensor_map_insert(m, k3, v3)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
l3 = map_ops.tensor_map_lookup(m, k3, v3.dtype)
g = tape.gradient(l * 5, v)
g2 = tape.gradient(l2 * 6, v2)
g3 = tape.gradient(l3 * 7, v3)
self.assertAllClose(g, 5)
self.assertAllClose(g2, 6)
self.assertAllClose(g3, 7)
self.assertAllEqual(g, 5)
self.assertAllEqual(g2, 6)
self.assertAllEqual(g3, 7)
del tape
def testSameKeyInsertLookupGrad(self):
with backprop.GradientTape(persistent=True) as tape:
def testInsertLookupComposeGrad(self):
with backprop.GradientTape() as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
v = constant_op.constant(2.0)
v2 = constant_op.constant(22.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k, v2)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
g = tape.gradient(l * 5, v)
g2 = tape.gradient(l * 5, v2)
self.assertAllClose(g, array_ops.zeros_like(v))
self.assertAllClose(g2, 5)
m = map_ops.tensor_map_insert(m, k2, l)
l2 = map_ops.tensor_map_lookup(m, k2, l.dtype)
g = tape.gradient(l2 * 5, v)
self.assertAllEqual(g, 5)
def testSameKeyAlternatingInsertLookupGrad(self):
def testReplaceLookupGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
v = constant_op.constant(2.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(22.0)
tape.watch(v)
tape.watch(v2)
@@ -191,7 +199,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
l = map_ops.tensor_map_lookup(m, k, v.dtype)
self.assertAllClose(l, v)
g = tape.gradient(l * 5, v)
self.assertAllClose(g, 5)
self.assertAllEqual(g, 5)
m = map_ops.tensor_map_insert(m, k, v2)
l2 = map_ops.tensor_map_lookup(m, k, v2.dtype)
self.assertAllClose(l2, v2)
@@ -199,6 +207,195 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
g3 = tape.gradient(l2 * 7, v2)
self.assertAllClose(g2, array_ops.zeros_like(v))
self.assertAllClose(g3, 7)
del tape
def testDiffKeySameValueGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(11.0)
v = constant_op.constant(2.0)
v2 = constant_op.constant(2.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
l2 = map_ops.tensor_map_lookup(m, k2, v.dtype)
g = tape.gradient(l + l2, v)
self.assertAllEqual(g, 2)
m = map_ops.tensor_map_insert(m, k2, v2)
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
g2 = tape.gradient(l + l2, v2)
self.assertAllEqual(g2, 1)
del tape
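
The expected values here are ordinary chain-rule sums: v reaches the loss through both map entries, so d(l + l2)/dv = 1 + 1 = 2, while after k2 is overwritten only l2 depends on v2, giving 1. The same arithmetic holds with a plain tape and no maps at all:

    with backprop.GradientTape() as tape:
      v = constant_op.constant(2.0)
      tape.watch(v)
      total = v + v      # both lookups return the same watched tensor
    # tape.gradient(total, v) == 2.0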
def testLookupAddGrad(self):
with backprop.GradientTape(persistent=True) as tape:
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(22.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.empty_tensor_map()
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
l1 = map_ops.tensor_map_lookup(m, k, v.dtype)
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
g = tape.gradient(l1 + l2, [l1, l2])
self.assertAllClose(g, [1, 1])
g2 = tape.gradient(l1 + l2, [v, v2])
self.assertAllClose(g2, [1, 1])
g3 = tape.gradient(l1 + l2 * 4, v2)
self.assertAllEqual(g3, 4)
del tape
def testLookupMultiplyGrad(self):
with backprop.GradientTape(persistent=True) as tape:
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(22.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.empty_tensor_map()
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
l1 = map_ops.tensor_map_lookup(m, k, v.dtype)
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
g = tape.gradient(l1 * l2, [v, v2])
self.assertAllClose(g, [v2, v])
g2 = tape.gradient(l1 * l1, v)
self.assertAllClose(g2, 2 * v)
del tape
def testEraseSecondGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(22.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
m, e = map_ops.tensor_map_erase(m, k2, v2.dtype)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
self.assertAllClose(l, v)
self.assertAllClose(e, v2)
g = tape.gradient(l * 5, v)
self.assertAllEqual(g, 5)
g2 = tape.gradient(e * 6, v2)
self.assertAllEqual(g2, 6)
del tape
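
tensor_map_erase has two outputs, the shrunken map and the erased value, and each carries its own gradient path: the dval path flows back to whatever produced the erased entry, while the map path covers the surviving entries. The forward semantics, as a dict model (hypothetical, mirroring what the test exercises):

    def erase_model(m, k):
      # Forward semantics: (map without k, value previously stored at k).
      return {kk: vv for kk, vv in m.items() if kk != k}, m[k]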
def testEraseFirstGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(22.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k, v)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
m = map_ops.tensor_map_insert(m, k2, v2)
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
self.assertAllClose(l2, v2)
self.assertAllClose(e, v)
g = tape.gradient(l * 5, v)
self.assertAllEqual(g, 5)
g2 = tape.gradient(l2 * 6, v2)
self.assertAllEqual(g2, 6)
g3 = tape.gradient(e * 7, v)
self.assertAllEqual(g3, 7)
m, e2 = map_ops.tensor_map_erase(m, k2, v2.dtype)
g4 = tape.gradient(e2 * 8, v2)
self.assertAllEqual(g4, 8)
del tape
def testEraseInsertComposedGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
v = constant_op.constant(11.0)
v2 = constant_op.constant(22.0)
tape.watch(v)
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k, v)
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
m = map_ops.tensor_map_insert(m, k2, e)
l = map_ops.tensor_map_lookup(m, k2, e.dtype)
self.assertAllClose(e, v)
self.assertAllClose(l, e)
g = tape.gradient(l * 5, v)
self.assertAllEqual(g, 5)
g2 = tape.gradient(e * 6, v)
self.assertAllEqual(g2, 6)
del tape
def testStringKeyGrad(self):
with backprop.GradientTape(persistent=True) as tape:
m = map_ops.empty_tensor_map()
k = constant_op.constant("key")
k2 = constant_op.constant("key2")
v = constant_op.constant(2.0)
v2 = constant_op.constant(22.0)
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k2, v2)
m = map_ops.tensor_map_insert(m, k, v)
s = map_ops.tensor_map_size(m)
self.assertAllEqual(s, 2)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
self.assertAllClose(l, v)
m = map_ops.tensor_map_insert(m, k, v2)
l2 = map_ops.tensor_map_lookup(m, k, v2.dtype)
self.assertAllClose(l2, v2)
g = tape.gradient(l2 * 5, v2)
self.assertAllEqual(g, 5)
m, e = map_ops.tensor_map_erase(m, k, v2.dtype)
s = map_ops.tensor_map_size(m)
self.assertAllEqual(s, 1)
self.assertAllClose(e, v2)
g2 = tape.gradient(e * 6, v2)
self.assertAllEqual(g2, 6)
del tape
def testStringValue(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant("key")
v = constant_op.constant("value")
k2 = constant_op.constant(1.0)
v2 = constant_op.constant(2.0)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
self.assertAllEqual(l, v)
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
self.assertAllClose(l2, v2)
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
self.assertAllEqual(e, v)
def testVectorValue(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant([1.0, 2.0])
v = constant_op.constant([11.0, 22.0])
m = map_ops.tensor_map_insert(m, k, v)
s = map_ops.tensor_map_size(m)
self.assertAllEqual(s, 1)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
self.assertAllEqual(l, v)
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
s = map_ops.tensor_map_size(m)
self.assertAllEqual(s, 0)
self.assertAllClose(e, v)
if __name__ == "__main__":

tensorflow/python/ops/map_ops.py

@@ -66,10 +66,16 @@ def LookupGrad(op, dval):
def InsertGrad(op, dmap):
_, k, v = op.inputs
key_grad = None
value_grad = control_flow_ops.cond(
tensor_map_has_key(dmap, k), lambda: tensor_map_lookup(dmap, k, v.dtype),
lambda: array_ops.zeros_like(v))
map_grad = control_flow_ops.cond(
tensor_map_has_key(dmap, k),
lambda: tensor_map_erase(dmap, k, v.dtype)[0], lambda: dmap)
(value_grad, map_grad) = control_flow_ops.cond(
tensor_map_has_key(dmap, k), lambda: (tensor_map_lookup(
dmap, k, v.dtype), tensor_map_erase(dmap, k, v.dtype)[0]), lambda:
(array_ops.zeros_like(v), dmap))
return map_grad, key_grad, value_grad
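
The rewrite folds the old pair of conds into a single cond, so the has-key predicate is evaluated once and value_grad and map_grad cannot disagree. Erasing k from dmap also prevents an earlier insert of the same key from consuming the same gradient entry twice, which is the behaviour testReplaceLookupGrad pins down. A hypothetical dict model of the routing:

    def insert_grad_model(dmap, k, zero):
      # dmap is dL/d(insert(m, k, v)). v receives dmap[k] when present;
      # that entry is then erased so m's gradient excludes the key the
      # insert overwrote.
      if k in dmap:
        return {kk: vv for kk, vv in dmap.items() if kk != k}, dmap[k]
      return dict(dmap), zero   # (map_grad, value_grad)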
@ops.RegisterGradient("TensorMapErase")
def EraseGrad(op, dmap, dval):
_, k = op.inputs
key_grad = None
map_grad = tensor_map_insert(dmap, k, dval)
return map_grad, key_grad
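
EraseGrad is the mirror image of InsertGrad: the forward op removed key k and exposed its value, so the backward pass re-inserts the value gradient under k before passing the map gradient upstream. In the same dict-model style (hypothetical):

    def erase_grad_model(dmap, k, dval):
      # dmap is dL/d(new map), dval is dL/d(erased value); re-insert
      # dval under k so the op that produced m[k] gets its gradient.
      map_grad = dict(dmap)
      map_grad[k] = dval
      return map_grad           # the key itself gets no gradient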