Merge pull request #42141 from kttian:erase_grad
PiperOrigin-RevId: 326348715 Change-Id: I68e1ff52418ab030d434e91162ebf4bb82c2649e
This commit is contained in:
commit
4f8bc7c25f
@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/map_kernels.h"
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -38,4 +41,16 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
|
||||
TensorMapHasKey);
|
||||
|
||||
#undef REGISTER_TENSOR_MAP_OPS_CPU
|
||||
|
||||
#define REGISTER_TENSOR_MAP_OPS_CPU(T)
|
||||
|
||||
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
|
||||
TensorMap,
|
||||
TensorMapBinaryAdd<CPUDevice>);
|
||||
|
||||
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||
DEVICE_CPU, TensorMap,
|
||||
TensorMapZerosLike<CPUDevice>);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/kernels/tensor_map.h"
|
||||
#include "tensorflow/core/util/tensor_ops_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -175,6 +176,34 @@ class TensorMapHasKey : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
|
||||
const TensorMap& b, TensorMap* out) {
|
||||
// Binary add returns a map containing the union of keys.
|
||||
// Values with keys in the intersection are added.
|
||||
out->tensors() = a.tensors();
|
||||
for (const std::pair<TensorKey, Tensor>& p : b.tensors()) {
|
||||
absl::flat_hash_map<TensorKey, Tensor>::iterator it =
|
||||
out->tensors().find(p.first);
|
||||
if (it != out->tensors().end()) {
|
||||
Tensor out_tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BinaryAddTensors<Device>(c, p.second, it->second, &out_tensor));
|
||||
it->second = out_tensor;
|
||||
} else {
|
||||
out->tensors().emplace(p.first, p.second);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Device>
|
||||
Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x,
|
||||
TensorMap* y) {
|
||||
// Zeros like returns an empty map.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
|
||||
|
@ -41,22 +41,11 @@ void TensorMap::Encode(VariantTensorData* data) const {
|
||||
*data->add_tensors() = v;
|
||||
map_it++;
|
||||
}
|
||||
string metadata;
|
||||
// TODO(b/118838800): Add a proto for storing the metadata.
|
||||
// Metadata format:
|
||||
// <element_dtype><element_shape_proto>
|
||||
core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
|
||||
TensorShapeProto element_shape_proto;
|
||||
element_shape.AsProto(&element_shape_proto);
|
||||
element_shape_proto.AppendToString(&metadata);
|
||||
data->set_metadata(metadata);
|
||||
}
|
||||
|
||||
static Status TensorMapDeviceCopy(
|
||||
const TensorMap& from, TensorMap* to,
|
||||
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
|
||||
to->element_shape = from.element_shape;
|
||||
to->element_dtype = from.element_dtype;
|
||||
for (const std::pair<TensorKey, Tensor>& p : from.tensors()) {
|
||||
TensorKey to_key(p.first.dtype());
|
||||
Tensor to_val(p.second.dtype());
|
||||
@ -81,11 +70,6 @@ bool TensorMap::Decode(const VariantTensorData& data) {
|
||||
// TODO(srbs): Change the signature to Decode(VariantTensorData data) so
|
||||
// that we do not have to copy each tensor individually below. This would
|
||||
// require changing VariantTensorData::tensors() as well.
|
||||
string metadata;
|
||||
data.get_metadata(&metadata);
|
||||
uint64 scratch;
|
||||
StringPiece iter(metadata);
|
||||
|
||||
std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
|
||||
|
||||
while (tensors_it != data.tensors().end()) {
|
||||
@ -95,13 +79,6 @@ bool TensorMap::Decode(const VariantTensorData& data) {
|
||||
tensors().emplace(tensors_it[0], tensors_it[1]);
|
||||
tensors_it += 2;
|
||||
}
|
||||
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
element_dtype = static_cast<DataType>(scratch);
|
||||
core::GetVarint64(&iter, &scratch);
|
||||
TensorShapeProto element_shape_proto;
|
||||
element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
|
||||
element_shape = PartialTensorShape(element_shape_proto);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -69,24 +69,16 @@ class TensorMap {
|
||||
TensorMap() : tensors_(new Tensors) {}
|
||||
~TensorMap();
|
||||
|
||||
TensorMap(const TensorMap& other)
|
||||
: element_shape(other.element_shape),
|
||||
element_dtype(other.element_dtype),
|
||||
tensors_(other.tensors_) {
|
||||
TensorMap(const TensorMap& other) : tensors_(other.tensors_) {
|
||||
tensors_->Ref();
|
||||
}
|
||||
|
||||
TensorMap(TensorMap&& rhs)
|
||||
: element_shape(std::move(rhs.element_shape)),
|
||||
element_dtype(rhs.element_dtype),
|
||||
tensors_(rhs.tensors_) {
|
||||
TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) {
|
||||
rhs.tensors_ = nullptr;
|
||||
}
|
||||
|
||||
TensorMap& operator=(const TensorMap& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
element_shape = rhs.element_shape;
|
||||
element_dtype = rhs.element_dtype;
|
||||
tensors_->Unref();
|
||||
tensors_ = rhs.tensors_;
|
||||
tensors_->Ref();
|
||||
@ -95,8 +87,6 @@ class TensorMap {
|
||||
|
||||
TensorMap& operator=(TensorMap&& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
element_shape = rhs.element_shape;
|
||||
element_dtype = rhs.element_dtype;
|
||||
std::swap(tensors_, rhs.tensors_);
|
||||
return *this;
|
||||
}
|
||||
@ -112,27 +102,18 @@ class TensorMap {
|
||||
// TODO(apassos) fill this out
|
||||
string DebugString() const { return "TensorMap"; }
|
||||
|
||||
PartialTensorShape element_shape;
|
||||
|
||||
DataType element_dtype;
|
||||
|
||||
// Access to the underlying tensor container.
|
||||
absl::flat_hash_map<TensorKey, Tensor>& tensors() {
|
||||
return tensors_->values_;
|
||||
}
|
||||
|
||||
const absl::flat_hash_map<TensorKey, Tensor>& tensors() const {
|
||||
return tensors_->values_;
|
||||
}
|
||||
|
||||
// Access to shape and element dtype
|
||||
PartialTensorShape& shape() { return element_shape; }
|
||||
DataType dtype() { return element_dtype; }
|
||||
|
||||
// Get a new TensorMap containing a copy of the underlying tensor container.
|
||||
TensorMap Copy() const {
|
||||
TensorMap out;
|
||||
out.element_shape = element_shape;
|
||||
out.element_dtype = element_dtype;
|
||||
// This performs a copy of the absl::hashmap.
|
||||
out.tensors_->values_ = tensors_->values_;
|
||||
return out;
|
||||
|
@ -114,7 +114,6 @@ TEST(TensorMapTest, Copy) {
|
||||
Tensor v = Tensor(22);
|
||||
tm.insert(k, v);
|
||||
TensorMap tmc = tm.Copy();
|
||||
EXPECT_EQ(tm.dtype(), tmc.dtype());
|
||||
EXPECT_EQ(tm.size(), tmc.size());
|
||||
EXPECT_NE(tm.find(k), tm.tensors().end());
|
||||
EXPECT_NE(tmc.find(k), tmc.tensors().end());
|
||||
@ -131,8 +130,6 @@ TEST(TensorMapTest, EncodeDecode) {
|
||||
tm.Encode(&data);
|
||||
TensorMap tmc;
|
||||
tmc.Decode(data);
|
||||
|
||||
EXPECT_EQ(tm.dtype(), tmc.dtype());
|
||||
EXPECT_EQ(tm.size(), tmc.size());
|
||||
EXPECT_NE(tm.find(k), tm.tensors().end());
|
||||
EXPECT_NE(tmc.find(k), tmc.tensors().end());
|
||||
|
@ -53,7 +53,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
|
||||
self.assertAllClose(l, v)
|
||||
|
||||
def testTensorMapLookupMissingKeyFails(self):
|
||||
def testTensorMapLookupFromEmptyMapFails(self):
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
@ -61,6 +61,17 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
|
||||
self.evaluate(l)
|
||||
|
||||
def testTensorMapLookupMissingKeyFails(self):
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Trying to lookup non-existent key."):
|
||||
l = map_ops.tensor_map_lookup(m, k2, dtypes.float32)
|
||||
self.evaluate(l)
|
||||
|
||||
def testTensorMapErase(self):
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
@ -105,85 +116,82 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(b, True)
|
||||
self.assertAllEqual(b2, False)
|
||||
|
||||
def testHasKeyLookup(self):
|
||||
with self.test_session():
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(2.0)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
def testIfHasKeyLookup(self):
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(2.0)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
|
||||
default_value = array_ops.zeros_like(v)
|
||||
l = control_flow_ops.cond(
|
||||
map_ops.tensor_map_has_key(m, k),
|
||||
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
|
||||
lambda: default_value)
|
||||
l2 = control_flow_ops.cond(
|
||||
map_ops.tensor_map_has_key(m, k2),
|
||||
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
|
||||
lambda: default_value)
|
||||
self.assertAllClose(l, v)
|
||||
self.assertAllClose(l2, default_value)
|
||||
default_value = array_ops.zeros_like(v)
|
||||
l = control_flow_ops.cond(
|
||||
map_ops.tensor_map_has_key(m, k),
|
||||
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
|
||||
lambda: default_value)
|
||||
l2 = control_flow_ops.cond(
|
||||
map_ops.tensor_map_has_key(m, k2),
|
||||
lambda: map_ops.tensor_map_lookup(m, k, dtypes.float32),
|
||||
lambda: default_value)
|
||||
self.assertAllClose(l, v)
|
||||
self.assertAllClose(l2, default_value)
|
||||
|
||||
def testInsertLookupGrad(self):
|
||||
with backprop.GradientTape() as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
v = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
tape.watch(v)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
|
||||
l *= 5
|
||||
g = tape.gradient(l, v)
|
||||
self.assertAllClose(g, 5)
|
||||
self.assertAllEqual(g, 5)
|
||||
|
||||
def testMultipleInsertLookupGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
v = constant_op.constant(2.0)
|
||||
k2 = constant_op.constant(12.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
k3 = constant_op.constant(13.0)
|
||||
v3 = constant_op.constant(23.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
k3 = constant_op.constant(3.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(12.0)
|
||||
v3 = constant_op.constant(13.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
tape.watch(v3)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
m = map_ops.tensor_map_insert(m, k3, v3)
|
||||
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
|
||||
l3 = map_ops.tensor_map_lookup(m, k3, v3.dtype)
|
||||
g = tape.gradient(l * 5, v)
|
||||
g2 = tape.gradient(l2 * 6, v2)
|
||||
g3 = tape.gradient(l3 * 7, v3)
|
||||
self.assertAllClose(g, 5)
|
||||
self.assertAllClose(g2, 6)
|
||||
self.assertAllClose(g3, 7)
|
||||
self.assertAllEqual(g, 5)
|
||||
self.assertAllEqual(g2, 6)
|
||||
self.assertAllEqual(g3, 7)
|
||||
del tape
|
||||
|
||||
def testSameKeyInsertLookupGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
def testInsertLookupComposeGrad(self):
|
||||
with backprop.GradientTape() as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
v = constant_op.constant(2.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k, v2)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
g = tape.gradient(l * 5, v)
|
||||
g2 = tape.gradient(l * 5, v2)
|
||||
self.assertAllClose(g, array_ops.zeros_like(v))
|
||||
self.assertAllClose(g2, 5)
|
||||
m = map_ops.tensor_map_insert(m, k2, l)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, l.dtype)
|
||||
g = tape.gradient(l2 * 5, v)
|
||||
self.assertAllEqual(g, 5)
|
||||
|
||||
def testSameKeyAlternatingInsertLookupGrad(self):
|
||||
def testReplaceLookupGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
v = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
@ -191,7 +199,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
self.assertAllClose(l, v)
|
||||
g = tape.gradient(l * 5, v)
|
||||
self.assertAllClose(g, 5)
|
||||
self.assertAllEqual(g, 5)
|
||||
m = map_ops.tensor_map_insert(m, k, v2)
|
||||
l2 = map_ops.tensor_map_lookup(m, k, v2.dtype)
|
||||
self.assertAllClose(l2, v2)
|
||||
@ -199,6 +207,195 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
g3 = tape.gradient(l2 * 7, v2)
|
||||
self.assertAllClose(g2, array_ops.zeros_like(v))
|
||||
self.assertAllClose(g3, 7)
|
||||
del tape
|
||||
|
||||
def testDiffKeySameValueGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(11.0)
|
||||
v = constant_op.constant(2.0)
|
||||
v2 = constant_op.constant(2.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k2, v)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v.dtype)
|
||||
g = tape.gradient(l + l2, v)
|
||||
self.assertAllEqual(g, 2)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
|
||||
g2 = tape.gradient(l + l2, v2)
|
||||
self.assertAllEqual(g2, 1)
|
||||
del tape
|
||||
|
||||
def testLookupAddGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.empty_tensor_map()
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
l1 = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
|
||||
g = tape.gradient(l1 + l2, [l1, l2])
|
||||
self.assertAllClose(g, [1, 1])
|
||||
g2 = tape.gradient(l1 + l2, [v, v2])
|
||||
self.assertAllClose(g2, [1, 1])
|
||||
g3 = tape.gradient(l1 + l2 * 4, v2)
|
||||
self.assertAllEqual(g3, 4)
|
||||
del tape
|
||||
|
||||
def testLookupMultiplyGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.empty_tensor_map()
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
l1 = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
|
||||
g = tape.gradient(l1 * l2, [v, v2])
|
||||
self.assertAllClose(g, [v2, v])
|
||||
g2 = tape.gradient(l1 * l1, v)
|
||||
self.assertAllClose(g2, 2 * v)
|
||||
del tape
|
||||
|
||||
def testEraseSecondGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
m, e = map_ops.tensor_map_erase(m, k2, v2.dtype)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
self.assertAllClose(l, v)
|
||||
self.assertAllClose(e, v2)
|
||||
g = tape.gradient(l * 5, v)
|
||||
self.assertAllEqual(g, 5)
|
||||
g2 = tape.gradient(e * 6, v2)
|
||||
self.assertAllEqual(g2, 6)
|
||||
del tape
|
||||
|
||||
def testEraseFirstGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
|
||||
self.assertAllClose(l2, v2)
|
||||
self.assertAllClose(e, v)
|
||||
g = tape.gradient(l * 5, v)
|
||||
self.assertAllEqual(g, 5)
|
||||
g2 = tape.gradient(l2 * 6, v2)
|
||||
self.assertAllEqual(g2, 6)
|
||||
g3 = tape.gradient(e * 7, v)
|
||||
self.assertAllEqual(g3, 7)
|
||||
m, e2 = map_ops.tensor_map_erase(m, k2, v2.dtype)
|
||||
g4 = tape.gradient(e2 * 8, v2)
|
||||
self.assertAllEqual(g4, 8)
|
||||
del tape
|
||||
|
||||
def testEraseInsertComposedGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant(1.0)
|
||||
k2 = constant_op.constant(2.0)
|
||||
v = constant_op.constant(11.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v)
|
||||
tape.watch(v2)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
|
||||
m = map_ops.tensor_map_insert(m, k2, e)
|
||||
l = map_ops.tensor_map_lookup(m, k2, e.dtype)
|
||||
self.assertAllClose(e, v)
|
||||
self.assertAllClose(l, e)
|
||||
g = tape.gradient(l * 5, v)
|
||||
self.assertAllEqual(g, 5)
|
||||
g2 = tape.gradient(e * 6, v)
|
||||
self.assertAllEqual(g2, 6)
|
||||
del tape
|
||||
|
||||
def testStringKeyGrad(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant("key")
|
||||
k2 = constant_op.constant("key2")
|
||||
v = constant_op.constant(2.0)
|
||||
v2 = constant_op.constant(22.0)
|
||||
tape.watch(v2)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
s = map_ops.tensor_map_size(m)
|
||||
self.assertAllEqual(s, 2)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
self.assertAllClose(l, v)
|
||||
m = map_ops.tensor_map_insert(m, k, v2)
|
||||
l2 = map_ops.tensor_map_lookup(m, k, v2.dtype)
|
||||
self.assertAllClose(l2, v2)
|
||||
g = tape.gradient(l2 * 5, v2)
|
||||
self.assertAllEqual(g, 5)
|
||||
|
||||
m, e = map_ops.tensor_map_erase(m, k, v2.dtype)
|
||||
s = map_ops.tensor_map_size(m)
|
||||
self.assertAllEqual(s, 1)
|
||||
self.assertAllClose(e, v2)
|
||||
g2 = tape.gradient(e * 6, v2)
|
||||
self.assertAllEqual(g2, 6)
|
||||
del tape
|
||||
|
||||
def testStringValue(self):
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant("key")
|
||||
v = constant_op.constant("value")
|
||||
k2 = constant_op.constant(1.0)
|
||||
v2 = constant_op.constant(2.0)
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
self.assertAllEqual(l, v)
|
||||
l2 = map_ops.tensor_map_lookup(m, k2, v2.dtype)
|
||||
self.assertAllClose(l2, v2)
|
||||
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
|
||||
self.assertAllEqual(e, v)
|
||||
|
||||
def testVectorValue(self):
|
||||
m = map_ops.empty_tensor_map()
|
||||
k = constant_op.constant([1.0, 2.0])
|
||||
v = constant_op.constant([11.0, 22.0])
|
||||
m = map_ops.tensor_map_insert(m, k, v)
|
||||
s = map_ops.tensor_map_size(m)
|
||||
self.assertAllEqual(s, 1)
|
||||
l = map_ops.tensor_map_lookup(m, k, v.dtype)
|
||||
self.assertAllEqual(l, v)
|
||||
|
||||
m, e = map_ops.tensor_map_erase(m, k, v.dtype)
|
||||
s = map_ops.tensor_map_size(m)
|
||||
self.assertAllEqual(s, 0)
|
||||
self.assertAllClose(e, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -66,10 +66,16 @@ def LookupGrad(op, dval):
|
||||
def InsertGrad(op, dmap):
|
||||
_, k, v = op.inputs
|
||||
key_grad = None
|
||||
value_grad = control_flow_ops.cond(
|
||||
tensor_map_has_key(dmap, k), lambda: tensor_map_lookup(dmap, k, v.dtype),
|
||||
lambda: array_ops.zeros_like(v))
|
||||
map_grad = control_flow_ops.cond(
|
||||
tensor_map_has_key(dmap, k),
|
||||
lambda: tensor_map_erase(dmap, k, v.dtype)[0], lambda: dmap)
|
||||
(value_grad, map_grad) = control_flow_ops.cond(
|
||||
tensor_map_has_key(dmap, k), lambda: (tensor_map_lookup(
|
||||
dmap, k, v.dtype), tensor_map_erase(dmap, k, v.dtype)[0]), lambda:
|
||||
(array_ops.zeros_like(v), dmap))
|
||||
return map_grad, key_grad, value_grad
|
||||
|
||||
|
||||
@ops.RegisterGradient("TensorMapErase")
|
||||
def EraseGrad(op, dmap, dval):
|
||||
_, k = op.inputs
|
||||
key_grad = None
|
||||
map_grad = tensor_map_insert(dmap, k, dval)
|
||||
return map_grad, key_grad
|
||||
|
Loading…
Reference in New Issue
Block a user