K-FAC: Support onehot categorical in kfac.loss_functions.
PiperOrigin-RevId: 180536416
This commit is contained in:
parent
12d82a1f53
commit
7b700c515b
@ -114,5 +114,76 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
|||||||
self.assertEqual(loss.num_registered_minibatches, num_towers)
|
self.assertEqual(loss.num_registered_minibatches, num_towers)
|
||||||
|
|
||||||
|
|
||||||
|
class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
|
||||||
|
|
||||||
|
def testSample(self):
|
||||||
|
"""Ensure samples can be drawn."""
|
||||||
|
with ops.Graph().as_default(), self.test_session() as sess:
|
||||||
|
logits = np.asarray([
|
||||||
|
[0., 0., 0.], #
|
||||||
|
[1., -1., 0.]
|
||||||
|
]).astype(np.float32)
|
||||||
|
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||||
|
array_ops.constant(logits))
|
||||||
|
sample = loss.sample(42)
|
||||||
|
sample = sess.run(sample)
|
||||||
|
self.assertEqual(sample.shape, (2, 3))
|
||||||
|
|
||||||
|
def testEvaluateOnTargets(self):
|
||||||
|
"""Ensure log probability can be evaluated correctly."""
|
||||||
|
with ops.Graph().as_default(), self.test_session() as sess:
|
||||||
|
logits = np.asarray([
|
||||||
|
[0., 0., 0.], #
|
||||||
|
[1., -1., 0.]
|
||||||
|
]).astype(np.float32)
|
||||||
|
targets = np.asarray([2, 1]).astype(np.int32)
|
||||||
|
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||||
|
array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
|
||||||
|
neg_log_prob = loss.evaluate()
|
||||||
|
neg_log_prob = sess.run(neg_log_prob)
|
||||||
|
|
||||||
|
# Calculate explicit log probability of targets.
|
||||||
|
probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
|
||||||
|
log_probs = np.log([
|
||||||
|
probs[0, targets[0]], #
|
||||||
|
probs[1, targets[1]]
|
||||||
|
])
|
||||||
|
expected_log_prob = np.sum(log_probs)
|
||||||
|
|
||||||
|
self.assertAllClose(neg_log_prob, -expected_log_prob)
|
||||||
|
|
||||||
|
def testEvaluateOnSample(self):
|
||||||
|
"""Ensure log probability of a sample can be drawn."""
|
||||||
|
with ops.Graph().as_default(), self.test_session() as sess:
|
||||||
|
logits = np.asarray([
|
||||||
|
[0., 0., 0.], #
|
||||||
|
[1., -1., 0.]
|
||||||
|
]).astype(np.float32)
|
||||||
|
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||||
|
array_ops.constant(logits))
|
||||||
|
neg_log_prob = loss.evaluate_on_sample(42)
|
||||||
|
|
||||||
|
# Simply ensure this doesn't crash. As the output is random, it's
|
||||||
|
# difficult to say if the output is correct or not...
|
||||||
|
neg_log_prob = sess.run(neg_log_prob)
|
||||||
|
|
||||||
|
def testMultiMinibatchRegistration(self):
|
||||||
|
"""Ensure this loss function supports registering multiple minibatches."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
tower_logits = []
|
||||||
|
loss = None
|
||||||
|
num_towers = 5
|
||||||
|
for _ in range(num_towers):
|
||||||
|
logits = random_ops.random_uniform(shape=[2, 3])
|
||||||
|
tower_logits.append(logits)
|
||||||
|
if loss is None:
|
||||||
|
loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||||
|
logits)
|
||||||
|
else:
|
||||||
|
loss.register_additional_minibatch(logits)
|
||||||
|
self.assertListEqual(loss.input_minibatches, tower_logits)
|
||||||
|
self.assertEqual(loss.num_registered_minibatches, num_towers)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -65,6 +65,7 @@ py_library(
|
|||||||
srcs = ["loss_functions.py"],
|
srcs = ["loss_functions.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/contrib/distributions:distributions_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:tensor_shape",
|
"//tensorflow/python:tensor_shape",
|
||||||
|
@ -22,6 +22,7 @@ import abc
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import onehot_categorical
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -785,3 +786,16 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
|
|||||||
after[dim] = dim_size - position - 1
|
after[dim] = dim_size - position - 1
|
||||||
|
|
||||||
return array_ops.pad(slice_to_insert, list(zip(before, after)))
|
return array_ops.pad(slice_to_insert, list(zip(before, after)))
|
||||||
|
|
||||||
|
|
||||||
|
class OnehotCategoricalLogitsNegativeLogProbLoss(
|
||||||
|
CategoricalLogitsNegativeLogProbLoss):
|
||||||
|
"""Neg log prob loss for a categorical distribution with onehot targets.
|
||||||
|
|
||||||
|
Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
|
||||||
|
distribution is OneHotCategorical as opposed to Categorical.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dist(self):
|
||||||
|
return onehot_categorical.OneHotCategorical(logits=self._logits)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user