K-FAC: Support onehot categorical in kfac.loss_functions.
PiperOrigin-RevId: 180536416
This commit is contained in:
parent
12d82a1f53
commit
7b700c515b
@ -114,5 +114,76 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
||||
self.assertEqual(loss.num_registered_minibatches, num_towers)
|
||||
|
||||
|
||||
class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
  """Tests for the onehot-categorical negative log probability loss."""

  def testSample(self):
    """Drawing a sample from the loss's distribution runs and has onehot shape."""
    with ops.Graph().as_default(), self.test_session() as sess:
      logits = np.array([[0., 0., 0.], [1., -1., 0.]], dtype=np.float32)
      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
          array_ops.constant(logits))
      drawn = sess.run(loss.sample(42))
      # One onehot row of length 3 per example in the minibatch.
      self.assertEqual(drawn.shape, (2, 3))

  def testEvaluateOnTargets(self):
    """The evaluated loss matches an explicit numpy negative log prob."""
    with ops.Graph().as_default(), self.test_session() as sess:
      logits = np.array([[0., 0., 0.], [1., -1., 0.]], dtype=np.float32)
      targets = np.array([2, 1], dtype=np.int32)
      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
          array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
      neg_log_prob = sess.run(loss.evaluate())

      # Recompute the expected value directly: softmax the logits, pick out
      # each example's target probability, and sum the logs.
      probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
      expected_log_prob = np.sum(
          [np.log(probs[row, col]) for row, col in enumerate(targets)])
      self.assertAllClose(neg_log_prob, -expected_log_prob)

  def testEvaluateOnSample(self):
    """Evaluating the loss on a freshly drawn sample executes cleanly."""
    with ops.Graph().as_default(), self.test_session() as sess:
      logits = np.array([[0., 0., 0.], [1., -1., 0.]], dtype=np.float32)
      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
          array_ops.constant(logits))
      # The sample is random, so there is no deterministic expected value;
      # we only verify that evaluation does not crash.
      sess.run(loss.evaluate_on_sample(42))

  def testMultiMinibatchRegistration(self):
    """Several towers' logits can be registered on one loss object."""
    with ops.Graph().as_default():
      num_towers = 5
      tower_logits = [
          random_ops.random_uniform(shape=[2, 3]) for _ in range(num_towers)
      ]
      # The first tower's logits go to the constructor; the rest are added
      # through register_additional_minibatch.
      loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
          tower_logits[0])
      for logits in tower_logits[1:]:
        loss.register_additional_minibatch(logits)
      self.assertListEqual(loss.input_minibatches, tower_logits)
      self.assertEqual(loss.num_registered_minibatches, num_towers)
|
||||
|
||||
|
||||
# Standard TensorFlow test entry point: discovers and runs every test case
# defined in this module.
if __name__ == "__main__":
  test.main()
|
||||
|
@ -65,6 +65,7 @@ py_library(
|
||||
srcs = ["loss_functions.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
|
@ -22,6 +22,7 @@ import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import onehot_categorical
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -785,3 +786,16 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
|
||||
after[dim] = dim_size - position - 1
|
||||
|
||||
return array_ops.pad(slice_to_insert, list(zip(before, after)))
|
||||
|
||||
|
||||
class OnehotCategoricalLogitsNegativeLogProbLoss(
    CategoricalLogitsNegativeLogProbLoss):
  """Negative log prob loss for a categorical distribution, onehot variant.

  Behaves exactly like CategoricalLogitsNegativeLogProbLoss except that the
  underlying distribution is OneHotCategorical instead of Categorical, so
  samples (and targets) are onehot vectors rather than integer class ids.
  """

  @property
  def dist(self):
    # Construct the distribution on demand from the currently registered
    # logits.
    logits = self._logits
    return onehot_categorical.OneHotCategorical(logits=logits)
|
||||
|
Loading…
x
Reference in New Issue
Block a user