Merge pull request #3656 from zheng-xq/branch_129393964

Branch 129393964
Authored by zheng-xq on 2016-08-04 22:18:56 -07:00; committed by GitHub.
commit c99b9bef4c
409 changed files with 16177 additions and 5195 deletions


@@ -37,7 +37,10 @@ config_setting(
package_group(
    name = "internal",
-   packages = ["//tensorflow/..."],
+   packages = [
+       "//learning/vis/...",
+       "//tensorflow/...",
+   ],
)

sh_binary(


@@ -482,7 +482,6 @@ static void TF_Run_Helper(
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
-   LOG(ERROR) << result.error_message();
    status->status = result;
    return;
  }


@@ -73,6 +73,46 @@ tf_cc_test(
    ],
)

cc_library(
name = "grad_op_registry",
srcs = ["framework/grad_op_registry.cc"],
hdrs = ["framework/grad_op_registry.h"],
deps = [
":ops",
":scope",
],
)
cc_library(
name = "math_grad",
srcs = ["gradients/math_grad.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
],
)
tf_cc_test(
name = "gradients/math_grad_test",
deps = [
":cc_ops",
":grad_op_registry",
":math_grad",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_gen_op_wrappers_cc(
    name = "cc_ops",
    op_lib_names = [


@@ -0,0 +1,42 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
// static
GradOpRegistry* GradOpRegistry::Global() {
static GradOpRegistry* grad_op_registry = new GradOpRegistry;
return grad_op_registry;
}
bool GradOpRegistry::Register(const string& op, GradFunc func) {
CHECK(registry_.insert({op, func}).second) << "Existing gradient for " << op;
return true;
}
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) {
auto iter = registry_.find(op);
if (iter == registry_.end()) {
return errors::NotFound("No gradient defined for op: ", op);
}
*func = iter->second;
return Status::OK();
}
} // end namespace ops
} // namespace tensorflow


@@ -0,0 +1,75 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#include <unordered_map>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
namespace tensorflow {
namespace ops {
// GradFunc is the signature for all gradient functions in GradOpRegistry.
// Implementations should add operations to compute the gradient outputs of 'op'
// (returned in 'grad_outputs') using 'scope' and 'grad_inputs'.
typedef Status (*GradFunc)(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs);
// GradOpRegistry maintains a static registry of gradient functions.
// Gradient functions are indexed in the registry by the forward op name (i.e.
// "MatMul" -> MatMulGrad func).
class GradOpRegistry {
public:
// Registers 'func' as the gradient function for 'op'.
// Returns true if registration was successful, check-fails otherwise.
bool Register(const string& op, GradFunc func);
// Sets 'func' to the gradient function for 'op' and returns Status OK if
// the gradient function for 'op' exists in the registry.
// Note that 'func' can be null for ops that have registered no-gradient with
// the registry.
// Returns error status otherwise.
Status Lookup(const string& op, GradFunc* func);
// Returns a pointer to the global gradient function registry.
static GradOpRegistry* Global();
private:
std::unordered_map<string, GradFunc> registry_;
};
} // namespace ops
// Macros used to define gradient functions for ops.
#define REGISTER_GRADIENT_OP(name, fn) \
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)
#define REGISTER_NO_GRADIENT_OP(name) \
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)
#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \
REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)
#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \
static bool unused_ret_val_##ctr = \
::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
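
For orientation (not part of this change): a gradient function is defined with the `GradFunc` signature, registered once at static-initialization time via `REGISTER_GRADIENT_OP`, and later retrieved with `GradOpRegistry::Global()->Lookup`. A minimal sketch, assuming the generated C++ ops `Const` and `Mul` from `standard_ops.h` are available; the op name "Square" and the function `SquareGrad` below are hypothetical, purely for illustration:

```cpp
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// Hypothetical gradient for y = Square(x): dx = dz * 2 * x.
Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  Output x = op.input(0);
  Output two_x = Mul(scope, Const(scope, 2.0f), x);
  Output dx = Mul(scope, grad_inputs[0], two_x);
  grad_outputs->push_back(dx);
  return Status::OK();
}

// Expands to a static GradOpRegistry::Global()->Register("Square", SquareGrad) call.
REGISTER_GRADIENT_OP("Square", SquareGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
```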


@@ -18,6 +18,44 @@ limitations under the License.
namespace tensorflow {
namespace ops {
Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
Output Operation::input(int i) const {
CHECK_NOTNULL(node_);
CHECK_GE(i, 0);
CHECK_LT(i, node_->num_inputs());
// Handle the case where the input was unknown at the time this
// Operation was constructed.
if (inputs_[i].first == nullptr && inputs_[i].second == -1) {
for (const Edge* e : node_->in_edges()) {
if (e->IsControlEdge()) continue;
if (e->dst_input() == i) {
return Output(e->src(), e->src_output());
}
}
}
return Output(inputs_[i].first, inputs_[i].second);
}
Output Operation::output(int i) const {
CHECK_NOTNULL(node_);
CHECK_GE(i, 0);
CHECK_LT(i, node_->num_outputs());
return Output(node_, i);
}
Operation::Inputs Operation::GetInputs(Node* node) {
Operation::Inputs inputs;
if (node != nullptr) {
inputs.resize(node->num_inputs(), {nullptr, -1});
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
inputs[e->dst_input()] = std::make_pair(e->src(), e->src_output());
}
}
return inputs;
}
Input::Initializer::Initializer(
    const std::initializer_list<Input::Initializer>& v) {
  if (v.size() < 1) {


@@ -27,17 +27,29 @@ limitations under the License.
namespace tensorflow {
namespace ops {
+ class Output;
// Represents a node in the computation graph.
class Operation {
 public:
  Operation() : node_(nullptr) {}
- explicit Operation(Node* n) : node_(n) {}
+ explicit Operation(Node* n);
+ int num_inputs() const { return node_->num_inputs(); }
+ DataType input_type(int o) const { return node_->input_type(o); }
+ Output input(int i) const;
  int num_outputs() const { return node_->num_outputs(); }
  DataType output_type(int o) const { return node_->output_type(o); }
+ Output output(int i) const;
  Node* node() const { return node_; }
 private:
+ typedef std::vector<std::pair<Node*, int64>> Inputs;
+ static Inputs GetInputs(Node* node);
+ Inputs inputs_;
  Node* node_;
};

@@ -81,7 +93,7 @@ class Input {
    tensor = t;
  }
- explicit Initializer(const Tensor& t) : tensor(t) {}
+ Initializer(const Tensor& t) : tensor(t) {}  // NOLINT(runtime/explicit)
  // Construct from a scalar value and an explicit shape
  template <typename T, typename = typename std::enable_if<


@@ -0,0 +1,91 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
namespace {
// TODO(andydavis) Move this to a more appropriate file.
REGISTER_NO_GRADIENT_OP("Const");
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
const Output& x1, const bool adj_x1, const Output& y0,
const bool adj_y0, const Output& y1, const bool adj_y1,
std::vector<Output>* grad_outputs) {
auto dx =
MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
grad_outputs->push_back(dx);
auto dy =
MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
grad_outputs->push_back(dy);
return Status::OK();
}
// Common MatMulGrad helper: reads and checks node attr state, and determines
// the proper MatMul products for the gradients based on the input matrix
// transposition combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
const string& attr_adj_x, const string& attr_adj_y,
std::vector<Output>* grad_outputs) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
return errors::Unimplemented(
"MatMul gradient for complex data type is not supported yet.");
}
bool ta;
bool tb;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
if (!ta && !tb) {
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
op.input(0), true, grad_inputs[0], false,
grad_outputs);
} else if (!ta && tb) {
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
grad_inputs[0], true, op.input(0), false,
grad_outputs);
} else if (ta && !tb) {
return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
op.input(0), false, grad_inputs[0], false,
grad_outputs);
}
return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
grad_inputs[0], true, op.input(0), true,
grad_outputs);
}
Status MatMulGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
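
As a reading aid (not part of the change), the four branches in `MatMulGradCommon` implement the standard MatMul gradient identities, where `dz` denotes the incoming gradient:

$$
\begin{aligned}
z &= x\,y: & dx &= dz\,y^{\top}, & dy &= x^{\top}dz \\
z &= x\,y^{\top}: & dx &= dz\,y, & dy &= dz^{\top}x \\
z &= x^{\top}y: & dx &= y\,dz^{\top}, & dy &= x\,dz \\
z &= x^{\top}y^{\top}: & dx &= y^{\top}dz^{\top}, & dy &= dz^{\top}x^{\top}
\end{aligned}
$$

These are exactly the products that `MatMulGradHelper` builds via the `TransposeA`/`TransposeB` attributes, and they are what the tests in the next file compare against.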


@@ -0,0 +1,183 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope()) {}
void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
const bool t_y, const Output& dz,
std::vector<Tensor>* out) {
// Compute forward MatMul: z = MatMul(x, y).
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
TF_EXPECT_OK(root_.status());
CHECK_NOTNULL(z.node());
std::vector<Output> grad_outputs;
// Call MatMulGrad which populates 'grad_outputs'.
CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
EXPECT_EQ(2, grad_outputs.size());
// Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
}
void CallGradFunction(const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
GradFunc grad_fn;
TF_EXPECT_OK(GradOpRegistry::Global()->Lookup(op.node()->name(), &grad_fn));
TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
TF_EXPECT_OK(root_.status());
}
Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
const bool t_y) {
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
TF_EXPECT_OK(root_.status());
Tensor out;
GetTensor(root_, z, &out);
return out;
}
void RandMatMulGradData(const bool tx, const bool ty,
std::vector<Tensor>* data) {
// z = MatMul(x, y)
const int m = Rand();
const int k = Rand();
const int n = Rand();
// x.shape = [m, k]
const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
data->emplace_back(DT_FLOAT, x_shape);
RandTensor(&data->back());
// y.shape = [k, n]
const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
data->emplace_back(DT_FLOAT, y_shape);
RandTensor(&data->back());
// z.shape = [m, n]
data->emplace_back(DT_FLOAT, TensorShape({m, n}));
RandTensor(&data->back());
}
void RandTensor(Tensor* t) {
test::FillFn<float>(
t, [this](const int i) { return static_cast<float>(Rand()); });
}
int Rand() { return 1 + (random::New64() % 10); }
// TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
// Note: they should be moved to a general/non-grad specific testutil class.
void GetTensors(const Scope& scope, OutputList tensors,
std::vector<Tensor>* out) {
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
GraphDef def;
scope.graph()->ToGraphDef(&def);
graph::SetDefaultDevice("/cpu:0", &def);
TF_CHECK_OK(session->Create(def));
std::vector<string> names;
for (const auto& t : tensors) {
names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
}
TF_CHECK_OK(session->Run({}, names, {}, out));
TF_CHECK_OK(session->Close());
}
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
std::vector<Tensor> outputs;
GetTensors(scope, {tensor}, &outputs);
*out = outputs[0];
}
Scope root_;
};
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
std::vector<Tensor> data;
RandMatMulGradData(false, false, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
}
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
std::vector<Tensor> data;
RandMatMulGradData(true, false, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
}
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
std::vector<Tensor> data;
RandMatMulGradData(false, true, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
}
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
std::vector<Tensor> data;
RandMatMulGradData(true, true, &data);
auto x = Const(root_, data[0]);
auto y = Const(root_, data[1]);
auto dz = Const(root_, data[2]);
std::vector<Tensor> grad_outputs;
ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
}
} // namespace
} // namespace tensorflow


@@ -99,7 +99,16 @@ cuda_py_tests(
    srcs = ["python/kernel_tests/beta_test.py"],
    additional_deps = [
        ":distributions_py",
-       "//tensorflow/python:framework_test_lib",
+       "//tensorflow/python:platform_test",
+   ],
+)
+
+cuda_py_tests(
+   name = "binomial_test",
+   size = "small",
+   srcs = ["python/kernel_tests/binomial_test.py"],
+   additional_deps = [
+       ":distributions_py",
        "//tensorflow/python:platform_test",
    ],
    tags = ["notsan"],

@@ -179,9 +188,8 @@ cuda_py_tests(
)

cuda_py_tests(
-   name = "kullback_leibler_test",
-   size = "small",
-   srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+   name = "laplace_test",
+   srcs = ["python/kernel_tests/laplace_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:framework_test_lib",

@@ -190,13 +198,14 @@ cuda_py_tests(
)

cuda_py_tests(
-   name = "laplace_test",
-   srcs = ["python/kernel_tests/laplace_test.py"],
+   name = "multinomial_test",
+   srcs = ["python/kernel_tests/multinomial_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
+   tags = ["notsan"],
)

@@ -239,6 +248,15 @@ cuda_py_tests(
    srcs = ["python/kernel_tests/uniform_test.py"],
    additional_deps = [
        ":distributions_py",
+       "//tensorflow/python:framework_test_lib",
+   ],
+)
+
+cuda_py_tests(
+   name = "kullback_leibler_test",
+   size = "small",
+   srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+   additional_deps = [
        "//tensorflow/python:platform_test",
    ],
)


@@ -25,6 +25,7 @@ initialized with parameters that define the distributions.
### Univariate (scalar) distributions

+@@Binomial
@@Bernoulli
@@Beta
@@Categorical

@@ -50,6 +51,7 @@ initialized with parameters that define the distributions.
@@Dirichlet
@@DirichletMultinomial
+@@Multinomial

### Transformed distributions

@@ -79,6 +81,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.bernoulli import *
from tensorflow.contrib.distributions.python.ops.beta import *
+from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.categorical import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.dirichlet import *

@@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
from tensorflow.contrib.distributions.python.ops.laplace import *
+from tensorflow.contrib.distributions.python.ops.multinomial import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *


@@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase):
      self.assertAllClose(scipy.special.logit(p), dist.logits.eval())

  def testInvalidP(self):
-   invalid_ps = [1.01, -0.01, 2., -3.]
+   invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.test_session():
-       with self.assertRaisesOpError("x <= y"):
+       with self.assertRaisesOpError("p has components greater than 1"):
+         dist = tf.contrib.distributions.Bernoulli(p=p)
+         dist.p.eval()
+
+   invalid_ps = [-0.01, -3.]
+   for p in invalid_ps:
+     with self.test_session():
+       with self.assertRaisesOpError("Condition x >= 0"):
          dist = tf.contrib.distributions.Bernoulli(p=p)
          dist.p.eval()


@@ -0,0 +1,173 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class BinomialTest(tf.test.TestCase):
def testSimpleShapes(self):
with self.test_session():
p = np.float32(np.random.beta(1, 1))
binom = tf.contrib.distributions.Binomial(n=1., p=p)
self.assertAllEqual([], binom.event_shape().eval())
self.assertAllEqual([], binom.batch_shape().eval())
self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
self.assertEqual(tf.TensorShape([]), binom.get_batch_shape())
def testComplexShapes(self):
with self.test_session():
p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
n = [[3., 2], [4, 5], [6, 7]]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
self.assertAllEqual([], binom.event_shape().eval())
self.assertAllEqual([3, 2], binom.batch_shape().eval())
self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape())
def testNProperty(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=n, p=p)
self.assertEqual((2, 1), binom.n.get_shape())
self.assertAllClose(n, binom.n.eval())
def testPProperty(self):
p = [[0.1, 0.2, 0.7]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=3., p=p)
self.assertEqual((1, 3), binom.p.get_shape())
self.assertEqual((1, 3), binom.logits.get_shape())
self.assertAllClose(p, binom.p.eval())
def testLogitsProperty(self):
logits = [[0., 9., -0.5]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=3., logits=logits)
self.assertEqual((1, 3), binom.p.get_shape())
self.assertEqual((1, 3), binom.logits.get_shape())
self.assertAllClose(logits, binom.logits.eval())
def testPmfNandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
binom = tf.contrib.distributions.Binomial(n=n, p=p)
binom.pmf([2., 3, 2]).eval()
binom.pmf([3., 1, 2]).eval()
with self.assertRaisesOpError('Condition x >= 0.*'):
binom.pmf([-1., 4, 2]).eval()
with self.assertRaisesOpError('Condition x <= y.*'):
binom.pmf([7., 3, 0]).eval()
def testPmf_non_integer_counts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
# No errors with integer n.
binom = tf.contrib.distributions.Binomial(n=n, p=p)
binom.pmf([2., 3, 2]).eval()
binom.pmf([3., 1, 2]).eval()
# Both equality and integer checking fail.
with self.assertRaisesOpError('Condition x == y.*'):
binom.pmf([1.0, 2.5, 1.5]).eval()
binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False)
binom.pmf([1., 2., 3.]).eval()
# Non-integer arguments work.
binom.pmf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
with self.test_session():
# Both zero-batches. No broadcast
p = 0.5
counts = 1.
pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
self.assertAllClose(0.5, pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
with self.test_session():
# Both zero-batches. No broadcast
p = 0.1
counts = 3.
binom = tf.contrib.distributions.Binomial(n=5., p=p)
pmf = binom.pmf(counts)
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9]]
counts = [[1., 2.]]
pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts)
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
self.assertEqual((1, 2), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [0.1, 0.4]
counts = [[1.], [0.]]
pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
self.assertEqual((2, 2), pmf.get_shape())
def testBinomialMean(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
expected_means = stats.binom.mean(n, p)
self.assertEqual((3,), binom.mean().get_shape())
self.assertAllClose(expected_means, binom.mean().eval())
def testBinomialVariance(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
expected_variances = stats.binom.var(n, p)
self.assertEqual((3,), binom.variance().get_shape())
self.assertAllClose(expected_variances, binom.variance().eval())
def testBinomialMode(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
expected_modes = [0., 1, 4]
self.assertEqual((3,), binom.mode().get_shape())
self.assertAllClose(expected_modes, binom.mode().eval())
def testBinomialMultipleMode(self):
with self.test_session():
n = 9.
p = [0.1, 0.2, 0.7]
binom = tf.contrib.distributions.Binomial(n=n, p=p)
# For the case where (n + 1) * p is an integer, the modes are:
# (n + 1) * p and (n + 1) * p - 1. In this case, we get back
# the larger of the two modes.
expected_modes = [1., 2, 7]
self.assertEqual((3,), binom.mode().get_shape())
self.assertAllClose(expected_modes, binom.mode().eval())
if __name__ == '__main__':
tf.test.main()
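
The expected values in these tests are the standard Binomial moments (a reading aid; the mean and variance are also what `scipy.stats.binom` returns in the tests above):

$$E[X] = np, \qquad \operatorname{Var}(X) = np(1-p), \qquad \operatorname{mode}(X) = \lfloor (n+1)p \rfloor$$

When $(n+1)p$ is an integer, both $(n+1)p$ and $(n+1)p - 1$ are modes and the tests expect the larger one; e.g. $n = 9$, $p = (0.1, 0.2, 0.7)$ gives $(n+1)p = (1, 2, 7)$.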


@@ -65,7 +65,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
      dist.pmf([3., 0, 2]).eval()
      with self.assertRaisesOpError('Condition x >= 0.*'):
        dist.pmf([-1., 4, 2]).eval()
-     with self.assertRaisesOpError('Condition x == y.*'):
+     with self.assertRaisesOpError('counts do not sum to n'):
        dist.pmf([3., 3, 0]).eval()

  def testPmf_non_integer_counts(self):


@@ -0,0 +1,226 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class MultinomialTest(tf.test.TestCase):
def testSimpleShapes(self):
with self.test_session():
p = [.1, .3, .6]
dist = tf.contrib.distributions.Multinomial(n=1., p=p)
self.assertEqual(3, dist.event_shape().eval())
self.assertAllEqual([], dist.batch_shape().eval())
self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())
def testComplexShapes(self):
with self.test_session():
p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
n = [[3., 2], [4, 5], [6, 7]]
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
self.assertEqual(2, dist.event_shape().eval())
self.assertAllEqual([3, 2], dist.batch_shape().eval())
self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())
def testNProperty(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
with self.test_session():
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
self.assertEqual((2, 1), dist.n.get_shape())
self.assertAllClose(n, dist.n.eval())
def testPProperty(self):
p = [[0.1, 0.2, 0.7]]
with self.test_session():
dist = tf.contrib.distributions.Multinomial(n=3., p=p)
self.assertEqual((1, 3), dist.p.get_shape())
self.assertEqual((1, 3), dist.logits.get_shape())
self.assertAllClose(p, dist.p.eval())
def testLogitsProperty(self):
logits = [[0., 9., -0.5]]
with self.test_session():
multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits)
self.assertEqual((1, 3), multinom.p.get_shape())
self.assertEqual((1, 3), multinom.logits.get_shape())
self.assertAllClose(logits, multinom.logits.eval())
def testPmfNandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
dist.pmf([2., 3, 0]).eval()
dist.pmf([3., 0, 2]).eval()
with self.assertRaisesOpError('Condition x >= 0.*'):
dist.pmf([-1., 4, 2]).eval()
with self.assertRaisesOpError('counts do not sum to n'):
dist.pmf([3., 3, 0]).eval()
def testPmf_non_integer_counts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
# No errors with integer n.
multinom = tf.contrib.distributions.Multinomial(n=n, p=p)
multinom.pmf([2., 1, 2]).eval()
multinom.pmf([3., 0, 2]).eval()
# Counts don't sum to n.
with self.assertRaisesOpError('counts do not sum to n'):
multinom.pmf([2., 3, 2]).eval()
# Counts are non-integers.
with self.assertRaisesOpError('Condition x == y.*'):
multinom.pmf([1.0, 2.5, 1.5]).eval()
multinom = tf.contrib.distributions.Multinomial(
n=n, p=p, validate_args=False)
multinom.pmf([1., 2., 2.]).eval()
# Non-integer arguments work.
multinom.pmf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
with self.test_session():
# Both zero-batches. No broadcast
p = [0.5, 0.5]
counts = [1., 0]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose(0.5, pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
with self.test_session():
# Both zero-batches. No broadcast
p = [0.1, 0.9]
counts = [3., 2]
dist = tf.contrib.distributions.Multinomial(n=5., p=p)
pmf = dist.pmf(counts)
# 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
self.assertAllClose(81./10000, pmf.eval())
self.assertEqual((), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9]]
counts = [[1., 0], [0, 1]]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose([0.1, 0.9], pmf.eval())
self.assertEqual((2), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [0.1, 0.9]
counts = [[1., 0], [0, 1]]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose([0.1, 0.9], pmf.eval())
self.assertEqual((2), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [[1., 0]]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose(pmf.eval(), [0.1, 0.7])
self.assertEqual((2), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [1., 0]
pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
self.assertAllClose(pmf.eval(), [0.1, 0.7])
self.assertEqual(pmf.get_shape(), (2))
def testPmfShapeCountsStretched_N(self):
with self.test_session():
# [2, 2, 2]
p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
# [2, 2]
n = [[3., 3], [3, 3]]
# [2]
counts = [2., 1]
pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
pmf.eval()
self.assertEqual(pmf.get_shape(), (2, 2))
def testPmfShapeCountsPStretched_N(self):
with self.test_session():
p = [0.1, 0.9]
counts = [3., 2]
n = np.full([4, 3], 5., dtype=np.float32)
pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
pmf.eval()
self.assertEqual((4, 3), pmf.get_shape())
def testMultinomialMean(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
expected_means = 5 * np.array(p, dtype=np.float32)
self.assertEqual((3,), dist.mean().get_shape())
self.assertAllClose(expected_means, dist.mean().eval())
def testMultinomialVariance(self):
with self.test_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
expected_variances = [
[9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]]
self.assertEqual((3, 3), dist.variance().get_shape())
self.assertAllClose(expected_variances, dist.variance().eval())
def testMultinomialVariance_batch(self):
with self.test_session():
# Shape [2]
n = [5.] * 2
# Shape [4, 1, 2]
p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
dist = tf.contrib.distributions.Multinomial(n=n, p=p)
# Shape [2, 2]
inner_var = [[9./20, -9/20], [-9/20, 9/20]]
# Shape [4, 2, 2, 2]
expected_variances = [[inner_var, inner_var]] * 4
self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
self.assertAllClose(expected_variances, dist.variance().eval())
def testVariance_multidimensional(self):
# Shape [3, 5, 4]
p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
# Shape [6, 3, 3]
p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)
ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
with self.test_session():
dist = tf.contrib.distributions.Multinomial(ns, p)
dist2 = tf.contrib.distributions.Multinomial(ns2, p2)
variance = dist.variance()
variance2 = dist2.variance()
self.assertEqual((3, 5, 4, 4), variance.get_shape())
self.assertEqual((6, 3, 3, 3), variance2.get_shape())
if __name__ == '__main__':
tf.test.main()
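
The hand-computed matrices in `testMultinomialVariance` follow the standard multinomial covariance (a reading aid, not part of the change): for counts $X \sim \operatorname{Multinomial}(n, p)$,

$$\operatorname{Var}(X_i) = n\,p_i(1 - p_i), \qquad \operatorname{Cov}(X_i, X_j) = -n\,p_i\,p_j \quad (i \neq j)$$

With $n = 5$ and $p = (0.1, 0.2, 0.7)$ the diagonal is $(9/20,\ 4/5,\ 21/20)$ and the off-diagonal entries are $-1/10$, $-7/20$ and $-7/10$, matching `expected_variances` above.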


@@ -369,5 +369,87 @@ class MultivariateNormalCholeskyTest(tf.test.TestCase):
      self.assertEqual((3, 5), tuple(mvn.batch_shape().eval()))
class MultivariateNormalFullTest(tf.test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(42)
def _random_mu_and_sigma(self, batch_shape, event_shape):
# This ensures sigma is positive definite.
mat_shape = batch_shape + event_shape + event_shape
mat = self._rng.randn(*mat_shape)
sigma = tf.batch_matmul(mat, mat, adj_y=True).eval()
mu_shape = batch_shape + event_shape
mu = self._rng.randn(*mu_shape)
return mu, sigma
def testKLNonBatch(self):
batch_shape = ()
event_shape = (2,)
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
kl = distributions.kl(mvn_a, mvn_b)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
self.assertAllClose(expected_kl, kl_v)
def testKLBatch(self):
batch_shape = (2,)
event_shape = (3,)
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
kl = distributions.kl(mvn_a, mvn_b)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
expected_kl_0 = _compute_non_batch_kl(
mu_a[0, :], sigma_a[0, :, :], mu_b[0, :], sigma_b[0, :])
expected_kl_1 = _compute_non_batch_kl(
mu_a[1, :], sigma_a[1, :, :], mu_b[1, :], sigma_b[1, :])
self.assertAllClose(expected_kl_0, kl_v[0])
self.assertAllClose(expected_kl_1, kl_v[1])
def testKLTwoIdenticalDistributionsIsZero(self):
batch_shape = (2,)
event_shape = (3,)
with self.test_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
# Should be zero since KL(p || p) = 0.
kl = distributions.kl(mvn_a, mvn_a)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
self.assertAllClose(np.zeros(*batch_shape), kl_v)
def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
"""Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
# Check using numpy operations
# This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
# So it is important to also check that KL(mvn, mvn) = 0.
sigma_b_inv = np.linalg.inv(sigma_b)
t = np.trace(sigma_b_inv.dot(sigma_a))
q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
k = mu_a.shape[0]
l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
return 0.5 * (t + q - k + l)
if __name__ == "__main__":
  tf.test.main()
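
For reference, `_compute_non_batch_kl` is the closed-form KL divergence between two $k$-dimensional multivariate normals, which is also what `distributions.kl` is expected to return here:

$$KL\big(\mathcal{N}(\mu_a, \Sigma_a)\,\|\,\mathcal{N}(\mu_b, \Sigma_b)\big) = \tfrac{1}{2}\Big(\operatorname{tr}(\Sigma_b^{-1}\Sigma_a) + (\mu_b - \mu_a)^{\top}\Sigma_b^{-1}(\mu_b - \mu_a) - k + \ln\frac{\det\Sigma_b}{\det\Sigma_a}\Big)$$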


@@ -19,15 +19,13 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import kullback_leibler  # pylint: disable=line-too-long
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops

@@ -38,10 +36,6 @@ class Bernoulli(distribution.Distribution):
  The Bernoulli distribution is parameterized by p, the probability of a
  positive event.

- Note, the following methods of the base class aren't implemented:
-   * cdf
-   * log_cdf
  """

  def __init__(self,

@@ -64,10 +58,10 @@ class Bernoulli(distribution.Distribution):
      dtype: dtype for samples.
      validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args,
        `log_pmf` may return nans.
-     allow_nan_stats: Boolean, default False. If False, raise an exception if
-       a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-       If True, batch members with valid parameters leading to undefined
-       statistics will return NaN for this statistic.
+     allow_nan_stats: Boolean, default `False`. If `False`, raise an
+       exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+       batch member. If `True`, batch members with valid parameters leading to
+       undefined statistics will return NaN for this statistic.
      name: A name for this distribution.

    Raises:

@@ -77,27 +71,8 @@ class Bernoulli(distribution.Distribution):
    self._name = name
    self._dtype = dtype
    self._validate_args = validate_args
-   check_op = check_ops.assert_less_equal
-   if p is None and logits is None:
-     raise ValueError("Must pass p or logits.")
-   elif p is not None and logits is not None:
-     raise ValueError("Must pass either p or logits, not both.")
-   elif p is None:
-     with ops.op_scope([logits], name):
-       self._logits = array_ops.identity(logits, name="logits")
-       with ops.name_scope(name):
-         with ops.name_scope("p"):
-           self._p = math_ops.sigmoid(self._logits)
-   elif logits is None:
-     with ops.name_scope(name):
-       with ops.name_scope("p"):
-         p = array_ops.identity(p)
-         one = constant_op.constant(1., p.dtype)
-         zero = constant_op.constant(0., p.dtype)
-         self._p = control_flow_ops.with_dependencies(
-             [check_op(p, one), check_op(zero, p)] if validate_args else [], p)
-       with ops.name_scope("logits"):
-         self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
+   self._logits, self._p = distribution_util.get_logits_and_prob(
+       name=name, logits=logits, p=p, validate_args=validate_args)
    with ops.name_scope(name):
      with ops.name_scope("q"):
        self._q = 1. - self._p

@@ -184,8 +159,12 @@ class Bernoulli(distribution.Distribution):
      event = ops.convert_to_tensor(event, name="event")
      event = math_ops.cast(event, self.logits.dtype)
      logits = self.logits
-     if ((event.get_shape().ndims is not None) or
-         (logits.get_shape().ndims is not None) or
+     # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
+     # so we do this here.
+     # TODO(b/30637701): Check dynamic shape, and don't broadcast if the
+     # dynamic shapes are the same.
+     if (not event.get_shape().is_fully_defined() or
+         not logits.get_shape().is_fully_defined() or
          event.get_shape() != logits.get_shape()):
        logits = array_ops.ones_like(event) * logits
        event = array_ops.ones_like(logits) * event

@@ -206,8 +185,7 @@ class Bernoulli(distribution.Distribution):
    with ops.name_scope(self.name):
      with ops.op_scope([self.p, n], name):
        n = ops.convert_to_tensor(n, name="n")
-       new_shape = array_ops.concat(
-           0, [array_ops.expand_dims(n, 0), self.batch_shape()])
+       new_shape = array_ops.concat(0, ([n], self.batch_shape()))
        uniform = random_ops.random_uniform(
            new_shape, seed=seed, dtype=dtypes.float32)
        sample = math_ops.less(uniform, self.p)


@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@@ -95,6 +96,7 @@ class Beta(distribution.Distribution):
  x = [.2, .3, .9]
  dist.pdf(x)  # Shape [2]
  ```
+
  """

  def __init__(self, a, b, validate_args=True, allow_nan_stats=False,

@@ -102,20 +104,20 @@ class Beta(distribution.Distribution):
    """Initialize a batch of Beta distributions.

    Args:
-     a: Positive `float` or `double` tensor with shape broadcastable to
+     a: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
        different Beta distributions. This also defines the
        dtype of the distribution.
-     b: Positive `float` or `double` tensor with shape broadcastable to
+     b: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
        different Beta distributions.
      validate_args: Whether to assert valid values for parameters `a` and `b`,
-       and `x` in `prob` and `log_prob`. If False, correct behavior is not
+       and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
        guaranteed.
-     allow_nan_stats: Boolean, default False. If False, raise an exception if
-       a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-       If True, batch members with valid parameters leading to undefined
-       statistics will return NaN for this statistic.
+     allow_nan_stats: Boolean, default `False`. If `False`, raise an
+       exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+       batch member. If `True`, batch members with valid parameters leading to
+       undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

@@ -127,6 +129,7 @@ class Beta(distribution.Distribution):
    # Define a 2-batch.
    dist = Beta([1.0, 2.0], [4.0, 5.0])
    ```
+
    """
    with ops.op_scope([a, b], name):
      with ops.control_dependencies([

@@ -276,8 +279,14 @@ class Beta(distribution.Distribution):
            array_ops.ones_like(a_b_sum, dtype=self.dtype)))
      else:
        return control_flow_ops.with_dependencies([
-           check_ops.assert_less(one, a),
-           check_ops.assert_less(one, b)], mode)
+           check_ops.assert_less(
+               one, a,
+               message="mode not defined for components of a <= 1"
+           ),
+           check_ops.assert_less(
+               one, b,
+               message="mode not defined for components of b <= 1"
+           )], mode)

  def entropy(self, name="entropy"):
    """Entropy of the distribution in nats."""

@@ -306,7 +315,7 @@ class Beta(distribution.Distribution):
    """`Log(P[counts])`, computed for every batch member.

    Args:
-     x: Non-negative `float` or `double`, tensor whose shape can
+     x: Non-negative floating point tensor whose shape can
        be broadcast with `self.a` and `self.b`. For fixed leading
        dimensions, the last dimension represents counts for the corresponding
        Beta distribution in `self.a` and `self.b`. `x` is only legal if

@@ -334,7 +343,7 @@ class Beta(distribution.Distribution):
    """`P[x]`, computed for every batch member.

    Args:
-     x: Non-negative `float`, `double` tensor whose shape can
+     x: Non-negative floating point tensor whose shape can
        be broadcast with `self.a` and `self.b`. For fixed leading
        dimensions, the last dimension represents x for the corresponding Beta
        distribution in `self.a` and `self.b`. `x` is only legal if is


@@ -0,0 +1,340 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Binomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# pylint: enable=line-too-long
class Binomial(distribution.Distribution):
"""Binomial distribution.
This distribution is parameterized by a vector `p` of probabilities and `n`,
the total counts.
#### Mathematical details
The Binomial is a distribution over the number of successes in `n` independent
trials, with each trial having the same probability of success `p`.
The probability mass function (pmf):
```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```
#### Examples
Create a single distribution, corresponding to 5 coin flips.
```python
dist = Binomial(n=5., p=.5)
```
Create a single distribution (using logits), corresponding to 5 coin flips.
```python
dist = Binomial(n=5., logits=0.)
```
Creates 3 distributions with the third distribution most likely to have
successes.
```python
p = [.2, .3, .8]
# n will be broadcast to [4., 4., 4.], to match p.
dist = Binomial(n=4., p=p)
```
The distribution functions can be evaluated on counts.
```python
# counts same shape as p.
counts = [1., 2, 3]
dist.prob(counts) # Shape [3]
# p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
counts = [[1., 2, 1], [2, 2, 4]]
dist.prob(counts) # Shape [2, 3]
# p will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]] # Shape [5, 7, 3]
dist.prob(counts) # Shape [5, 7, 3]
```
"""
def __init__(self,
n,
logits=None,
p=None,
validate_args=True,
allow_nan_stats=False,
name="Binomial"):
"""Initialize a batch of Binomial distributions.
Args:
n: Non-negative floating point tensor with shape broadcastable to
`[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
Defines this as a batch of `N1 x ... x Nm` different Binomial
distributions. Its components should be equal to integer values.
logits: Floating point tensor representing the log-odds of a
positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
the same dtype as `n`. Each entry represents logits for the probability
of success for independent Binomial distributions.
p: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
probability of success for independent Binomial distributions.
validate_args: Whether to assert valid values for parameters `n` and `p`,
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
```python
# Define 1-batch of a binomial distribution.
dist = Binomial(n=2., p=.9)
# Define a 2-batch.
dist = Binomial(n=[4., 5], p=[.1, .3])
```
"""
self._logits, self._p = distribution_util.get_logits_and_prob(
name=name, logits=logits, p=p, validate_args=validate_args)
with ops.op_scope([n], name):
with ops.control_dependencies([
check_ops.assert_non_negative(
n, message="n has negative components."),
distribution_util.assert_integer_form(
n, message="n has non-integer components."
)] if validate_args else []):
self._n = array_ops.identity(n, name="convert_n")
self._name = name
self._validate_args = validate_args
self._allow_nan_stats = allow_nan_stats
self._mean = self._n * self._p
self._get_batch_shape = self._mean.get_shape()
self._get_event_shape = tensor_shape.TensorShape([])
@property
def name(self):
"""Name to prepend to all ops."""
return self._name
@property
def dtype(self):
"""dtype of samples from this distribution."""
return self._p.dtype
@property
def validate_args(self):
"""Boolean describing behavior on invalid input."""
return self._validate_args
@property
def allow_nan_stats(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
return self._allow_nan_stats
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
independent distributions of this kind the instance represents.
Args:
name: name to give to the op
Returns:
`Tensor` `batch_shape`
"""
return array_ops.shape(self._mean)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `batch_shape`. May be only partially defined.
Returns:
batch shape
"""
return self._get_batch_shape
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
name: name to give to the op
Returns:
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([], name):
return constant_op.constant([], name=name, dtype=dtypes.int32)
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `event_shape`. May be only partially defined.
Returns:
event shape
"""
return self._get_event_shape
@property
def n(self):
"""Number of trials."""
return self._n
@property
def logits(self):
"""Log-odds."""
return self._logits
@property
def p(self):
"""Probability of success."""
return self._p
def mean(self, name="mean"):
"""Mean of the distribution."""
with ops.name_scope(self.name):
return array_ops.identity(self._mean, name=name)
def variance(self, name="variance"):
"""Variance of the distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p], name):
return self._n * self._p * (1 - self._p)
def std(self, name="std"):
"""Standard deviation of the distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p], name):
return math_ops.sqrt(self.variance())
def mode(self, name="mode"):
"""Mode of the distribution.
Note that when `(n + 1) * p` is an integer, there are actually two modes.
Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
only the larger of the two modes.
Args:
name: The name for this op.
Returns:
The mode of the Binomial distribution.
"""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p], name):
return math_ops.floor((self._n + 1) * self._p)
def log_prob(self, counts, name="log_prob"):
"""`Log(P[counts])`, computed for every batch member.
For each batch member of counts `k`, `P[counts]` is the probability that
after sampling `n` draws from this Binomial distribution, the number of
successes is `k`. Note that different sequences of draws can result in the
same counts, thus the probability includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
less than or equal to `n` and its components are equal to integer
values.
name: Name to give this Op, defaults to "log_prob".
Returns:
Log probabilities for each record, shape `[N1,...,Nm]`.
"""
n = self._n
p = self._p
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p, counts], name):
counts = self._check_counts(counts)
prob_prob = counts * math_ops.log(p) + (
n - counts) * math_ops.log(1 - p)
combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
counts + 1) - math_ops.lgamma(n - counts + 1)
log_prob = prob_prob + combinations
return log_prob
def prob(self, counts, name="prob"):
"""`P[counts]`, computed for every batch member.
For each batch member of counts `k`, `P[counts]` is the probability that
after sampling `n` draws from this Binomial distribution, the number of
successes is `k`. Note that different sequences of draws can result in the
same counts, thus the probability includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
less than or equal to `n` and its components are equal to integer
values.
name: Name to give this Op, defaults to "prob".
Returns:
Probabilities for each record, shape `[N1,...,Nm]`.
"""
return super(Binomial, self).prob(counts, name=name)
@property
def is_continuous(self):
return False
@property
def is_reparameterized(self):
return False
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
if not self.validate_args:
return counts
return control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
counts, message="counts has negative components."),
check_ops.assert_less_equal(
counts, self._n, message="counts are not less than or equal to n."),
distribution_util.assert_integer_form(
counts, message="counts have non-integer components.")], counts)
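As a sanity check on the log-pmf and mode formulas implemented by `log_prob` and `mode` above, a minimal NumPy/SciPy sketch follows. It is not part of this change, and the concrete values of `n`, `p`, and `counts` are arbitrary illustrative assumptions.

```python
# Check: log P[k] = k*log(p) + (n - k)*log(1 - p)
#                   + lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
import numpy as np
from scipy.special import gammaln
from scipy.stats import binom

n, p = 5., 0.3                   # assumed example parameters
counts = np.array([0., 2., 5.])  # assumed example counts

log_prob = (counts * np.log(p) + (n - counts) * np.log(1. - p)
            + gammaln(n + 1) - gammaln(counts + 1) - gammaln(n - counts + 1))

# Agrees with SciPy's reference Binomial log-pmf.
print(np.allclose(log_prob, binom.logpmf(counts, n, p)))  # True
# The (larger) mode returned by mode() is floor((n + 1) * p).
print(np.floor((n + 1) * p))  # 1.0
```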

View File

@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
The categorical distribution is parameterized by the log-probabilities The categorical distribution is parameterized by the log-probabilities
of a set of classes. of a set of classes.
Note, the following methods of the base class aren't implemented:
* mean
* cdf
* log_cdf
""" """
def __init__( def __init__(
@ -57,10 +52,10 @@ class Categorical(distribution.Distribution):
indexes into the classes. indexes into the classes.
dtype: The type of the event samples (default: int32). dtype: The type of the event samples (default: int32).
validate_args: Unused in this distribution. validate_args: Unused in this distribution.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: A name for this distribution (optional). name: A name for this distribution (optional).
""" """
self._allow_nan_stats = allow_nan_stats self._allow_nan_stats = allow_nan_stats
@ -177,8 +172,7 @@ class Categorical(distribution.Distribution):
samples = math_ops.cast(samples, self._dtype) samples = math_ops.cast(samples, self._dtype)
ret = array_ops.reshape( ret = array_ops.reshape(
array_ops.transpose(samples), array_ops.transpose(samples),
array_ops.concat( array_ops.concat(0, ([n], self.batch_shape())))
0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n)) ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
.concatenate(self.get_batch_shape())) .concatenate(self.get_batch_shape()))
return ret return ret

View File

@ -42,15 +42,15 @@ class Chi2(gamma.Gamma):
"""Construct Chi2 distributions with parameter `df`. """Construct Chi2 distributions with parameter `df`.
Args: Args:
df: `float` or `double` tensor, the degrees of freedom of the df: Floating point tensor, the degrees of freedom of the
distribution(s). `df` must contain only positive values. distribution(s). `df` must contain only positive values.
validate_args: Whether to assert that `df > 0`, and that `x > 0` in the validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
methods `prob(x)` and `log_prob(x)`. If `validate_args` is False methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed. and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution. name: The name to prepend to all ops created by this distribution.
""" """
# Even though all stats of chi2 are defined for valid parameters, this is # Even though all stats of chi2 are defined for valid parameters, this is

View File

@ -19,9 +19,8 @@ from __future__ import print_function
# pylint: disable=line-too-long # pylint: disable=line-too-long
import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import special_math_ops
@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops
# pylint: enable=line-too-long # pylint: enable=line-too-long
def _assert_close(x, y, data=None, summarize=None, name=None):
if x.dtype.is_integer:
return check_ops.assert_equal(
x, y, data=data, summarize=summarize, name=name)
with ops.op_scope([x, y, data], name, "assert_close"):
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y")
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
if data is None:
data = [
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
y.name, y
]
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
return logging_ops.Assert(condition, data, summarize=summarize)
class Dirichlet(distribution.Distribution): class Dirichlet(distribution.Distribution):
"""Dirichlet distribution. """Dirichlet distribution.
@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution):
x = [.2, .3, .5] x = [.2, .3, .5]
dist.prob(x) # Shape [2] dist.prob(x) # Shape [2]
``` ```
""" """
def __init__(self, def __init__(self,
@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution):
"""Initialize a batch of Dirichlet distributions. """Initialize a batch of Dirichlet distributions.
Args: Args:
alpha: Positive `float` or `double` tensor with shape broadcastable to alpha: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
different `k` class Dirichlet distributions. different `k` class Dirichlet distributions.
validate_args: Whether to assert valid values for parameters `alpha` and validate_args: Whether to assert valid values for parameters `alpha` and
`x` in `prob` and `log_prob`. If False, correct behavior is not `x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed. guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class. name: The name to prefix Ops created by this distribution class.
Examples: Examples:
@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution):
# Define a 2-batch of 3-class distributions. # Define a 2-batch of 3-class distributions.
dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
``` ```
""" """
with ops.op_scope([alpha], name): with ops.op_scope([alpha], name):
alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps") alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
array_ops.ones_like(self._alpha, dtype=self.dtype))) array_ops.ones_like(self._alpha, dtype=self.dtype)))
else: else:
return control_flow_ops.with_dependencies([ return control_flow_ops.with_dependencies([
check_ops.assert_less(one, self._alpha) check_ops.assert_less(
one, self._alpha,
message="mode not defined for components of alpha <= 1")
], mode) ], mode)
def entropy(self, name="entropy"): def entropy(self, name="entropy"):
@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
"""`Log(P[counts])`, computed for every batch member. """`Log(P[counts])`, computed for every batch member.
Args: Args:
x: Non-negative `float` or `double`, tensor whose shape can x: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last be broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents counts for the corresponding Dirichlet distribution dimension represents counts for the corresponding Dirichlet distribution
in `self.alpha`. `x` is only legal if it sums up to one. in `self.alpha`. `x` is only legal if it sums up to one.
@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
"""`P[x]`, computed for every batch member. """`P[x]`, computed for every batch member.
Args: Args:
x: Non-negative `float`, `double` tensor whose shape can x: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.alpha`. For fixed leading dimensions, the last be broadcast with `self.alpha`. For fixed leading dimensions, the last
dimension represents x for the corresponding Dirichlet distribution in dimension represents x for the corresponding Dirichlet distribution in
`self.alpha` and `self.beta`. `x` is only legal if it sums up to one. `self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
x = ops.convert_to_tensor(x, name="x_before_deps") x = ops.convert_to_tensor(x, name="x_before_deps")
candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1]) candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
one = constant_op.constant(1., self.dtype) one = constant_op.constant(1., self.dtype)
dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one), dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
_assert_close(one, candidate_one) x, one, message="x has components greater than or equal to 1"),
distribution_util.assert_close(one, candidate_one)
] if self.validate_args else [] ] if self.validate_args else []
return control_flow_ops.with_dependencies(dependencies, x) return control_flow_ops.with_dependencies(dependencies, x)

View File

@ -13,13 +13,15 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""The Dirichlet Multinomial distribution class.""" """The Dirichlet Multinomial distribution class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
# pylint: disable=line-too-long # pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops
# pylint: enable=line-too-long # pylint: enable=line-too-long
def _assert_integer_form(x):
"""Check x for integer components (or floats that are equal to integers)."""
x = ops.convert_to_tensor(x, name='x')
casted_x = math_ops.to_int64(x)
return check_ops.assert_equal(x, math_ops.cast(
math_ops.round(casted_x), x.dtype))
def _log_combinations(n, counts, name='log_combinations'):
"""Log number of ways counts could have come in."""
# First a bit about the number of ways counts could have come in:
# E.g. if counts = [1, 2], then this is 3 choose 2.
# In general, this is (sum counts)! / sum(counts!)
# The sum should be along the last dimension of counts. This is the
# "distribution" dimension. Here n a priori represents the sum of counts.
with ops.op_scope([counts], name):
# To compute factorials, use the fact that Gamma(n + 1) = n!
# Compute two terms, each a sum over counts. Compute each for each
# batch member.
# Log Gamma((sum counts) + 1) = Log((sum counts)!)
total_permutations = math_ops.lgamma(n + 1)
# sum(Log Gamma(counts + 1)) = Log sum(counts!)
counts_factorial = math_ops.lgamma(counts + 1)
redundant_permutations = math_ops.reduce_sum(counts_factorial,
reduction_indices=[-1])
return total_permutations - redundant_permutations
class DirichletMultinomial(distribution.Distribution): class DirichletMultinomial(distribution.Distribution):
"""DirichletMultinomial mixture distribution. """DirichletMultinomial mixture distribution.
@ -126,6 +100,7 @@ class DirichletMultinomial(distribution.Distribution):
counts = [2, 1, 0] counts = [2, 1, 0]
dist.pmf(counts) # Shape [2] dist.pmf(counts) # Shape [2]
``` ```
""" """
# TODO(b/27419586) Change docstring for dtype of alpha once int allowed. # TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
@ -134,26 +109,26 @@ class DirichletMultinomial(distribution.Distribution):
alpha, alpha,
validate_args=True, validate_args=True,
allow_nan_stats=False, allow_nan_stats=False,
name='DirichletMultinomial'): name="DirichletMultinomial"):
"""Initialize a batch of DirichletMultinomial distributions. """Initialize a batch of DirichletMultinomial distributions.
Args: Args:
n: Non-negative `float` or `double` tensor, whose dtype is the same as n: Non-negative floating point tensor, whose dtype is the same as
`alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`. `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
Defines this as a batch of `N1 x ... x Nm` different Dirichlet Defines this as a batch of `N1 x ... x Nm` different Dirichlet
multinomial distributions. Its components should be equal to integral multinomial distributions. Its components should be equal to integer
values. values.
alpha: Positive `float` or `double` tensor, whose dtype is the same as alpha: Positive floating point tensor, whose dtype is the same as
`n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
this as a batch of `N1 x ... x Nm` different `k` class Dirichlet this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
multinomial distributions. multinomial distributions.
validate_args: Whether to assert valid values for parameters `alpha` and validate_args: Whether to assert valid values for parameters `alpha` and
`n`, and `x` in `prob` and `log_prob`. If False, correct behavior is `n`, and `x` in `prob` and `log_prob`. If `False`, correct behavior is
not guaranteed. not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class. name: The name to prefix Ops created by this distribution class.
Examples: Examples:
@ -166,6 +141,7 @@ class DirichletMultinomial(distribution.Distribution):
# Define a 2-batch of 3-class distributions. # Define a 2-batch of 3-class distributions.
dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
``` ```
""" """
self._allow_nan_stats = allow_nan_stats self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args self._validate_args = validate_args
@ -221,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
"""dtype of samples from this distribution.""" """dtype of samples from this distribution."""
return self._alpha.dtype return self._alpha.dtype
def mean(self, name='mean'): def mean(self, name="mean"):
"""Class means for every batch member.""" """Class means for every batch member."""
alpha = self._alpha alpha = self._alpha
alpha_sum = self._alpha_sum alpha_sum = self._alpha_sum
@ -231,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1) mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
return array_ops.expand_dims(n, -1) * mean_no_n return array_ops.expand_dims(n, -1) * mean_no_n
def variance(self, name='mean'): def variance(self, name="mean"):
"""Class variances for every batch member. """Class variances for every batch member.
The variance for each batch member is defined as the following: The variance for each batch member is defined as the following:
@ -273,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
variance *= array_ops.expand_dims(shared_factor, -1) variance *= array_ops.expand_dims(shared_factor, -1)
return variance return variance
def batch_shape(self, name='batch_shape'): def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`. """Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of The product of the dimensions of the `batch_shape` is the number of
@ -299,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
""" """
return self._get_batch_shape return self._get_batch_shape
def event_shape(self, name='event_shape'): def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`. """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args: Args:
@ -322,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
""" """
return self._get_event_shape return self._get_event_shape
def cdf(self, x, name='cdf'): def cdf(self, x, name="cdf"):
raise NotImplementedError( raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.') "DirichletMultinomial does not have a well-defined cdf.")
def log_cdf(self, x, name='log_cdf'): def log_cdf(self, x, name="log_cdf"):
raise NotImplementedError( raise NotImplementedError(
'DirichletMultinomial does not have a well-defined cdf.') "DirichletMultinomial does not have a well-defined cdf.")
def log_prob(self, counts, name='log_prob'): def log_prob(self, counts, name="log_prob"):
"""`Log(P[counts])`, computed for every batch member. """`Log(P[counts])`, computed for every batch member.
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
@ -340,12 +316,11 @@ class DirichletMultinomial(distribution.Distribution):
probability includes a combinatorial coefficient. probability includes a combinatorial coefficient.
Args: Args:
counts: Non-negative `float` or `double` tensor whose dtype is the same counts: Non-negative tensor with dtype `dtype` and whose shape can be
`self` and whose shape can be broadcast with `self.alpha`. For fixed broadcast with `self.alpha`. For fixed leading dimensions, the last
leading dimensions, the last dimension represents counts for the dimension represents counts for the corresponding Dirichlet Multinomial
corresponding Dirichlet Multinomial distribution in `self.alpha`. distribution in `self.alpha`. `counts` is only legal if it sums up to
`counts` is only legal if it sums up to `n` and its components are `n` and its components are equal to integer values.
equal to integral values.
name: Name to give this Op, defaults to "log_prob". name: Name to give this Op, defaults to "log_prob".
Returns: Returns:
@ -359,20 +334,11 @@ class DirichletMultinomial(distribution.Distribution):
ordered_prob = (special_math_ops.lbeta(alpha + counts) - ordered_prob = (special_math_ops.lbeta(alpha + counts) -
special_math_ops.lbeta(alpha)) special_math_ops.lbeta(alpha))
log_prob = ordered_prob + _log_combinations(n, counts) log_prob = ordered_prob + distribution_util.log_combinations(
# If alpha = counts = [[]], ordered_prob carries the right shape, which n, counts)
# is []. However, since reduce_sum([[]]) = [0], log_combinations = [0],
# which is not correct. Luckily, [] + [0] = [], so the sum is fine, but
# shape must be inferred from ordered_prob. We must also make this
# broadcastable with n, so this is multiplied by n to ensure the shape
# is correctly inferred.
# Note also that tf.constant([]).get_shape() =
# TensorShape([Dimension(0)])
broadcasted_tensor = ordered_prob * n
log_prob.set_shape(broadcasted_tensor.get_shape())
return log_prob return log_prob
def prob(self, counts, name='prob'): def prob(self, counts, name="prob"):
"""`P[counts]`, computed for every batch member. """`P[counts]`, computed for every batch member.
For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
@ -382,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution):
probability includes a combinatorial coefficient. probability includes a combinatorial coefficient.
Args: Args:
counts: Non-negative `float` or `double` tensor whose dtype is the same counts: Non-negative tensor with dtype `dtype` and whose shape can be
`self` and whose shape can be broadcast with `self.alpha`. For fixed broadcast with `self.alpha`. For fixed leading dimensions, the last
leading dimensions, the last dimension represents counts for the dimension represents counts for the corresponding Dirichlet Multinomial
corresponding Dirichlet Multinomial distribution in `self.alpha`. distribution in `self.alpha`. `counts` is only legal if it sums up to
`counts` is only legal if it sums up to `n` and its components are `n` and its components are equal to integer values.
equal to integral values.
name: Name to give this Op, defaults to "prob". name: Name to give this Op, defaults to "prob".
Returns: Returns:
@ -397,18 +362,21 @@ class DirichletMultinomial(distribution.Distribution):
def _check_counts(self, counts): def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version.""" """Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name='counts') counts = ops.convert_to_tensor(counts, name="counts")
if not self.validate_args: if not self.validate_args:
return counts return counts
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1]) candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
return control_flow_ops.with_dependencies([ return control_flow_ops.with_dependencies([
check_ops.assert_non_negative(counts), check_ops.assert_non_negative(counts),
check_ops.assert_equal(self._n, candidate_n), check_ops.assert_equal(
_assert_integer_form(counts)], counts) self._n, candidate_n,
message="counts do not sum to n"
),
distribution_util.assert_integer_form(counts)], counts)
def _check_alpha(self, alpha): def _check_alpha(self, alpha):
alpha = ops.convert_to_tensor(alpha, name='alpha') alpha = ops.convert_to_tensor(alpha, name="alpha")
if not self.validate_args: if not self.validate_args:
return alpha return alpha
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
@ -416,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution):
check_ops.assert_positive(alpha)], alpha) check_ops.assert_positive(alpha)], alpha)
def _check_n(self, n): def _check_n(self, n):
n = ops.convert_to_tensor(n, name='n') n = ops.convert_to_tensor(n, name="n")
if not self.validate_args: if not self.validate_args:
return n return n
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(n), _assert_integer_form(n)], n) [check_ops.assert_non_negative(n),
distribution_util.assert_integer_form(n)], n)
@property @property
def is_continuous(self): def is_continuous(self):

View File

@ -0,0 +1,177 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
def assert_close(
x, y, data=None, summarize=None, message=None, name="assert_close"):
"""Assert that that x and y are within machine epsilon of each other.
Args:
x: Numeric `Tensor`
y: Numeric `Tensor`
data: The tensors to print out if the condition is `False`. Defaults to
error message and first few entries of `x` and `y`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Returns:
Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
"""
message = message or ""
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y")
if x.dtype.is_integer:
return check_ops.assert_equal(
x, y, data=data, summarize=summarize, message=message, name=name)
with ops.op_scope([x, y, data], name, "assert_close"):
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
if data is None:
data = [
message,
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
y.name, y
]
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
return logging_ops.Assert(
condition, data, summarize=summarize)
def assert_integer_form(
x, data=None, summarize=None, message=None, name="assert_integer_form"):
"""Assert that x has integer components (or floats equal to integers).
Args:
x: Numeric `Tensor`
data: The tensors to print out if the condition is `False`. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
message: A string to prefix to the default message.
name: A name for this operation (optional).
Returns:
Op raising `InvalidArgumentError` if round(x) != x.
"""
message = message or "x has non-integer components"
x = ops.convert_to_tensor(x, name="x")
casted_x = math_ops.to_int64(x)
return check_ops.assert_equal(
x, math_ops.cast(math_ops.round(casted_x), x.dtype),
data=data, summarize=summarize, message=message, name=name)
def get_logits_and_prob(
logits=None, p=None, multidimensional=False, validate_args=True, name=None):
"""Converts logits to probabilities and vice-versa, and returns both.
Args:
logits: Numeric `Tensor` representing log-odds.
p: Numeric `Tensor` representing probabilities.
multidimensional: Given `p`, a `[N1, N2, ..., k]`-dimensional tensor,
whether the last dimension represents probabilities over `k` classes.
This will additionally assert that the values in the last dimension
sum to one. If `False`, will instead assert that each value is in
`[0, 1]`.
validate_args: Whether to assert `0 <= p <= 1` if multidimensional is
`False`, otherwise that the last dimension of `p` sums to one.
name: A name for this operation (optional).
Returns:
Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then
the corresponding entry in the returned logits will be `-Inf` and `Inf`
respectively.
Raises:
ValueError: if neither `p` nor `logits` were passed in, or both were.
"""
if p is None and logits is None:
raise ValueError("Must pass p or logits.")
elif p is not None and logits is not None:
raise ValueError("Must pass either p or logits, not both.")
elif p is None:
with ops.op_scope([logits], name):
logits = array_ops.identity(logits, name="logits")
with ops.name_scope(name):
with ops.name_scope("p"):
p = math_ops.sigmoid(logits)
elif logits is None:
with ops.name_scope(name):
with ops.name_scope("p"):
p = array_ops.identity(p)
if validate_args:
one = constant_op.constant(1., p.dtype)
dependencies = [check_ops.assert_non_negative(p)]
if multidimensional:
dependencies += [assert_close(
math_ops.reduce_sum(p, reduction_indices=[-1]),
one, message="p does not sum to 1.")]
else:
dependencies += [check_ops.assert_less_equal(
p, one, message="p has components greater than 1.")]
p = control_flow_ops.with_dependencies(dependencies, p)
with ops.name_scope("logits"):
logits = math_ops.log(p) - math_ops.log(1. - p)
return (logits, p)
def log_combinations(n, counts, name="log_combinations"):
"""Multinomial coefficient.
Given `n` and `counts`, where `counts` has last dimension `k`, we compute
the multinomial coefficient as:
```n! / prod_i n_i!```
where `i` runs over all `k` classes.
Args:
n: Numeric `Tensor` broadcastable with `counts`. This represents the total
number of draws.
counts: Numeric `Tensor` broadcastable with `n`. This represents counts
in `k` classes, where `k` is the last dimension of the tensor.
name: A name for this operation (optional).
Returns:
`Tensor` representing the multinomial coefficient between `n` and `counts`.
"""
# First a bit about the number of ways counts could have come in:
# E.g. if counts = [1, 2], then this is 3 choose 2.
# In general, this is (sum counts)! / sum(counts!)
# The sum should be along the last dimension of counts. This is the
# "distribution" dimension. Here n a priori represents the sum of counts.
with ops.op_scope([n, counts], name):
total_permutations = math_ops.lgamma(n + 1)
counts_factorial = math_ops.lgamma(counts + 1)
redundant_permutations = math_ops.reduce_sum(counts_factorial,
reduction_indices=[-1])
return total_permutations - redundant_permutations
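`get_logits_and_prob` above relies on the usual logit/probability correspondence `logits = log(p) - log(1 - p)` and `p = sigmoid(logits)`. A rough NumPy illustration of that round trip in the scalar (`multidimensional=False`) case, with assumed example probabilities:

```python
# NumPy sketch (not part of the commit) of the logits <-> p relationship
# used by get_logits_and_prob.
import numpy as np

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

p = np.array([0.1, 0.5, 0.9])          # assumed example probabilities
logits = np.log(p) - np.log(1. - p)    # i.e. log(p / (1 - p))

print(np.allclose(sigmoid(logits), p))  # True: sigmoid inverts the logit

# p == 0 or 1 maps to -Inf / Inf logits, as noted in the docstring.
with np.errstate(divide="ignore"):
    edge = np.log(np.array([0., 1.])) - np.log(1. - np.array([0., 1.]))
print(edge)  # [-inf  inf]
```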

View File

@ -46,15 +46,15 @@ class Exponential(gamma.Gamma):
"""Construct Exponential distribution with parameter `lam`. """Construct Exponential distribution with parameter `lam`.
Args: Args:
lam: `float` or `double` tensor, the rate of the distribution(s). lam: Floating point tensor, the rate of the distribution(s).
`lam` must contain only positive values. `lam` must contain only positive values.
validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
methods `prob(x)` and `log_prob(x)`. If `validate_args` is False methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed. and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution. name: The name to prepend to all ops created by this distribution.
""" """
# Even though all statistics of are defined for valid inputs, this is not # Even though all statistics of are defined for valid inputs, this is not
@ -95,8 +95,7 @@ class Exponential(gamma.Gamma):
broadcast_shape = self._lam.get_shape() broadcast_shape = self._lam.get_shape()
with ops.op_scope([self.lam, n], name, "ExponentialSample"): with ops.op_scope([self.lam, n], name, "ExponentialSample"):
n = ops.convert_to_tensor(n, name="n") n = ops.convert_to_tensor(n, name="n")
shape = array_ops.concat( shape = array_ops.concat(0, ([n], array_ops.shape(self._lam)))
0, [array_ops.pack([n]), array_ops.shape(self._lam)])
# Sample uniformly-at-random from the open-interval (0, 1). # Sample uniformly-at-random from the open-interval (0, 1).
sampled = random_ops.random_uniform( sampled = random_ops.random_uniform(
shape, minval=np.nextafter( shape, minval=np.nextafter(
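The hunk above only changes how the sample shape is assembled; the sampler it feeds (not fully shown in this view) draws uniforms on the open interval (0, 1) and applies the standard inverse transform for the Exponential distribution. A hedged NumPy sketch of that idea, where the helper name and values are assumptions rather than TensorFlow API:

```python
# Inverse-transform sampling for Exponential(lam): X = -log(U) / lam,
# with U uniform on the open interval (0, 1).
import numpy as np

def sample_exponential(lam, n, seed=None):
    rng = np.random.RandomState(seed)
    # Open interval (0, 1): avoid an exact 0, which would make log blow up.
    u = rng.uniform(low=np.nextafter(0., 1.), high=1.,
                    size=(n,) + np.shape(lam))
    return -np.log(u) / lam

samples = sample_exponential(lam=np.array([1., 2.]), n=100000, seed=0)
print(samples.mean(axis=0))  # approximately [1.0, 0.5] == 1 / lam
```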

View File

@ -69,19 +69,19 @@ class Gamma(distribution.Distribution):
broadcasting (e.g. `alpha + beta` is a valid operation). broadcasting (e.g. `alpha + beta` is a valid operation).
Args: Args:
alpha: `float` or `double` tensor, the shape params of the alpha: Floating point tensor, the shape params of the
distribution(s). distribution(s).
alpha must contain only positive values. alpha must contain only positive values.
beta: `float` or `double` tensor, the inverse scale params of the beta: Floating point tensor, the inverse scale params of the
distribution(s). distribution(s).
beta must contain only positive values. beta must contain only positive values.
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed. and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution. name: The name to prepend to all ops created by this distribution.
Raises: Raises:
@ -213,9 +213,12 @@ class Gamma(distribution.Distribution):
nan = np.nan * self._ones() nan = np.nan * self._ones()
return math_ops.select(alpha_ge_1, mode_if_defined, nan) return math_ops.select(alpha_ge_1, mode_if_defined, nan)
else: else:
one = ops.convert_to_tensor(1.0, dtype=self.dtype) one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, alpha)], mode_if_defined) [check_ops.assert_less(
one, alpha,
message="mode not defined for components of alpha <= 1"
)], mode_if_defined)
def variance(self, name="variance"): def variance(self, name="variance"):
"""Variance of each batch member.""" """Variance of each batch member."""

View File

@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution):
broadcasting (e.g. `alpha + beta` is a valid operation). broadcasting (e.g. `alpha + beta` is a valid operation).
Args: Args:
alpha: `float` or `double` tensor, the shape params of the alpha: Floating point tensor, the shape params of the
distribution(s). distribution(s).
alpha must contain only positive values. alpha must contain only positive values.
beta: `float` or `double` tensor, the scale params of the distribution(s). beta: Floating point tensor, the scale params of the distribution(s).
beta must contain only positive values. beta must contain only positive values.
validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
and the inputs are invalid, correct behavior is not guaranteed. and the inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prepend to all ops created by this distribution. name: The name to prepend to all ops created by this distribution.
Raises: Raises:
@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution):
nan = np.nan * self._ones() nan = np.nan * self._ones()
return math_ops.select(alpha_gt_1, mean_if_defined, nan) return math_ops.select(alpha_gt_1, mean_if_defined, nan)
else: else:
one = ops.convert_to_tensor(1.0, dtype=self.dtype) one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, alpha)], mean_if_defined) [check_ops.assert_less(
one, alpha,
message="mean not defined for components of alpha <= 1")],
mean_if_defined)
def mode(self, name="mode"): def mode(self, name="mode"):
"""Mode of each batch member. """Mode of each batch member.
@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution):
nan = np.nan * self._ones() nan = np.nan * self._ones()
return math_ops.select(alpha_gt_2, var_if_defined, nan) return math_ops.select(alpha_gt_2, var_if_defined, nan)
else: else:
two = ops.convert_to_tensor(2.0, dtype=self.dtype) two = constant_op.constant(2.0, dtype=self.dtype)
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
[check_ops.assert_less(two, alpha)], var_if_defined) [check_ops.assert_less(
two, alpha,
message="variance not defined for components of alpha <= 2")],
var_if_defined)
def log_prob(self, x, name="log_prob"): def log_prob(self, x, name="log_prob"):
"""Log prob of observations in `x` under these InverseGamma distribution(s). """Log prob of observations in `x` under these InverseGamma distribution(s).

View File

@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
Args: Args:
dist_a: instance of distributions.Distribution. dist_a: instance of distributions.Distribution.
dist_b: instance of distributions.Distribution. dist_b: instance of distributions.Distribution.
allow_nan: If False (default), a runtime error is raised allow_nan: If `False` (default), a runtime error is raised
if the KL returns NaN values for any batch entry of the given if the KL returns NaN values for any batch entry of the given
distributions. If True, the KL may return a NaN for the given entry. distributions. If `True`, the KL may return a NaN for the given entry.
name: (optional) Name scope to use for created operations. name: (optional) Name scope to use for created operations.
Returns: Returns:

View File

@ -60,17 +60,17 @@ class Laplace(distribution.Distribution):
broadcasting (e.g., `loc / scale` is a valid operation). broadcasting (e.g., `loc / scale` is a valid operation).
Args: Args:
loc: `float` or `double` tensor which characterizes the location (center) loc: Floating point tensor which characterizes the location (center)
of the distribution. of the distribution.
scale: `float` or `double`, positive-valued tensor which characterzes the scale: Positive floating point tensor which characterizes the spread of
spread of the distribution. the distribution.
validate_args: Whether to validate input with asserts. If `validate_args` validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not is `False`, and the inputs are invalid, correct behavior is not
guaranteed. guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer. name: The name to give Ops created by the initializer.
Raises: Raises:
@ -294,8 +294,7 @@ class Laplace(distribution.Distribution):
with ops.op_scope([self._loc, self._scale, n], name): with ops.op_scope([self._loc, self._scale, n], name):
n = ops.convert_to_tensor(n) n = ops.convert_to_tensor(n)
n_val = tensor_util.constant_value(n) n_val = tensor_util.constant_value(n)
shape = array_ops.concat( shape = array_ops.concat(0, ([n], self.batch_shape()))
0, [array_ops.pack([n]), self.batch_shape()])
# Sample uniformly-at-random from the open-interval (-1, 1). # Sample uniformly-at-random from the open-interval (-1, 1).
uniform_samples = random_ops.random_uniform( uniform_samples = random_ops.random_uniform(
shape=shape, shape=shape,

View File

@ -0,0 +1,343 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# pylint: enable=line-too-long
class Multinomial(distribution.Distribution):
"""Multinomial distribution.
This distribution is parameterized by a vector `p` of probability
parameters for `k` classes and `n`, the total number of draws.
#### Mathematical details
The Multinomial is a distribution over k-class count data, meaning
for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, we have a
probability of these draws being made from the distribution. The distribution
has hyperparameters `p = (p_1,...,p_k)`, and probability mass
function (pmf):
```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k```
where above `n = sum_j n_j`, `n!` is `n` factorial.
#### Examples
Create a 3-class distribution where the 3rd class is the most likely to be
drawn, using logits.
```python
logits = [-50., -43, 0]
dist = Multinomial(n=4., logits=logits)
```
Create a 3-class distribution where the 3rd class is the most likely to be drawn.
```python
p = [.2, .3, .5]
dist = Multinomial(n=4., p=p)
```
The distribution functions can be evaluated on counts.
```python
# counts same shape as p.
counts = [1., 0, 3]
dist.prob(counts) # Shape []
# p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
counts = [[1., 2, 1], [2, 2, 0]]
dist.prob(counts) # Shape [2]
# p will be broadcast to shape [5, 7, 3] to match counts.
counts = [[...]] # Shape [5, 7, 3]
dist.prob(counts) # Shape [5, 7]
```
Create a 2-batch of 3-class distributions.
```python
p = [[.1, .2, .7], [.3, .3, .4]] # Shape [2, 3]
dist = Multinomial(n=[4., 5], p=p)
counts = [[2., 1, 1], [3, 1, 1]]
dist.prob(counts) # Shape [2]
```
"""
def __init__(self,
n,
logits=None,
p=None,
validate_args=True,
allow_nan_stats=False,
name="Multinomial"):
"""Initialize a batch of Multinomial distributions.
Args:
n: Non-negative floating point tensor with shape broadcastable to
`[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
`N1 x ... x Nm` different Multinomial distributions. Its components
should be equal to integer values.
logits: Floating point tensor representing the log-odds of a
positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
different `k` class Multinomial distributions.
p: Positive floating point tensor with shape broadcastable to
`[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as
a batch of `N1 x ... x Nm` different `k` class Multinomial
distributions. `p`'s components along the last dimension should
sum up to 1.
validate_args: Whether to assert valid values for parameters `n` and `p`,
and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
guaranteed.
allow_nan_stats: Boolean, default `False`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
batch member. If `True`, batch members with valid parameters leading to
undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class.
Examples:
```python
# Define 1-batch of 2-class multinomial distribution,
# also known as a Binomial distribution.
dist = Multinomial(n=2., p=[.1, .9])
# Define a 2-batch of 3-class distributions.
dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
```
"""
self._logits, self._p = distribution_util.get_logits_and_prob(
name=name, logits=logits, p=p, validate_args=validate_args,
multidimensional=True)
with ops.op_scope([n, self._p], name):
with ops.control_dependencies([
check_ops.assert_non_negative(
n, message="n has negative components."),
distribution_util.assert_integer_form(
n, message="n has non-integer components."
)] if validate_args else []):
self._n = array_ops.identity(n, name="convert_n")
self._name = name
self._validate_args = validate_args
self._allow_nan_stats = allow_nan_stats
self._mean = array_ops.expand_dims(n, -1) * self._p
# Only used for inferring shape.
self._broadcast_shape = math_ops.reduce_sum(self._mean,
reduction_indices=[-1],
keep_dims=False)
self._get_batch_shape = self._broadcast_shape.get_shape()
self._get_event_shape = (
self._mean.get_shape().with_rank_at_least(1)[-1:])
@property
def n(self):
"""Number of trials."""
return self._n
@property
def p(self):
"""Event probabilities."""
return self._p
@property
def logits(self):
"""Log-odds."""
return self._logits
@property
def name(self):
"""Name to prepend to all ops."""
return self._name
@property
def dtype(self):
"""dtype of samples from this distribution."""
return self._p.dtype
@property
def validate_args(self):
"""Boolean describing behavior on invalid input."""
return self._validate_args
@property
def allow_nan_stats(self):
"""Boolean describing behavior when a stat is undefined for batch member."""
return self._allow_nan_stats
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
independent distributions of this kind the instance represents.
Args:
name: name to give to the op
Returns:
`Tensor` `batch_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([self._broadcast_shape], name):
return array_ops.shape(self._broadcast_shape)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `batch_shape`. May be only partially defined.
Returns:
batch shape
"""
return self._get_batch_shape
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
name: name to give to the op
Returns:
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
with ops.op_scope([self._mean], name):
return array_ops.gather(array_ops.shape(self._mean),
[array_ops.rank(self._mean) - 1])
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `event_shape`. May be only partially defined.
Returns:
event shape
"""
return self._get_event_shape
def mean(self, name="mean"):
"""Mean of the distribution."""
with ops.name_scope(self.name):
return array_ops.identity(self._mean, name=name)
def variance(self, name="variance"):
"""Variance of the distribution."""
with ops.name_scope(self.name):
with ops.op_scope([self._n, self._p, self._mean], name):
p = array_ops.expand_dims(
self._p * array_ops.expand_dims(
array_ops.ones_like(self._n), -1), -1)
variance = -math_ops.batch_matmul(
array_ops.expand_dims(self._mean, -1), p, adj_y=True)
variance += array_ops.batch_matrix_diag(self._mean)
return variance
def log_prob(self, counts, name="log_prob"):
"""`Log(P[counts])`, computed for every batch member.
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
that after sampling `n` draws from this Multinomial distribution, the
number of draws falling in class `j` is `n_j`. Note that different
sequences of draws can result in the same counts, thus the probability
includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
the last dimension represents counts for the corresponding Multinomial
distribution in `self.p`. `counts` is only legal if it sums up to `n`
and its components are equal to integer values.
name: Name to give this Op, defaults to "log_prob".
Returns:
Log probabilities for each record, shape `[N1,...,Nm]`.
"""
n = self._n
p = self._p
with ops.name_scope(self.name):
with ops.op_scope([n, p, counts], name):
counts = self._check_counts(counts)
prob_prob = math_ops.reduce_sum(counts * math_ops.log(self._p),
reduction_indices=[-1])
log_prob = prob_prob + distribution_util.log_combinations(
n, counts)
return log_prob
def prob(self, counts, name="prob"):
"""`P[counts]`, computed for every batch member.
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
that after sampling `n` draws from this Multinomial distribution, the
number of draws falling in class `j` is `n_j`. Note that different
sequences of draws can result in the same counts, thus the probability
includes a combinatorial coefficient.
Args:
counts: Non-negative tensor with dtype `dtype` and whose shape can
be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
the last dimension represents counts for the corresponding Multinomial
distribution in `self.p`. `counts` is only legal if it sums up to `n`
and its components are equal to integer values.
name: Name to give this Op, defaults to "prob".
Returns:
Probabilities for each record, shape `[N1,...,Nm]`.
"""
return super(Multinomial, self).prob(counts, name=name)
@property
def is_continuous(self):
return False
@property
def is_reparameterized(self):
return False
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
if not self.validate_args:
return counts
return control_flow_ops.with_dependencies([
check_ops.assert_non_negative(
counts, message="counts has negative components."),
check_ops.assert_equal(
self._n, candidate_n, message="counts do not sum to n."),
distribution_util.assert_integer_form(
counts, message="counts have non-integer components.")], counts)
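To make the quantities exposed above concrete, a small NumPy sketch (not part of the commit, with assumed values for `n`, `p`, and `counts`) of the log-pmf computed by `log_prob` and the covariance-style matrix returned by `variance`:

```python
# For a single (non-batched) Multinomial:
#   log P[counts] = sum_i counts_i * log(p_i) + log_combinations(n, counts)
#   variance      = diag(n * p) - n * outer(p, p)
import numpy as np
from scipy.special import gammaln

n = 4.                            # assumed number of draws
p = np.array([0.2, 0.3, 0.5])     # assumed class probabilities
counts = np.array([1., 0., 3.])   # assumed counts, summing to n

log_prob = (np.sum(counts * np.log(p))
            + gammaln(n + 1) - np.sum(gammaln(counts + 1)))
print(np.exp(log_prob))  # 4!/(1! 0! 3!) * .2 * .5**3 = 0.1

covariance = np.diag(n * p) - n * np.outer(p, p)
print(covariance)  # mirrors the batched matmul/diag computation in variance()
```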

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import math import math
from tensorflow.contrib.distributions.python.ops import distribution from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
from tensorflow.contrib.distributions.python.ops import operator_pd_diag from tensorflow.contrib.distributions.python.ops import operator_pd_diag
from tensorflow.contrib.distributions.python.ops import operator_pd_full from tensorflow.contrib.distributions.python.ops import operator_pd_full
@ -104,9 +105,9 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
which determines the covariance. which determines the covariance.
Args: Args:
mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`. mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
cov: `float` or `double` instance of `OperatorPDBase` with same `dtype` cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape
as `mu` and shape `[N1,...,Nb, k, k]`. `[N1,...,Nb, k, k]`.
validate_args: Whether to validate input with asserts. If `validate_args` validate_args: Whether to validate input with asserts. If `validate_args`
is `False`, and the inputs are invalid, correct behavior is not is `False`, and the inputs are invalid, correct behavior is not
guaranteed. guaranteed.
@ -149,7 +150,7 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
else: else:
return mu return mu
# Static checks could not be run, so possibly do dyamic checks. # Static checks could not be run, so possibly do dynamic checks.
if not self.validate_args: if not self.validate_args:
return mu return mu
else: else:
@ -465,7 +466,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`. The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.
Args: Args:
mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`, mu: Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`,
`b >= 0`. `b >= 0`.
diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`, diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
representing the standard deviations. Must be positive. representing the standard deviations. Must be positive.
@ -580,13 +581,13 @@ class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
``` ```
Args: Args:
mu: Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`, mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
`n >= 0`. The means. `n >= 0`. The means.
diag_large: Optional rank `n + 1` `float` or `double` tensor, shape diag_large: Optional rank `n + 1` floating point tensor, shape
`[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`. `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`.
v: Rank `n + 1` `float` or `double` tensor, shape `[N1,...,Nn, k, r]` v: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]`
`n >= 0`. Defines the matrix `V`. `n >= 0`. Defines the matrix `V`.
diag_small: Rank `n + 1` `float` or `double` tensor, shape diag_small: Rank `n + 1` floating point tensor, shape
`[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default
is `None`, which means `D` will be the identity matrix. is `None`, which means `D` will be the identity matrix.
validate_args: Whether to validate input with asserts. If `validate_args` validate_args: Whether to validate input with asserts. If `validate_args`
@ -669,7 +670,7 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
factors, such that the covariance of each batch member is `chol chol^T`. factors, such that the covariance of each batch member is `chol chol^T`.
Args: Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`, mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
`b >= 0`. `b >= 0`.
chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as `[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as
@ -749,7 +750,7 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
User must provide means `mu` and `sigma`, the mean and covariance. User must provide means `mu` and `sigma`, the mean and covariance.
Args: Args:
mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`, mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
`b >= 0`. `b >= 0`.
sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
`[N1,...,Nb, k, k]`. Each batch member must be positive definite. `[N1,...,Nb, k, k]`. Each batch member must be positive definite.
@ -772,3 +773,72 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
allow_nan_stats=allow_nan_stats, allow_nan_stats=allow_nan_stats,
validate_args=validate_args, validate_args=validate_args,
name=name) name=name)
def _kl_mvn_mvn_brute_force(mvn_a, mvn_b, name=None):
"""Batched KL divergence `KL(mvn_a || mvn_b)` for multivariate normals.
With `X`, `Y` both multivariate normals in `R^k` with means `mu_x`, `mu_y` and
covariance `C_x`, `C_y` respectively,
```
KL(X || Y) = 0.5 * ( T + Q - k + L ),
T := trace(C_b^{-1} C_a),
Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
L := Log[Det(C_b)] - Log[Det(C_a)]
```
This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
methods for solving systems with `C_b` may be available, a dense version of
(the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B`
is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
and `y`.
Args:
mvn_a: Instance of subclass of `MultivariateNormalOperatorPD`.
mvn_b: Instance of subclass of `MultivariateNormalOperatorPD`.
name: (optional) name to use for created ops. Default "kl_mvn_mvn".
Returns:
Batchwise `KL(mvn_a || mvn_b)`.
"""
# Access the "private" OperatorPD that each mvn is built from.
cov_a = mvn_a._cov # pylint: disable=protected-access
cov_b = mvn_b._cov # pylint: disable=protected-access
mu_a = mvn_a.mu
mu_b = mvn_b.mu
inputs = [mu_a, mu_b] + cov_a.inputs + cov_b.inputs
with ops.op_scope(inputs, name, "kl_mvn_mvn"):
# If Ca = AA', Cb = BB', then
# tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
# = tr[inv(B) A A' inv(B)']
# = tr[(inv(B) A) (inv(B) A)']
# = sum_{ik} (inv(B) A)_{ik}^2
# The second equality follows from the cyclic permutation property.
b_inv_a = cov_b.sqrt_solve(cov_a.sqrt_to_dense())
t = math_ops.reduce_sum(
math_ops.square(b_inv_a),
reduction_indices=[-1, -2])
q = cov_b.inv_quadratic_form_on_vectors(mu_b - mu_a)
k = math_ops.cast(cov_a.vector_space_dimension(), mvn_a.dtype)
one_half_l = cov_b.sqrt_log_det() - cov_a.sqrt_log_det()
return 0.5 * (t + q - k) + one_half_l
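For reference, a dense NumPy sketch of the same brute-force KL between full-covariance Gaussians; `mu_a`, `cov_a`, `mu_b`, `cov_b` are plain arrays standing in for the OperatorPD machinery used above:
```
import numpy as np

def kl_mvn_dense(mu_a, cov_a, mu_b, cov_b):
    k = mu_a.shape[-1]
    cov_b_inv = np.linalg.inv(cov_b)
    t = np.trace(cov_b_inv.dot(cov_a))                   # trace term
    diff = mu_b - mu_a
    q = diff.dot(cov_b_inv).dot(diff)                    # quadratic form
    l = np.linalg.slogdet(cov_b)[1] - np.linalg.slogdet(cov_a)[1]
    return 0.5 * (t + q - k + l)
```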
# Register KL divergences.
kl_classes = [
MultivariateNormalFull,
MultivariateNormalCholesky,
MultivariateNormalDiag,
MultivariateNormalDiagPlusVDVT,
]
for mvn_aa in kl_classes:
# Register the same-class pairs here, and skip them in the inner loop below so
# that no pair is registered twice.
kullback_leibler.RegisterKL(mvn_aa, mvn_aa)(_kl_mvn_mvn_brute_force)
for mvn_bb in kl_classes:
if mvn_bb != mvn_aa:
kullback_leibler.RegisterKL(mvn_aa, mvn_bb)(_kl_mvn_mvn_brute_force)
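With these registrations in place, the divergence should be reachable through the generic dispatcher. A minimal sketch, assuming `kullback_leibler.kl` is the lookup entry point and using made-up parameters:
```
mvn_a = MultivariateNormalDiag(mu=[0., 0.], diag_stdev=[1., 1.])
mvn_b = MultivariateNormalDiag(mu=[1., 1.], diag_stdev=[2., 2.])
kl = kullback_leibler.kl(mvn_a, mvn_b)  # dispatches to _kl_mvn_mvn_brute_force
```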

View File

@ -92,15 +92,15 @@ class Normal(distribution.Distribution):
broadcasting (e.g. `mu + sigma` is a valid operation). broadcasting (e.g. `mu + sigma` is a valid operation).
Args: Args:
mu: `float` or `double` tensor, the means of the distribution(s). mu: Floating point tensor, the means of the distribution(s).
sigma: `float` or `double` tensor, the stddevs of the distribution(s). sigma: Floating point tensor, the stddevs of the distribution(s).
sigma must contain only positive values. sigma must contain only positive values.
validate_args: Whether to assert that `sigma > 0`. If `validate_args` is validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
False, correct output is not guaranteed when input is invalid. `False`, correct output is not guaranteed when input is invalid.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer. name: The name to give Ops created by the initializer.
Raises: Raises:
@ -321,8 +321,7 @@ class Normal(distribution.Distribution):
with ops.op_scope([self._mu, self._sigma, n], name): with ops.op_scope([self._mu, self._sigma, n], name):
broadcast_shape = (self._mu + self._sigma).get_shape() broadcast_shape = (self._mu + self._sigma).get_shape()
n = ops.convert_to_tensor(n) n = ops.convert_to_tensor(n)
shape = array_ops.concat( shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
0, [array_ops.pack([n]), array_ops.shape(self.mean())])
sampled = random_ops.random_normal( sampled = random_ops.random_normal(
shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed) shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)

View File

@ -82,6 +82,7 @@ class StudentT(distribution.Distribution):
# returning a length 2 tensor. # returning a length 2 tensor.
dist.pdf(3.0) dist.pdf(3.0)
``` ```
""" """
def __init__(self, def __init__(self,
@ -99,19 +100,19 @@ class StudentT(distribution.Distribution):
broadcasting (e.g. `df + mu + sigma` is a valid operation). broadcasting (e.g. `df + mu + sigma` is a valid operation).
Args: Args:
df: `float` or `double` tensor, the degrees of freedom of the df: Floating point tensor, the degrees of freedom of the
distribution(s). `df` must contain only positive values. distribution(s). `df` must contain only positive values.
mu: `float` or `double` tensor, the means of the distribution(s). mu: Floating point tensor, the means of the distribution(s).
sigma: `float` or `double` tensor, the scaling factor for the sigma: Floating point tensor, the scaling factor for the
distribution(s). `sigma` must contain only positive values. distribution(s). `sigma` must contain only positive values.
Note that `sigma` is not the standard deviation of this distribution. Note that `sigma` is not the standard deviation of this distribution.
validate_args: Whether to assert that `df > 0, sigma > 0`. If validate_args: Whether to assert that `df > 0, sigma > 0`. If
`validate_args` is False and inputs are invalid, correct behavior is not `validate_args` is `False` and inputs are invalid, correct behavior is
guaranteed. not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer. name: The name to give Ops created by the initializer.
Raises: Raises:
@ -185,9 +186,12 @@ class StudentT(distribution.Distribution):
nan = np.nan + self._zeros() nan = np.nan + self._zeros()
return math_ops.select(df_gt_1, result_if_defined, nan) return math_ops.select(df_gt_1, result_if_defined, nan)
else: else:
one = ops.convert_to_tensor(1.0, dtype=self.dtype) one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, self._df)], result_if_defined) [check_ops.assert_less(
one, self._df,
message="mean not defined for components of df <= 1"
)], result_if_defined)
def mode(self, name="mode"): def mode(self, name="mode"):
with ops.name_scope(self.name): with ops.name_scope(self.name):
@ -232,9 +236,12 @@ class StudentT(distribution.Distribution):
result_where_defined, result_where_defined,
self._zeros() + np.nan) self._zeros() + np.nan)
else: else:
one = ops.convert_to_tensor(1.0, self.dtype) one = constant_op.constant(1.0, dtype=self.dtype)
return control_flow_ops.with_dependencies( return control_flow_ops.with_dependencies(
[check_ops.assert_less(one, self._df)], result_where_defined) [check_ops.assert_less(
one, self._df,
message="variance not defined for components of df <= 1"
)], result_where_defined)
def std(self, name="std"): def std(self, name="std"):
with ops.name_scope(self.name): with ops.name_scope(self.name):
@ -348,8 +355,7 @@ class StudentT(distribution.Distribution):
# Let X = R*cos(theta), and let Y = R*sin(theta). # Let X = R*cos(theta), and let Y = R*sin(theta).
# Then X ~ t_df and Y ~ t_df. # Then X ~ t_df and Y ~ t_df.
# The variates X and Y are not independent. # The variates X and Y are not independent.
shape = array_ops.concat(0, [array_ops.pack([2, n]), shape = array_ops.concat(0, ([2, n], self.batch_shape()))
self.batch_shape()])
uniform = random_ops.random_uniform(shape=shape, uniform = random_ops.random_uniform(shape=shape,
dtype=self.dtype, dtype=self.dtype,
seed=seed) seed=seed)

View File

@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution):
name="LogitNormalTransformedDistribution" name="LogitNormalTransformedDistribution"
) )
``` ```
""" """
def __init__(self, def __init__(self,

View File

@ -67,14 +67,14 @@ class Uniform(distribution.Distribution):
``` ```
Args: Args:
a: `float` or `double` tensor, the minimum endpoint. a: Floating point tensor, the minimum endpoint.
b: `float` or `double` tensor, the maximum endpoint. Must be > `a`. b: Floating point tensor, the maximum endpoint. Must be > `a`.
validate_args: Whether to assert that `a > b`. If `validate_args` is False validate_args: Whether to assert that `a > b`. If `validate_args` is
and inputs are invalid, correct behavior is not guaranteed. `False` and inputs are invalid, correct behavior is not guaranteed.
allow_nan_stats: Boolean, default False. If False, raise an exception if allow_nan_stats: Boolean, default `False`. If `False`, raise an
a statistic (e.g. mean/mode/etc...) is undefined for any batch member. exception if a statistic (e.g. mean/mode/etc...) is undefined for any
If True, batch members with valid parameters leading to undefined batch member. If `True`, batch members with valid parameters leading to
statistics will return NaN for this statistic. undefined statistics will return NaN for this statistic.
name: The name to prefix Ops created by this distribution class. name: The name to prefix Ops created by this distribution class.
Raises: Raises:
@ -83,8 +83,9 @@ class Uniform(distribution.Distribution):
self._allow_nan_stats = allow_nan_stats self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args self._validate_args = validate_args
with ops.op_scope([a, b], name): with ops.op_scope([a, b], name):
with ops.control_dependencies([check_ops.assert_less(a, b)] if with ops.control_dependencies([check_ops.assert_less(
validate_args else []): a, b, message="uniform not defined when a > b.")] if validate_args
else []):
a = array_ops.identity(a, name="a") a = array_ops.identity(a, name="a")
b = array_ops.identity(b, name="b") b = array_ops.identity(b, name="b")
@ -228,7 +229,7 @@ class Uniform(distribution.Distribution):
n = ops.convert_to_tensor(n, name="n") n = ops.convert_to_tensor(n, name="n")
n_val = tensor_util.constant_value(n) n_val = tensor_util.constant_value(n)
shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()]) shape = array_ops.concat(0, ([n], self.batch_shape()))
samples = random_ops.random_uniform(shape=shape, samples = random_ops.random_uniform(shape=shape,
dtype=self.dtype, dtype=self.dtype,
seed=seed) seed=seed)

View File

@ -94,6 +94,30 @@ tf_py_test(
], ],
) )
tf_py_test(
name = "gmm_test",
srcs = [
"python/ops/gmm_test.py",
],
additional_deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
tf_py_test(
name = "gmm_ops_test",
srcs = [
"python/ops/gmm_ops_test.py",
],
additional_deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
tf_py_test( tf_py_test(
name = "factorization_ops_test", name = "factorization_ops_test",
srcs = ["python/ops/factorization_ops_test.py"], srcs = ["python/ops/factorization_ops_test.py"],

View File

@ -304,7 +304,7 @@ class WalsModelTest(tf.test.TestCase):
col_factors2 = [x.eval() for x in wals_model.col_factors] col_factors2 = [x.eval() for x in wals_model.col_factors]
for c1, c2 in zip(col_factors1, col_factors2): for c1, c2 in zip(col_factors1, col_factors2):
self.assertAllClose(c1, c2, atol=1e-3) self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
def test_als_transposed(self): def test_als_transposed(self):
with self.test_session(): with self.test_session():
@ -383,7 +383,7 @@ class WalsModelTest(tf.test.TestCase):
regularization=1e-5, regularization=1e-5,
row_weights=None, row_weights=None,
col_weights=None) col_weights=None)
self.simple_train(model, inp, 15) self.simple_train(model, inp, 25)
row_factor = model.row_factors[0].eval() row_factor = model.row_factors[0].eval()
col_factor = model.col_factors[0].eval() col_factor = model.col_factors[0].eval()
self.assertAllClose(data, self.assertAllClose(data,
@ -407,7 +407,7 @@ class WalsModelTest(tf.test.TestCase):
regularization=1e-5, regularization=1e-5,
row_weights=[0] * rows, row_weights=[0] * rows,
col_weights=[0] * cols) col_weights=[0] * cols)
self.simple_train(model, inp, 15) self.simple_train(model, inp, 25)
row_factor = model.row_factors[0].eval() row_factor = model.row_factors[0].eval()
col_factor = model.col_factors[0].eval() col_factor = model.col_factors[0].eval()
self.assertAllClose(data, self.assertAllClose(data,
@ -438,7 +438,7 @@ class WalsModelTest(tf.test.TestCase):
regularization=0.001, regularization=0.001,
row_weights=row_wts, row_weights=row_wts,
col_weights=col_wts) col_weights=col_wts)
self.simple_train(model, inp, 10) self.simple_train(model, inp, 25)
row_factor = model.row_factors[0].eval() row_factor = model.row_factors[0].eval()
col_factor = model.col_factors[0].eval() col_factor = model.col_factors[0].eval()
out = np.dot(row_factor, np.transpose(col_factor)) out = np.dot(row_factor, np.transpose(col_factor))
@ -446,7 +446,7 @@ class WalsModelTest(tf.test.TestCase):
for j in xrange(cols): for j in xrange(cols):
if keep_index([i, j]): if keep_index([i, j]):
self.assertNear(data[i][j], out[i][j], self.assertNear(data[i][j], out[i][j],
err=0.2, msg="%d, %d" % (i, j)) err=0.4, msg="%d, %d" % (i, j))
else: else:
self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j)) self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))

View File

@ -0,0 +1,211 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Gaussian mixture model (GMM) clustering.
This builds on top of the skflow API.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.python.ops.control_flow_ops import with_dependencies
class GMM(estimator.Estimator, TransformerMixin):
"""GMM clustering."""
SCORES = 'scores'
ASSIGNMENTS = 'assignments'
ALL_SCORES = 'all_scores'
def __init__(self,
num_clusters,
model_dir=None,
random_seed=0,
params='wmc',
initial_clusters='random',
covariance_type='full',
batch_size=128,
steps=10,
continue_training=False,
config=None,
verbose=1):
"""Creates a model for running GMM training and inference.
Args:
num_clusters: number of clusters to train.
model_dir: the directory to save the model results and log files.
random_seed: Python integer. Seed for PRNG used to initialize centers.
params: Controls which parameters are updated in the training process.
Can contain any combination of "w" for weights, "m" for means,
and "c" for covars.
initial_clusters: specifies how to initialize the clusters for training.
See gmm_ops.gmm for the possible values.
covariance_type: one of "full", "diag".
batch_size: See TensorFlowEstimator
steps: See TensorFlowEstimator
continue_training: See TensorFlowEstimator
config: See TensorFlowEstimator
verbose: See TensorFlowEstimator
"""
super(GMM, self).__init__(
model_dir=model_dir,
config=config)
self.batch_size = batch_size
self.steps = steps
self.continue_training = continue_training
self.verbose = verbose
self._num_clusters = num_clusters
self._params = params
self._training_initial_clusters = initial_clusters
self._covariance_type = covariance_type
self._training_graph = None
self._random_seed = random_seed
def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
"""Trains a GMM clustering on x.
Note: See TensorFlowEstimator for logic for continuous training and graph
construction across multiple calls to fit.
Args:
x: training input matrix of shape [n_samples, n_features].
y: labels. Should be None.
monitors: List of `Monitor` objects to print training progress and
invoke early stopping.
logdir: the directory to save the log file that can be used for optional
visualization.
steps: number of training steps. If not None, overrides the value passed
in constructor.
Returns:
Returns self.
"""
if logdir is not None:
self._model_dir = logdir
self._data_feeder = data_feeder.setup_train_data_feeder(
x, None, self._num_clusters, self.batch_size)
self._train_model(input_fn=self._data_feeder.input_builder,
feed_fn=self._data_feeder.get_feed_dict_fn(),
steps=steps or self.steps,
monitors=monitors,
init_feed_fn=self._data_feeder.get_feed_dict_fn())
return self
def predict(self, x, batch_size=None):
"""Predict cluster id for each element in x.
Args:
x: 2-D matrix or iterator.
batch_size: size to use for batching up x for querying the model.
Returns:
Array with same number of rows as x, containing cluster ids.
"""
return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ASSIGNMENTS]
def score(self, x, batch_size=None):
"""Predict total sum of distances to nearest clusters.
Args:
x: 2-D matrix or iterator.
batch_size: size to use for batching up x for querying the model.
Returns:
Total score.
"""
return np.sum(self.evaluate(x=x, batch_size=batch_size)[GMM.SCORES])
def transform(self, x, batch_size=None):
"""Transforms each element in x to distances to cluster centers.
Args:
x: 2-D matrix or iterator.
batch_size: size to use for batching up x for querying the model.
Returns:
Array with same number of rows as x, and num_clusters columns, containing
distances to the cluster centers.
"""
return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ALL_SCORES]
def clusters(self):
"""Returns cluster centers."""
clusters = checkpoints.load_variable(self.model_dir,
gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
return np.squeeze(clusters, 1)
def covariances(self):
"""Returns the covariances."""
return checkpoints.load_variable(
self.model_dir,
gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
def _get_train_ops(self, features, _):
(_,
_,
losses,
training_op) = gmm_ops.gmm(
features,
self._training_initial_clusters,
self._num_clusters,
self._random_seed,
self._covariance_type,
self._params)
incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
loss = tf.reduce_sum(losses)
training_op = with_dependencies([training_op, incr_step], loss)
return training_op, loss
def _get_predict_ops(self, features):
(all_scores,
model_predictions,
_,
_) = gmm_ops.gmm(
features,
self._training_initial_clusters,
self._num_clusters,
self._random_seed,
self._covariance_type,
self._params)
return {
GMM.ALL_SCORES: all_scores[0],
GMM.ASSIGNMENTS: model_predictions[0]
}
def _get_eval_ops(self, features, _, unused_metrics):
(_,
_,
losses,
_) = gmm_ops.gmm(
features,
self._training_initial_clusters,
self._num_clusters,
self._random_seed,
self._covariance_type,
self._params)
return {
GMM.SCORES: tf.reduce_sum(losses),
}
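A minimal usage sketch for this estimator, assuming `points` is a float32 numpy array of shape `[n_samples, n_features]`; the hyper-parameters are illustrative only:
```
gmm = GMM(num_clusters=2, batch_size=128, steps=10, random_seed=4)
gmm.fit(x=points)
assignments = gmm.predict(points)    # cluster id per row of `points`
total_score = gmm.score(points)      # summed distance to assigned clusters
centers = gmm.clusters()             # shape [num_clusters, n_features]
```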

View File

@ -0,0 +1,461 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gaussian mixture models Operations."""
# TODO(xavigonzalvo): Factor out covariance matrix operations to make
# code reusable for different types (e.g. diag).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.ops.embedding_ops import embedding_lookup
# Machine epsilon.
MEPS = np.finfo(float).eps
FULL_COVARIANCE = 'full'
DIAG_COVARIANCE = 'diag'
def _covariance(x, diag):
"""Defines the covariance operation of a matrix.
Args:
x: a matrix Tensor. Dimension 0 should contain the number of examples.
diag: if True, it computes the diagonal covariance.
Returns:
A Tensor representing the covariance of x. In the diagonal case,
only the diagonal of the matrix is returned.
"""
num_points = tf.to_float(tf.shape(x)[0])
x -= tf.reduce_mean(x, 0, keep_dims=True)
if diag:
cov = tf.reduce_sum(
tf.square(x), 0, keep_dims=True) / (num_points - 1)
else:
cov = tf.matmul(x, x, transpose_a=True) / (num_points - 1)
return cov
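A quick NumPy cross-check of what `_covariance` computes; `x` here is an arbitrary stand-in matrix, not data from this change:
```
import numpy as np

x = np.random.randn(100, 3).astype(np.float32)
xc = x - x.mean(axis=0, keepdims=True)
full_cov = xc.T.dot(xc) / (len(x) - 1)                 # diag=False case
diag_cov = np.square(xc).sum(axis=0) / (len(x) - 1)    # diag=True case
# full_cov agrees with np.cov(x, rowvar=False) up to float32 precision.
```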
def _init_clusters_random(data, num_clusters, random_seed):
"""Does random initialization of clusters.
Args:
data: a list of Tensors with a matrix of data, each row is an example.
num_clusters: an integer with the number of clusters.
random_seed: Seed for the PRNG used to choose the initial cluster centers.
Returns:
A Tensor with num_clusters random rows of data.
"""
assert isinstance(data, list)
num_data = tf.add_n([tf.shape(inp)[0] for inp in data])
with tf.control_dependencies([tf.assert_less_equal(num_clusters, num_data)]):
indices = tf.random_uniform([num_clusters],
minval=0,
maxval=tf.cast(num_data, tf.int64),
seed=random_seed,
dtype=tf.int64)
indices = tf.cast(indices, tf.int32) % num_data
clusters_init = embedding_lookup(data, indices, partition_strategy='div')
return clusters_init
class GmmAlgorithm(object):
"""Tensorflow Gaussian mixture model clustering class."""
CLUSTERS_VARIABLE = 'clusters'
CLUSTERS_COVS_VARIABLE = 'clusters_covs'
def __init__(self, data, num_classes, initial_means=None, params='wmc',
covariance_type=FULL_COVARIANCE, random_seed=0):
"""Constructor.
Args:
data: a list of Tensors with data, each row is a new example.
num_classes: number of clusters.
initial_means: a Tensor with a matrix of means. If None, means are
computed by sampling randomly.
params: Controls which parameters are updated in the training
process. Can contain any combination of "w" for weights, "m" for
means, and "c" for covariances.
covariance_type: one of "full", "diag".
random_seed: Seed for the PRNG used to initialize the means.
Raises:
Exception if covariance type is unknown.
"""
self._params = params
self._random_seed = random_seed
self._covariance_type = covariance_type
if self._covariance_type not in [DIAG_COVARIANCE, FULL_COVARIANCE]:
raise Exception( # pylint: disable=g-doc-exception
'programmer error: Invalid covariance type: %s' %
self._covariance_type)
# Create sharded variables for multiple shards. The following
# lists are indexed by shard.
# Probability per example in a class.
num_shards = len(data)
self._probs = [None] * num_shards
# Prior probability.
self._prior_probs = [None] * num_shards
# Membership weights w_{ik} where "i" is the i-th example and "k"
# is the k-th mixture.
self._w = [None] * num_shards
# Number of examples in a class.
self._points_in_k = [None] * num_shards
first_shard = data[0]
self._dimensions = tf.shape(first_shard)[1]
self._num_classes = num_classes
# Small value to guarantee that covariances are invertible.
self._min_var = tf.diag(tf.ones(tf.pack([self._dimensions]))) * 1e-3
self._create_variables(data, initial_means)
# Operations of partial statistics for the computation of the means.
self._w_mul_x = []
# Operations of partial statistics for the computation of the covariances.
self._w_mul_x2 = []
self._define_graph(data)
def _create_variables(self, data, initial_means=None):
"""Initializes GMM algorithm.
Args:
data: a list of Tensors with data, each row is a new example.
initial_means: a Tensor with a matrix of means.
"""
first_shard = data[0]
# Initialize means: num_classes X 1 X dimensions.
if initial_means is not None:
self._means = tf.Variable(tf.expand_dims(initial_means, 1),
name=self.CLUSTERS_VARIABLE,
validate_shape=False, dtype=tf.float32)
else:
# Sample data randomly
self._means = tf.Variable(tf.expand_dims(
_init_clusters_random(data, self._num_classes, self._random_seed), 1),
name=self.CLUSTERS_VARIABLE,
validate_shape=False)
# Initialize covariances.
if self._covariance_type == FULL_COVARIANCE:
cov = _covariance(first_shard, False) + self._min_var
# A matrix per class, num_classes X dimensions X dimensions
covs = tf.tile(
tf.expand_dims(cov, 0), [self._num_classes, 1, 1])
elif self._covariance_type == DIAG_COVARIANCE:
cov = _covariance(first_shard, True) + self._min_var
# A diagonal per row, num_classes X dimensions.
covs = tf.tile(tf.expand_dims(tf.diag_part(cov), 0),
[self._num_classes, 1])
self._covs = tf.Variable(covs, name='clusters_covs', validate_shape=False)
# Mixture weights, representing the probability that a randomly
# selected unobservable data (in EM terms) was generated by component k.
self._alpha = tf.Variable(tf.tile([1.0 / self._num_classes],
[self._num_classes]))
def training_ops(self):
"""Returns the training operation."""
return self._train_ops
def alphas(self):
return self._alpha
def clusters(self):
"""Returns the clusters with dimensions num_classes X 1 X num_dimensions."""
return self._means
def covariances(self):
"""Returns the covariances matrices."""
return self._covs
def assignments(self):
"""Returns a list of Tensors with the matrix of assignments per shard."""
ret = []
for w in self._w:
ret.append(tf.argmax(w, 1))
return ret
def scores(self):
"""Returns the distances to each class.
Returns:
A tuple with two Tensors. The first contains the distance to
each class. The second contains the distance to the assigned
class.
"""
return (self._all_scores, self._scores)
def _define_graph(self, data):
"""Define graph for a single iteration.
Args:
data: a list of Tensors defining the training data.
"""
for shard_id, shard in enumerate(data):
self._num_examples = tf.shape(shard)[0]
shard = tf.expand_dims(shard, 0)
self._define_log_prob_operation(shard_id, shard)
self._define_prior_log_prob_operation(shard_id)
self._define_expectation_operation(shard_id)
self._define_partial_maximization_operation(shard_id, shard)
self._define_maximization_operation(len(data))
self._define_distance_to_clusters(data)
def _define_full_covariance_probs(self, shard_id, shard):
"""Defines the full covariance probabilties per example in a class.
Updates a matrix with dimension num_examples X num_classes.
Args:
shard_id: id of the current shard.
shard: current data shard, 1 X num_examples X dimensions.
"""
diff = shard - self._means
cholesky = tf.batch_cholesky(self._covs + self._min_var)
log_det_covs = 2.0 * tf.reduce_sum(tf.log(
tf.batch_matrix_diag_part(cholesky)), 1)
x_mu_cov = tf.square(tf.batch_matrix_triangular_solve(
cholesky, tf.transpose(diff, perm=[0, 2, 1]),
lower=True))
diag_m = tf.transpose(tf.reduce_sum(x_mu_cov, 1))
self._probs[shard_id] = -0.5 * (
diag_m + tf.to_float(self._dimensions) * tf.log(2 * np.pi) +
log_det_covs)
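This is the usual Cholesky form of the Gaussian log-density, `log N(x; mu, C) = -0.5 * ((x - mu)^T C^{-1} (x - mu) + d log(2 pi) + log|C|)`, with the quadratic term obtained from a solve against the Cholesky factor. A dense NumPy sketch for a single example and class (inputs are hypothetical):
```
import numpy as np

def gaussian_log_density(x, mu, cov):
    d = x.shape[-1]
    chol = np.linalg.cholesky(cov)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))      # log|C| from the factor
    z = np.linalg.solve(chol, x - mu)                  # solve L z = (x - mu)
    quad = np.sum(np.square(z))                        # (x-mu)^T C^{-1} (x-mu)
    return -0.5 * (quad + d * np.log(2 * np.pi) + log_det)
```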
def _define_diag_covariance_probs(self, shard_id, shard):
"""Defines the diagonal covariance probabilities per example in a class.
Args:
shard_id: id of the current shard.
shard: current data shard, 1 X num_examples X dimensions.
Updates a matrix with dimension num_examples X num_classes.
"""
# num_classes X 1
# TODO(xavigonzalvo): look into alternatives to log for
# reparametrization of variance parameters.
det_expanded = tf.reduce_sum(tf.log(self._covs + 1e-3),
1, keep_dims=True)
diff = shard - self._means
x2 = tf.square(diff)
cov_expanded = tf.expand_dims(1.0 / (self._covs + 1e-3), 2)
# num_classes X num_examples
x2_cov = tf.batch_matmul(x2, cov_expanded)
x2_cov = tf.transpose(tf.squeeze(x2_cov, [2]))
self._probs[shard_id] = -0.5 * (
tf.to_float(self._dimensions) * tf.log(2.0 * np.pi) +
tf.transpose(det_expanded) + x2_cov)
def _define_log_prob_operation(self, shard_id, shard):
"""Probability per example in a class.
Updates a matrix with dimension num_examples X num_classes.
Args:
shard_id: id of the current shard.
shard: current data shard, 1 X num_examples X dimensions.
"""
# TODO(xavigonzalvo): Use the pdf defined in
# third_party/tensorflow/contrib/distributions/python/ops/gaussian.py
if self._covariance_type == FULL_COVARIANCE:
self._define_full_covariance_probs(shard_id, shard)
elif self._covariance_type == DIAG_COVARIANCE:
self._define_diag_covariance_probs(shard_id, shard)
self._probs[shard_id] += tf.log(self._alpha)
def _define_prior_log_prob_operation(self, shard_id):
"""Computes the prior probability of all samples.
Updates a vector where each item is the prior probability of an
input example.
Args:
shard_id: id of current shard_id.
"""
self._prior_probs[shard_id] = tf.log(
tf.reduce_sum(tf.exp(self._probs[shard_id]), 1, keep_dims=True))
def _define_expectation_operation(self, shard_id):
# Shape broadcasting.
probs = tf.expand_dims(self._probs[shard_id], 0)
# Membership weights are computed as:
# w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}
# {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}
# where "i" is the i-th example, "k" is the k-th mixture, theta are
# the model parameters and y_i the observations.
# These are defined for each shard.
self._w[shard_id] = tf.reshape(
tf.exp(probs - self._prior_probs[shard_id]),
tf.pack([self._num_examples, self._num_classes]))
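Because `self._probs` already includes `log(alpha_k)`, the expression above is just a per-row normalization in log space. A NumPy sketch of the same E-step arithmetic, with `log_joint` standing in for `self._probs[shard_id]`:
```
import numpy as np

log_joint = np.random.randn(5, 3)                    # num_examples x num_classes
log_norm = np.log(np.sum(np.exp(log_joint), axis=1, keepdims=True))
w = np.exp(log_joint - log_norm)                     # membership weights
assert np.allclose(w.sum(axis=1), 1.0)               # each row sums to one
```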
def _define_partial_maximization_operation(self, shard_id, shard):
"""Computes the partial statistics of the means and covariances.
Args:
shard_id: current shard id.
shard: current data shard, 1 X num_examples X dimensions.
"""
# Soft assignment of each data point to each of the clusters.
self._points_in_k[shard_id] = tf.reduce_sum(self._w[shard_id], 0,
keep_dims=True)
# Partial means.
w_mul_x = tf.expand_dims(
tf.matmul(self._w[shard_id],
tf.squeeze(shard, [0]), transpose_a=True), 1)
self._w_mul_x.append(w_mul_x)
# Partial covariances.
x = tf.concat(0, [shard for _ in range(self._num_classes)])
x_trans = tf.transpose(x, perm=[0, 2, 1])
x_mul_w = tf.concat(0, [
tf.expand_dims(x_trans[k, :, :] * self._w[shard_id][:, k], 0)
for k in range(self._num_classes)])
self._w_mul_x2.append(tf.batch_matmul(x_mul_w, x))
def _define_maximization_operation(self, num_batches):
"""Maximization operations."""
# TODO(xavigonzalvo): some of these operations could be moved to C++.
# Compute the effective number of data points assigned to component k.
with tf.control_dependencies(self._w):
points_in_k = tf.squeeze(tf.add_n(self._points_in_k), squeeze_dims=[0])
# Update alpha.
if 'w' in self._params:
final_points_in_k = points_in_k / num_batches
num_examples = tf.to_float(tf.reduce_sum(final_points_in_k))
self._alpha_op = self._alpha.assign(
final_points_in_k / (num_examples + MEPS))
else:
self._alpha_op = tf.no_op()
self._train_ops = [self._alpha_op]
# Update means.
points_in_k_expanded = tf.reshape(points_in_k,
[self._num_classes, 1, 1])
if 'm' in self._params:
self._means_op = self._means.assign(
tf.div(tf.add_n(self._w_mul_x), points_in_k_expanded + MEPS))
else:
self._means_op = tf.no_op()
# means are (num_classes x 1 x dims)
# Update covariances.
with tf.control_dependencies([self._means_op]):
b = tf.add_n(self._w_mul_x2) / (points_in_k_expanded + MEPS)
new_covs = []
for k in range(self._num_classes):
mean = self._means.ref()[k, :, :]
square_mean = tf.matmul(mean, mean, transpose_a=True)
new_cov = b[k, :, :] - square_mean + self._min_var
if self._covariance_type == FULL_COVARIANCE:
new_covs.append(tf.expand_dims(new_cov, 0))
elif self._covariance_type == DIAG_COVARIANCE:
new_covs.append(tf.expand_dims(tf.diag_part(new_cov), 0))
new_covs = tf.concat(0, new_covs)
if 'c' in self._params:
# Train operations don't need to take care of the means
# because the covariances already depend on them.
with tf.control_dependencies([self._means_op, new_covs]):
self._train_ops.append(
tf.assign(self._covs, new_covs, validate_shape=False))
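The M-step above is the standard EM update: mixture weights are the fraction of soft counts per component, means are the weighted averages, and covariances follow `E[x x^T] - mu mu^T`. A batched NumPy sketch of those updates, with `x` and `w` as illustrative stand-ins for one shard:
```
import numpy as np

x = np.random.randn(100, 2)                          # examples
w = np.random.rand(100, 3)
w /= w.sum(axis=1, keepdims=True)                    # membership weights
points_in_k = w.sum(axis=0)                          # soft counts per component
alpha = points_in_k / points_in_k.sum()              # mixture weights
means = w.T.dot(x) / points_in_k[:, None]            # one row per component
covs = [(w[:, k, None] * x).T.dot(x) / points_in_k[k]
        - np.outer(means[k], means[k]) for k in range(3)]
```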
def _define_distance_to_clusters(self, data):
"""Defines the Mahalanobis distance to the assigned Gaussian."""
# TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
# mean) from log probability function.
self._all_scores = []
for shard in data:
all_scores = []
shard = tf.expand_dims(shard, 0)
for c in xrange(self._num_classes):
if self._covariance_type == FULL_COVARIANCE:
cov = self._covs[c, :, :]
elif self._covariance_type == DIAG_COVARIANCE:
cov = tf.diag(self._covs[c, :])
inverse = tf.matrix_inverse(cov + self._min_var)
inv_cov = tf.tile(
tf.expand_dims(inverse, 0),
tf.pack([self._num_examples, 1, 1]))
diff = tf.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
m_left = tf.batch_matmul(diff, inv_cov)
all_scores.append(tf.sqrt(tf.batch_matmul(
m_left, tf.transpose(diff, perm=[0, 2, 1])
)))
self._all_scores.append(tf.reshape(
tf.concat(1, all_scores),
tf.pack([self._num_examples, self._num_classes])))
# Distance to the associated class.
self._all_scores = tf.concat(0, self._all_scores)
assignments = tf.concat(0, self.assignments())
rows = tf.to_int64(tf.range(0, self._num_examples))
indices = tf.concat(1, [tf.expand_dims(rows, 1),
tf.expand_dims(assignments, 1)])
self._scores = tf.gather_nd(self._all_scores, indices)
def _define_loglikelihood_operation(self):
"""Defines the total log-likelihood of current iteration."""
self._ll_op = []
for prior_probs in self._prior_probs:
self._ll_op.append(tf.reduce_sum(tf.log(prior_probs)))
tf.scalar_summary('ll', tf.reduce_sum(self._ll_op))
def gmm(inp, initial_clusters, num_clusters, random_seed,
covariance_type=FULL_COVARIANCE, params='wmc'):
"""Creates the graph for Gaussian mixture model (GMM) clustering.
Args:
inp: An input tensor or list of input tensors
initial_clusters: Specifies the clusters used during
initialization. Can be a tensor or numpy array, or a function
that generates the clusters. Can also be "random" to specify
that clusters should be chosen randomly from input data. Note: the type
is kept flexible to stay consistent with skflow.
num_clusters: number of clusters.
random_seed: Python integer. Seed for PRNG used to initialize centers.
covariance_type: one of "diag", "full".
params: Controls which parameters are updated in the training
process. Can contain any combination of "w" for weights, "m" for
means, and "c" for covars.
Returns:
Note: a tuple of lists is returned to be consistent with skflow.
A tuple consisting of:
all_scores: A matrix (or list of matrices) of dimensions (num_input,
num_clusters) where the value is the distance of an input vector and a
cluster center.
assignments: A vector (or list of vectors). Each element in the vector
corresponds to an input row in 'inp' and specifies the cluster id
corresponding to the input.
scores: Similar to assignments but specifies the distance to the
assigned cluster instead.
training_op: an op that runs an iteration of training.
"""
initial_means = None
if initial_clusters != 'random' and not isinstance(
initial_clusters, tf.Tensor):
initial_means = tf.constant(initial_clusters, dtype=tf.float32)
# Implementation of GMM.
inp = inp if isinstance(inp, list) else [inp]
gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params,
covariance_type, random_seed)
training_ops = gmm_tool.training_ops()
assignments = gmm_tool.assignments()
all_scores, scores = gmm_tool.scores()
return [all_scores], [assignments], [scores], tf.group(*training_ops)
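A minimal end-to-end usage sketch for this entry point, using the graph-mode APIs seen in the tests below; `np_points` is an assumed numpy array:
```
import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import gmm_ops

np_points = np.random.randn(200, 2).astype(np.float32)
data = tf.constant(np_points)
all_scores, assignments, scores, train_op = gmm_ops.gmm(
    data, 'random', num_clusters=3, random_seed=0)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(20):
        sess.run(train_op)
    cluster_ids = np.squeeze(sess.run(assignments))  # per-example cluster ids
```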

View File

@ -0,0 +1,198 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gmm_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.python.platform import tf_logging as logging
class GmmOpsTest(tf.test.TestCase):
def setUp(self):
self.num_examples = 1000
self.iterations = 40
self.seed = 4
tf.set_random_seed(self.seed)
np.random.seed(self.seed * 2)
self.data, self.true_assignments = self.make_data(self.num_examples)
# Generate more complicated data.
self.centers = [[1, 1], [-1, 0.5], [2, 1]]
self.more_data, self.more_true_assignments = self.make_data_from_centers(
self.num_examples, self.centers)
@staticmethod
def make_data(num_vectors):
"""Generates 2-dimensional data centered on (2,2), (-1,-1).
Args:
num_vectors: number of training examples.
Returns:
A tuple containing the data as a numpy array and the cluster ids.
"""
vectors = []
classes = []
for _ in xrange(num_vectors):
if np.random.random() > 0.5:
vectors.append([np.random.normal(2.0, 0.6),
np.random.normal(2.0, 0.9)])
classes.append(0)
else:
vectors.append([np.random.normal(-1.0, 0.4),
np.random.normal(-1.0, 0.5)])
classes.append(1)
return np.asarray(vectors), classes
@staticmethod
def make_data_from_centers(num_vectors, centers):
"""Generates 2-dimensional data with random centers.
Args:
num_vectors: number of training examples.
centers: a list of random 2-dimensional centers.
Returns:
A tuple containing the data as a numpy array and the cluster ids.
"""
vectors = []
classes = []
for _ in xrange(num_vectors):
current_class = np.random.random_integers(0, len(centers) - 1)
vectors.append([np.random.normal(centers[current_class][0],
np.random.random_sample()),
np.random.normal(centers[current_class][1],
np.random.random_sample())])
classes.append(current_class)
return np.asarray(vectors), classes
def test_covariance(self):
start_time = time.time()
data = self.data.T
np_cov = np.cov(data)
logging.info('Numpy took %f', time.time() - start_time)
start_time = time.time()
with self.test_session() as sess:
op = gmm_ops._covariance(
tf.constant(data.T, dtype=tf.float32),
False)
op_diag = gmm_ops._covariance(
tf.constant(data.T, dtype=tf.float32),
True)
tf.initialize_all_variables().run()
tf_cov = sess.run(op)
np.testing.assert_array_almost_equal(np_cov, tf_cov)
logging.info('Tensorflow took %f', time.time() - start_time)
tf_cov = sess.run(op_diag)
np.testing.assert_array_almost_equal(
np.diag(np_cov), np.ravel(tf_cov), decimal=5)
def test_simple_cluster(self):
"""Tests that the clusters are correct."""
num_classes = 2
graph = tf.Graph()
with graph.as_default() as g:
g.seed = 5
with self.test_session() as sess:
data = tf.constant(self.data, dtype=tf.float32)
_, assignments, _, training_op = gmm_ops.gmm(data, 'random',
num_classes,
random_seed=self.seed)
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_op)
assignments = sess.run(assignments)
accuracy = np.mean(
np.asarray(self.true_assignments) == np.squeeze(assignments))
logging.info('Accuracy: %f', accuracy)
self.assertGreater(accuracy, 0.98)
def testParams(self):
"""Tests that the params work as intended."""
num_classes = 2
with self.test_session() as sess:
# Experiment 1. Update weights only.
data = tf.constant(self.data, dtype=tf.float32)
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
[[3.0, 3.0], [0.0, 0.0]], 'w')
training_ops = gmm_tool.training_ops()
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_ops)
# Only the probability to each class is updated.
alphas = sess.run(gmm_tool.alphas())
self.assertGreater(alphas[1], 0.6)
means = sess.run(gmm_tool.clusters())
np.testing.assert_almost_equal(
np.expand_dims([[3.0, 3.0], [0.0, 0.0]], 1), means)
covs = sess.run(gmm_tool.covariances())
np.testing.assert_almost_equal(covs[0], covs[1])
# Experiment 2. Update means and covariances.
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
[[3.0, 3.0], [0.0, 0.0]], 'mc')
training_ops = gmm_tool.training_ops()
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_ops)
alphas = sess.run(gmm_tool.alphas())
self.assertAlmostEqual(alphas[0], alphas[1])
means = sess.run(gmm_tool.clusters())
np.testing.assert_almost_equal(
np.expand_dims([[2.0, 2.0], [-1.0, -1.0]], 1), means, decimal=1)
covs = sess.run(gmm_tool.covariances())
np.testing.assert_almost_equal(
[[0.371111, -0.0050774], [-0.0050774, 0.8651744]],
covs[0], decimal=4)
np.testing.assert_almost_equal(
[[0.146976, 0.0259463], [0.0259463, 0.2543971]],
covs[1], decimal=4)
# Experiment 3. Update covariances only.
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
[[-1.0, -1.0], [1.0, 1.0]], 'c')
training_ops = gmm_tool.training_ops()
tf.initialize_all_variables().run()
for _ in xrange(self.iterations):
sess.run(training_ops)
alphas = sess.run(gmm_tool.alphas())
self.assertAlmostEqual(alphas[0], alphas[1])
means = sess.run(gmm_tool.clusters())
np.testing.assert_almost_equal(
np.expand_dims([[-1.0, -1.0], [1.0, 1.0]], 1), means)
covs = sess.run(gmm_tool.covariances())
np.testing.assert_almost_equal(
[[0.1299582, 0.0435872], [0.0435872, 0.2558578]],
covs[0], decimal=5)
np.testing.assert_almost_equal(
[[3.195385, 2.6989155], [2.6989155, 3.3881593]],
covs[1], decimal=5)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,172 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops.gmm."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops.gmm import GMM
from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering as KMeans
from tensorflow.contrib.learn.python.learn.estimators import run_config
FLAGS = tf.app.flags.FLAGS
class GMMTest(tf.test.TestCase):
def setUp(self):
np.random.seed(3)
tf.set_random_seed(2)
self.num_centers = 2
self.num_dims = 2
self.num_points = 4000
self.batch_size = 100
self.true_centers = self.make_random_centers(self.num_centers,
self.num_dims)
self.points, self.assignments, self.scores = self.make_random_points(
self.true_centers,
self.num_points)
self.true_score = np.add.reduce(self.scores)
# Use initial means from kmeans (just like scikit-learn does).
clusterer = KMeans(num_clusters=self.num_centers)
clusterer.fit(self.points, steps=30)
self.initial_means = clusterer.clusters()
@staticmethod
def make_random_centers(num_centers, num_dims):
return np.round(np.random.rand(num_centers,
num_dims).astype(np.float32) * 500)
@staticmethod
def make_random_points(centers, num_points):
num_centers, num_dims = centers.shape
assignments = np.random.choice(num_centers, num_points)
offsets = np.round(np.random.randn(num_points,
num_dims).astype(np.float32) * 20)
points = centers[assignments] + offsets
means = [np.mean(points[assignments == center], axis=0)
for center in xrange(num_centers)]
covs = [np.cov(points[assignments == center].T)
for center in xrange(num_centers)]
scores = []
for r in xrange(num_points):
scores.append(np.sqrt(np.dot(
np.dot(points[r, :] - means[assignments[r]],
np.linalg.inv(covs[assignments[r]])),
points[r, :] - means[assignments[r]])))
return (points, assignments, scores)
def test_clusters(self):
"""Tests the shape of the clusters."""
gmm = GMM(self.num_centers,
initial_clusters=self.initial_means,
batch_size=self.batch_size,
steps=40,
continue_training=True,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=0)
clusters = gmm.clusters()
self.assertAllEqual(list(clusters.shape),
[self.num_centers, self.num_dims])
def test_fit(self):
gmm = GMM(self.num_centers,
initial_clusters='random',
batch_size=self.batch_size,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=1)
score1 = gmm.score(x=self.points)
gmm = GMM(self.num_centers,
initial_clusters='random',
batch_size=self.batch_size,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=10)
score2 = gmm.score(x=self.points)
self.assertGreater(score1, score2)
self.assertNear(self.true_score, score2, self.true_score * 0.15)
def test_infer(self):
gmm = GMM(self.num_centers,
initial_clusters=self.initial_means,
batch_size=self.batch_size,
steps=40,
continue_training=True,
random_seed=4,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(x=self.points, steps=60)
clusters = gmm.clusters()
# Make a small test set
points, true_assignments, true_offsets = (
self.make_random_points(clusters, 40))
assignments = np.ravel(gmm.predict(points))
self.assertAllEqual(true_assignments, assignments)
# Test score
score = gmm.score(points)
self.assertNear(score, np.sum(true_offsets), 4.05)
def _compare_with_sklearn(self, cov_type):
# sklearn version.
iterations = 40
np.random.seed(5)
sklearn_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])
sklearn_means = np.asarray([[144.83417719, 254.20130341],
[274.38754816, 353.16074346]])
sklearn_covs = np.asarray([[[395.0081194, -4.50389512],
[-4.50389512, 408.27543989]],
[[385.17484203, -31.27834935],
[-31.27834935, 391.74249925]]])
# skflow version.
gmm = GMM(self.num_centers,
initial_clusters=self.initial_means,
covariance_type=cov_type,
batch_size=self.num_points,
steps=iterations,
continue_training=True,
config=run_config.RunConfig(tf_random_seed=2))
gmm.fit(self.points)
skflow_assignments = gmm.predict(self.points[:10, :]).astype(int)
self.assertAllClose(sklearn_assignments,
np.ravel(skflow_assignments))
self.assertAllClose(sklearn_means, gmm.clusters())
if cov_type == 'full':
self.assertAllClose(sklearn_covs, gmm.covariances(), rtol=0.01)
else:
for d in [0, 1]:
self.assertAllClose(np.diag(sklearn_covs[d]),
gmm.covariances()[d, :], rtol=0.01)
def test_compare_full(self):
self._compare_with_sklearn('full')
def test_compare_diag(self):
self._compare_with_sklearn('diag')
if __name__ == '__main__':
tf.test.main()

View File

@ -153,9 +153,11 @@ class KMeansTest(tf.test.TestCase):
def test_fit_with_cosine_distance(self): def test_fit_with_cosine_distance(self):
# Create points on y=x and y=1.5x lines to check the cosine similarity. # Create points on y=x and y=1.5x lines to check the cosine similarity.
# Note that euclidean distance will give different results in this case. # Note that euclidean distance will give different results in this case.
points = np.array([[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]]) points = np.array(
[[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]], dtype=np.float32)
# true centers are the unit vectors on lines y=x and y=1.5x # true centers are the unit vectors on lines y=x and y=1.5x
true_centers = np.array([[0.70710678, 0.70710678], [0.5547002, 0.83205029]]) true_centers = np.array(
[[0.70710678, 0.70710678], [0.5547002, 0.83205029]], dtype=np.float32)
kmeans = KMeans(2, kmeans = KMeans(2,
initial_clusters=kmeans_ops.RANDOM_INIT, initial_clusters=kmeans_ops.RANDOM_INIT,
distance_metric=kmeans_ops.COSINE_DISTANCE, distance_metric=kmeans_ops.COSINE_DISTANCE,
@ -168,8 +170,9 @@ class KMeansTest(tf.test.TestCase):
np.sort(true_centers, axis=0)) np.sort(true_centers, axis=0))
def test_transform_with_cosine_distance(self): def test_transform_with_cosine_distance(self):
points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18], points = np.array(
[-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]) [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
[0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0, true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
keepdims=True))[0], keepdims=True))[0],
@ -180,8 +183,8 @@ class KMeansTest(tf.test.TestCase):
initial_clusters=kmeans_ops.RANDOM_INIT, initial_clusters=kmeans_ops.RANDOM_INIT,
distance_metric=kmeans_ops.COSINE_DISTANCE, distance_metric=kmeans_ops.COSINE_DISTANCE,
use_mini_batch=self.use_mini_batch, use_mini_batch=self.use_mini_batch,
config=self.config(3)) config=self.config(5))
kmeans.fit(x=points, steps=30, batch_size=8) kmeans.fit(x=points, steps=50, batch_size=8)
centers = normalize(kmeans.clusters()) centers = normalize(kmeans.clusters())
self.assertAllClose(np.sort(centers, axis=0), self.assertAllClose(np.sort(centers, axis=0),
@ -193,16 +196,16 @@ class KMeansTest(tf.test.TestCase):
self.assertAllClose(transform, true_transform, atol=1e-3) self.assertAllClose(transform, true_transform, atol=1e-3)
def test_predict_with_cosine_distance(self): def test_predict_with_cosine_distance(self):
points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18], points = np.array(
[-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]).astype( [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
np.float32) [0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
true_centers = np.array( true_centers = np.array(
[normalize(np.mean(normalize(points)[0:4, :], [normalize(np.mean(normalize(points)[0:4, :],
axis=0, axis=0,
keepdims=True))[0], keepdims=True))[0],
normalize(np.mean(normalize(points)[4:, :], normalize(np.mean(normalize(points)[4:, :],
axis=0, axis=0,
keepdims=True))[0]]) keepdims=True))[0]], dtype=np.float32)
true_assignments = [0] * 4 + [1] * 4 true_assignments = [0] * 4 + [1] * 4
true_score = len(points) - np.tensordot(normalize(points), true_score = len(points) - np.tensordot(normalize(points),
true_centers[true_assignments]) true_centers[true_assignments])
@ -230,14 +233,14 @@ class KMeansTest(tf.test.TestCase):
# the less populated centers. # the less populated centers.
points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
[-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1], [-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
[-3., -3.1], [-3.2, -3.], [-3., -3.]]).astype(np.float32) [-3., -3.1], [-3.2, -3.], [-3., -3.]], dtype=np.float32)
true_centers = np.array( true_centers = np.array(
[normalize(np.mean(normalize(points)[0:2, :], axis=0, [normalize(np.mean(normalize(points)[0:2, :], axis=0,
keepdims=True))[0], keepdims=True))[0],
normalize(np.mean(normalize(points)[2:4, :], axis=0, normalize(np.mean(normalize(points)[2:4, :], axis=0,
keepdims=True))[0], keepdims=True))[0],
normalize(np.mean(normalize(points)[4:, :], axis=0, normalize(np.mean(normalize(points)[4:, :], axis=0,
keepdims=True))[0]]) keepdims=True))[0]], dtype=np.float32)
true_assignments = [0] * 2 + [1] * 2 + [2] * 8 true_assignments = [0] * 2 + [1] * 2 + [2] * 8
true_score = len(points) - np.tensordot(normalize(points), true_score = len(points) - np.tensordot(normalize(points),
true_centers[true_assignments]) true_centers[true_assignments])
@ -262,7 +265,7 @@ class KMeansTest(tf.test.TestCase):
self.assertAllClose(score, true_score, atol=1e-2) self.assertAllClose(score, true_score, atol=1e-2)
def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self): def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
points = np.array([[2.0, 3.0], [1.6, 8.2]]) points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)
with self.assertRaisesOpError('less'): with self.assertRaisesOpError('less'):
kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT) kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
@ -270,7 +273,7 @@ class KMeansTest(tf.test.TestCase):
def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus( def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
self): self):
points = np.array([[2.0, 3.0], [1.6, 8.2]]) points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)
with self.assertRaisesOpError(AssertionError): with self.assertRaisesOpError(AssertionError):
kmeans = KMeans(num_clusters=3, kmeans = KMeans(num_clusters=3,
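A minimal numpy sketch (not part of the change) of how the cosine-distance tests above derive `true_centers`: points are unit-normalized, averaged per cluster, and the mean is normalized again. The local `normalize` helper below stands in for the one the test file uses.

import numpy as np

def normalize(x):
  # Unit-normalize each row.
  return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True))

points = np.array([[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
                   [0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
true_centers = np.array(
    [normalize(np.mean(normalize(points)[0:4, :], axis=0, keepdims=True))[0],
     normalize(np.mean(normalize(points)[4:, :], axis=0, keepdims=True))[0]],
    dtype=np.float32)
# The first center points almost along the x-axis and the second almost along
# the y-axis: cosine distance depends only on direction, not magnitude.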



@ -21,6 +21,7 @@
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
@ -63,13 +64,11 @@ class FileDeleter {
class DecodeAudioOp : public OpKernel { class DecodeAudioOp : public OpKernel {
public: public:
explicit DecodeAudioOp(OpKernelConstruction* context) explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_)); OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
file_format_ = str_util::Lowercase(file_format_); file_format_ = str_util::Lowercase(file_format_);
const std::set<string> valid_file_formats( const std::set<string> valid_file_formats(
kValidFileFormats, kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1, OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
errors::InvalidArgument( errors::InvalidArgument(
"file_format arg must be in {", "file_format arg must be in {",
@ -80,8 +79,7 @@ class DecodeAudioOp : public OpKernel {
OP_REQUIRES(context, samples_per_second_ > 0, OP_REQUIRES(context, samples_per_second_ > 0,
errors::InvalidArgument("samples_per_second must be > 0.")); errors::InvalidArgument("samples_per_second must be > 0."));
OP_REQUIRES_OK( OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
context, context->GetAttr("channel_count", &channel_count_));
OP_REQUIRES(context, channel_count_ > 0, OP_REQUIRES(context, channel_count_ > 0,
errors::InvalidArgument("channel_count must be > 0.")); errors::InvalidArgument("channel_count must be > 0."));
} }
@ -117,14 +115,13 @@ class DecodeAudioOp : public OpKernel {
LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
<< "'. Returning empty tensor."; << "'. Returning empty tensor.";
Tensor* output = nullptr; Tensor* output = nullptr;
OP_REQUIRES_OK( OP_REQUIRES_OK(context,
context, context->allocate_output(0, TensorShape({0, 0}), &output)); context->allocate_output(0, TensorShape({0, 0}), &output));
return; return;
} else { } else {
OP_REQUIRES_OK(context, result); OP_REQUIRES_OK(context, result);
} }
OP_REQUIRES( OP_REQUIRES(context, !output_samples.empty(),
context, !output_samples.empty(),
errors::Unknown("No output created by FFmpeg.")); errors::Unknown("No output created by FFmpeg."));
OP_REQUIRES( OP_REQUIRES(
context, output_samples.size() % channel_count_ == 0, context, output_samples.size() % channel_count_ == 0,
@ -133,8 +130,8 @@ class DecodeAudioOp : public OpKernel {
// Copy the output data to the output Tensor. // Copy the output data to the output Tensor.
Tensor* output = nullptr; Tensor* output = nullptr;
const int64 frame_count = output_samples.size() / channel_count_; const int64 frame_count = output_samples.size() / channel_count_;
OP_REQUIRES_OK( OP_REQUIRES_OK(context,
context, context->allocate_output( context->allocate_output(
0, TensorShape({frame_count, channel_count_}), &output)); 0, TensorShape({frame_count, channel_count_}), &output));
auto matrix = output->tensor<float, 2>(); auto matrix = output->tensor<float, 2>();
for (int32 frame = 0; frame < frame_count; ++frame) { for (int32 frame = 0; frame < frame_count; ++frame) {
@ -159,6 +156,15 @@ REGISTER_OP("DecodeAudio")
.Attr("file_format: string") .Attr("file_format: string")
.Attr("samples_per_second: int") .Attr("samples_per_second: int")
.Attr("channel_count: int") .Attr("channel_count: int")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int64 channels;
if (c->GetAttr("channel_count", &channels).ok()) {
c->set_output(0, c->Matrix(c->UnknownDim(), channels));
} else {
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
}
return Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Processes the contents of an audio file into a tensor using FFmpeg to decode Processes the contents of an audio file into a tensor using FFmpeg to decode
the file. the file.
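The new `SetShapeFn` above gives the graph a partially known static shape instead of a fully unknown one. A hedged sketch of the effect on the Python side, assuming the op is reached through the `tf.contrib.ffmpeg.decode_audio` wrapper (the file name and parameters are made up):

import tensorflow as tf
from tensorflow.contrib import ffmpeg

binary = tf.read_file('/tmp/example.mp3')  # hypothetical input file
audio = ffmpeg.decode_audio(binary, file_format='mp3',
                            samples_per_second=44100, channel_count=2)
# With the shape function above, the column count is known at graph
# construction time even though the number of frames is not:
print(audio.get_shape())  # expected: (?, 2)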


@ -16,6 +16,7 @@
#include <limits> #include <limits>
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -24,8 +25,7 @@ namespace ffmpeg {
class EncodeAudioOp : public OpKernel { class EncodeAudioOp : public OpKernel {
public: public:
explicit EncodeAudioOp(OpKernelConstruction* context) explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_)); OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
file_format_ = str_util::Lowercase(file_format_); file_format_ = str_util::Lowercase(file_format_);
OP_REQUIRES(context, file_format_ == "wav", OP_REQUIRES(context, file_format_ == "wav",
@ -35,15 +35,15 @@ class EncodeAudioOp : public OpKernel {
context, context->GetAttr("samples_per_second", &samples_per_second_)); context, context->GetAttr("samples_per_second", &samples_per_second_));
OP_REQUIRES(context, samples_per_second_ > 0, OP_REQUIRES(context, samples_per_second_ > 0,
errors::InvalidArgument("samples_per_second must be > 0.")); errors::InvalidArgument("samples_per_second must be > 0."));
OP_REQUIRES_OK( OP_REQUIRES_OK(context,
context, context->GetAttr("bits_per_second", &bits_per_second_)); context->GetAttr("bits_per_second", &bits_per_second_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
// Get and verify the input data. // Get and verify the input data.
OP_REQUIRES(context, context->num_inputs() == 1, OP_REQUIRES(
errors::InvalidArgument( context, context->num_inputs() == 1,
"EncodeAudio requires exactly one input.")); errors::InvalidArgument("EncodeAudio requires exactly one input."));
const Tensor& contents = context->input(0); const Tensor& contents = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()), OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
errors::InvalidArgument( errors::InvalidArgument(
@ -88,6 +88,7 @@ REGISTER_OP("EncodeAudio")
.Attr("file_format: string") .Attr("file_format: string")
.Attr("samples_per_second: int") .Attr("samples_per_second: int")
.Attr("bits_per_second: int = 192000") .Attr("bits_per_second: int = 192000")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc( .Doc(R"doc(
Processes a `Tensor` containing sampled audio with the number of channels Processes a `Tensor` containing sampled audio with the number of channels
and length of the audio specified by the dimensions of the `Tensor`. The and length of the audio specified by the dimensions of the `Tensor`. The


@ -91,6 +91,15 @@ py_test(
deps = ["//tensorflow:tensorflow_py"], deps = ["//tensorflow:tensorflow_py"],
) )
py_test(
name = "sampling_ops_threading_test",
size = "small",
srcs = ["python/ops/sampling_ops_threading_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = ["//tensorflow:tensorflow_py"],
)
filegroup( filegroup(
name = "all_files", name = "all_files",
srcs = glob( srcs = glob(


@ -30,6 +30,7 @@
## Deprecation ## Deprecation
@@deprecated @@deprecated
@@deprecated_arg_values
## Arg_Scope ## Arg_Scope
@@arg_scope @@arg_scope


@ -21,4 +21,5 @@ from __future__ import print_function
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
from tensorflow.contrib.framework.python.framework.deprecation import deprecated from tensorflow.contrib.framework.python.framework.deprecation import deprecated
from tensorflow.contrib.framework.python.framework.deprecation import deprecated_arg_values
from tensorflow.contrib.framework.python.framework.tensor_util import * from tensorflow.contrib.framework.python.framework.tensor_util import *


@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import inspect
import re import re
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -34,45 +36,77 @@ def _get_qualified_name(function):
return function.__name__ return function.__name__
def _add_deprecation_to_docstring(doc, date, instructions): def _add_deprecation_to_docstring(
doc, instructions, no_doc_str, suffix_str, notice):
"""Adds a deprecation notice to a docstring.""" """Adds a deprecation notice to a docstring."""
if not doc: if not doc:
lines = ['DEPRECATED FUNCTION'] lines = [no_doc_str]
else: else:
lines = doc.splitlines() lines = doc.splitlines()
lines[0] += ' (deprecated)' lines[0] += ' ' + suffix_str
notice = [ notice = [''] + notice + [instructions]
'',
'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
'Instructions for updating:',
'%s' % instructions,
]
if len(lines) > 1: if len(lines) > 1:
# Make sure that we keep our distance from the main body # Make sure that we keep our distance from the main body
if lines[1].strip(): if lines[1].strip():
notice += [''] notice.append('')
lines = [lines[0]] + notice + lines[1:] lines[1:1] = notice
else: else:
lines += notice lines += notice
return '\n'.join(lines) return '\n'.join(lines)
def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
"""Adds a deprecation notice to a docstring for deprecated functions."""
return _add_deprecation_to_docstring(
doc, instructions,
'DEPRECATED FUNCTION',
'(deprecated)', [
'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
'Instructions for updating:'])
def _add_deprecated_arg_notice_to_docstring(doc, date, instructions):
"""Adds a deprecation notice to a docstring for deprecated arguments."""
return _add_deprecation_to_docstring(
doc, instructions,
'DEPRECATED FUNCTION ARGUMENTS',
'(deprecated arguments)', [
'SOME ARGUMENTS ARE DEPRECATED. '
'They will be removed after %s.' % date,
'Instructions for updating:'])
def _validate_deprecation_args(date, instructions):
if not date:
raise ValueError('Tell us what date this will be deprecated!')
if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
raise ValueError('Date must be YYYY-MM-DD.')
if not instructions:
raise ValueError('Don\'t deprecate things without conversion instructions!')
def _validate_callable(func, decorator_name):
if not hasattr(func, '__call__'):
raise ValueError(
'%s is not a function. If this is a property, '
'apply @%s after @property.' % (func, decorator_name))
def deprecated(date, instructions): def deprecated(date, instructions):
"""Decorator for marking functions or methods deprecated. """Decorator for marking functions or methods deprecated.
This decorator adds a deprecation warning to a function's docstring. It has This decorator logs a deprecation warning whenever the decorated function is
the following format: called. It has the following format:
<function> (from <module>) is deprecated and will be removed after <date>. <function> (from <module>) is deprecated and will be removed after <date>.
Instructions for updating: Instructions for updating:
<instructions> <instructions>
whenever the decorated function is called. <function> will include the class <function> will include the class name if it is a method.
name if it is a method.
It also edits the docstring of the function: ' (deprecated)' is appended It also edits the docstring of the function: ' (deprecated)' is appended
to the first line of the docstring and a deprecation notice is prepended to the first line of the docstring and a deprecation notice is prepended
@ -90,28 +124,73 @@ def deprecated(date, instructions):
Raises: Raises:
ValueError: If date is not in ISO 8601 format, or instructions are empty. ValueError: If date is not in ISO 8601 format, or instructions are empty.
""" """
if not date: _validate_deprecation_args(date, instructions)
raise ValueError('Tell us what date this will be deprecated!')
if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
raise ValueError('Date must be YYYY-MM-DD.')
if not instructions:
raise ValueError('Don\'t deprecate things without conversion instructions!')
def deprecated_wrapper(func): def deprecated_wrapper(func):
"""Deprecation wrapper.""" """Deprecation wrapper."""
if not hasattr(func, '__call__'): _validate_callable(func, 'deprecated')
raise ValueError( @functools.wraps(func)
'%s is not a function.'
'If this is a property, apply @deprecated after @property.' % func)
def new_func(*args, **kwargs): def new_func(*args, **kwargs):
logging.warning('%s (from %s) is deprecated and will be removed after %s.' logging.warning(
'\nInstructions for updating:\n%s', '%s (from %s) is deprecated and will be removed after %s.\n'
_get_qualified_name(func), func.__module__, 'Instructions for updating:\n%s',
date, instructions) _get_qualified_name(func), func.__module__, date, instructions)
return func(*args, **kwargs) return func(*args, **kwargs)
new_func.__name__ = func.__name__ new_func.__doc__ = _add_deprecated_function_notice_to_docstring(
new_func.__doc__ = _add_deprecation_to_docstring(func.__doc__, date, func.__doc__, date, instructions)
instructions) return new_func
new_func.__dict__.update(func.__dict__) return deprecated_wrapper
def deprecated_arg_values(date, instructions, **deprecated_kwargs):
"""Decorator for marking specific function argument values as deprecated.
This decorator logs a deprecation warning whenever the decorated function is
called with the deprecated argument values. It has the following format:
Calling <function> (from <module>) with <arg>=<value> is deprecated and
will be removed after <date>. Instructions for updating:
<instructions>
<function> will include the class name if it is a method.
It also edits the docstring of the function: ' (deprecated arguments)' is
appended to the first line of the docstring and a deprecation notice is
prepended to the rest of the docstring.
Args:
date: String. The date the function is scheduled to be removed. Must be
ISO 8601 (YYYY-MM-DD).
instructions: String. Instructions on how to update code using the
deprecated function.
**deprecated_kwargs: The deprecated argument values.
Returns:
Decorated function or method.
Raises:
ValueError: If date is not in ISO 8601 format, or instructions are empty.
"""
_validate_deprecation_args(date, instructions)
if not deprecated_kwargs:
raise ValueError('Specify which argument values are deprecated.')
def deprecated_wrapper(func):
"""Deprecation decorator."""
_validate_callable(func, 'deprecated_arg_values')
@functools.wraps(func)
def new_func(*args, **kwargs):
"""Deprecation wrapper."""
named_args = inspect.getcallargs(func, *args, **kwargs)
for arg_name, arg_value in deprecated_kwargs.items():
if arg_name in named_args and named_args[arg_name] == arg_value:
logging.warning(
'Calling %s (from %s) with %s=%s is deprecated and will be '
'removed after %s.\nInstructions for updating:\n%s',
_get_qualified_name(func), func.__module__,
arg_name, arg_value, date, instructions)
return func(*args, **kwargs)
new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
func.__doc__, date, instructions)
return new_func return new_func
return deprecated_wrapper return deprecated_wrapper
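A minimal usage sketch of the two decorators defined above, via the `tf.contrib.framework` re-exports added in this change; the dates and messages are placeholders.

import tensorflow as tf

@tf.contrib.framework.deprecated("2016-12-31", "Use new_fn instead.")
def old_fn(arg0, arg1):
  """Adds two numbers."""
  return arg0 + arg1

@tf.contrib.framework.deprecated_arg_values(
    "2016-12-31", "Stop passing deprecated=True.", deprecated=True)
def fn(arg0, arg1, deprecated=True):
  """Adds two numbers."""
  return arg0 + arg1

old_fn(1, 2)                # always logs a deprecation warning
fn(1, 2, deprecated=False)  # no warning: the deprecated value is not used
fn(1, 2)                    # warns: the default deprecated=True is the deprecated value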


@ -56,7 +56,7 @@ class DeprecationTest(tf.test.TestCase):
Args: Args:
arg0: Arg 0. arg0: Arg 0.
arg1: Arg 0. arg1: Arg 1.
Returns: Returns:
Sum of args. Sum of args.
@ -73,13 +73,38 @@ class DeprecationTest(tf.test.TestCase):
"\n" "\n"
"\n Args:" "\n Args:"
"\n arg0: Arg 0." "\n arg0: Arg 0."
"\n arg1: Arg 0." "\n arg1: Arg 1."
"\n" "\n"
"\n Returns:" "\n Returns:"
"\n Sum of args." "\n Sum of args."
"\n " % (date, instructions), "\n " % (date, instructions),
_fn.__doc__) _fn.__doc__)
self.assertEqual({}, _fn.__dict__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_one_line_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions)
def _fn(arg0, arg1):
"""fn doc."""
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn issues log warning. # Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2)) self.assertEqual(3, _fn(1, 2))
@ -106,7 +131,6 @@ class DeprecationTest(tf.test.TestCase):
"\nInstructions for updating:" "\nInstructions for updating:"
"\n%s" % (date, instructions), "\n%s" % (date, instructions),
_fn.__doc__) _fn.__doc__)
self.assertEqual({}, _fn.__dict__)
# Assert calling new fn issues log warning. # Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2)) self.assertEqual(3, _fn(1, 2))
@ -131,7 +155,7 @@ class DeprecationTest(tf.test.TestCase):
Args: Args:
arg0: Arg 0. arg0: Arg 0.
arg1: Arg 0. arg1: Arg 1.
Returns: Returns:
Sum of args. Sum of args.
@ -147,7 +171,7 @@ class DeprecationTest(tf.test.TestCase):
"\n" "\n"
"\n Args:" "\n Args:"
"\n arg0: Arg 0." "\n arg0: Arg 0."
"\n arg1: Arg 0." "\n arg1: Arg 1."
"\n" "\n"
"\n Returns:" "\n Returns:"
"\n Sum of args." "\n Sum of args."
@ -161,6 +185,36 @@ class DeprecationTest(tf.test.TestCase):
self.assertRegexpMatches(args[0], r"deprecated and will be removed after") self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:])) self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_instance_fn_with_one_line_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
class _Object(object):
def __init__(self):
pass
@deprecation.deprecated(date, instructions)
def _fn(self, arg0, arg1):
"""fn doc."""
return arg0 + arg1
# Assert function docs are properly updated.
self.assertEqual(
"fn doc. (deprecated)"
"\n"
"\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
"\nInstructions for updating:\n%s" % (date, instructions),
getattr(_Object, "_fn").__doc__)
# Assert calling new fn issues log warning.
self.assertEqual(3, _Object()._fn(1, 2))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
@tf.test.mock.patch.object(logging, "warning", autospec=True) @tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_instance_fn_no_doc(self, mock_warning): def test_instance_fn_no_doc(self, mock_warning):
date = "2016-07-04" date = "2016-07-04"
@ -280,5 +334,155 @@ class DeprecationTest(tf.test.TestCase):
self._assert_subset(set([date, instructions]), set(args[1:])) self._assert_subset(set([date, instructions]), set(args[1:]))
class DeprecatedArgsTest(tf.test.TestCase):
def _assert_subset(self, expected_subset, actual_set):
self.assertTrue(
actual_set.issuperset(expected_subset),
msg="%s is not a superset of %s." % (actual_set, expected_subset))
def test_deprecated_illegal_args(self):
instructions = "This is how you update..."
with self.assertRaisesRegexp(ValueError, "date"):
deprecation.deprecated_arg_values(
None, instructions, deprecated=True)
with self.assertRaisesRegexp(ValueError, "date"):
deprecation.deprecated_arg_values(
"", instructions, deprecated=True)
with self.assertRaisesRegexp(ValueError, "YYYY-MM-DD"):
deprecation.deprecated_arg_values(
"07-04-2016", instructions, deprecated=True)
date = "2016-07-04"
with self.assertRaisesRegexp(ValueError, "instructions"):
deprecation.deprecated_arg_values(
date, None, deprecated=True)
with self.assertRaisesRegexp(ValueError, "instructions"):
deprecation.deprecated_arg_values(
date, "", deprecated=True)
with self.assertRaisesRegexp(ValueError, "argument", deprecated=True):
deprecation.deprecated_arg_values(
date, instructions)
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
def _fn(arg0, arg1, deprecated=True):
"""fn doc.
Args:
arg0: Arg 0.
arg1: Arg 1.
deprecated: Deprecated!
Returns:
Sum of args.
"""
return arg0 + arg1 if deprecated else arg1 + arg0
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated arguments)"
"\n"
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
"\nInstructions for updating:\n%s"
"\n"
"\n Args:"
"\n arg0: Arg 0."
"\n arg1: Arg 1."
"\n deprecated: Deprecated!"
"\n"
"\n Returns:"
"\n Sum of args."
"\n " % (date, instructions),
_fn.__doc__)
# Assert calling new fn with non-deprecated value logs nothing.
self.assertEqual(3, _fn(1, 2, deprecated=False))
self.assertEqual(0, mock_warning.call_count)
# Assert calling new fn with deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2, deprecated=True))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
# Assert calling new fn with default deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(2, mock_warning.call_count)
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_with_one_line_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
def _fn(arg0, arg1, deprecated=True):
"""fn doc."""
return arg0 + arg1 if deprecated else arg1 + arg0
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"fn doc. (deprecated arguments)"
"\n"
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
"\nInstructions for updating:\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn with non-deprecated value logs nothing.
self.assertEqual(3, _fn(1, 2, deprecated=False))
self.assertEqual(0, mock_warning.call_count)
# Assert calling new fn with deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2, deprecated=True))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
# Assert calling new fn with default deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(2, mock_warning.call_count)
@tf.test.mock.patch.object(logging, "warning", autospec=True)
def test_static_fn_no_doc(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated_arg_values(date, instructions, deprecated=True)
def _fn(arg0, arg1, deprecated=True):
return arg0 + arg1 if deprecated else arg1 + arg0
# Assert function docs are properly updated.
self.assertEqual("_fn", _fn.__name__)
self.assertEqual(
"DEPRECATED FUNCTION ARGUMENTS"
"\n"
"\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
"\nInstructions for updating:"
"\n%s" % (date, instructions),
_fn.__doc__)
# Assert calling new fn with non-deprecated value logs nothing.
self.assertEqual(3, _fn(1, 2, deprecated=False))
self.assertEqual(0, mock_warning.call_count)
# Assert calling new fn issues log warning.
self.assertEqual(3, _fn(1, 2, deprecated=True))
self.assertEqual(1, mock_warning.call_count)
(args, _) = mock_warning.call_args
self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions]), set(args[1:]))
# Assert calling new fn with default deprecated value issues log warning.
self.assertEqual(3, _fn(1, 2))
self.assertEqual(2, mock_warning.call_count)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()


@ -27,6 +27,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import input as input_ops from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner from tensorflow.python.training import queue_runner
@ -34,10 +35,8 @@ __all__ = ['stratified_sample',
'stratified_sample_unknown_dist',] 'stratified_sample_unknown_dist',]
# TODO(joelshor): Use an exponential-moving-average to estimate the initial def stratified_sample(tensors, labels, target_probs, batch_size,
# class distribution and remove the requirement that it be provided. init_probs=None, enqueue_many=False, queue_capacity=16,
def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
enqueue_many=False, queue_capacity=16,
threads_per_queue=1, name=None): threads_per_queue=1, name=None):
"""Stochastically creates batches based on per-class probabilities. """Stochastically creates batches based on per-class probabilities.
@ -52,11 +51,12 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
batch, according to enqueue_many. batch, according to enqueue_many.
labels: Tensor for label of data. Label is a single integer or a batch, labels: Tensor for label of data. Label is a single integer or a batch,
depending on enqueue_many. It is not a one-hot vector. depending on enqueue_many. It is not a one-hot vector.
init_probs: Class proportions in the data. An object whose type has a
registered Tensor conversion function.
target_probs: Target class proportions in batch. An object whose type has a target_probs: Target class proportions in batch. An object whose type has a
registered Tensor conversion function. registered Tensor conversion function.
batch_size: Size of batch to be returned. batch_size: Size of batch to be returned.
init_probs: Class proportions in the data. An object whose type has a
registered Tensor conversion function, or `None` for estimating the
initial distribution.
enqueue_many: Bool. If true, interpret input tensors as having a batch enqueue_many: Bool. If true, interpret input tensors as having a batch
dimension. dimension.
queue_capacity: Capacity of the large queue that holds input examples. queue_capacity: Capacity of the large queue that holds input examples.
@ -81,10 +81,9 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
data, label = data_provider.Get(['data', 'label']) data, label = data_provider.Get(['data', 'label'])
# Get stratified batch according to per-class probabilities. # Get stratified batch according to per-class probabilities.
init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)]
target_probs = [...distribution you want...] target_probs = [...distribution you want...]
[data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample( [data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
[data], label, init_probs, target_probs) [data], label, target_probs)
# Run batch through network. # Run batch through network.
... ...
@ -92,22 +91,34 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
with ops.op_scope(tensors + [labels], name, 'stratified_sample'): with ops.op_scope(tensors + [labels], name, 'stratified_sample'):
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors) tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
labels = ops.convert_to_tensor(labels) labels = ops.convert_to_tensor(labels)
init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32) target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
# Reduce the case of a single example to that of a batch of size 1. # Reduce the case of a single example to that of a batch of size 1.
if not enqueue_many: if not enqueue_many:
tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list] tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
labels = array_ops.expand_dims(labels, 0) labels = array_ops.expand_dims(labels, 0)
# If `init_probs` is `None`, set up online estimation of data distribution.
if init_probs is None:
# We use `target_probs` to get the number of classes, so its shape must be
# fully defined at graph construction time.
target_probs.get_shape().assert_is_fully_defined()
init_probs = _estimate_data_distribution(
labels, target_probs.get_shape().num_elements())
else:
init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
# Validate that input is consistent. # Validate that input is consistent.
tensor_list, labels, [init_probs, target_probs] = _verify_input( tensor_list, labels, [init_probs, target_probs] = _verify_input(
tensor_list, labels, [init_probs, target_probs]) tensor_list, labels, [init_probs, target_probs])
# Check that all zero initial probabilities also have zero target # Check that all zero initial probabilities also have zero target
# probabilities. # probabilities.
assert_op = logging_ops.Assert(math_ops.reduce_all(math_ops.logical_or( assert_op = logging_ops.Assert(
math_ops.reduce_all(math_ops.logical_or(
math_ops.not_equal(init_probs, 0), math_ops.not_equal(init_probs, 0),
math_ops.equal(target_probs, 0))), [init_probs, target_probs]) math_ops.equal(target_probs, 0))),
['All classes with zero initial probability must also have zero target '
'probability: ', init_probs, target_probs])
init_probs = control_flow_ops.with_dependencies([assert_op], init_probs) init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)
# Calculate acceptance sampling probabilities. # Calculate acceptance sampling probabilities.
@ -212,6 +223,40 @@ def stratified_sample_unknown_dist(tensors, labels, probs, batch_size,
per_class_queues, probs, batch_size) per_class_queues, probs, batch_size)
def _estimate_data_distribution(labels, num_classes):
"""Estimate data distribution as labels are seen."""
# Variable to track running count of classes. Add 1 to avoid division-by-zero,
# and to guarantee that calculation of acceptance probabilities is (mostly)
# correct.
num_examples_per_class_seen = variables.Variable(
initial_value=[1] * num_classes, trainable=False, name='class_count',
dtype=dtypes.int64)
# Update the class-count based on what labels are seen in batch.
num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
math_ops.reduce_sum(array_ops.one_hot(labels, num_classes,
dtype=dtypes.int64), 0))
# Normalize count into a probability.
# NOTE: Without the `+= 0` line below, the test
# `testMultiThreadedEstimateDataDistribution` fails. The reason is that
# before this line, `num_examples_per_class_seen` is a Tensor that shares a
# buffer with an underlying `ref` object. When the `ref` is changed by another
# thread, `num_examples_per_class_seen` changes as well. Since this can happen
# in the middle of the normalization computation, we get probabilities that
# are very far from summing to one. Adding `+= 0` copies the contents of the
# tensor to a new buffer, which will be consistent from the start to the end
# of the normalization computation.
num_examples_per_class_seen += 0
init_prob_estimate = math_ops.truediv(
num_examples_per_class_seen,
math_ops.reduce_sum(num_examples_per_class_seen))
# Must return float32 (not float64) to agree with downstream `_verify_input`
# checks.
return math_ops.cast(init_prob_estimate, dtypes.float32)
def _verify_input(tensor_list, labels, probs_list): def _verify_input(tensor_list, labels, probs_list):
"""Verify that batched inputs are well-formed.""" """Verify that batched inputs are well-formed."""
checked_probs_list = [] checked_probs_list = []
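A numpy-only sketch of the running estimate that `_estimate_data_distribution` above maintains: per-class counts start at one to avoid division by zero, are incremented with one-hot label counts, and are then normalized into the initial probabilities.

import numpy as np

num_classes = 3
counts = np.ones(num_classes, dtype=np.int64)   # initial_value=[1] * num_classes
for label in [0, 0, 1, 0, 2, 0]:                # labels observed so far
  counts += np.eye(num_classes, dtype=np.int64)[label]
init_probs = counts.astype(np.float32) / counts.sum()
print(init_probs)  # -> [0.5555556  0.22222222 0.22222222]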


@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
class SamplingOpsTest(tf.test.TestCase): class SamplingOpsTest(tf.test.TestCase):
@ -33,15 +34,22 @@ class SamplingOpsTest(tf.test.TestCase):
# Curry the rejection sampler so we can easily run the same tests on both # Curry the rejection sampler so we can easily run the same tests on both
# stratified_sample and stratified_sample_unknown_dist. # stratified_sample and stratified_sample_unknown_dist.
def curried_sampler(val, lbls, probs, batch, enqueue_many=True): def curried_sampler(tensors, labels, probs, batch_size, enqueue_many=True):
return tf.contrib.framework.sampling_ops.stratified_sample( return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many) tensors=tensors,
labels=labels,
target_probs=probs,
batch_size=batch_size,
init_probs=initial_p,
enqueue_many=enqueue_many)
samplers = [ samplers = [
tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist, tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist,
curried_sampler, curried_sampler,
] ]
for sampler in samplers: for sampler in samplers:
logging.info('Now testing `%s`', sampler.__class__.__name__)
# Label must have only batch dimension if enqueue_many is True. # Label must have only batch dimension if enqueue_many is True.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True) sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True)
@ -70,20 +78,21 @@ class SamplingOpsTest(tf.test.TestCase):
# Probabilities shape must be fully defined. # Probabilities shape must be fully defined.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
sampler(val, label, tf.placeholder(tf.float32, shape=[None]), sampler(
batch_size) val, label, tf.placeholder(
tf.float32, shape=[None]), batch_size)
# In the rejection sampling case, make sure that probability lengths are # In the rejection sampling case, make sure that probability lengths are
# the same. # the same.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tf.contrib.framework.sampling_ops.stratified_sample( tf.contrib.framework.sampling_ops.stratified_sample(
val, label, [.2] * 5, [.1] * 10, batch_size) val, label, [.1] * 10, batch_size, init_probs=[.2] * 5)
# In the rejection sampling case, make sure that zero initial probability # In the rejection sampling case, make sure that zero initial probability
# classes also have zero target probability. # classes also have zero target probability.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tf.contrib.framework.sampling_ops.stratified_sample( tf.contrib.framework.sampling_ops.stratified_sample(
val, label, [0, .5, .5], [.2, .4, .4], batch_size) val, label, [.2, .4, .4], batch_size, init_probs=[0, .5, .5])
# Probabilities must be 1D. # Probabilities must be 1D.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -116,14 +125,16 @@ class SamplingOpsTest(tf.test.TestCase):
# Run session that should fail. # Run session that should fail.
with self.test_session() as sess: with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError): with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run([val_tf, lbl_tf], feed_dict={label_ph: illegal_label, sess.run([val_tf, lbl_tf],
feed_dict={label_ph: illegal_label,
probs_ph: valid_probs}) probs_ph: valid_probs})
for illegal_prob in illegal_probs: for illegal_prob in illegal_probs:
# Run session that should fail. # Run session that should fail.
with self.test_session() as sess: with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError): with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run([prob_tf], feed_dict={label_ph: valid_labels, sess.run([prob_tf],
feed_dict={label_ph: valid_labels,
probs_ph: illegal_prob}) probs_ph: illegal_prob})
def batchingBehaviorHelper(self, sampler): def batchingBehaviorHelper(self, sampler):
@ -152,15 +163,14 @@ class SamplingOpsTest(tf.test.TestCase):
lbl_input_batch = tf.ones([], dtype=tf.int32) lbl_input_batch = tf.ones([], dtype=tf.int32)
probs = np.array([0, 1, 0, 0, 0]) probs = np.array([0, 1, 0, 0, 0])
batches = tf.contrib.framework.sampling_ops.stratified_sample( batches = tf.contrib.framework.sampling_ops.stratified_sample(
val_input_batch, lbl_input_batch, probs, probs, batch_size) val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
batches += tf.contrib.framework.sampling_ops.stratified_sample( batches += tf.contrib.framework.sampling_ops.stratified_sample(
val_input_batch, lbl_input_batch, probs, probs, batch_size) val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist( batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
val_input_batch, lbl_input_batch, probs, batch_size) val_input_batch, lbl_input_batch, probs, batch_size)
batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist( batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
val_input_batch, lbl_input_batch, probs, batch_size) val_input_batch, lbl_input_batch, probs, batch_size)
summary_op = tf.merge_summary(tf.get_collection( summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES))
tf.GraphKeys.SUMMARIES))
with self.test_session() as sess: with self.test_session() as sess:
coord = tf.train.Coordinator() coord = tf.train.Coordinator()
@ -177,9 +187,15 @@ class SamplingOpsTest(tf.test.TestCase):
def testRejectionBatchingBehavior(self): def testRejectionBatchingBehavior(self):
initial_p = [0, .3, 0, .7, 0] initial_p = [0, .3, 0, .7, 0]
def curried_sampler(val, lbls, probs, batch, enqueue_many=True): def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
return tf.contrib.framework.sampling_ops.stratified_sample( return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many) val,
lbls,
probs,
batch,
init_probs=initial_p,
enqueue_many=enqueue_many)
self.batchingBehaviorHelper(curried_sampler) self.batchingBehaviorHelper(curried_sampler)
@ -190,8 +206,7 @@ class SamplingOpsTest(tf.test.TestCase):
lbl2 = 3 lbl2 = 3
# This cond allows the necessary class queues to be populated. # This cond allows the necessary class queues to be populated.
label = tf.cond( label = tf.cond(
tf.greater(.5, tf.random_uniform([])), tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
lambda: tf.constant(lbl1),
lambda: tf.constant(lbl2)) lambda: tf.constant(lbl2))
val = [np.array([1, 4]) * label] val = [np.array([1, 4]) * label]
probs = tf.placeholder(tf.float32, shape=[5]) probs = tf.placeholder(tf.float32, shape=[5])
@ -225,7 +240,7 @@ class SamplingOpsTest(tf.test.TestCase):
def testBatchDimensionNotRequired(self): def testBatchDimensionNotRequired(self):
classes = 5 classes = 5
# Probs must be a tensor, since we pass it directly to _verify_input. # Probs must be a tensor, since we pass it directly to _verify_input.
probs = tf.constant([1.0/classes] * classes) probs = tf.constant([1.0 / classes] * classes)
# Make sure that these vals/labels pairs don't throw any runtime exceptions. # Make sure that these vals/labels pairs don't throw any runtime exceptions.
legal_input_pairs = [ legal_input_pairs = [
@ -243,7 +258,8 @@ class SamplingOpsTest(tf.test.TestCase):
# Run graph to make sure there are no shape-related runtime errors. # Run graph to make sure there are no shape-related runtime errors.
for vals, labels in legal_input_pairs: for vals, labels in legal_input_pairs:
with self.test_session() as sess: with self.test_session() as sess:
sess.run([val_tf, labels_tf], feed_dict={vals_ph: vals, sess.run([val_tf, labels_tf],
feed_dict={vals_ph: vals,
labels_ph: labels}) labels_ph: labels})
def dataListHelper(self, sampler): def dataListHelper(self, sampler):
@ -251,8 +267,8 @@ class SamplingOpsTest(tf.test.TestCase):
val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3] val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3]
lbl_input_batch = tf.ones([], dtype=tf.int32) lbl_input_batch = tf.ones([], dtype=tf.int32)
probs = np.array([0, 1, 0, 0, 0]) probs = np.array([0, 1, 0, 0, 0])
val_list, lbls = sampler( val_list, lbls = sampler(val_input_batch, lbl_input_batch, probs,
val_input_batch, lbl_input_batch, probs, batch_size) batch_size)
# Check output shapes. # Check output shapes.
self.assertTrue(isinstance(val_list, list)) self.assertTrue(isinstance(val_list, list))
@ -277,9 +293,16 @@ class SamplingOpsTest(tf.test.TestCase):
def testRejectionDataListInput(self): def testRejectionDataListInput(self):
initial_p = [0, 1, 0, 0, 0] initial_p = [0, 1, 0, 0, 0]
def curried_sampler(val, lbls, probs, batch, enqueue_many=False): def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
return tf.contrib.framework.sampling_ops.stratified_sample( return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many) val,
lbls,
probs,
batch,
init_probs=initial_p,
enqueue_many=enqueue_many)
self.dataListHelper(curried_sampler) self.dataListHelper(curried_sampler)
def normalBehaviorHelper(self, sampler): def normalBehaviorHelper(self, sampler):
@ -289,8 +312,7 @@ class SamplingOpsTest(tf.test.TestCase):
lbl2 = 3 lbl2 = 3
# This cond allows the necessary class queues to be populated. # This cond allows the necessary class queues to be populated.
label = tf.cond( label = tf.cond(
tf.greater(.5, tf.random_uniform([])), tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
lambda: tf.constant(lbl1),
lambda: tf.constant(lbl2)) lambda: tf.constant(lbl2))
val = [np.array([1, 4]) * label] val = [np.array([1, 4]) * label]
probs = np.array([.8, 0, 0, .2, 0]) probs = np.array([.8, 0, 0, .2, 0])
@ -302,6 +324,9 @@ class SamplingOpsTest(tf.test.TestCase):
data_l = [] data_l = []
label_l = [] label_l = []
with self.test_session() as sess: with self.test_session() as sess:
# Need to initialize variables that keep running total of classes seen.
tf.initialize_all_variables().run()
coord = tf.train.Coordinator() coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord) threads = tf.train.start_queue_runners(coord=coord)
@ -329,7 +354,7 @@ class SamplingOpsTest(tf.test.TestCase):
# is fixed, for a given implementation, this test will pass or fail 100% of # is fixed, for a given implementation, this test will pass or fail 100% of
# the time. This use of assertNear is to cover cases where someone changes # the time. This use of assertNear is to cover cases where someone changes
# an implementation detail, which would cause the random behavior to differ. # an implementation detail, which would cause the random behavior to differ.
self.assertNear(actual_lbl, expected_label, 3*lbl_std_dev_of_mean) self.assertNear(actual_lbl, expected_label, 3 * lbl_std_dev_of_mean)
def testNormalBehavior(self): def testNormalBehavior(self):
self.normalBehaviorHelper( self.normalBehaviorHelper(
@ -337,10 +362,26 @@ class SamplingOpsTest(tf.test.TestCase):
def testRejectionNormalBehavior(self): def testRejectionNormalBehavior(self):
initial_p = [.7, 0, 0, .3, 0] initial_p = [.7, 0, 0, .3, 0]
def curried_sampler(val, lbls, probs, batch, enqueue_many=False): def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
return tf.contrib.framework.sampling_ops.stratified_sample( return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many) val,
lbls,
probs,
batch,
init_probs=initial_p,
enqueue_many=enqueue_many)
self.normalBehaviorHelper(curried_sampler) self.normalBehaviorHelper(curried_sampler)
def testRejectionNormalBehaviorWithOnlineInitPEstimate(self):
def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
return tf.contrib.framework.sampling_ops.stratified_sample(
val, lbls, probs, batch, init_probs=None, enqueue_many=enqueue_many)
self.normalBehaviorHelper(curried_sampler)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
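A sketch of the call signature the tests above migrate to: `init_probs` becomes an optional keyword argument, and passing `None` enables the online estimate. The tensors and probabilities below are illustrative only.

import numpy as np
import tensorflow as tf

val = [tf.constant([[1.0, 4.0]])]   # batch of one example
label = tf.constant([1])            # its class label
target_probs = np.array([0, 1, 0, 0, 0], dtype=np.float32)

[data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
    val, label, target_probs, batch_size=4,
    init_probs=None,   # None -> estimate the data distribution online
    enqueue_many=True)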


@ -0,0 +1,65 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class SamplingOpsThreadingTest(tf.test.TestCase):
def testMultiThreadedEstimateDataDistribution(self):
num_classes = 10
# Set up graph.
tf.set_random_seed(1234)
label = tf.cast(tf.round(tf.random_uniform([1]) * num_classes), tf.int32)
prob_estimate = tf.contrib.framework.sampling_ops._estimate_data_distribution( # pylint: disable=line-too-long
label, num_classes)
# Check that prob_estimate is well-behaved in a multithreaded context.
_, _, [prob_estimate] = tf.contrib.framework.sampling_ops._verify_input(
[], label, [prob_estimate])
# Use queues to run multiple threads over the graph, each of which
# fetches `prob_estimate`.
queue = tf.FIFOQueue(
capacity=25,
dtypes=[prob_estimate.dtype],
shapes=[prob_estimate.get_shape()])
enqueue_op = queue.enqueue([prob_estimate])
tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 25))
out_tensor = queue.dequeue()
# Run the multi-threaded session.
with self.test_session() as sess:
# Need to initialize variables that keep running total of classes seen.
tf.initialize_all_variables().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for _ in range(25):
sess.run([out_tensor])
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
tf.test.main()


@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow { namespace tensorflow {
REGISTER_OP("SparseFeatureCross") REGISTER_OP("SparseFeatureCross")
@ -31,6 +32,12 @@ REGISTER_OP("SparseFeatureCross")
.Attr("dense_types: list({int64, string}) >= 0") .Attr("dense_types: list({int64, string}) >= 0")
.Attr("out_type: {int64, string}") .Attr("out_type: {int64, string}")
.Attr("internal_type: {int64, string}") .Attr("internal_type: {int64, string}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Matrix(c->UnknownDim(), 2));
c->set_output(1, c->Vector(c->UnknownDim()));
c->set_output(2, c->Vector(2));
return Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Generates sparse cross from a list of sparse tensors. Generates sparse cross from a list of sparse tensors.
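The shape function above simply declares the three components of a rank-2 sparse result. A hedged sketch of the corresponding Python-side triple (the values are made up, not produced by the op):

import tensorflow as tf

indices = [[0, 0], [1, 0], [1, 1]]    # shape [nnz, 2] -> c->Matrix(UnknownDim, 2)
values = ['a_X_c', 'b_X_c', 'b_X_d']  # shape [nnz]    -> c->Vector(UnknownDim)
shape = [2, 2]                        # shape [2]      -> c->Vector(2)
cross = tf.SparseTensor(indices, values, shape)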


@ -193,35 +193,36 @@ class _SparseColumn(_FeatureColumn,
combiner="sum", combiner="sum",
dtype=dtypes.string): dtype=dtypes.string):
if is_integerized and bucket_size is None: if is_integerized and bucket_size is None:
raise ValueError("bucket_size should be set if is_integerized=True. " raise ValueError("bucket_size must be set if is_integerized is True. "
"column_name: {}".format(column_name)) "column_name: {}".format(column_name))
if is_integerized and not dtype.is_integer: if is_integerized and not dtype.is_integer:
raise ValueError("dtype should be an integer if is_integerized is True. " raise ValueError("dtype must be an integer if is_integerized is True. "
"Column {}.".format(column_name)) "dtype: {}, column_name: {}.".format(dtype, column_name))
if bucket_size is None and lookup_config is None: if bucket_size is None and lookup_config is None:
raise ValueError("one of bucket_size or lookup_config should be " raise ValueError("one of bucket_size or lookup_config must be set. "
"set. column_name: {}".format(column_name)) "column_name: {}".format(column_name))
if bucket_size is not None and lookup_config: if bucket_size is not None and lookup_config:
raise ValueError("one and only one of bucket_size or lookup_config " raise ValueError("one and only one of bucket_size or lookup_config "
"should be set. column_name: {}".format(column_name)) "must be set. column_name: {}".format(column_name))
if bucket_size is not None and bucket_size < 2: if bucket_size is not None and bucket_size < 2:
raise ValueError("bucket_size should be at least 2. " raise ValueError("bucket_size must be at least 2. "
"column_name: {}".format(column_name)) "bucket_size: {}, column_name: {}".format(bucket_size,
column_name))
if ((lookup_config) and if ((lookup_config) and
(not isinstance(lookup_config, _SparseIdLookupConfig))): (not isinstance(lookup_config, _SparseIdLookupConfig))):
raise TypeError( raise TypeError(
"lookup_config should be an instance of _SparseIdLookupConfig. " "lookup_config must be an instance of _SparseIdLookupConfig. "
"Given one is in type {} for column_name {}".format( "Given one is in type {} for column_name {}".format(
type(lookup_config), column_name)) type(lookup_config), column_name))
if (lookup_config and lookup_config.vocabulary_file and if (lookup_config and lookup_config.vocabulary_file and
lookup_config.vocab_size is None): lookup_config.vocab_size is None):
raise ValueError("vocab_size should be defined. " raise ValueError("vocab_size must be defined. "
"column_name: {}".format(column_name)) "column_name: {}".format(column_name))
return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized, return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized,
@ -262,8 +263,8 @@ class _SparseColumn(_FeatureColumn,
input_tensor, input_tensor,
weight_collections=None, weight_collections=None,
trainable=True): trainable=True):
raise ValueError("Column {} is not supported in DNN. " raise ValueError("SparseColumn is not supported in DNN. "
"Please use embedding_column.".format(self)) "Please use embedding_column. column: {}".format(self))
def to_weighted_sum(self, def to_weighted_sum(self,
input_tensor, input_tensor,
@ -279,7 +280,7 @@ class _SparseColumn(_FeatureColumn,
initializer=init_ops.zeros_initializer, initializer=init_ops.zeros_initializer,
combiner=self.combiner, combiner=self.combiner,
trainable=trainable, trainable=trainable,
name=self.name + "_weights") name=self.name)
class _SparseColumnIntegerized(_SparseColumn): class _SparseColumnIntegerized(_SparseColumn):
@ -291,8 +292,8 @@ class _SparseColumnIntegerized(_SparseColumn):
combiner="sum", combiner="sum",
dtype=dtypes.int64): dtype=dtypes.int64):
if not dtype.is_integer: if not dtype.is_integer:
raise ValueError("dtype should be an integer. Given {}".format( raise ValueError("dtype must be an integer. "
column_name)) "dtype: {}, column_name: {}".format(dtype, column_name))
return super(_SparseColumnIntegerized, cls).__new__(cls, return super(_SparseColumnIntegerized, cls).__new__(cls,
column_name, column_name,
@ -507,8 +508,8 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
input_tensor, input_tensor,
weight_collections=None, weight_collections=None,
trainable=True): trainable=True):
raise ValueError("Column {} is not supported in DNN. " raise ValueError("WeightedSparseColumn is not supported in DNN. "
"Please use embedding_column.".format(self)) "Please use embedding_column. column: {}".format(self))
def to_weighted_sum(self, def to_weighted_sum(self,
input_tensor, input_tensor,
@ -524,7 +525,7 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer, initializer=init_ops.zeros_initializer,
combiner=self.sparse_id_column.combiner, combiner=self.sparse_id_column.combiner,
trainable=trainable, trainable=trainable,
name=self.name + "_weights") name=self.name)
def weighted_sparse_column(sparse_id_column, def weighted_sparse_column(sparse_id_column,
@ -609,7 +610,9 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
ckpt_to_load_from=None, ckpt_to_load_from=None,
tensor_name_in_ckpt=None): tensor_name_in_ckpt=None):
if initializer is not None and not callable(initializer): if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.") raise ValueError("initializer must be callable if specified. "
"Embedding of column_name: {}".format(
sparse_id_column.name))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
raise ValueError("Must specify both `ckpt_to_load_from` and " raise ValueError("Must specify both `ckpt_to_load_from` and "
@ -674,7 +677,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
initializer=self.initializer, initializer=self.initializer,
combiner=self.combiner, combiner=self.combiner,
trainable=trainable, trainable=trainable,
name=self.name + "_weights") name=self.name)
if self.ckpt_to_load_from is not None: if self.ckpt_to_load_from is not None:
weights_to_restore = embedding_weights weights_to_restore = embedding_weights
if len(embedding_weights) == 1: if len(embedding_weights) == 1:
@ -690,8 +693,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
num_outputs=1, num_outputs=1,
weight_collections=None, weight_collections=None,
trainable=True): trainable=True):
raise ValueError("Column {} is not supported in linear models. " raise ValueError("EmbeddingColumn is not supported in linear models. "
"Please use sparse_column.".format(self)) "Please use sparse_column. column: {}".format(self))
def embedding_column(sparse_id_column, def embedding_column(sparse_id_column,
@ -744,7 +747,8 @@ class _HashedEmbeddingColumn(collections.namedtuple(
combiner="mean", combiner="mean",
initializer=None): initializer=None):
if initializer is not None and not callable(initializer): if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.") raise ValueError("initializer must be callable if specified. "
"column_name: {}".format(column_name))
if initializer is None: if initializer is None:
stddev = 0.1 stddev = 0.1
# TODO(b/25671353): Better initial value? # TODO(b/25671353): Better initial value?
@ -770,7 +774,7 @@ class _HashedEmbeddingColumn(collections.namedtuple(
weight_collections=None, weight_collections=None,
trainable=True): trainable=True):
embeddings = _create_embeddings( embeddings = _create_embeddings(
name=self.name + "_weights", name=self.name,
shape=[self.size], shape=[self.size],
initializer=self.initializer, initializer=self.initializer,
dtype=dtypes.float32, dtype=dtypes.float32,
@ -815,10 +819,14 @@ def hashed_embedding_column(column_name,
""" """
if (dimension < 1) or (size < 1): if (dimension < 1) or (size < 1):
raise ValueError("Dimension and size must be greater than 0.") raise ValueError("Dimension and size must be greater than 0. "
"dimension: {}, size: {}, column_name: {}".format(
dimension, size, column_name))
if combiner not in ("mean", "sqrtn", "sum"): if combiner not in ("mean", "sqrtn", "sum"):
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.") raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'. "
"combiner: {}, column_name: {}".format(
combiner, column_name))
return _HashedEmbeddingColumn(column_name, size, dimension, combiner, return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
initializer) initializer)
@ -929,14 +937,18 @@ def real_valued_column(column_name,
""" """
if not isinstance(dimension, int): if not isinstance(dimension, int):
raise TypeError("dimension must be an integer") raise TypeError("dimension must be an integer. "
"dimension: {}, column_name: {}".format(dimension,
column_name))
if dimension < 1: if dimension < 1:
raise ValueError("dimension must be greater than 0") raise ValueError("dimension must be greater than 0. "
"dimension: {}, column_name: {}".format(dimension,
column_name))
if not (dtype.is_integer or dtype.is_floating): if not (dtype.is_integer or dtype.is_floating):
raise ValueError("dtype is not convertible to tf.float32. Given {}".format( raise ValueError("dtype must be convertible to float. "
dtype)) "dtype: {}, column_name: {}".format(dtype, column_name))
if default_value is None: if default_value is None:
return _RealValuedColumn(column_name, dimension, default_value, dtype) return _RealValuedColumn(column_name, dimension, default_value, dtype)
@ -957,9 +969,10 @@ def real_valued_column(column_name,
if isinstance(default_value, list): if isinstance(default_value, list):
if len(default_value) != dimension: if len(default_value) != dimension:
raise ValueError("The length of default_value is not equal to the " raise ValueError(
"value of dimension. default_value is {}.".format( "The length of default_value must be equal to dimension. "
default_value)) "default_value: {}, dimension: {}, column_name: {}".format(
default_value, dimension, column_name))
# Check if the values in the list are all integers or are convertible to # Check if the values in the list are all integers or are convertible to
# floats. # floats.
is_list_all_int = True is_list_all_int = True
@ -980,8 +993,9 @@ def real_valued_column(column_name,
default_value = [float(v) for v in default_value] default_value = [float(v) for v in default_value]
return _RealValuedColumn(column_name, dimension, default_value, dtype) return _RealValuedColumn(column_name, dimension, default_value, dtype)
raise TypeError("default_value is not compatible with dtype. " raise TypeError("default_value must be compatible with dtype. "
"default_value is {}.".format(default_value)) "default_value: {}, dtype: {}, column_name: {}".format(
default_value, dtype, column_name))
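The net effect of these message changes is easiest to see from the caller's side. A minimal sketch, assuming the `tf.contrib.layers` entry points exercised by the tests further down (the column name "age" is purely illustrative):

import tensorflow as tf

# Each invalid call now reports the offending value and the column name.
try:
  tf.contrib.layers.real_valued_column("age", dimension=0)
except ValueError as e:
  print(e)  # "dimension must be greater than 0. dimension: 0, column_name: age"

try:
  tf.contrib.layers.real_valued_column("age", dtype=tf.string)
except ValueError as e:
  print(e)  # "dtype must be convertible to float. dtype: <string>, column_name: age"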
class _BucketizedColumn(_FeatureColumn, collections.namedtuple( class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
@ -1008,10 +1022,12 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
def __new__(cls, source_column, boundaries): def __new__(cls, source_column, boundaries):
if not isinstance(source_column, _RealValuedColumn): if not isinstance(source_column, _RealValuedColumn):
raise TypeError( raise TypeError(
"source_column should be an instance of _RealValuedColumn.") "source_column must be an instance of _RealValuedColumn. "
"source_column: {}".format(source_column))
if not isinstance(boundaries, list) or not boundaries: if not isinstance(boundaries, list) or not boundaries:
raise ValueError("boundaries must be a list and it should not be empty.") raise ValueError("boundaries must be a non-empty list. "
"boundaries: {}".format(boundaries))
# We allow bucket boundaries to be monotonically increasing # We allow bucket boundaries to be monotonically increasing
# (ie a[i+1] >= a[i]). When two bucket boundaries are the same, we # (ie a[i+1] >= a[i]). When two bucket boundaries are the same, we
@ -1023,7 +1039,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
elif boundaries[i] < boundaries[i + 1]: elif boundaries[i] < boundaries[i + 1]:
sanitized_boundaries.append(boundaries[i]) sanitized_boundaries.append(boundaries[i])
else: else:
raise ValueError("boundaries must be a sorted list") raise ValueError("boundaries must be a sorted list. "
"boundaries: {}".format(boundaries))
sanitized_boundaries.append(boundaries[len(boundaries) - 1]) sanitized_boundaries.append(boundaries[len(boundaries) - 1])
return super(_BucketizedColumn, cls).__new__(cls, source_column, return super(_BucketizedColumn, cls).__new__(cls, source_column,
@ -1104,7 +1121,7 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer, initializer=init_ops.zeros_initializer,
combiner="sum", combiner="sum",
trainable=trainable, trainable=trainable,
name=self.name + "_weights") name=self.name)
def bucketized_column(source_column, boundaries): def bucketized_column(source_column, boundaries):
@ -1186,18 +1203,21 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
ckpt_to_load_from=None, tensor_name_in_ckpt=None): ckpt_to_load_from=None, tensor_name_in_ckpt=None):
for column in columns: for column in columns:
if not _CrossedColumn._is_crossable(column): if not _CrossedColumn._is_crossable(column):
raise TypeError("columns should be a set of " raise TypeError("columns must be a set of _SparseColumn, "
"_SparseColumn, _CrossedColumn, or _BucketizedColumn. " "_CrossedColumn, or _BucketizedColumn instances. "
"Column is {}".format(column)) "column: {}".format(column))
if len(columns) < 2: if len(columns) < 2:
raise ValueError("columns should contain at least 2 elements.") raise ValueError("columns must contain at least 2 elements. "
"columns: {}".format(columns))
if not isinstance(hash_bucket_size, int): if not isinstance(hash_bucket_size, int):
raise TypeError("hash_bucket_size should be an int.") raise TypeError("hash_bucket_size must be an int. "
"hash_bucket_size: {}".format(hash_bucket_size))
if hash_bucket_size < 2: if hash_bucket_size < 2:
raise ValueError("hash_bucket_size should be at least 2.") raise ValueError("hash_bucket_size must be at least 2. "
"hash_bucket_size: {}".format(hash_bucket_size))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
raise ValueError("Must specify both `ckpt_to_load_from` and " raise ValueError("Must specify both `ckpt_to_load_from` and "
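For reference, a short sketch of which `crossed_column` calls pass the tightened checks and which now fail with the offending value echoed back; the column names are illustrative and the calls assume the `tf.contrib.layers` API used in the tests below:

import tensorflow as tf

aaa = tf.contrib.layers.sparse_column_with_hash_bucket("aaa", hash_bucket_size=100)
bbb = tf.contrib.layers.sparse_column_with_hash_bucket("bbb", hash_bucket_size=100)

# Valid: at least two crossable columns and an integer hash_bucket_size >= 2.
cross = tf.contrib.layers.crossed_column(set([aaa, bbb]), hash_bucket_size=10000)

# These now raise with the offending value in the message:
#   crossed_column(set([aaa]), hash_bucket_size=10000)    -> "columns must contain at least 2 elements. ..."
#   crossed_column(set([aaa, bbb]), hash_bucket_size=1)   -> "hash_bucket_size must be at least 2. ..."
#   crossed_column(set([aaa, bbb]), hash_bucket_size=1.0) -> "hash_bucket_size must be an int. ..."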
@ -1275,8 +1295,8 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
input_tensor, input_tensor,
weight_collections=None, weight_collections=None,
trainable=True): trainable=True):
raise ValueError("Column {} is not supported in DNN. " raise ValueError("CrossedColumn is not supported in DNN. "
"Please use embedding_column.".format(self)) "Please use embedding_column. column: {}".format(self))
def to_weighted_sum(self, def to_weighted_sum(self,
input_tensor, input_tensor,
@ -1292,7 +1312,7 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer, initializer=init_ops.zeros_initializer,
combiner=self.combiner, combiner=self.combiner,
trainable=trainable, trainable=trainable,
name=self.name + "_weights") name=self.name)
if self.ckpt_to_load_from is not None: if self.ckpt_to_load_from is not None:
weights_to_restore = embedding_weights weights_to_restore = embedding_weights
if len(embedding_weights) == 1: if len(embedding_weights) == 1:
@ -1337,7 +1357,7 @@ def crossed_column(columns, hash_bucket_size, combiner="sum",
class DataFrameColumn(_FeatureColumn, class DataFrameColumn(_FeatureColumn,
collections.namedtuple("DataFrameColumn", collections.namedtuple("DataFrameColumn",
["name", "series"])): ["column_name", "series"])):
"""Represents a feature column produced from a `DataFrame`. """Represents a feature column produced from a `DataFrame`.
Instances of this class are immutable. A `DataFrame` column may be dense or Instances of this class are immutable. A `DataFrame` column may be dense or
@ -1345,13 +1365,17 @@ class DataFrameColumn(_FeatureColumn,
batch_size. batch_size.
Args: Args:
name: a name for this column column_name: a name for this column
series: a `Series` to be wrapped, which has already had its base features series: a `Series` to be wrapped, which has already had its base features
substituted with `PredefinedSeries`. substituted with `PredefinedSeries`.
""" """
def __new__(cls, name, series): def __new__(cls, column_name, series):
return super(DataFrameColumn, cls).__new__(cls, name, series) return super(DataFrameColumn, cls).__new__(cls, column_name, series)
@property
def name(self):
return self.column_name
@property @property
def config(self): def config(self):
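The rename stays backward compatible because the old accessor is reintroduced as a property. A stand-alone sketch of the pattern; the class here is a hypothetical stand-in, not the real `DataFrameColumn`:

import collections

class _Column(collections.namedtuple("_Column", ["column_name", "series"])):
  """Renames the namedtuple field but keeps `.name` working via a property."""

  @property
  def name(self):
    return self.column_name

col = _Column(column_name="income", series=None)
print(col.name)         # "income"; callers that read .name are unaffected
print(col.column_name)  # "income"; the new canonical field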
@ -1379,6 +1403,16 @@ class DataFrameColumn(_FeatureColumn,
input_tensor, input_tensor,
weight_collections=None, weight_collections=None,
trainable=True): trainable=True):
# DataFrame typically provides Tensors of shape [batch_size],
# but Estimator requires shape [batch_size, 1]
dims = input_tensor.get_shape().ndims
if dims == 0:
raise ValueError(
"Can't build input layer from tensor of shape (): {}".format(
self.column_name))
elif dims == 1:
return array_ops.expand_dims(input_tensor, 1)
else:
return input_tensor return input_tensor
# TODO(soergel): This mirrors RealValuedColumn for now, but should become # TODO(soergel): This mirrors RealValuedColumn for now, but should become
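The reshape being added is just an `expand_dims` on rank-1 inputs. A minimal sketch with a toy tensor (values are illustrative):

import tensorflow as tf

series_batch = tf.constant([1.0, 2.0, 3.0])  # shape [batch_size], as a DataFrame column yields it
as_input = tf.expand_dims(series_batch, 1)   # shape [batch_size, 1], as the input layer expects

with tf.Session() as sess:
  print(sess.run(as_input))  # [[1.], [2.], [3.]]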
@ -1547,7 +1581,7 @@ def _create_embeddings(name, shape, dtype, initializer, trainable,
with just one variable. with just one variable.
Args: Args:
name: A string specifying the name of the embedding variable. name: A string. The name of the embedding variable will be name + _weights.
shape: shape of the embedding. Note this is not the shape of partitioned shape: shape of the embedding. Note this is not the shape of partitioned
variables. variables.
dtype: type of the embedding. Also the shape of each partitioned variable. dtype: type of the embedding. Also the shape of each partitioned variable.
@ -1609,7 +1643,7 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
A Tensor with shape [batch_size, dimension] and embedding Variable. A Tensor with shape [batch_size, dimension] and embedding Variable.
""" """
embeddings = _create_embeddings(name=name, embeddings = _create_embeddings(name=name + "_weights",
shape=[vocab_size, dimension], shape=[vocab_size, dimension],
dtype=dtypes.float32, dtype=dtypes.float32,
initializer=initializer, initializer=initializer,
@ -1621,4 +1655,4 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
sparse_weights=weight_tensor, sparse_weights=weight_tensor,
default_id=0, default_id=0,
combiner=combiner, combiner=combiner,
name=name), embeddings name=name + "_weights"), embeddings

View File

@ -60,14 +60,17 @@ class FeatureColumnTest(tf.test.TestCase):
self.assertEqual(b.dimension, 10) self.assertEqual(b.dimension, 10)
self.assertTrue(b.default_value is None) self.assertTrue(b.default_value is None)
# dimension is an integer with self.assertRaisesRegexp(TypeError, "dimension must be an integer"):
with self.assertRaises(TypeError):
tf.contrib.layers.real_valued_column("d3", dimension=1.0) tf.contrib.layers.real_valued_column("d3", dimension=1.0)
# dimension is a positive integer with self.assertRaisesRegexp(ValueError,
with self.assertRaises(ValueError): "dimension must be greater than 0"):
tf.contrib.layers.real_valued_column("d3", dimension=0) tf.contrib.layers.real_valued_column("d3", dimension=0)
with self.assertRaisesRegexp(ValueError,
"dtype must be convertible to float"):
tf.contrib.layers.real_valued_column("d3", dtype=tf.string)
# default_value is an integer. # default_value is an integer.
c1 = tf.contrib.layers.real_valued_column("c1", default_value=2) c1 = tf.contrib.layers.real_valued_column("c1", default_value=2)
self.assertListEqual(list(c1.default_value), [2.]) self.assertListEqual(list(c1.default_value), [2.])
@ -92,15 +95,18 @@ class FeatureColumnTest(tf.test.TestCase):
dimension=4, dimension=4,
default_value=2.) default_value=2.)
self.assertListEqual(list(d2.default_value), [2., 2., 2., 2.]) self.assertListEqual(list(d2.default_value), [2., 2., 2., 2.])
with self.assertRaises(TypeError): with self.assertRaisesRegexp(TypeError,
"default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("d3", tf.contrib.layers.real_valued_column("d3",
default_value=2., default_value=2.,
dtype=tf.int32) dtype=tf.int32)
# default_value is neither interger nor float. # default_value is neither integer nor float.
with self.assertRaises(TypeError): with self.assertRaisesRegexp(
TypeError, "default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("e1", default_value="string") tf.contrib.layers.real_valued_column("e1", default_value="string")
with self.assertRaises(TypeError): with self.assertRaisesRegexp(
TypeError, "default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("e1", tf.contrib.layers.real_valued_column("e1",
dimension=3, dimension=3,
default_value=[1, 3., "string"]) default_value=[1, 3., "string"])
@ -125,11 +131,13 @@ class FeatureColumnTest(tf.test.TestCase):
dimension=3, dimension=3,
default_value=[2., 2, 2]) default_value=[2., 2, 2])
self.assertListEqual(list(g2.default_value), [2., 2., 2.]) self.assertListEqual(list(g2.default_value), [2., 2., 2.])
with self.assertRaises(TypeError): with self.assertRaisesRegexp(
TypeError, "default_value must be compatible with dtype"):
tf.contrib.layers.real_valued_column("g3", tf.contrib.layers.real_valued_column("g3",
default_value=[2.], default_value=[2.],
dtype=tf.int32) dtype=tf.int32)
with self.assertRaises(ValueError): with self.assertRaisesRegexp(
ValueError, "The length of default_value must be equal to dimension"):
tf.contrib.layers.real_valued_column("g4", tf.contrib.layers.real_valued_column("g4",
dimension=3, dimension=3,
default_value=[2.]) default_value=[2.])
@ -140,11 +148,19 @@ class FeatureColumnTest(tf.test.TestCase):
self.assertEqual(a.name, "aaa_BUCKETIZED") self.assertEqual(a.name, "aaa_BUCKETIZED")
def testBucketizedColumnRequiresRealValuedColumn(self): def testBucketizedColumnRequiresRealValuedColumn(self):
with self.assertRaises(TypeError): with self.assertRaisesRegexp(
TypeError, "source_column must be an instance of _RealValuedColumn"):
tf.contrib.layers.bucketized_column("bbb", [0]) tf.contrib.layers.bucketized_column("bbb", [0])
with self.assertRaisesRegexp(
TypeError, "source_column must be an instance of _RealValuedColumn"):
tf.contrib.layers.bucketized_column(
tf.contrib.layers.sparse_column_with_integerized_feature(
column_name="bbb", bucket_size=10),
[0])
def testBucketizedColumnRequiresSortedBuckets(self): def testBucketizedColumnRequiresSortedBuckets(self):
with self.assertRaises(ValueError): with self.assertRaisesRegexp(
ValueError, "boundaries must be a sorted list"):
tf.contrib.layers.bucketized_column( tf.contrib.layers.bucketized_column(
tf.contrib.layers.real_valued_column("ccc"), [5, 0, 4]) tf.contrib.layers.real_valued_column("ccc"), [5, 0, 4])
@ -173,7 +189,10 @@ class FeatureColumnTest(tf.test.TestCase):
def testCrossedColumnNotSupportRealValuedColumn(self): def testCrossedColumnNotSupportRealValuedColumn(self):
b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb", b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
hash_bucket_size=100) hash_bucket_size=100)
with self.assertRaises(TypeError): with self.assertRaisesRegexp(
TypeError,
"columns must be a set of _SparseColumn, _CrossedColumn, "
"or _BucketizedColumn instances"):
tf.contrib.layers.crossed_column( tf.contrib.layers.crossed_column(
set([b, tf.contrib.layers.real_valued_column("real")]), set([b, tf.contrib.layers.real_valued_column("real")]),
hash_bucket_size=10000) hash_bucket_size=10000)
@ -194,7 +213,8 @@ class FeatureColumnTest(tf.test.TestCase):
"weights": tf.VarLenFeature(tf.int32)}, "weights": tf.VarLenFeature(tf.int32)},
weighted_ids.config) weighted_ids.config)
with self.assertRaises(ValueError): with self.assertRaisesRegexp(ValueError,
"dtype is not convertible to float"):
weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights", weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights",
dtype=tf.string) dtype=tf.string)
@ -211,7 +231,8 @@ class FeatureColumnTest(tf.test.TestCase):
[1], dtype=tf.int32)}, [1], dtype=tf.int32)},
rvc.config) rvc.config)
with self.assertRaises(ValueError): with self.assertRaisesRegexp(ValueError,
"dtype must be convertible to float"):
tf.contrib.layers.real_valued_column("rvc", dtype=tf.string) tf.contrib.layers.real_valued_column("rvc", dtype=tf.string)
def testSparseColumnDtypes(self): def testSparseColumnDtypes(self):
@ -222,7 +243,8 @@ class FeatureColumnTest(tf.test.TestCase):
"sc", 10, dtype=tf.int32) "sc", 10, dtype=tf.int32)
self.assertDictEqual({"sc": tf.VarLenFeature(dtype=tf.int32)}, sc.config) self.assertDictEqual({"sc": tf.VarLenFeature(dtype=tf.int32)}, sc.config)
with self.assertRaises(ValueError): with self.assertRaisesRegexp(ValueError,
"dtype must be an integer"):
tf.contrib.layers.sparse_column_with_integerized_feature("sc", tf.contrib.layers.sparse_column_with_integerized_feature("sc",
10, 10,
dtype=tf.float32) dtype=tf.float32)

View File

@ -70,7 +70,7 @@ def multi_class_target(n_classes, label_name=None, weight_column_name=None):
will be multiplied by the loss of the example. will be multiplied by the loss of the example.
Returns: Returns:
An instance of _TargetColumn An instance of _MultiClassTargetColumn.
Raises: Raises:
ValueError: if n_classes is < 2 ValueError: if n_classes is < 2

View File

@ -33,6 +33,7 @@ from tensorflow.contrib.learn.python.learn import preprocessing
from tensorflow.contrib.learn.python.learn import utils from tensorflow.contrib.learn.python.learn import utils
from tensorflow.contrib.learn.python.learn.dataframe import * from tensorflow.contrib.learn.python.learn.dataframe import *
from tensorflow.contrib.learn.python.learn.estimators import * from tensorflow.contrib.learn.python.learn.estimators import *
from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
from tensorflow.contrib.learn.python.learn.experiment import Experiment from tensorflow.contrib.learn.python.learn.experiment import Experiment
from tensorflow.contrib.learn.python.learn.graph_actions import evaluate from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
from tensorflow.contrib.learn.python.learn.graph_actions import infer from tensorflow.contrib.learn.python.learn.graph_actions import infer
@ -41,4 +42,5 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
from tensorflow.contrib.learn.python.learn.graph_actions import run_n from tensorflow.contrib.learn.python.learn.graph_actions import run_n
from tensorflow.contrib.learn.python.learn.graph_actions import train from tensorflow.contrib.learn.python.learn.graph_actions import train
from tensorflow.contrib.learn.python.learn.learn_io import * from tensorflow.contrib.learn.python.learn.learn_io import *
from tensorflow.contrib.learn.python.learn.trainable import Trainable
# pylint: enable=wildcard-import # pylint: enable=wildcard-import

View File

@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
# Transforms # Transforms
from tensorflow.contrib.learn.python.learn.dataframe.transforms.boolean_mask import BooleanMask from tensorflow.contrib.learn.python.learn.dataframe.transforms.boolean_mask import BooleanMask
from tensorflow.contrib.learn.python.learn.dataframe.transforms.difference import Difference from tensorflow.contrib.learn.python.learn.dataframe.transforms.difference import Difference
from tensorflow.contrib.learn.python.learn.dataframe.transforms.hashes import HashFast
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource
from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource
from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource

View File

@ -117,10 +117,11 @@ class DataFrame(object):
value = [value] value = [value]
self.assign(**dict(zip(key, value))) self.assign(**dict(zip(key, value)))
def build(self): def build(self, **kwargs):
# We do not allow passing a cache here, because that would encourage # We do not allow passing a cache here, because that would encourage
# working around the rule that DataFrames cannot be expected to be # working around the rule that DataFrames cannot be expected to be
# synced with each other (e.g., they shuffle independently). # synced with each other (e.g., they shuffle independently).
cache = {} cache = {}
tensors = {name: c.build(cache) for name, c in self._columns.items()} tensors = {name: c.build(cache, **kwargs)
for name, c in self._columns.items()}
return tensors return tensors
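The forwarding is uniform: whatever keyword arguments the caller hands to `DataFrame.build` reach every column's `build` untouched. A toy illustration of that dict comprehension; the `_Col` class is hypothetical, whereas the real columns are `Series` objects that construct Tensors:

class _Col(object):
  def build(self, cache, **kwargs):
    # A real Series would build Tensors here; this just echoes the option.
    return "num_epochs=%s" % kwargs.get("num_epochs")

columns = {"a": _Col(), "b": _Col()}
cache = {}
tensors = {name: c.build(cache, num_epochs=1) for name, c in columns.items()}
print(tensors)  # both columns see the same num_epochs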

View File

@ -91,7 +91,8 @@ def _build_alternate_universe(
def to_feature_columns_and_input_fn(dataframe, def to_feature_columns_and_input_fn(dataframe,
base_input_keys_with_defaults, base_input_keys_with_defaults,
feature_keys, feature_keys,
target_keys=None): target_keys=None,
**kwargs):
"""Build a list of FeatureColumns and an input_fn for use with Estimator. """Build a list of FeatureColumns and an input_fn for use with Estimator.
Args: Args:
@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe,
These may include base features and/or derived features. These may include base features and/or derived features.
target_keys: the names of columns to be used as targets. None is target_keys: the names of columns to be used as targets. None is
acceptable for unsupervised learning. acceptable for unsupervised learning.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A tuple of two elements: A tuple of two elements:
@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe,
# Build an input_fn suitable for use with Estimator. # Build an input_fn suitable for use with Estimator.
def input_fn(): def input_fn():
"""An input_fn() for feeding the given set of DataFrameColumns."""
# It's important to build all the tensors together in one DataFrame. # It's important to build all the tensors together in one DataFrame.
# If we did df.select() for both key sets and then build those, the two # If we did df.select() for both key sets and then build those, the two
# resulting DataFrames would be shuffled independently. # resulting DataFrames would be shuffled independently.
tensors = limited_dataframe.build() tensors = limited_dataframe.build(**kwargs)
base_input_features = {key: tensors[key] for key in base_input_keys} base_input_features = {key: tensors[key] for key in base_input_keys}
targets = {key: tensors[key] for key in target_keys} targets = {key: tensors[key] for key in target_keys}

View File

@ -98,7 +98,7 @@ class Series(object):
return transform_cls return transform_cls
return register return register
def build(self, cache): def build(self, cache, **kwargs):
"""Returns a Tensor.""" """Returns a Tensor."""
raise NotImplementedError() raise NotImplementedError()
@ -122,7 +122,7 @@ class PredefinedSeries(Series):
def required_base_features(self): def required_base_features(self):
return {self.name: self.feature_spec} return {self.name: self.feature_spec}
def build(self, cache): def build(self, cache, **kwargs):
try: try:
return cache[self.name] return cache[self.name]
except KeyError: except KeyError:
@ -171,10 +171,11 @@ class TransformedSeries(Series):
result.update(s.required_base_features) result.update(s.required_base_features)
return result return result
def build(self, cache=None): def build(self, cache=None, **kwargs):
if cache is None: if cache is None:
cache = {} cache = {}
all_outputs = self._transform.build_transitive(self._input_series, cache) all_outputs = self._transform.build_transitive(
self._input_series, cache, **kwargs)
return getattr(all_outputs, self._output_name) return getattr(all_outputs, self._output_name)
def __repr__(self): def __repr__(self):

View File

@ -28,6 +28,7 @@ from tensorflow.contrib.learn.python.learn.dataframe import dataframe as df
from tensorflow.contrib.learn.python.learn.dataframe.transforms import batch from tensorflow.contrib.learn.python.learn.dataframe.transforms import batch
from tensorflow.contrib.learn.python.learn.dataframe.transforms import csv_parser from tensorflow.contrib.learn.python.learn.dataframe.transforms import csv_parser
from tensorflow.contrib.learn.python.learn.dataframe.transforms import example_parser from tensorflow.contrib.learn.python.learn.dataframe.transforms import example_parser
from tensorflow.contrib.learn.python.learn.dataframe.transforms import hashes
from tensorflow.contrib.learn.python.learn.dataframe.transforms import in_memory_source from tensorflow.contrib.learn.python.learn.dataframe.transforms import in_memory_source
from tensorflow.contrib.learn.python.learn.dataframe.transforms import reader_source from tensorflow.contrib.learn.python.learn.dataframe.transforms import reader_source
from tensorflow.contrib.learn.python.learn.dataframe.transforms import sparsify from tensorflow.contrib.learn.python.learn.dataframe.transforms import sparsify
@ -83,7 +84,8 @@ class TensorFlowDataFrame(df.DataFrame):
graph=None, graph=None,
session=None, session=None,
start_queues=True, start_queues=True,
initialize_variables=True): initialize_variables=True,
**kwargs):
"""Builds and runs the columns of the `DataFrame` and yields batches. """Builds and runs the columns of the `DataFrame` and yields batches.
This is a generator that yields a dictionary mapping column names to This is a generator that yields a dictionary mapping column names to
@ -97,6 +99,7 @@ class TensorFlowDataFrame(df.DataFrame):
start_queues: if true, queues will be started before running and halted start_queues: if true, queues will be started before running and halted
after producing `n` batches. after producing `n` batches.
initialize_variables: if true, variables will be initialized. initialize_variables: if true, variables will be initialized.
**kwargs: Additional keyword arguments e.g. `num_epochs`.
Yields: Yields:
A dictionary, mapping column names to the values resulting from running A dictionary, mapping column names to the values resulting from running
@ -107,7 +110,7 @@ class TensorFlowDataFrame(df.DataFrame):
with graph.as_default(): with graph.as_default():
if session is None: if session is None:
session = sess.Session() session = sess.Session()
self_built = self.build() self_built = self.build(**kwargs)
keys = list(self_built.keys()) keys = list(self_built.keys())
cols = list(self_built.values()) cols = list(self_built.values())
if initialize_variables: if initialize_variables:
@ -157,6 +160,52 @@ class TensorFlowDataFrame(df.DataFrame):
"Original error: {}").format(type(col), e)) "Original error: {}").format(type(col), e))
return result return result
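Since `num_epochs` no longer lives on the CSV/TFRecord source constructors (see the removals below), the epoch limit is supplied when the graph is built. A hedged usage sketch; the file pattern and default values are hypothetical:

df = TensorFlowDataFrame.from_csv(
    ["/tmp/data-*.csv"],            # hypothetical file pattern
    default_values=[0.0, 0.0, ""],  # one default per CSV column
    batch_size=32)

# num_epochs travels run() -> build(**kwargs) -> ReaderSource._apply_transform.
for batch in df.run(num_epochs=1):
  pass  # each `batch` maps column names to the values produced for that batch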
def split(self, index_series, proportion, batch_size=None):
"""Deterministically split a `DataFrame` into two `DataFrame`s.
Note this split is only as deterministic as the underlying hash function;
see `tf.string_to_hash_bucket_fast`. The hash function is deterministic
for a given binary, but may change occasionally. The only way to achieve
an absolute guarantee that the split `DataFrame`s do not change across runs
is to materialize them.
Note too that the allocation of a row to one partition or the
other is evaluated independently for each row, so the exact number of rows
in each partition is binomially distributed.
Args:
index_series: a `Series` of unique strings, whose hash will determine the
partitioning; or the name in this `DataFrame` of such a `Series`.
(This `Series` must contain strings because TensorFlow provides hash
ops only for strings, and there are no number-to-string converter ops.)
proportion: The proportion of the rows to select for the 'left'
partition; the remaining (1 - proportion) rows form the 'right'
partition.
batch_size: the batch size to use when rebatching the left and right
`DataFrame`s. If None (default), the `DataFrame`s are not rebatched;
thus their batches will have variable sizes, according to which rows
are selected from each batch of the original `DataFrame`.
Returns:
Two `DataFrame`s containing the partitioned rows.
"""
# TODO(soergel): allow seed?
if isinstance(index_series, str):
index_series = self[index_series]
num_buckets = 1000000 # close enough for simple splits
hashed_input, = hashes.HashFast(num_buckets)(index_series)
threshold = int(num_buckets * proportion)
left = hashed_input < threshold
right = ~left
left_rows = self.select_rows(left)
right_rows = self.select_rows(right)
if batch_size:
left_rows = left_rows.batch(batch_size=batch_size, shuffle=False)
right_rows = right_rows.batch(batch_size=batch_size, shuffle=False)
return left_rows, right_rows
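A usage sketch, assuming the `DataFrame` has a string-valued column of unique ids named "example_id" (the name is illustrative):

train_df, eval_df = df.split("example_id", proportion=0.8, batch_size=32)

Because each row is assigned independently by hashing its id, roughly, not exactly, 80% of rows land in `train_df`; rebatching with `batch_size` smooths out the variable batch sizes that per-row selection would otherwise produce.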
def run_once(self): def run_once(self):
"""Creates a new 'Graph` and `Session` and runs a single batch. """Creates a new 'Graph` and `Session` and runs a single batch.
@ -208,7 +257,7 @@ class TensorFlowDataFrame(df.DataFrame):
@classmethod @classmethod
def _from_csv_base(cls, filepatterns, get_default_values, has_header, def _from_csv_base(cls, filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, enqueue_size, column_names, num_threads, enqueue_size,
batch_size, queue_capacity, min_after_dequeue, shuffle, batch_size, queue_capacity, min_after_dequeue, shuffle,
seed): seed):
"""Create a `DataFrame` from CSV files. """Create a `DataFrame` from CSV files.
@ -223,9 +272,6 @@ class TensorFlowDataFrame(df.DataFrame):
each column, given the column names. each column, given the column names.
has_header: whether or not the CSV files have headers. has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files. column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -265,7 +311,6 @@ class TensorFlowDataFrame(df.DataFrame):
reader_kwargs=reader_kwargs, reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,
@ -287,7 +332,6 @@ class TensorFlowDataFrame(df.DataFrame):
default_values, default_values,
has_header=True, has_header=True,
column_names=None, column_names=None,
num_epochs=None,
num_threads=1, num_threads=1,
enqueue_size=None, enqueue_size=None,
batch_size=32, batch_size=32,
@ -306,9 +350,6 @@ class TensorFlowDataFrame(df.DataFrame):
default_values: a list of default values for each column. default_values: a list of default values for each column.
has_header: whether or not the CSV files have headers. has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files. column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -332,7 +373,7 @@ class TensorFlowDataFrame(df.DataFrame):
return default_values return default_values
return cls._from_csv_base(filepatterns, get_default_values, has_header, return cls._from_csv_base(filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, column_names, num_threads,
enqueue_size, batch_size, queue_capacity, enqueue_size, batch_size, queue_capacity,
min_after_dequeue, shuffle, seed) min_after_dequeue, shuffle, seed)
@ -342,7 +383,6 @@ class TensorFlowDataFrame(df.DataFrame):
feature_spec, feature_spec,
has_header=True, has_header=True,
column_names=None, column_names=None,
num_epochs=None,
num_threads=1, num_threads=1,
enqueue_size=None, enqueue_size=None,
batch_size=32, batch_size=32,
@ -362,9 +402,6 @@ class TensorFlowDataFrame(df.DataFrame):
`VarLenFeature`. `VarLenFeature`.
has_header: whether or not the CSV files have headers. has_header: whether or not the CSV files have headers.
column_names: a list of names for the columns in the CSV files. column_names: a list of names for the columns in the CSV files.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -387,7 +424,7 @@ class TensorFlowDataFrame(df.DataFrame):
return [_get_default_value(feature_spec[name]) for name in column_names] return [_get_default_value(feature_spec[name]) for name in column_names]
dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header, dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
column_names, num_epochs, num_threads, column_names, num_threads,
enqueue_size, batch_size, queue_capacity, enqueue_size, batch_size, queue_capacity,
min_after_dequeue, shuffle, seed) min_after_dequeue, shuffle, seed)
@ -405,7 +442,6 @@ class TensorFlowDataFrame(df.DataFrame):
filepatterns, filepatterns,
features, features,
reader_cls=io_ops.TFRecordReader, reader_cls=io_ops.TFRecordReader,
num_epochs=None,
num_threads=1, num_threads=1,
enqueue_size=None, enqueue_size=None,
batch_size=32, batch_size=32,
@ -421,9 +457,6 @@ class TensorFlowDataFrame(df.DataFrame):
`FixedLenFeature`. `FixedLenFeature`.
reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
read the `Example`s. read the `Example`s.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
num_threads: the number of readers that will work in parallel. num_threads: the number of readers that will work in parallel.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: desired batch size. batch_size: desired batch size.
@ -454,7 +487,6 @@ class TensorFlowDataFrame(df.DataFrame):
filenames, filenames,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,

View File

@ -223,13 +223,14 @@ class Transform(object):
# pylint: disable=not-callable # pylint: disable=not-callable
return self.return_type(*output_series) return self.return_type(*output_series)
def build_transitive(self, input_series, cache=None): def build_transitive(self, input_series, cache=None, **kwargs):
"""Apply this `Transform` to the provided `Series`, producing 'Tensor's. """Apply this `Transform` to the provided `Series`, producing 'Tensor's.
Args: Args:
input_series: None, a `Series`, or a list of input `Series`, acting as input_series: None, a `Series`, or a list of input `Series`, acting as
positional arguments. positional arguments.
cache: a dict from Series reprs to Tensors. cache: a dict from Series reprs to Tensors.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of the output Tensors. A namedtuple of the output Tensors.
@ -244,7 +245,7 @@ class Transform(object):
if len(input_series) != self.input_valency: if len(input_series) != self.input_valency:
raise ValueError("Expected %s input Series but received %s." % raise ValueError("Expected %s input Series but received %s." %
(self.input_valency, len(input_series))) (self.input_valency, len(input_series)))
input_tensors = [series.build(cache) for series in input_series] input_tensors = [series.build(cache, **kwargs) for series in input_series]
# Note we cache each output individually, not just the entire output # Note we cache each output individually, not just the entire output
# tuple. This allows using the graph as the cache, since it can sensibly # tuple. This allows using the graph as the cache, since it can sensibly
@ -254,7 +255,7 @@ class Transform(object):
output_tensors = [cache.get(output_repr) for output_repr in output_reprs] output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
if None in output_tensors: if None in output_tensors:
result = self._apply_transform(input_tensors) result = self._apply_transform(input_tensors, **kwargs)
for output_name, output_repr in zip(self.output_names, output_reprs): for output_name, output_repr in zip(self.output_names, output_reprs):
cache[output_repr] = getattr(result, output_name) cache[output_repr] = getattr(result, output_name)
else: else:
@ -264,12 +265,13 @@ class Transform(object):
return result return result
@abstractmethod @abstractmethod
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.
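Downstream `Transform` subclasses need to grow the same `**kwargs` parameter, even if they ignore it, so that options such as `num_epochs` can pass through the build chain. A minimal hypothetical subclass following the structure used elsewhere in this commit (e.g. `HashFast`):

from tensorflow.contrib.learn.python.learn.dataframe import transform
from tensorflow.python.ops import math_ops

class Double(transform.Transform):
  """Hypothetical Transform that multiplies its single input by two."""

  @property
  def name(self):
    return "Double"

  @property
  def input_valency(self):
    return 1

  @property
  def _output_names(self):
    return "output",

  def _apply_transform(self, input_tensors, **kwargs):
    # kwargs (e.g. num_epochs) are accepted for interface compatibility only.
    # pylint: disable=not-callable
    return self.return_type(math_ops.mul(input_tensors[0], 2))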

View File

@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform):
def name(self): def name(self):
return "Batch" return "Batch"
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
batched = input_ops.batch(transform_input, batched = input_ops.batch(transform_input,
batch_size=self.batch_size, batch_size=self.batch_size,
num_threads=self.num_threads, num_threads=self.num_threads,
@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform):
def seed(self): def seed(self):
return self._seed return self._seed
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
batched = input_ops.shuffle_batch(transform_input, batched = input_ops.shuffle_batch(transform_input,
batch_size=self.batch_size, batch_size=self.batch_size,
capacity=self.queue_capacity, capacity=self.queue_capacity,

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
# TODO(jamieas): consider supporting sparse inputs. # TODO(jamieas): consider supporting sparse inputs.
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance( if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
input_tensors[1], ops.SparseTensor): input_tensors[1], ops.SparseTensor):
@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor): if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices, result = ops.SparseTensor(input_tensor.indices,

View File

@ -77,18 +77,21 @@ class BooleanMask(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.
""" """
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
mask = input_tensors[1] mask = input_tensors[1]
if mask.get_shape().ndims > 1:
mask = array_ops.squeeze(mask)
if isinstance(input_tensor, ops.SparseTensor): if isinstance(input_tensor, ops.SparseTensor):
mask_fn = sparse_boolean_mask mask_fn = sparse_boolean_mask
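The added squeeze guards against masks of shape [batch_size, 1], which batched comparisons can produce. The raw-tensor equivalent, with toy values:

import tensorflow as tf

rows = tf.constant([[1.0], [2.0], [3.0]])      # shape [3, 1]
mask = tf.constant([[True], [False], [True]])  # shape [3, 1]; ndims > 1

flat_mask = tf.squeeze(mask)                   # shape [3], as boolean_mask expects
kept = tf.boolean_mask(rows, flat_mask)

with tf.Session() as sess:
  print(sess.run(kept))  # [[1.], [3.]]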

View File

@ -58,7 +58,7 @@ class CSVParser(transform.Transform):
def default_values(self): def default_values(self):
return self._default_values return self._default_values
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
default_consts = [constant_op.constant(d, shape=[1]) default_consts = [constant_op.constant(d, shape=[1])
for d in self._default_values] for d in self._default_values]
parsed_values = parsing_ops.decode_csv(input_tensors[0], parsed_values = parsing_ops.decode_csv(input_tensors[0],

View File

@ -47,12 +47,13 @@ class Densify(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -50,7 +50,7 @@ class Difference(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor), pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
isinstance(input_tensors[1], ops.SparseTensor)) isinstance(input_tensors[1], ops.SparseTensor))

View File

@ -61,7 +61,7 @@ class ExampleParser(transform.Transform):
def feature_definitions(self): def feature_definitions(self):
return self._ordered_features return self._ordered_features
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
parsed_values = parsing_ops.parse_example(input_tensors[0], parsed_values = parsing_ops.parse_example(input_tensors[0],
features=self._ordered_features) features=self._ordered_features)
# pylint: disable=not-callable # pylint: disable=not-callable

View File

@ -0,0 +1,68 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hashes a `Series` of strings into a specified number of buckets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.dataframe import transform
from tensorflow.python.ops import string_ops
class HashFast(transform.Transform):
"""Perform a fast hash of a `Series`."""
def __init__(self, num_buckets):
"""Initialize `HashFast`.
Args:
num_buckets: The number of hash buckets to use.
"""
# TODO(soergel): allow seed?
super(HashFast, self).__init__()
self._num_buckets = num_buckets
@property
def name(self):
return "HashFast"
@property
def input_valency(self):
return 1
@property
def _output_names(self):
return "output",
def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`.
Args:
input_tensors: a list of Tensors representing the input to
the Transform.
**kwargs: additional keyword arguments, unused here.
Returns:
A namedtuple of Tensors representing the transformed output.
"""
result = string_ops.string_to_hash_bucket_fast(input_tensors[0],
self._num_buckets,
name=None)
# pylint: disable=not-callable
return self.return_type(result)
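HashFast is the building block that `DataFrame.split` relies on. A hedged usage sketch, where `id_series` stands for some string-valued `Series` (the name is hypothetical):

hashed, = HashFast(num_buckets=1000000)(id_series)

# `hashed` is itself a Series of bucket indices; comparing it against a
# threshold yields a boolean Series suitable for select_rows / BooleanMask.
in_left_partition = hashed < 800000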

View File

@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform):
def input_valency(self): def input_valency(self):
return 0 return 0
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
queue = feeding_functions.enqueue_data(self.data, queue = feeding_functions.enqueue_data(self.data,
self.queue_capacity, self.queue_capacity,
self.shuffle, self.shuffle,

View File

@ -32,7 +32,6 @@ class ReaderSource(transform.Transform):
reader_kwargs=None, reader_kwargs=None,
enqueue_size=None, enqueue_size=None,
batch_size=1, batch_size=1,
num_epochs=None,
queue_capacity=None, queue_capacity=None,
shuffle=False, shuffle=False,
min_after_dequeue=None, min_after_dequeue=None,
@ -49,9 +48,6 @@ class ReaderSource(transform.Transform):
is constructed. is constructed.
enqueue_size: block size for each read operation. enqueue_size: block size for each read operation.
batch_size: The desired batch size of output. Defaults to 1. batch_size: The desired batch size of output. Defaults to 1.
num_epochs: the number of times that the reader should loop through all
the file names. If set to `None`, then the reader will continue
indefinitely.
queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`. queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
shuffle: Whether records will be shuffled before returning. Defaults to shuffle: Whether records will be shuffled before returning. Defaults to
false. false.
@ -73,7 +69,6 @@ class ReaderSource(transform.Transform):
self._batch_size = batch_size self._batch_size = batch_size
self._queue_capacity = (batch_size * 10 if queue_capacity is None else self._queue_capacity = (batch_size * 10 if queue_capacity is None else
queue_capacity) queue_capacity)
self._num_epochs = num_epochs
self._shuffle = shuffle self._shuffle = shuffle
self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
is None else min_after_dequeue) is None else min_after_dequeue)
@ -100,10 +95,6 @@ class ReaderSource(transform.Transform):
def batch_size(self): def batch_size(self):
return self._batch_size return self._batch_size
@transform.parameter
def num_epochs(self):
return self._num_epochs
@transform.parameter @transform.parameter
def queue_capacity(self): def queue_capacity(self):
return self._queue_capacity return self._queue_capacity
@ -136,9 +127,10 @@ class ReaderSource(transform.Transform):
def _output_names(self): def _output_names(self):
return ("index", "value") return ("index", "value")
def _apply_transform(self, transform_input): def _apply_transform(self, transform_input, **kwargs):
filename_queue = input_ops.string_input_producer(self.work_units, filename_queue = input_ops.string_input_producer(
num_epochs=self.num_epochs, self.work_units,
num_epochs=kwargs.get("num_epochs"),
shuffle=self.shuffle, shuffle=self.shuffle,
seed=self.seed) seed=self.seed)
reader_ops = [] reader_ops = []
@ -174,7 +166,6 @@ def TextFileSource(file_names,
reader_kwargs=None, reader_kwargs=None,
enqueue_size=1, enqueue_size=1,
batch_size=1, batch_size=1,
num_epochs=None,
queue_capacity=None, queue_capacity=None,
shuffle=False, shuffle=False,
min_after_dequeue=None, min_after_dequeue=None,
@ -185,7 +176,6 @@ def TextFileSource(file_names,
reader_kwargs=reader_kwargs, reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,
@ -197,7 +187,6 @@ def TFRecordSource(file_names,
reader_kwargs=None, reader_kwargs=None,
enqueue_size=1, enqueue_size=1,
batch_size=1, batch_size=1,
num_epochs=None,
queue_capacity=None, queue_capacity=None,
shuffle=False, shuffle=False,
min_after_dequeue=None, min_after_dequeue=None,
@ -208,7 +197,6 @@ def TFRecordSource(file_names,
reader_kwargs=reader_kwargs, reader_kwargs=reader_kwargs,
enqueue_size=enqueue_size, enqueue_size=enqueue_size,
batch_size=batch_size, batch_size=batch_size,
num_epochs=num_epochs,
queue_capacity=queue_capacity, queue_capacity=queue_capacity,
shuffle=shuffle, shuffle=shuffle,
min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,

View File

@ -52,12 +52,13 @@ class Sparsify(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
"""Applies the transformation to the `transform_input`. """Applies the transformation to the `transform_input`.
Args: Args:
input_tensors: a list of Tensors representing the input to input_tensors: a list of Tensors representing the input to
the Transform. the Transform.
**kwargs: Additional keyword arguments, unused here.
Returns: Returns:
A namedtuple of Tensors representing the transformed output. A namedtuple of Tensors representing the transformed output.

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -44,7 +44,7 @@ class Sum(transform.Transform):
def _output_names(self): def _output_names(self):
return "output", return "output",
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor), pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
isinstance(input_tensors[1], ops.SparseTensor)) isinstance(input_tensors[1], ops.SparseTensor))

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -43,7 +43,8 @@ UNARY_TRANSFORMS = [("__neg__", math_ops.neg),
("lgamma", math_ops.lgamma), ("lgamma", math_ops.lgamma),
("digamma", math_ops.digamma), ("digamma", math_ops.digamma),
("erf", math_ops.erf), ("erf", math_ops.erf),
("erfc", math_ops.erfc)] ("erfc", math_ops.erfc),
("__invert__", math_ops.logical_not, bool)]
DOC_FORMAT_STRING = ( DOC_FORMAT_STRING = (
"A `Transform` that wraps the `{0}` operation. " "A `Transform` that wraps the `{0}` operation. "
@ -52,7 +53,7 @@ DOC_FORMAT_STRING = (
# pylint: disable=unused-argument # pylint: disable=unused-argument
def register_unary_op(registered_name, operation): def register_unary_op(registered_name, operation, ignore_dtype=None):
"""Creates a `Transform` that wraps a unary tensorflow operation. """Creates a `Transform` that wraps a unary tensorflow operation.
If `registered_name` is specified, the `Transform` is registered as a member If `registered_name` is specified, the `Transform` is registered as a member
@ -62,6 +63,8 @@ def register_unary_op(registered_name, operation):
registered_name: the name of the member function of `Series` corresponding registered_name: the name of the member function of `Series` corresponding
to the returned `Transform`. to the returned `Transform`.
operation: a unary TensorFlow operation. operation: a unary TensorFlow operation.
ignore_dtype: an optional dtype, not used here but needed for symmetry with
test.
""" """
doc = DOC_FORMAT_STRING.format(operation.__name__, operation.__doc__) doc = DOC_FORMAT_STRING.format(operation.__name__, operation.__doc__)
@ -78,7 +81,7 @@ def register_unary_op(registered_name, operation):
def _output_names(self): def _output_names(self):
return "output" return "output"
def _apply_transform(self, input_tensors): def _apply_transform(self, input_tensors, **kwargs):
input_tensor = input_tensors[0] input_tensor = input_tensors[0]
if isinstance(input_tensor, ops.SparseTensor): if isinstance(input_tensor, ops.SparseTensor):
result = ops.SparseTensor(input_tensor.indices, result = ops.SparseTensor(input_tensor.indices,
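
The table of unary transforms now allows an optional third element per row (a dtype that, per the docstring above, only the tests consume), which is why `register_unary_op` grows an unused `ignore_dtype` argument. A self-contained sketch of that variadic-row pattern follows; it is an illustration, not the module's actual registration loop.

import tensorflow as tf

UNARY_TABLE = [
    ("abs", tf.abs),                          # two-element row
    ("__invert__", tf.logical_not, tf.bool),  # three-element row with a dtype hint
]


def register(registered_name, operation, ignore_dtype=None):
  # ignore_dtype exists only so that star-unpacking works for both row shapes.
  print("registering %s -> %s (dtype hint: %s)"
        % (registered_name, operation.__name__, ignore_dtype))


for entry in UNARY_TABLE:
  register(*entry)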

View File

@ -29,14 +29,10 @@ from tensorflow.contrib.learn.python.learn.estimators import _sklearn
def iris_input_fn(num_epochs=None): def iris_input_fn(num_epochs=None):
iris = tf.contrib.learn.datasets.load_iris() iris = tf.contrib.learn.datasets.load_iris()
features = tf.cast( features = tf.reshape(tf.constant(iris.data), [-1, 4])
tf.reshape(
tf.constant(iris.data), [-1, 4]), tf.float32)
if num_epochs: if num_epochs:
features = tf.train.limit_epochs(features, num_epochs=num_epochs) features = tf.train.limit_epochs(features, num_epochs=num_epochs)
target = tf.cast( target = tf.reshape(tf.constant(iris.target), [-1])
tf.reshape(
tf.constant(iris.target), [-1]), tf.int64)
return features, target return features, target

View File

@ -20,11 +20,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import re
import six import six
from tensorflow.contrib import layers from tensorflow.contrib import layers
from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients from tensorflow.python.ops import gradients
@ -47,31 +49,31 @@ class _ComposableModel(object):
def __init__(self, def __init__(self,
num_label_columns, num_label_columns,
optimizer, optimizer,
weight_collection_name,
gradient_clip_norm, gradient_clip_norm,
num_ps_replicas): num_ps_replicas,
scope):
"""Common initialization for all _ComposableModel objects. """Common initialization for all _ComposableModel objects.
Args: Args:
num_label_columns: The number of label/target columns. num_label_columns: The number of label/target columns.
optimizer: An instance of `tf.Optimizer` used to apply gradients to optimizer: An instance of `tf.Optimizer` used to apply gradients to
the model. If `None`, will use a FTRL optimizer. the model. If `None`, will use a FTRL optimizer.
weight_collection_name: A string defining the name to use for the
collection of weights (e.g. 'dnn').
gradient_clip_norm: A float > 0. If provided, gradients are clipped gradient_clip_norm: A float > 0. If provided, gradients are clipped
to their global norm with this clipping ratio. See to their global norm with this clipping ratio. See
tf.clip_by_global_norm for more details. tf.clip_by_global_norm for more details.
num_ps_replicas: The number of parameter server replicas. num_ps_replicas: The number of parameter server replicas.
scope: Scope for variables created in this model.
""" """
self._num_label_columns = num_label_columns self._num_label_columns = num_label_columns
self._optimizer = optimizer self._optimizer = optimizer
self._weight_collection_name = weight_collection_name
self._gradient_clip_norm = gradient_clip_norm self._gradient_clip_norm = gradient_clip_norm
self._num_ps_replicas = num_ps_replicas self._num_ps_replicas = num_ps_replicas
self._scope = scope
self._feature_columns = None self._feature_columns = None
def get_weight_collection_name(self): def get_scope_name(self):
return self._weight_collection_name """Returns the scope name used by this model for variables."""
return self._scope
def build_model(self, features, feature_columns, is_training): def build_model(self, features, feature_columns, is_training):
"""Builds the model that can calculate the logits. """Builds the model that can calculate the logits.
@ -114,7 +116,7 @@ class _ComposableModel(object):
def _get_vars(self): def _get_vars(self):
if self._get_feature_columns(): if self._get_feature_columns():
return ops.get_collection(self._weight_collection_name) return ops.get_collection(self._scope)
return [] return []
def _get_optimizer(self): def _get_optimizer(self):
@ -142,7 +144,8 @@ class LinearComposableModel(_ComposableModel):
num_label_columns, num_label_columns,
optimizer=None, optimizer=None,
gradient_clip_norm=None, gradient_clip_norm=None,
num_ps_replicas=0): num_ps_replicas=0,
scope=None):
"""Initializes LinearComposableModel objects. """Initializes LinearComposableModel objects.
Args: Args:
@ -153,13 +156,49 @@ class LinearComposableModel(_ComposableModel):
to their global norm with this clipping ratio. See to their global norm with this clipping ratio. See
tf.clip_by_global_norm for more details. tf.clip_by_global_norm for more details.
num_ps_replicas: The number of parameter server replicas. num_ps_replicas: The number of parameter server replicas.
scope: Optional scope for variables created in this model. If scope
is not supplied, it will default to 'linear'.
""" """
scope = "linear" if not scope else scope
super(LinearComposableModel, self).__init__( super(LinearComposableModel, self).__init__(
num_label_columns=num_label_columns, num_label_columns=num_label_columns,
optimizer=optimizer, optimizer=optimizer,
weight_collection_name="linear",
gradient_clip_norm=gradient_clip_norm, gradient_clip_norm=gradient_clip_norm,
num_ps_replicas=num_ps_replicas) num_ps_replicas=num_ps_replicas,
scope=scope)
def get_weights(self, model_dir):
"""Returns weights per feature of the linear part.
Args:
model_dir: Directory where model parameters, graph and etc. are saved.
Returns:
The weights created by this model (without the optimizer weights).
"""
all_variables = [name for name, _ in checkpoints.list_variables(model_dir)]
values = {}
optimizer_regex = r".*/" + self._get_optimizer().get_name() + r"(_\d)?$"
for name in all_variables:
if (name.startswith(self._scope + "/") and
name != self._scope + "/bias_weight" and
not re.match(optimizer_regex, name)):
values[name] = checkpoints.load_variable(model_dir, name)
if len(values) == 1:
return values[list(values.keys())[0]]
return values
def get_bias(self, model_dir):
"""Returns bias of the model.
Args:
model_dir: Directory where model parameters, graph and etc. are saved.
Returns:
The bias weights created by this model.
"""
return checkpoints.load_variable(model_dir,
name=(self._scope+"/bias_weight"))
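
`get_weights` and `get_bias` above read the linear model's parameters straight from the checkpoint in `model_dir`, keyed by the model's scope, with optimizer slot variables filtered out by a regex on the optimizer name (and the bias variable skipped). A hedged standalone sketch of that filtering; the default optimizer name and the example directory are placeholders.

import re

from tensorflow.contrib.learn.python.learn.utils import checkpoints


def variables_under_scope(model_dir, scope, optimizer_name="Ftrl"):
  """Loads every non-optimizer variable saved under `scope` in `model_dir`."""
  optimizer_regex = r".*/" + optimizer_name + r"(_\d)?$"
  values = {}
  for name, _ in checkpoints.list_variables(model_dir):
    if name.startswith(scope + "/") and not re.match(optimizer_regex, name):
      values[name] = checkpoints.load_variable(model_dir, name)
  return values

# e.g. variables_under_scope("/tmp/linear_model", "linear")  # hypothetical path
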
def build_model(self, features, feature_columns, is_training): def build_model(self, features, feature_columns, is_training):
"""See base class.""" """See base class."""
@ -168,12 +207,12 @@ class LinearComposableModel(_ComposableModel):
max_partitions=self._num_ps_replicas, max_partitions=self._num_ps_replicas,
min_slice_size=64 << 20) min_slice_size=64 << 20)
with variable_scope.variable_op_scope( with variable_scope.variable_op_scope(
features.values(), "linear", partitioner=partitioner) as scope: features.values(), self._scope, partitioner=partitioner) as scope:
logits, _, _ = layers.weighted_sum_from_feature_columns( logits, _, _ = layers.weighted_sum_from_feature_columns(
columns_to_tensors=features, columns_to_tensors=features,
feature_columns=self._get_feature_columns(), feature_columns=self._get_feature_columns(),
num_outputs=self._num_label_columns, num_outputs=self._num_label_columns,
weight_collections=[self._weight_collection_name], weight_collections=[self._scope],
scope=scope) scope=scope)
return logits return logits
@ -200,7 +239,8 @@ class DNNComposableModel(_ComposableModel):
activation_fn=nn.relu, activation_fn=nn.relu,
dropout=None, dropout=None,
gradient_clip_norm=None, gradient_clip_norm=None,
num_ps_replicas=0): num_ps_replicas=0,
scope=None):
"""Initializes DNNComposableModel objects. """Initializes DNNComposableModel objects.
Args: Args:
@ -217,17 +257,50 @@ class DNNComposableModel(_ComposableModel):
to their global norm with this clipping ratio. See to their global norm with this clipping ratio. See
tf.clip_by_global_norm for more details. tf.clip_by_global_norm for more details.
num_ps_replicas: The number of parameter server replicas. num_ps_replicas: The number of parameter server replicas.
scope: Optional scope for variables created in this model. If no scope
is supplied, it defaults to 'dnn'.
""" """
scope = "dnn" if not scope else scope
super(DNNComposableModel, self).__init__( super(DNNComposableModel, self).__init__(
num_label_columns=num_label_columns, num_label_columns=num_label_columns,
optimizer=optimizer, optimizer=optimizer,
weight_collection_name="DNN",
gradient_clip_norm=gradient_clip_norm, gradient_clip_norm=gradient_clip_norm,
num_ps_replicas=num_ps_replicas) num_ps_replicas=num_ps_replicas,
scope=scope)
self._hidden_units = hidden_units self._hidden_units = hidden_units
self._activation_fn = activation_fn self._activation_fn = activation_fn
self._dropout = dropout self._dropout = dropout
def get_weights(self, model_dir):
"""Returns the weights of the model.
Args:
model_dir: Directory where model parameters, graph and etc. are saved.
Returns:
The weights created by this model.
"""
return [checkpoints.load_variable(
model_dir, name=(self._scope+"/hiddenlayer_%d/weights" % i))
for i, _ in enumerate(self._hidden_units)] + [
checkpoints.load_variable(
model_dir, name=(self._scope+"/logits/weights"))]
def get_bias(self, model_dir):
"""Returns the bias of the model.
Args:
model_dir: Directory where model parameters, graph and etc. are saved.
Returns:
The bias weights created by this model.
"""
return [checkpoints.load_variable(
model_dir, name=(self._scope+"/hiddenlayer_%d/biases" % i))
for i, _ in enumerate(self._hidden_units)] + [
checkpoints.load_variable(
model_dir, name=(self._scope+"/logits/biases"))]
def _add_hidden_layer_summary(self, value, tag): def _add_hidden_layer_summary(self, value, tag):
# TODO(zakaria): Move this code to tf.learn and add test. # TODO(zakaria): Move this code to tf.learn and add test.
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag, logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
@ -244,12 +317,12 @@ class DNNComposableModel(_ComposableModel):
min_slice_size=64 << 20)) min_slice_size=64 << 20))
with variable_scope.variable_op_scope( with variable_scope.variable_op_scope(
features.values(), features.values(),
"input_from_feature_columns", self._scope + "/input_from_feature_columns",
partitioner=input_layer_partitioner) as scope: partitioner=input_layer_partitioner) as scope:
net = layers.input_from_feature_columns( net = layers.input_from_feature_columns(
features, features,
self._get_feature_columns(), self._get_feature_columns(),
weight_collections=[self._weight_collection_name], weight_collections=[self._scope],
scope=scope) scope=scope)
hidden_layer_partitioner = ( hidden_layer_partitioner = (
@ -257,13 +330,13 @@ class DNNComposableModel(_ComposableModel):
max_partitions=self._num_ps_replicas)) max_partitions=self._num_ps_replicas))
for layer_id, num_hidden_units in enumerate(self._hidden_units): for layer_id, num_hidden_units in enumerate(self._hidden_units):
with variable_scope.variable_op_scope( with variable_scope.variable_op_scope(
[net], "hiddenlayer_%d" % layer_id, [net], self._scope + "/hiddenlayer_%d" % layer_id,
partitioner=hidden_layer_partitioner) as scope: partitioner=hidden_layer_partitioner) as scope:
net = layers.fully_connected( net = layers.fully_connected(
net, net,
num_hidden_units, num_hidden_units,
activation_fn=self._activation_fn, activation_fn=self._activation_fn,
variables_collections=[self._weight_collection_name], variables_collections=[self._scope],
scope=scope) scope=scope)
if self._dropout is not None and is_training: if self._dropout is not None and is_training:
net = layers.dropout( net = layers.dropout(
@ -272,15 +345,15 @@ class DNNComposableModel(_ComposableModel):
self._add_hidden_layer_summary(net, scope.name) self._add_hidden_layer_summary(net, scope.name)
with variable_scope.variable_op_scope( with variable_scope.variable_op_scope(
[net], "dnn_logits", [net], self._scope + "/logits",
partitioner=hidden_layer_partitioner) as scope: partitioner=hidden_layer_partitioner) as scope:
logits = layers.fully_connected( logits = layers.fully_connected(
net, net,
self._num_label_columns, self._num_label_columns,
activation_fn=None, activation_fn=None,
variables_collections=[self._weight_collection_name], variables_collections=[self._scope],
scope=scope) scope=scope)
self._add_hidden_layer_summary(logits, "dnn_logits") self._add_hidden_layer_summary(logits, "logits")
return logits return logits
def _get_default_optimizer(self, optimizer_name=None): def _get_default_optimizer(self, optimizer_name=None):
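
With the scope configurable, both composable models can now return their trained parameters directly from a checkpoint directory. A hedged usage sketch; the constructor arguments are inferred from the hunks above and the directory is a placeholder.

dnn = DNNComposableModel(num_label_columns=1,
                         hidden_units=[10, 10],
                         scope="my_dnn")  # defaults to "dnn" when omitted
# ... after an estimator built around this model has trained into model_dir ...
weights = dnn.get_weights(model_dir="/tmp/model")  # hypothetical directory
biases = dnn.get_bias(model_dir="/tmp/model")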

View File

@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tempfile
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import layers from tensorflow.contrib import layers
@ -42,7 +44,7 @@ class _BaseEstimatorForTest(estimator.BaseEstimator):
def __init__(self, def __init__(self,
target_column, target_column,
feature_columns): feature_columns):
super(_BaseEstimatorForTest, self).__init__() super(_BaseEstimatorForTest, self).__init__(model_dir=tempfile.mkdtemp())
self._target_column = target_column self._target_column = target_column
self._feature_columns = feature_columns self._feature_columns = feature_columns

View File

@ -71,9 +71,9 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
Args: Args:
target_column: A _TargetColumn object. target_column: A _TargetColumn object.
model_dir: Directory to save model parameters, graph and etc. This can also model_dir: Directory to save model parameters, graph and etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator
training a previously saved model. to continue training a previously saved model.
linear_feature_columns: An iterable containing all the feature columns linear_feature_columns: An iterable containing all the feature columns
used by linear part of the model. All items in the set should be used by linear part of the model. All items in the set should be
instances of classes derived from `FeatureColumn`. instances of classes derived from `FeatureColumn`.
@ -102,8 +102,8 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
ValueError: If both linear_feature_columns and dnn_features_columns are ValueError: If both linear_feature_columns and dnn_features_columns are
empty at the same time. empty at the same time.
""" """
super(_DNNLinearCombinedBaseEstimator, self).__init__(model_dir=model_dir, super(_DNNLinearCombinedBaseEstimator, self).__init__(
config=config) model_dir=model_dir, config=config)
num_ps_replicas = config.num_ps_replicas if config else 0 num_ps_replicas = config.num_ps_replicas if config else 0
@ -124,8 +124,6 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
self._linear_feature_columns = linear_feature_columns self._linear_feature_columns = linear_feature_columns
self._linear_optimizer = linear_optimizer self._linear_optimizer = linear_optimizer
self._linear_weight_collection = (
self._linear_model.get_weight_collection_name())
self._dnn_feature_columns = dnn_feature_columns self._dnn_feature_columns = dnn_feature_columns
self._dnn_hidden_units = dnn_hidden_units self._dnn_hidden_units = dnn_hidden_units
self._centered_bias_weight_collection = "centered_bias" self._centered_bias_weight_collection = "centered_bias"
@ -135,38 +133,24 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
@property @property
def linear_weights_(self): def linear_weights_(self):
"""Returns weights per feature of the linear part.""" """Returns weights per feature of the linear part."""
all_variables = self.get_variable_names() return self._linear_model.get_weights(model_dir=self._model_dir)
# TODO(ispir): Figure out a better way to retrieve variables for features.
# for example using feature info / columns.
values = {}
for name in all_variables:
if (name.startswith("linear/") and name.rfind("/") == 6 and
name != "linear/bias_weight"):
values[name] = self.get_variable_value(name)
if len(values) == 1:
return values[list(values.keys())[0]]
return values
@property @property
def linear_bias_(self): def linear_bias_(self):
"""Returns bias of the linear part.""" """Returns bias of the linear part."""
return (self.get_variable_value("linear/bias_weight") + return (self._linear_model.get_bias(model_dir=self._model_dir) +
self.get_variable_value("centered_bias_weight")) self.get_variable_value("centered_bias_weight"))
@property @property
def dnn_weights_(self): def dnn_weights_(self):
"""Returns weights of deep neural network part.""" """Returns weights of deep neural network part."""
return [self.get_variable_value("hiddenlayer_%d/weights" % i) return self._dnn_model.get_weights(model_dir=self._model_dir)
for i, _ in enumerate(self._dnn_hidden_units)] + [
self.get_variable_value("dnn_logits/weights")]
@property @property
def dnn_bias_(self): def dnn_bias_(self):
"""Returns bias of deep neural network part.""" """Returns bias of deep neural network part."""
return [self.get_variable_value("hiddenlayer_%d/biases" % i) return (self._dnn_model.get_bias(model_dir=self._model_dir) +
for i, _ in enumerate(self._dnn_hidden_units)] + [ [self.get_variable_value("centered_bias_weight")])
self.get_variable_value("dnn_logits/biases"),
self.get_variable_value("centered_bias_weight")]
def _get_feature_dict(self, features): def _get_feature_dict(self, features):
if isinstance(features, dict): if isinstance(features, dict):
@ -347,9 +331,9 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
"""Constructs a DNNLinearCombinedClassifier instance. """Constructs a DNNLinearCombinedClassifier instance.
Args: Args:
model_dir: Directory to save model parameters, graph and etc. This can also model_dir: Directory to save model parameters, graph and etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator
training a previously saved model. to continue training a previously saved model.
n_classes: number of target classes. Default is binary classification. n_classes: number of target classes. Default is binary classification.
weight_column_name: A string defining feature column name representing weight_column_name: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. weights. It is used to down weight or boost examples during training.
@ -532,9 +516,9 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
"""Initializes a DNNLinearCombinedRegressor instance. """Initializes a DNNLinearCombinedRegressor instance.
Args: Args:
model_dir: Directory to save model parameters, graph and etc. This can also model_dir: Directory to save model parameters, graph and etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator
training a previously saved model. to continue training a previously saved model.
weight_column_name: A string defining feature column name representing weight_column_name: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example. will be multiplied by the loss of the example.
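
The `linear_weights_`, `linear_bias_`, `dnn_weights_` and `dnn_bias_` properties above no longer read live variables out of the current graph; they delegate to the composable models, which load values from the checkpoint in `model_dir`. A short usage sketch, modeled on the tests that follow:

import tensorflow as tf

def input_fn():  # single sparse feature, mirroring the test below
  features = {'language': tf.SparseTensor(values=['english'],
                                          indices=[[0, 0]],
                                          shape=[1, 1])}
  return features, tf.constant([[1]])

language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 99)
classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
    linear_feature_columns=[language])
classifier.fit(input_fn=input_fn, steps=100)

print(classifier.linear_weights_)  # dict keyed by checkpointed variable names
print(classifier.linear_bias_)     # per-feature bias plus centered_bias_weight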

View File

@ -23,6 +23,7 @@ import tempfile
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn from tensorflow.contrib.learn.python.learn.estimators import _sklearn
@ -458,10 +459,39 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
self.assertLess(loss2, 0.01) self.assertLess(loss2, 0.01)
self.assertTrue('centered_bias_weight' in classifier.get_variable_names()) self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
self.assertNotIn('dnn_logits/biases', classifier.get_variable_names()) self.assertNotIn('dnn/logits/biases', classifier.get_variable_names())
self.assertNotIn('dnn_logits/weights', classifier.get_variable_names()) self.assertNotIn('dnn/logits/weights', classifier.get_variable_names())
self.assertEquals(1, len(classifier.linear_bias_)) self.assertEquals(1, len(classifier.linear_bias_))
self.assertEquals(100, len(classifier.linear_weights_)) self.assertEquals(2, len(classifier.linear_weights_))
self.assertEquals(1, len(classifier.linear_weights_['linear/age/weight']))
self.assertEquals(
100, len(classifier.linear_weights_['linear/language_weights']))
def testLinearOnlyOneFeature(self):
"""Tests that linear-only instantiation works for one feature only."""
def input_fn():
return {
'language': tf.SparseTensor(values=['english'],
indices=[[0, 0]],
shape=[1, 1])
}, tf.constant([[1]])
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 99)
classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
linear_feature_columns=[language])
classifier.fit(input_fn=input_fn, steps=100)
loss1 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
classifier.fit(input_fn=input_fn, steps=200)
loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss2, loss1)
self.assertLess(loss2, 0.01)
self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
self.assertNotIn('dnn/logits/biases', classifier.get_variable_names())
self.assertNotIn('dnn/logits/weights', classifier.get_variable_names())
self.assertEquals(1, len(classifier.linear_bias_))
self.assertEquals(99, len(classifier.linear_weights_))
def testDNNOnly(self): def testDNNOnly(self):
"""Tests that DNN-only instantiation works.""" """Tests that DNN-only instantiation works."""

View File

@ -31,7 +31,9 @@ import six
from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import graph_actions from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
@ -138,7 +140,8 @@ def _get_arguments(func):
return _get_arguments(func.func) return _get_arguments(func.func)
class BaseEstimator(sklearn.BaseEstimator): class BaseEstimator(
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models. """Abstract BaseEstimator class to train and evaluate TensorFlow models.
Concrete implementation of this class should provide the following functions: Concrete implementation of this class should provide the following functions:
@ -158,9 +161,9 @@ class BaseEstimator(sklearn.BaseEstimator):
"""Initializes a BaseEstimator instance. """Initializes a BaseEstimator instance.
Args: Args:
model_dir: Directory to save model parameters, graph and etc. This can also model_dir: Directory to save model parameters, graph and etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator to
training a previously saved model. continue training a previously saved model.
config: A RunConfig instance. config: A RunConfig instance.
""" """
# Model directory. # Model directory.
@ -196,34 +199,8 @@ class BaseEstimator(sklearn.BaseEstimator):
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None): monitors=None, max_steps=None):
"""Trains a model given training data `x` predictions and `y` targets. # pylint: disable=g-doc-args,g-doc-return-or-yield
"""See `Trainable`.
Args:
x: Matrix of shape [n_samples, n_features...]. Can be iterator that
returns arrays of features. The training input samples for fitting the
model. If set, `input_fn` must be `None`.
y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
iterator that returns array of targets. The training target values
(class labels in classification, real numbers in regression). If set,
`input_fn` must be `None`.
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
steps: Number of steps for which to train model. If `None`, train forever.
If set, `max_steps` must be `None`.
batch_size: minibatch size to use on the input, defaults to first
dimension of `x`. Must be `None` if `input_fn` is provided.
monitors: List of `BaseMonitor` subclass instances. Used for callbacks
inside the training loop.
max_steps: Number of total steps for which to train model. If `None`,
train forever. If set, `steps` must be `None`.
Two calls to `fit(steps=100)` means 200 training
iterations. On the other hand, two calls to `fit(max_steps=100)` means
that the second call will not do any iteration since first call did
all 100 steps.
Returns:
`self`, for chaining.
Raises: Raises:
ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`. ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
@ -284,61 +261,11 @@ class BaseEstimator(sklearn.BaseEstimator):
return self.fit(x=x, y=y, input_fn=input_fn, steps=steps, return self.fit(x=x, y=y, input_fn=input_fn, steps=steps,
batch_size=batch_size, monitors=monitors) batch_size=batch_size, monitors=monitors)
def evaluate(self, def evaluate(
x=None, self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
y=None, steps=None, metrics=None, name=None):
input_fn=None, # pylint: disable=g-doc-args,g-doc-return-or-yield
feed_fn=None, """See `Evaluable`.
batch_size=None,
steps=None,
metrics=None,
name=None):
"""Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the training data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
Args:
x: Matrix of shape [n_samples, n_features...]. Can be iterator that
returns arrays of features. The training input samples for fitting the
model. If set, `input_fn` must be `None`.
y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
iterator that returns array of targets. The training target values
(class labels in classification, real numbers in regression). If set,
`input_fn` must be `None`.
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
feed_fn: Function creating a feed dict every time it is called. Called
once per iteration.
batch_size: minibatch size to use on the input, defaults to first
dimension of `x`, if specified. Must be `None` if `input_fn` is
provided.
steps: Number of steps for which to evaluate model. If `None`, evaluate
until running tensors generated by `metrics` raises an exception.
metrics: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If model has one
output (i.e., returning single predction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
`('accuracy', 'classes')`- name of the metric and name of `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../../metrics/python/metrics/ops/streaming_metrics.py.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
Returns:
Returns `dict` with evaluation results.
Raises: Raises:
ValueError: If at least one of `x` or `y` is provided, and at least one of ValueError: If at least one of `x` or `y` is provided, and at least one of
@ -571,7 +498,7 @@ class BaseEstimator(sklearn.BaseEstimator):
log_every_steps=log_every_steps, log_every_steps=log_every_steps,
supervisor_is_chief=(self._config.task == 0), supervisor_is_chief=(self._config.task == 0),
supervisor_master=self._config.master, supervisor_master=self._config.master,
supervisor_save_model_steps=self._config.save_checkpoints_steps, supervisor_save_model_secs=self._config.save_checkpoints_secs,
keep_checkpoint_max=self._config.keep_checkpoint_max, keep_checkpoint_max=self._config.keep_checkpoint_max,
feed_fn=feed_fn, feed_fn=feed_fn,
steps=steps, steps=steps,
@ -770,9 +697,9 @@ class Estimator(BaseEstimator):
is passed to Estimator in the `params` parameter. This allows is passed to Estimator in the `params` parameter. This allows
configuring Estimators from hyper parameter tuning. configuring Estimators from hyper parameter tuning.
model_dir: Directory to save model parameters, graph and etc. This can also model_dir: Directory to save model parameters, graph and etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator to
training a previously saved model. continue training a previously saved model.
config: Configuration object. config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`. params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types. Keys are names of parameters, values are basic python types.
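
`BaseEstimator` now explicitly implements the `Evaluable` and `Trainable` interfaces introduced by this change, and its `fit`/`evaluate` docstrings defer to them. A minimal check of that relationship, using the module paths shown in the imports above:

from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import estimator

# Every estimator now satisfies both interfaces at the class level, which is
# what Experiment (further below) relies on.
assert issubclass(estimator.BaseEstimator, evaluable.Evaluable)
assert issubclass(estimator.BaseEstimator, trainable.Trainable)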

View File

@ -36,32 +36,26 @@ _IRIS_INPUT_DIM = 4
def boston_input_fn(num_epochs=None): def boston_input_fn(num_epochs=None):
boston = tf.contrib.learn.datasets.load_boston() boston = tf.contrib.learn.datasets.load_boston()
features = tf.cast( features = tf.reshape(tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
tf.reshape(tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM]), tf.float32)
if num_epochs: if num_epochs:
features = tf.train.limit_epochs(features, num_epochs=num_epochs) features = tf.train.limit_epochs(features, num_epochs=num_epochs)
target = tf.cast( target = tf.reshape(tf.constant(boston.target), [-1, 1])
tf.reshape(tf.constant(boston.target), [-1, 1]), tf.float32)
return features, target return features, target
def iris_input_fn(): def iris_input_fn():
iris = tf.contrib.learn.datasets.load_iris() iris = tf.contrib.learn.datasets.load_iris()
features = tf.cast( features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM]), tf.float32) target = tf.reshape(tf.constant(iris.target), [-1])
target = tf.cast(
tf.reshape(tf.constant(iris.target), [-1]), tf.int32)
return features, target return features, target
def boston_eval_fn(): def boston_eval_fn():
boston = tf.contrib.learn.datasets.load_boston() boston = tf.contrib.learn.datasets.load_boston()
n_examples = len(boston.target) n_examples = len(boston.target)
features = tf.cast( features = tf.reshape(
tf.reshape(tf.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]), tf.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM])
tf.float32) target = tf.reshape(tf.constant(boston.target), [n_examples, 1])
target = tf.cast(
tf.reshape(tf.constant(boston.target), [n_examples, 1]), tf.float32)
return tf.concat(0, [features, features]), tf.concat(0, [target, target]) return tf.concat(0, [features, features]), tf.concat(0, [target, target])
@ -188,7 +182,7 @@ class EstimatorTest(tf.test.TestCase):
with self.assertRaises(tf.contrib.learn.NotFittedError): with self.assertRaises(tf.contrib.learn.NotFittedError):
_ = est.evaluate( _ = est.evaluate(
x=boston.data, x=boston.data,
y=boston.target.astype(np.float32)) y=boston.target.astype(np.float64))
with self.assertRaises(tf.contrib.learn.NotFittedError): with self.assertRaises(tf.contrib.learn.NotFittedError):
est.predict(x=boston.data) est.predict(x=boston.data)
@ -197,10 +191,11 @@ class EstimatorTest(tf.test.TestCase):
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn, est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
model_dir=output_dir) model_dir=output_dir)
est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=50) float64_target = boston.target.astype(np.float64)
est.fit(x=boston.data, y=float64_target, steps=50)
scores = est.evaluate( scores = est.evaluate(
x=boston.data, x=boston.data,
y=boston.target.astype(np.float32), y=float64_target,
metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
del est del est
# Create another estimator object with the same output dir. # Create another estimator object with the same output dir.
@ -210,19 +205,19 @@ class EstimatorTest(tf.test.TestCase):
# Check we can evaluate and predict. # Check we can evaluate and predict.
scores2 = est2.evaluate( scores2 = est2.evaluate(
x=boston.data, x=boston.data,
y=boston.target.astype(np.float32), y=float64_target,
metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
self.assertAllClose(scores2['MSE'], self.assertAllClose(scores2['MSE'],
scores['MSE']) scores['MSE'])
predictions = est2.predict(x=boston.data) predictions = est2.predict(x=boston.data)
other_score = _sklearn.mean_squared_error(predictions, boston.target) other_score = _sklearn.mean_squared_error(predictions, float64_target)
self.assertAllClose(other_score, scores['MSE']) self.assertAllClose(other_score, scores['MSE'])
# Check we can keep training. # Check we can keep training.
est2.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100) est2.fit(x=boston.data, y=float64_target, steps=100)
scores3 = est2.evaluate( scores3 = est2.evaluate(
x=boston.data, x=boston.data,
y=boston.target.astype(np.float32), y=float64_target,
metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
self.assertLess(scores3['MSE'], scores['MSE']) self.assertLess(scores3['MSE'], scores['MSE'])
@ -230,15 +225,16 @@ class EstimatorTest(tf.test.TestCase):
boston = tf.contrib.learn.datasets.load_boston() boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_params_fn, est = tf.contrib.learn.Estimator(model_fn=linear_model_params_fn,
params={'learning_rate': 0.01}) params={'learning_rate': 0.01})
est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100) est.fit(x=boston.data, y=boston.target, steps=100)
def testBostonAll(self): def testBostonAll(self):
boston = tf.contrib.learn.datasets.load_boston() boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn) est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100) float64_target = boston.target.astype(np.float64)
est.fit(x=boston.data, y=float64_target, steps=100)
scores = est.evaluate( scores = est.evaluate(
x=boston.data, x=boston.data,
y=boston.target.astype(np.float32), y=float64_target,
metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error}) metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
predictions = est.predict(x=boston.data) predictions = est.predict(x=boston.data)
other_score = _sklearn.mean_squared_error(predictions, boston.target) other_score = _sklearn.mean_squared_error(predictions, boston.target)
@ -277,7 +273,7 @@ class EstimatorTest(tf.test.TestCase):
iris = tf.contrib.learn.datasets.load_iris() iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn) est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
x_iter = itertools.islice(iris.data, 100) x_iter = itertools.islice(iris.data, 100)
y_iter = itertools.islice(np.int32(iris.target), 100) y_iter = itertools.islice(iris.target, 100)
est.fit(x_iter, y_iter, steps=100) est.fit(x_iter, y_iter, steps=100)
_ = est.evaluate(input_fn=iris_input_fn, steps=1) _ = est.evaluate(input_fn=iris_input_fn, steps=1)
predictions = est.predict(x=iris.data)['class'] predictions = est.predict(x=iris.data)['class']
@ -374,19 +370,16 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
'': tf.FixedLenFeature(shape=expected_shape, dtype=expected_dtype) '': tf.FixedLenFeature(shape=expected_shape, dtype=expected_dtype)
}, feature_column.config) }, feature_column.config)
# Note: See tf.contrib.learn.io.data_feeder for why int32 converts to float32.
def testInt32Input(self): def testInt32Input(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
np.ones(shape=[7, 8], dtype=np.int32)) np.ones(shape=[7, 8], dtype=np.int32))
self._assert_single_feature_column([8], tf.float32, feature_columns) self._assert_single_feature_column([8], tf.int32, feature_columns)
def testInt32InputFn(self): def testInt32InputFn(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
lambda: (tf.ones(shape=[7, 8], dtype=tf.int32), None)) lambda: (tf.ones(shape=[7, 8], dtype=tf.int32), None))
self._assert_single_feature_column([8], tf.int32, feature_columns) self._assert_single_feature_column([8], tf.int32, feature_columns)
# Note: See tf.contrib.learn.io.data_feeder for why int64 doesn't convert to
# float64.
def testInt64Input(self): def testInt64Input(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
np.ones(shape=[7, 8], dtype=np.int64)) np.ones(shape=[7, 8], dtype=np.int64))
@ -407,12 +400,10 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
lambda: (tf.ones(shape=[7, 8], dtype=tf.float32), None)) lambda: (tf.ones(shape=[7, 8], dtype=tf.float32), None))
self._assert_single_feature_column([8], tf.float32, feature_columns) self._assert_single_feature_column([8], tf.float32, feature_columns)
# Note: See tf.contrib.learn.io.data_feeder for why float64 converts to
# float32.
def testFloat64Input(self): def testFloat64Input(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
np.ones(shape=[7, 8], dtype=np.float64)) np.ones(shape=[7, 8], dtype=np.float64))
self._assert_single_feature_column([8], tf.float32, feature_columns) self._assert_single_feature_column([8], tf.float64, feature_columns)
def testFloat64InputFn(self): def testFloat64InputFn(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
@ -420,9 +411,10 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
self._assert_single_feature_column([8], tf.float64, feature_columns) self._assert_single_feature_column([8], tf.float64, feature_columns)
def testBoolInput(self): def testBoolInput(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input( with self.assertRaisesRegexp(
ValueError, 'on integer or non floating types are not supported'):
tf.contrib.learn.infer_real_valued_columns_from_input(
np.array([[False for _ in xrange(8)] for _ in xrange(7)])) np.array([[False for _ in xrange(8)] for _ in xrange(7)]))
self._assert_single_feature_column([8], tf.float32, feature_columns)
def testBoolInputFn(self): def testBoolInputFn(self):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
@ -431,18 +423,12 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
tf.contrib.learn.infer_real_valued_columns_from_input_fn( tf.contrib.learn.infer_real_valued_columns_from_input_fn(
lambda: (tf.constant(False, shape=[7, 8], dtype=tf.bool), None)) lambda: (tf.constant(False, shape=[7, 8], dtype=tf.bool), None))
def testInvalidStringInput(self):
# pylint: disable=g-long-lambda
with self.assertRaisesRegexp(
ValueError, 'could not convert string to float'):
tf.contrib.learn.infer_real_valued_columns_from_input(
np.array([['foo%d' % i for i in xrange(8)] for _ in xrange(7)]))
def testStringInput(self): def testStringInput(self):
with self.assertRaisesRegexp(
ValueError, 'on integer or non floating types are not supported'):
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input( tf.contrib.learn.infer_real_valued_columns_from_input(
np.array([['%d.0' % i for i in xrange(8)] for _ in xrange(7)])) np.array([['%d.0' % i for i in xrange(8)] for _ in xrange(7)]))
self._assert_single_feature_column([8], tf.float32, feature_columns)
def testStringInputFn(self): def testStringInputFn(self):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
@ -457,13 +443,13 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
boston_input_fn) boston_input_fn)
self._assert_single_feature_column( self._assert_single_feature_column(
[_BOSTON_INPUT_DIM], tf.float32, feature_columns) [_BOSTON_INPUT_DIM], tf.float64, feature_columns)
def testIrisInputFn(self): def testIrisInputFn(self):
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn( feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
iris_input_fn) iris_input_fn)
self._assert_single_feature_column( self._assert_single_feature_column(
[_IRIS_INPUT_DIM], tf.float32, feature_columns) [_IRIS_INPUT_DIM], tf.float64, feature_columns)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
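
The updated tests above pin down the new dtype behaviour of column inference: integer and floating-point inputs keep their dtype instead of being coerced to float32, while bool and string inputs are rejected. A brief sketch of that behaviour, based on the same calls the tests make:

import numpy as np
import tensorflow as tf

cols = tf.contrib.learn.infer_real_valued_columns_from_input(
    np.ones(shape=[7, 8], dtype=np.int32))
# -> a single real-valued column of shape [8] that keeps dtype tf.int32

try:
  tf.contrib.learn.infer_real_valued_columns_from_input(
      np.array([["foo"] * 8] * 7))
except ValueError:
  pass  # string (and bool) input is no longer silently coerced to float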

View File

@ -122,9 +122,9 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
feature_columns: An iterable containing all the feature columns used by feature_columns: An iterable containing all the feature columns used by
the model. All items in the set should be instances of classes derived the model. All items in the set should be instances of classes derived
from `FeatureColumn`. from `FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can also model_dir: Directory to save model parameters, graph and etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator
training a previously saved model. to continue training a previously saved model.
n_classes: number of target classes. Default is binary classification. n_classes: number of target classes. Default is binary classification.
weight_column_name: A string defining feature column name representing weight_column_name: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It weights. It is used to down weight or boost examples during training. It
@ -186,8 +186,8 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
columns_to_tensors=features, columns_to_tensors=features,
feature_columns=self._linear_feature_columns, feature_columns=self._linear_feature_columns,
num_outputs=self._target_column.num_label_columns, num_outputs=self._target_column.num_label_columns,
weight_collections=[self._linear_weight_collection], weight_collections=[self._linear_model.get_scope_name()],
scope="linear") scope=self._linear_model.get_scope_name())
with ops.control_dependencies([self._centered_bias()]): with ops.control_dependencies([self._centered_bias()]):
loss = self._target_column.loss(logits, targets, features) loss = self._target_column.loss(logits, targets, features)
logging_ops.scalar_summary("loss", loss) logging_ops.scalar_summary("loss", loss)
@ -282,9 +282,9 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
feature_columns: An iterable containing all the feature columns used by feature_columns: An iterable containing all the feature columns used by
the model. All items in the set should be instances of classes derived the model. All items in the set should be instances of classes derived
from `FeatureColumn`. from `FeatureColumn`.
model_dir: Directory to save model parameters, graph, etc. This can also model_dir: Directory to save model parameters, graph, etc. This can
be used to load checkpoints from the directory into a estimator to continue also be used to load checkpoints from the directory into a estimator
training a previously saved model. to continue training a previously saved model.
weight_column_name: A string defining feature column name representing weight_column_name: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example. will be multiplied by the loss of the example.

View File

@ -56,7 +56,7 @@ class LogisticRegressor(estimator.Estimator):
model_fn: Model function. See superclass Estimator for more details. This model_fn: Model function. See superclass Estimator for more details. This
expects the returned predictions to be probabilities in [0.0, 1.0]. expects the returned predictions to be probabilities in [0.0, 1.0].
thresholds: List of floating point thresholds to use for accuracy, thresholds: List of floating point thresholds to use for accuracy,
precision, and recall metrics. If None, defaults to [0.5]. precision, and recall metrics. If `None`, defaults to `[0.5]`.
model_dir: Directory to save model parameters, graphs, etc. This can also model_dir: Directory to save model parameters, graphs, etc. This can also
be used to load checkpoints from the directory into a estimator to continue be used to load checkpoints from the directory into a estimator to continue
training a previously saved model. training a previously saved model.

View File

@ -17,8 +17,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import time
import numpy as np import numpy as np
import six import six
@ -26,18 +24,36 @@ from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.learn.python.learn import monitors as mon from tensorflow.contrib.learn.python.learn import monitors as mon
from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.data import data_ops from tensorflow.contrib.tensor_forest.data import data_ops
from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops from tensorflow.python.ops import state_ops
def _assert_float32(tensors):
"""Assert all tensors are float32.
Args:
tensors: `Tensor` or `dict` of `Tensor` objects.
Raises:
TypeError: if any tensor is not float32.
"""
if not isinstance(tensors, dict):
tensors = [tensors]
else:
tensors = tensors.values()
for tensor in tensors:
if tensor.dtype.base_dtype != dtypes.float32:
raise TypeError('Expected dtype=float32, %s.' % tensor)
class LossMonitor(mon.EveryN): class LossMonitor(mon.EveryN):
"""Terminates training when training loss stops decreasing.""" """Terminates training when training loss stops decreasing."""
@ -146,6 +162,8 @@ class TensorForestEstimator(estimator.BaseEstimator):
Returns: Returns:
Tuple of train `Operation` and loss `Tensor`. Tuple of train `Operation` and loss `Tensor`.
""" """
_assert_float32(features)
_assert_float32(targets)
features, spec = data_ops.ParseDataTensorOrDict(features) features, spec = data_ops.ParseDataTensorOrDict(features)
labels = data_ops.ParseLabelTensorOrDict(targets) labels = data_ops.ParseLabelTensorOrDict(targets)
@ -168,6 +186,7 @@ class TensorForestEstimator(estimator.BaseEstimator):
return train, self.training_loss return train, self.training_loss
def _get_predict_ops(self, features): def _get_predict_ops(self, features):
_assert_float32(features)
graph_builder = self.graph_builder_class( graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner, training=False, self.params, device_assigner=self.device_assigner, training=False,
**self.construction_args) **self.construction_args)
@ -175,6 +194,8 @@ class TensorForestEstimator(estimator.BaseEstimator):
return graph_builder.inference_graph(features, data_spec=spec) return graph_builder.inference_graph(features, data_spec=spec)
def _get_eval_ops(self, features, targets, metrics): def _get_eval_ops(self, features, targets, metrics):
_assert_float32(features)
_assert_float32(targets)
features, spec = data_ops.ParseDataTensorOrDict(features) features, spec = data_ops.ParseDataTensorOrDict(features)
labels = data_ops.ParseLabelTensorOrDict(targets) labels = data_ops.ParseLabelTensorOrDict(targets)
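
`_assert_float32` is now called at the top of the training, prediction and evaluation graph builders so that non-float32 features or targets fail fast with a `TypeError` instead of causing dtype problems deeper inside the forest ops. An illustration of the guard, assuming it runs inside this module where the helper is defined:

import numpy as np
import tensorflow as tf

ok = tf.constant(np.ones([2, 3], dtype=np.float32))
bad = tf.constant(np.ones([2, 3], dtype=np.float64))

_assert_float32(ok)          # passes silently
_assert_float32({"x": bad})  # raises TypeError naming the offending tensor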

View File

@ -19,11 +19,20 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import tensorflow as tf import tensorflow as tf
class TensorForestTrainerTests(tf.test.TestCase): class TensorForestTrainerTests(tf.test.TestCase):
def testFloat64(self):
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
classifier = tf.contrib.learn.TensorForestEstimator(hparams)
iris = tf.contrib.learn.datasets.load_iris()
with self.assertRaisesRegexp(TypeError, 'float32'):
classifier.fit(x=iris.data, y=iris.target, steps=100)
def testClassification(self): def testClassification(self):
"""Tests multi-class classification using matrix data as input.""" """Tests multi-class classification using matrix data as input."""
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
@ -31,9 +40,11 @@ class TensorForestTrainerTests(tf.test.TestCase):
classifier = tf.contrib.learn.TensorForestEstimator(hparams) classifier = tf.contrib.learn.TensorForestEstimator(hparams)
iris = tf.contrib.learn.datasets.load_iris() iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)
classifier.fit(x=iris.data, y=iris.target, steps=100) classifier.fit(x=data, y=target, steps=100)
classifier.evaluate(x=iris.data, y=iris.target, steps=10) classifier.evaluate(x=data, y=target, steps=10)
def testRegression(self): def testRegression(self):
"""Tests multi-class classification using matrix data as input.""" """Tests multi-class classification using matrix data as input."""
@ -45,9 +56,11 @@ class TensorForestTrainerTests(tf.test.TestCase):
regressor = tf.contrib.learn.TensorForestEstimator(hparams) regressor = tf.contrib.learn.TensorForestEstimator(hparams)
boston = tf.contrib.learn.datasets.load_boston() boston = tf.contrib.learn.datasets.load_boston()
data = boston.data.astype(np.float32)
target = boston.target.astype(np.float32)
regressor.fit(x=boston.data, y=boston.target, steps=100) regressor.fit(x=data, y=target, steps=100)
regressor.evaluate(x=boston.data, y=boston.target, steps=10) regressor.evaluate(x=data, y=target, steps=10)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -38,8 +38,7 @@ class RunConfig(object):
save_summary_steps=100, save_summary_steps=100,
save_checkpoints_secs=60, save_checkpoints_secs=60,
keep_checkpoint_max=5, keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000, keep_checkpoint_every_n_hours=10000):
save_checkpoints_steps=1000):
"""Constructor. """Constructor.
Args: Args:
@ -61,7 +60,6 @@ class RunConfig(object):
keep_checkpoint_every_n_hours: Number of hours between each checkpoint keep_checkpoint_every_n_hours: Number of hours between each checkpoint
to be saved. The default value of 10,000 hours effectively disables to be saved. The default value of 10,000 hours effectively disables
the feature. the feature.
save_checkpoints_steps: Number of steps between each checkpoint saving.
""" """
self.master = master self.master = master
self.task = task self.task = task
@ -77,4 +75,3 @@ class RunConfig(object):
self.save_checkpoints_secs = save_checkpoints_secs self.save_checkpoints_secs = save_checkpoints_secs
self.keep_checkpoint_max = keep_checkpoint_max self.keep_checkpoint_max = keep_checkpoint_max
self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
self.save_checkpoints_steps = save_checkpoints_steps
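
With `save_checkpoints_steps` removed, checkpoint cadence is expressed purely in wall-clock seconds via `save_checkpoints_secs`, which `BaseEstimator._train_model` now forwards as `supervisor_save_model_secs`. A small sketch with illustrative values:

from tensorflow.contrib.learn.python.learn.estimators import run_config

config = run_config.RunConfig(save_checkpoints_secs=30,  # illustrative value
                              keep_checkpoint_max=3)
# config.save_checkpoints_secs alone determines checkpoint frequency;
# there is no save_checkpoints_steps attribute any more.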

View File

@ -0,0 +1,81 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`Evaluable` interface."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
class Evaluable(object):
"""Interface for objects that are evaluatable by, e.g., `Experiment`.
"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
steps=None, metrics=None, name=None):
"""Evaluates given model with provided evaluation data.
Evaluates on the given input data. If `input_fn` is provided, that
input function should raise an end-of-input exception (`OutOfRangeError` or
`StopIteration`) after one epoch of the training data has been provided.
By default, the whole evaluation dataset is used. If `steps` is provided,
only `steps` batches of size `batch_size` are processed.
The return value is a dict containing the metrics specified in `metrics`, as
well as an entry `global_step` which contains the value of the global step
for which this evaluation was performed.
Args:
x: Matrix of shape [n_samples, n_features...]. Can be iterator that
returns arrays of features. The training input samples for fitting the
model. If set, `input_fn` must be `None`.
y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
iterator that returns array of targets. The training target values
(class labels in classification, real numbers in regression). If set,
`input_fn` must be `None`.
input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
`None`.
feed_fn: Function creating a feed dict every time it is called. Called
once per iteration. Must be `None` if `input_fn` is provided.
batch_size: minibatch size to use on the input, defaults to first
dimension of `x`, if specified. Must be `None` if `input_fn` is
provided.
steps: Number of steps for which to evaluate the model. If `None`, evaluate
until running the tensors generated by `metrics` raises an exception.
metrics: Dict of metric ops to run. If `None`, the default metric
functions are used; if `{}`, no metrics are used. If the model has one
output (i.e., returning a single prediction), keys are `str`, e.g.
`'accuracy'` - just a name of the metric that will show up in
the logs / summaries. Otherwise, keys are tuples of two `str`, e.g.
`('accuracy', 'classes')` - name of the metric and name of the `Tensor` in
the predictions to run this metric on.
Metric ops should support streaming, e.g., returning
update_op and value tensors. See more details in
../../../metrics/python/metrics/ops/streaming_metrics.py.
name: Name of the evaluation if the user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
Returns:
Returns `dict` with evaluation results.
"""
raise NotImplementedError
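As an editorial aside, a minimal sketch of a class satisfying this interface may help; the `ConstantModel` below is purely hypothetical and not part of this change, it simply returns a fixed metrics dict plus `global_step`, as the docstring requires.

import abc

class Evaluable(object):
  """Stand-in for the `Evaluable` interface introduced above."""
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
               batch_size=None, steps=None, metrics=None, name=None):
    raise NotImplementedError

class ConstantModel(Evaluable):
  """Hypothetical model reporting a fixed accuracy."""

  def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
               batch_size=None, steps=None, metrics=None, name=None):
    # Return the requested metrics plus the global step of this evaluation.
    return {'accuracy': 0.9, 'global_step': 100}

print(ConstantModel().evaluate())  # {'accuracy': 0.9, 'global_step': 100}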

View File

@ -21,7 +21,9 @@ from __future__ import print_function
import time import time
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.python.platform import flags from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -47,7 +49,7 @@ class Experiment(object):
"""Constructor for `Experiment`. """Constructor for `Experiment`.
Args: Args:
estimator: `Estimator` object. estimator: Object implementing `Trainable` and `Evaluable`.
train_input_fn: function, returns features and targets for training. train_input_fn: function, returns features and targets for training.
eval_input_fn: function, returns features and targets for evaluation. If eval_input_fn: function, returns features and targets for evaluation. If
`eval_steps` is `None`, this should be configured only to produce for a `eval_steps` is `None`, this should be configured only to produce for a
@ -67,7 +69,14 @@ class Experiment(object):
continuous_eval_throttle_secs: Do not re-evaluate unless the last continuous_eval_throttle_secs: Do not re-evaluate unless the last
evaluation was started at least this many seconds ago for evaluation was started at least this many seconds ago for
continuous_eval(). continuous_eval().
Raises:
ValueError: if `estimator` does not implement `Evaluable` and `Trainable`.
""" """
if not isinstance(estimator, evaluable.Evaluable):
raise ValueError("`estimator` must implement `Evaluable`.")
if not isinstance(estimator, trainable.Trainable):
raise ValueError("`estimator` must implement `Trainable`.")
super(Experiment, self).__init__() super(Experiment, self).__init__()
self._estimator = estimator self._estimator = estimator
self._train_input_fn = train_input_fn self._train_input_fn = train_input_fn

View File

@ -130,7 +130,7 @@ def _supervised_train(graph,
log_every_steps=10, log_every_steps=10,
supervisor_is_chief=True, supervisor_is_chief=True,
supervisor_master='', supervisor_master='',
supervisor_save_model_steps=1000, supervisor_save_model_secs=600,
keep_checkpoint_max=5, keep_checkpoint_max=5,
supervisor_save_summaries_steps=100, supervisor_save_summaries_steps=100,
feed_fn=None, feed_fn=None,
@ -171,8 +171,8 @@ def _supervised_train(graph,
supervisor_is_chief: Whether the current process is the chief supervisor in supervisor_is_chief: Whether the current process is the chief supervisor in
charge of restoring the model and running standard services. charge of restoring the model and running standard services.
supervisor_master: The master string to use when preparing the session. supervisor_master: The master string to use when preparing the session.
supervisor_save_model_steps: Save a checkpoint every supervisor_save_model_secs: Save model every
`supervisor_save_model_steps` steps when training. `supervisor_save_model_secs` seconds when training.
keep_checkpoint_max: The maximum number of recent checkpoint files to keep_checkpoint_max: The maximum number of recent checkpoint files to
keep. As new files are created, older files are deleted. If None or 0, keep. As new files are created, older files are deleted. If None or 0,
all checkpoint files are kept. This is simply passed as the max_to_keep all checkpoint files are kept. This is simply passed as the max_to_keep
@ -251,15 +251,18 @@ def _supervised_train(graph,
init_fn=init_fn, init_fn=init_fn,
keep_checkpoint_max=keep_checkpoint_max) keep_checkpoint_max=keep_checkpoint_max)
if supervisor_is_chief: if supervisor_is_chief:
if scaffold.summary_op is not None:
monitors.append(monitors_lib.SummarySaver(
scaffold.summary_op,
save_steps=supervisor_save_summaries_steps,
summary_writer=summary_writer))
if supervisor_save_model_steps > 0:
monitors.append( monitors.append(
monitors_lib.CheckpointSaver(supervisor_save_model_steps, monitors_lib.SummarySaver(
scaffold.saver, output_dir)) summary_op=None,
save_steps=supervisor_save_summaries_steps,
summary_writer=summary_writer,
scaffold=scaffold))
if supervisor_save_model_secs > 0:
monitors.append(
monitors_lib.CheckpointSaver(
output_dir,
save_secs=supervisor_save_model_secs,
scaffold=scaffold))
if steps is not None or max_steps is not None: if steps is not None or max_steps is not None:
monitors.append(monitors_lib.StopAtStep(steps, max_steps)) monitors.append(monitors_lib.StopAtStep(steps, max_steps))

View File

@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-multiple-import,g-bad-import-order # pylint: disable=g-multiple-import,g-bad-import-order
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
@ -206,6 +207,13 @@ def _access(data, iloc):
return data[iloc] return data[iloc]
def _check_dtype(dtype):
if dtypes.as_dtype(dtype) == dtypes.float64:
logging.warn(
'float64 is not supported by many models, consider casting to float32.')
return dtype
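The helper above only warns and passes the dtype through; it does not cast. A self-contained sketch of the same behavior (numpy only, with `warnings.warn` standing in for `tf_logging`):

import warnings
import numpy as np

def check_dtype(dtype):
  # Mirrors _check_dtype: warn on float64, return the dtype unchanged.
  if np.dtype(dtype) == np.float64:
    warnings.warn(
        'float64 is not supported by many models, consider casting to float32.')
  return dtype

x = np.array([[1.0, 2.0]])      # numpy defaults to float64
dtype = check_dtype(x.dtype)    # warns, then returns dtype('float64')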
class DataFeeder(object): class DataFeeder(object):
"""Data feeder is an example class to sample data for TF trainer.""" """Data feeder is an example class to sample data for TF trainer."""
@ -215,60 +223,82 @@ class DataFeeder(object):
"""Initializes a DataFeeder instance. """Initializes a DataFeeder instance.
Args: Args:
x: feature Nd numpy matrix of shape [n_samples, n_features, ...]. x: Feature Nd numpy matrix of shape `[n_samples, n_features, ...]`.
y: target vector, either floats for regression or class id for y: Target vector, either floats for regression or class id for
classification. If matrix, will consider as a sequence classification. If matrix, will consider as a sequence
of targets. Can be None for unsupervised setting. of targets. Can be `None` for unsupervised setting.
n_classes: number of classes, 0 and 1 are considered regression, None will n_classes: Number of classes, 0 and 1 are considered regression, `None`
pass through the input labels without one-hot conversion. will pass through the input labels without one-hot conversion.
batch_size: mini batch size to accumulate. batch_size: Mini-batch size to accumulate.
random_state: numpy RandomState object to reproduce sampling. shuffle: Whether to shuffle `x`.
random_state: Numpy `RandomState` object to reproduce sampling.
epochs: Number of times to iterate over input data before raising
`StopIteration` exception.
Attributes: Attributes:
x: input features. x: Input features.
y: input target. y: Input target.
n_classes: number of classes (if None, pass through indices without n_classes: Number of classes (if `None`, pass through indices without
one-hot conversion). one-hot conversion).
batch_size: mini batch size to accumulate. batch_size: Mini-batch size to accumulate.
input_shape: shape of the input. input_shape: Shape of the input.
output_shape: shape of the output. output_shape: Shape of the output.
input_dtype: dtype of input. input_dtype: DType of input.
output_dtype: dtype of output. output_dtype: DType of output.
""" """
x_dtype = np.int64 if x.dtype == np.int64 else np.float32 self._x = check_array(x, dtype=x.dtype)
# self.n_classes is None means we're passing in raw target indices.
y_dtype = ( y_dtype = (
np.int64 if n_classes is not None and n_classes > 1 else np.float32) np.int64 if n_classes is not None and n_classes > 1 else np.float32)
self.x = check_array(x, dtype=x_dtype)
# self.n_classes is None means we're passing in raw target indices
if n_classes is not None: if n_classes is not None:
self.y = (None if y is None else check_array(y, dtype=y_dtype)) self._y = (None if y is None else check_array(y, dtype=y_dtype))
elif isinstance(y, list):
self._y = np.array(y)
else: else:
self.y = y self._y = y
if isinstance(self.y, list):
self.y = np.array(y)
self.n_classes = n_classes self.n_classes = n_classes
self.max_epochs = epochs self.max_epochs = epochs
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape( self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
self.x.shape, None if self.y is None else self.y.shape, n_classes, self._x.shape, None if self._y is None else self._y.shape, n_classes,
batch_size) batch_size)
# Input dtype matches dtype of x. # Input dtype matches dtype of x.
self.input_dtype = x_dtype self._input_dtype = _check_dtype(self._x.dtype)
# self.n_classes is None means we're passing in raw target indices # self.n_classes is None means we're passing in raw target indices
if n_classes is not None or y is None: if n_classes is not None or self._y is None:
self.output_dtype = np.float32 self._output_dtype = np.float32
else: else:
self.output_dtype = self.y.dtype self._output_dtype = _check_dtype(self._y.dtype)
self.shuffle = shuffle self._shuffle = shuffle
self.random_state = np.random.RandomState( self.random_state = np.random.RandomState(
42) if random_state is None else random_state 42) if random_state is None else random_state
if self.shuffle: if self._shuffle:
self.indices = self.random_state.permutation(self.x.shape[0]) self.indices = self.random_state.permutation(self._x.shape[0])
else: else:
self.indices = np.array(range(self.x.shape[0])) self.indices = np.array(range(self._x.shape[0]))
self.offset = 0 self.offset = 0
self.epoch = 0 self.epoch = 0
self._epoch_placeholder = None self._epoch_placeholder = None
@property
def x(self):
return self._x
@property
def y(self):
return self._y
@property
def shuffle(self):
return self._shuffle
@property
def input_dtype(self):
return self._input_dtype
@property
def output_dtype(self):
return self._output_dtype
@property @property
def batch_size(self): def batch_size(self):
return self._batch_size return self._batch_size
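The attribute renames above (`self.x` to `self._x`, and so on) put the feeder's state behind read-only properties: callers still read `feeder.x` or `feeder.input_dtype`, but can no longer reassign them. The pattern in miniature (an illustrative class, not the real DataFeeder):

class Feeder(object):

  def __init__(self, x):
    self._x = x          # stored privately

  @property
  def x(self):
    return self._x       # exposed read-only

f = Feeder([1, 2, 3])
print(f.x)               # [1, 2, 3]
# f.x = []               # would raise AttributeError: can't set attribute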
@ -291,7 +321,7 @@ class DataFeeder(object):
""" """
input_shape = [None] + self.input_shape[1:] input_shape = [None] + self.input_shape[1:]
self._input_placeholder = array_ops.placeholder( self._input_placeholder = array_ops.placeholder(
dtypes.as_dtype(self.input_dtype), dtypes.as_dtype(self._input_dtype),
input_shape, input_shape,
name='input') name='input')
if self.output_shape is None: if self.output_shape is None:
@ -299,7 +329,7 @@ class DataFeeder(object):
else: else:
output_shape = [None] + self.output_shape[1:] output_shape = [None] + self.output_shape[1:]
self._output_placeholder = array_ops.placeholder( self._output_placeholder = array_ops.placeholder(
dtypes.as_dtype(self.output_dtype), dtypes.as_dtype(self._output_dtype),
output_shape, output_shape,
name='output') name='output')
return self._input_placeholder, self._output_placeholder return self._input_placeholder, self._output_placeholder
@ -345,20 +375,20 @@ class DataFeeder(object):
feed_dict[self._epoch_placeholder.name] = [self.epoch] feed_dict[self._epoch_placeholder.name] = [self.epoch]
# Take next batch of indices. # Take next batch of indices.
end = min(self.x.shape[0], self.offset + self._batch_size) end = min(self._x.shape[0], self.offset + self._batch_size)
batch_indices = self.indices[self.offset:end] batch_indices = self.indices[self.offset:end]
# Assign input features from random indices. # Assign input features from random indices.
inp = ( inp = (
np.array(_access(self.x, batch_indices)).reshape( np.array(_access(self._x, batch_indices)).reshape(
(batch_indices.shape[0], 1)) (batch_indices.shape[0], 1))
if len(self.x.shape) == 1 else _access(self.x, batch_indices)) if len(self._x.shape) == 1 else _access(self._x, batch_indices))
feed_dict[self._input_placeholder.name] = inp feed_dict[self._input_placeholder.name] = inp
# move offset and reset it if necessary # move offset and reset it if necessary
self.offset += self._batch_size self.offset += self._batch_size
if self.offset >= self.x.shape[0]: if self.offset >= self._x.shape[0]:
self.indices = self.random_state.permutation(self.x.shape[0]) self.indices = self.random_state.permutation(self._x.shape[0])
self.offset = 0 self.offset = 0
self.epoch += 1 self.epoch += 1
@ -368,21 +398,21 @@ class DataFeeder(object):
# assign labels from random indices # assign labels from random indices
self.output_shape[0] = batch_indices.shape[0] self.output_shape[0] = batch_indices.shape[0]
out = np.zeros(self.output_shape, dtype=self.output_dtype) out = np.zeros(self.output_shape, dtype=self._output_dtype)
for i in xrange(out.shape[0]): for i in xrange(out.shape[0]):
sample = batch_indices[i] sample = batch_indices[i]
# self.n_classes is None means we're passing in raw target indices # self.n_classes is None means we're passing in raw target indices
if self.n_classes is None: if self.n_classes is None:
out[i] = _access(self.y, sample) out[i] = _access(self._y, sample)
else: else:
if self.n_classes > 1: if self.n_classes > 1:
if len(self.output_shape) == 2: if len(self.output_shape) == 2:
out.itemset((i, int(_access(self.y, sample))), 1.0) out.itemset((i, int(_access(self._y, sample))), 1.0)
else: else:
for idx, value in enumerate(_access(self.y, sample)): for idx, value in enumerate(_access(self._y, sample)):
out.itemset(tuple([i, idx, value]), 1.0) out.itemset(tuple([i, idx, value]), 1.0)
else: else:
out[i] = _access(self.y, sample) out[i] = _access(self._y, sample)
feed_dict[self._output_placeholder.name] = out feed_dict[self._output_placeholder.name] = out
return feed_dict return feed_dict
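For the common two-dimensional case, the label-handling loop above amounts to scattering each class id into a one-hot row of the output batch. A small numpy sketch of just that step (shapes and names are illustrative):

import numpy as np

def one_hot_batch(y, batch_indices, n_classes):
  # Equivalent in spirit to the n_classes > 1, len(output_shape) == 2 branch.
  out = np.zeros((len(batch_indices), n_classes), dtype=np.float32)
  for i, sample in enumerate(batch_indices):
    out[i, int(y[sample])] = 1.0
  return out

y = np.array([0, 2, 1, 2])
print(one_hot_batch(y, [1, 3], n_classes=3))
# [[0. 0. 1.]
#  [0. 0. 1.]]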
@ -420,32 +450,28 @@ class StreamingDataFeeder(DataFeeder):
""" """
# pylint: disable=invalid-name,super-init-not-called # pylint: disable=invalid-name,super-init-not-called
x_first_el = six.next(x) x_first_el = six.next(x)
self.x = itertools.chain([x_first_el], x) self._x = itertools.chain([x_first_el], x)
if y is not None: if y is not None:
y_first_el = six.next(y) y_first_el = six.next(y)
self.y = itertools.chain([y_first_el], y) self._y = itertools.chain([y_first_el], y)
else: else:
y_first_el = None y_first_el = None
self.y = None self._y = None
self.n_classes = n_classes self.n_classes = n_classes
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape( self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
[1] + list(x_first_el.shape), [1] + list(x_first_el.shape),
[1] + list(y_first_el.shape) if y is not None else None, [1] + list(y_first_el.shape) if y is not None else None,
n_classes, n_classes,
batch_size) batch_size)
self.input_dtype = x_first_el.dtype self._input_dtype = _check_dtype(x_first_el.dtype)
# Convert float64 to float32, as all the parameters in the model are
# floats32 and there is a lot of benefits in using it in NNs.
if self.input_dtype == np.float64:
self.input_dtype = np.float32
# Output types are floats, due to both softmaxes and regression req. # Output types are floats, due to both softmaxes and regression req.
if n_classes is not None and n_classes > 0: if n_classes is not None and n_classes > 0:
self.output_dtype = np.float32 self._output_dtype = np.float32
elif y is not None: elif y is not None:
if isinstance(y_first_el, list) or isinstance(y_first_el, np.ndarray): if isinstance(y_first_el, list) or isinstance(y_first_el, np.ndarray):
self.output_dtype = np.dtype(type(y_first_el[0])) self._output_dtype = _check_dtype(np.dtype(type(y_first_el[0])))
else: else:
self.output_dtype = np.dtype(type(y_first_el)) self._output_dtype = _check_dtype(np.dtype(type(y_first_el)))
def get_feed_params(self): def get_feed_params(self):
"""Function returns a dict with data feed params while training. """Function returns a dict with data feed params while training.
@ -472,22 +498,22 @@ class StreamingDataFeeder(DataFeeder):
""" """
if self.stopped: if self.stopped:
raise StopIteration raise StopIteration
inp = np.zeros(self.input_shape, dtype=self.input_dtype) inp = np.zeros(self.input_shape, dtype=self._input_dtype)
if self.y is not None: if self._y is not None:
out = np.zeros(self.output_shape, dtype=self.output_dtype) out = np.zeros(self.output_shape, dtype=self._output_dtype)
for i in xrange(self._batch_size): for i in xrange(self._batch_size):
# Add handling when queue ends. # Add handling when queue ends.
try: try:
inp[i, :] = six.next(self.x) inp[i, :] = six.next(self._x)
except StopIteration: except StopIteration:
self.stopped = True self.stopped = True
inp = inp[:i, :] inp = inp[:i, :]
if self.y is not None: if self._y is not None:
out = out[:i] out = out[:i]
break break
if self.y is not None: if self._y is not None:
y = six.next(self.y) y = six.next(self._y)
if self.n_classes is not None and self.n_classes > 1: if self.n_classes is not None and self.n_classes > 1:
if len(self.output_shape) == 2: if len(self.output_shape) == 2:
out.itemset((i, y), 1.0) out.itemset((i, y), 1.0)
@ -496,7 +522,7 @@ class StreamingDataFeeder(DataFeeder):
out.itemset(tuple([i, idx, value]), 1.0) out.itemset(tuple([i, idx, value]), 1.0)
else: else:
out[i] = y out[i] = y
if self.y is None: if self._y is None:
return {self._input_placeholder.name: inp} return {self._input_placeholder.name: inp}
return {self._input_placeholder.name: inp, return {self._input_placeholder.name: inp,
self._output_placeholder.name: out} self._output_placeholder.name: out}
@ -511,6 +537,7 @@ class DaskDataFeeder(object):
into them. DaskDataFeeder will remove requirement to have full dataset in the into them. DaskDataFeeder will remove requirement to have full dataset in the
memory and still do random seeks for sampling of batches. memory and still do random seeks for sampling of batches.
""" """
def __init__(self, x, y, n_classes, batch_size, shuffle=True, def __init__(self, x, y, n_classes, batch_size, shuffle=True,
random_state=None, epochs=None): random_state=None, epochs=None):
"""Initializes a DaskDataFeeder instance. """Initializes a DaskDataFeeder instance.
@ -521,8 +548,10 @@ class DaskDataFeeder(object):
regression values. regression values.
n_classes: indicator of how many classes the target has. n_classes: indicator of how many classes the target has.
batch_size: Mini batch size to accumulate. batch_size: Mini batch size to accumulate.
shuffle: Whether to shuffle the inputs.
random_state: random state for RNG. Note that it will mutate so use a random_state: random state for RNG. Note that it will mutate so use a
int value for this if you want consistent sized batches. int value for this if you want consistent sized batches.
epochs: Number of epochs to run.
Attributes: Attributes:
x: input features. x: input features.
@ -537,35 +566,33 @@ class DaskDataFeeder(object):
# pylint: disable=invalid-name,super-init-not-called # pylint: disable=invalid-name,super-init-not-called
import dask.dataframe as dd # pylint: disable=g-import-not-at-top import dask.dataframe as dd # pylint: disable=g-import-not-at-top
# TODO(terrytangyuan): check x and y dtypes in dask_io like pandas # TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
self.x = x self._x = x
self.y = y self._y = y
# save column names # save column names
self.x_columns = list(x.columns) self._x_columns = list(x.columns)
if isinstance(y.columns[0], str): if isinstance(y.columns[0], str):
self.y_columns = list(y.columns) self._y_columns = list(y.columns)
else: else:
# deal with cases where two DFs have overlapped default numeric colnames # deal with cases where two DFs have overlapped default numeric colnames
self.y_columns = len(self.x_columns) + 1 self._y_columns = len(self._x_columns) + 1
self.y = self.y.rename(columns={y.columns[0]: self.y_columns}) self._y = self._y.rename(columns={y.columns[0]: self._y_columns})
# TODO(terrytangyuan): deal with unsupervised cases # TODO(terrytangyuan): deal with unsupervised cases
# combine into a data frame # combine into a data frame
self.df = dd.multi.concat([self.x, self.y], axis=1) self.df = dd.multi.concat([self._x, self._y], axis=1)
self.n_classes = n_classes self.n_classes = n_classes
x_count = x.count().compute()[0] x_count = x.count().compute()[0]
x_shape = (x_count, len(self.x.columns)) x_shape = (x_count, len(self._x.columns))
y_shape = (x_count, len(self.y.columns)) y_shape = (x_count, len(self._y.columns))
# TODO(terrytangyuan): Add support for shuffle and epochs. # TODO(terrytangyuan): Add support for shuffle and epochs.
self.shuffle = shuffle self._shuffle = shuffle
self.epochs = epochs self.epochs = epochs
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape( self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
x_shape, y_shape, n_classes, batch_size) x_shape, y_shape, n_classes, batch_size)
self.sample_fraction = self._batch_size / float(x_count) self.sample_fraction = self._batch_size / float(x_count)
# TODO(ptucker,ipolosukhin): Remove this? self._input_dtype = _check_dtype(self._x.dtypes[0])
# TODO(ipolosukhin): remove or restore. self._output_dtype = _check_dtype(self._y.dtypes[self._y_columns])
# self.x.dtypes[0], self.y.dtypes[self.y_columns]
self.input_dtype, self.output_dtype = np.float32, np.float32
if random_state is None: if random_state is None:
self.random_state = 66 self.random_state = 66
else: else:
@ -597,17 +624,17 @@ class DaskDataFeeder(object):
sample = self.df.random_split( sample = self.df.random_split(
[self.sample_fraction, 1 - self.sample_fraction], [self.sample_fraction, 1 - self.sample_fraction],
random_state=self.random_state) random_state=self.random_state)
inp = extract_pandas_matrix(sample[0][self.x_columns].compute()).tolist() inp = extract_pandas_matrix(sample[0][self._x_columns].compute()).tolist()
out = extract_pandas_matrix(sample[0][self.y_columns].compute()) out = extract_pandas_matrix(sample[0][self._y_columns].compute())
# convert to correct dtype # convert to correct dtype
inp = np.array(inp, dtype=self.input_dtype) inp = np.array(inp, dtype=self._input_dtype)
# one-hot encode out for each class for cross entropy loss # one-hot encode out for each class for cross entropy loss
if HAS_PANDAS: if HAS_PANDAS:
import pandas as pd # pylint: disable=g-import-not-at-top import pandas as pd # pylint: disable=g-import-not-at-top
if not isinstance(out, pd.Series): if not isinstance(out, pd.Series):
out = out.flatten() out = out.flatten()
out_max = self.y.max().compute().values[0] out_max = self._y.max().compute().values[0]
encoded_out = np.zeros((out.size, out_max + 1), dtype=self.output_dtype) encoded_out = np.zeros((out.size, out_max + 1), dtype=self._output_dtype)
encoded_out[np.arange(out.size), out] = 1 encoded_out[np.arange(out.size), out] = 1
return {input_placeholder.name: inp, return {input_placeholder.name: inp,
output_placeholder.name: encoded_out} output_placeholder.name: encoded_out}

View File

@ -20,12 +20,17 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops from tensorflow.python.ops import io_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import input as input_ops from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner
# Default name for key in the feature dict. # Default name for key in the feature dict.
KEY_FEATURE_NAME = '__key__' KEY_FEATURE_NAME = '__key__'
@ -219,11 +224,18 @@ def read_keyed_batch_examples(
return queued_examples_with_keys return queued_examples_with_keys
def read_keyed_batch_features( def read_keyed_batch_features(file_pattern,
file_pattern, batch_size, features, reader, batch_size,
randomize_input=True, num_epochs=None, features,
queue_capacity=10000, reader_num_threads=1, reader,
parser_num_threads=1, name=None): randomize_input=True,
num_epochs=None,
queue_capacity=10000,
reader_num_threads=1,
feature_queue_capacity=100,
num_queue_runners=2,
parser_num_threads=None,
name=None):
"""Adds operations to read, queue, batch and parse `Example` protos. """Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names, Given file pattern (or list of files), will setup a queue for file names,
@ -251,7 +263,12 @@ def read_keyed_batch_features(
tf.initialize_local_variables() as shown in the tests. tf.initialize_local_variables() as shown in the tests.
queue_capacity: Capacity for input queue. queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples. reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples. feature_queue_capacity: Capacity of the parsed features queue.
num_queue_runners: Number of queue runners to start for the feature queue.
Adding multiple queue runners for the parsed example queue helps maintain
a full queue when the subsequent computations overall are cheaper than
parsing.
parser_num_threads: (Deprecated) The number of threads to parse examples.
name: Name of resulting op. name: Name of resulting op.
Returns: Returns:
@ -261,6 +278,11 @@ def read_keyed_batch_features(
Raises: Raises:
ValueError: for invalid inputs. ValueError: for invalid inputs.
""" """
if parser_num_threads:
# TODO(sibyl-Aix6ihai): Remove on Sept 3 2016.
logging.warning('parser_num_threads is deprecated, it will be removed on '
'Sept 3 2016')
with ops.op_scope([file_pattern], name, 'read_batch_features') as scope: with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
keys, examples = read_keyed_batch_examples( keys, examples = read_keyed_batch_examples(
file_pattern, batch_size, reader, randomize_input=randomize_input, file_pattern, batch_size, reader, randomize_input=randomize_input,
@ -268,24 +290,66 @@ def read_keyed_batch_features(
num_threads=reader_num_threads, read_batch_size=batch_size, num_threads=reader_num_threads, read_batch_size=batch_size,
name=scope) name=scope)
if parser_num_threads == 1: # Parse the example.
# Avoid queue overhead for single thread feature_map = parsing_ops.parse_example(examples, features)
return keys, parsing_ops.parse_example(examples, features)
# Parse features into tensors in many threads and put on the queue. # Lets also add preprocessed tensors into the queue types for each item of
features_list = [] # the queue.
for _ in range(parser_num_threads): tensors_to_enqueue = []
feature_dict = parsing_ops.parse_example(examples, features) # Each entry contains the key, and a boolean which indicates whether the
feature_dict[KEY_FEATURE_NAME] = keys # tensor was a sparse tensor.
features_list.append(feature_dict) tensors_mapping = []
queued_features = input_ops.batch_join( # TODO(sibyl-Aix6ihai): Most of the functionality here is about pushing sparse
features_list, # tensors into a queue. This could be taken care in somewhere else so others
batch_size=batch_size, # can reuse it. Also, QueueBase maybe extended to handle sparse tensors
capacity=queue_capacity, # directly.
enqueue_many=True, for key, tensor in feature_map.iteritems():
name='parse_example_batch_join') if isinstance(tensor, ops.SparseTensor):
queued_keys = queued_features.pop(KEY_FEATURE_NAME) tensors_mapping.append((key, True))
return queued_keys, queued_features tensors_to_enqueue.extend([tensor.indices, tensor.values, tensor.shape])
else:
tensors_mapping.append((key, False))
tensors_to_enqueue.append(tensor)
tensors_to_enqueue.append(keys)
queue_dtypes = [x.dtype for x in tensors_to_enqueue]
input_queue = data_flow_ops.FIFOQueue(feature_queue_capacity, queue_dtypes)
# Add a summary op to debug if our feature queue is full or not.
logging_ops.scalar_summary('queue/parsed_features/%s/fraction_of_%d_full' %
(input_queue.name, feature_queue_capacity),
math_ops.cast(input_queue.size(), dtypes.float32)
* (1. / feature_queue_capacity))
# Add multiple queue runners so that the queue is always full. Adding more
# than two queue-runners may hog the cpu on the worker to fill up the queue.
for _ in range(num_queue_runners):
queue_runner.add_queue_runner(
queue_runner.QueueRunner(input_queue, [input_queue.enqueue(
tensors_to_enqueue)]))
dequeued_tensors = input_queue.dequeue()
# Reset shapes on dequeued tensors.
for i in range(len(tensors_to_enqueue)):
dequeued_tensors[i].set_shape(tensors_to_enqueue[i].get_shape())
# Recreate feature mapping according to the original dictionary.
dequeued_feature_map = {}
index = 0
for key, is_sparse_tensor in tensors_mapping:
if is_sparse_tensor:
# Three tensors are (indices, values, shape).
dequeued_feature_map[key] = ops.SparseTensor(
dequeued_tensors[index], dequeued_tensors[index + 1],
dequeued_tensors[index + 2])
index += 3
else:
dequeued_feature_map[key] = dequeued_tensors[index]
index += 1
dequeued_keys = dequeued_tensors[-1]
return dequeued_keys, dequeued_feature_map
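Because a FIFO queue only holds dense tensors, the code above flattens each `SparseTensor` into its (indices, values, shape) components before enqueueing and reassembles it after dequeueing, using `tensors_mapping` to remember which keys were sparse. The bookkeeping, stripped of TensorFlow (plain Python tuples stand in for tensors; names are illustrative):

def flatten_features(feature_map):
  # Flatten a dict whose values are dense objects or (indices, values, shape)
  # triples, remembering which keys were "sparse".
  tensors, mapping = [], []
  for key, value in sorted(feature_map.items()):
    if isinstance(value, tuple) and len(value) == 3:
      mapping.append((key, True))
      tensors.extend(value)          # indices, values, shape
    else:
      mapping.append((key, False))
      tensors.append(value)
  return tensors, mapping

def rebuild_features(tensors, mapping):
  # Walk the flat list in the same order and rebuild the original dict.
  rebuilt, index = {}, 0
  for key, is_sparse in mapping:
    if is_sparse:
      rebuilt[key] = tuple(tensors[index:index + 3])
      index += 3
    else:
      rebuilt[key] = tensors[index]
      index += 1
  return rebuilt

features = {'age': 42, 'tokens': ([[0, 1]], [7], [1, 5])}
flat, mapping = flatten_features(features)
assert rebuild_features(flat, mapping) == features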
def read_batch_features(file_pattern, batch_size, features, reader, def read_batch_features(file_pattern, batch_size, features, reader,

View File

@ -124,18 +124,18 @@ class GraphIOTest(tf.test.TestCase):
_VALID_FILE_PATTERN, batch_size, features, randomize_input=False, _VALID_FILE_PATTERN, batch_size, features, randomize_input=False,
queue_capacity=queue_capacity, reader_num_threads=2, queue_capacity=queue_capacity, reader_num_threads=2,
parser_num_threads=2, name=name) parser_num_threads=2, name=name)
self.assertEqual("%s/parse_example_batch_join:1" % name, self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name,
features["feature"].name) features["feature"].name)
file_name_queue_name = "%s/file_name_queue" % name file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name example_queue_name = "%s/fifo_queue" % name
parse_example_queue_name = "%s/parse_example_batch_join" % name parse_example_queue_name = "%s/fifo_queue" % name
op_nodes = test_util.assert_ops_in_graph({ op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const", file_names_name: "Const",
file_name_queue_name: "FIFOQueue", file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader", "%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "FIFOQueue", example_queue_name: "FIFOQueue",
parse_example_queue_name: "QueueDequeueMany", parse_example_queue_name: "FIFOQueue",
name: "QueueDequeueMany" name: "QueueDequeueMany"
}, g) }, g)
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0]) self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])

View File

@ -19,10 +19,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.contrib.learn.python.learn.ops import dnn_ops from tensorflow.contrib.learn.python.learn.ops import dnn_ops
from tensorflow.contrib.learn.python.learn.ops import losses_ops from tensorflow.contrib.learn.python.learn.ops import losses_ops
from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import array_ops as array_ops_
@ -81,6 +81,7 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0):
with vs.variable_scope('linear_regression'): with vs.variable_scope('linear_regression'):
logging_ops.histogram_summary('linear_regression.x', x) logging_ops.histogram_summary('linear_regression.x', x)
logging_ops.histogram_summary('linear_regression.y', y) logging_ops.histogram_summary('linear_regression.y', y)
dtype = x.dtype.base_dtype
y_shape = y.get_shape() y_shape = y.get_shape()
if len(y_shape) == 1: if len(y_shape) == 1:
output_shape = 1 output_shape = 1
@ -88,15 +89,18 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0):
output_shape = y_shape[1] output_shape = y_shape[1]
# Set up the requested initialization. # Set up the requested initialization.
if init_mean is None: if init_mean is None:
weights = vs.get_variable('weights', [x.get_shape()[1], output_shape]) weights = vs.get_variable(
bias = vs.get_variable('bias', [output_shape]) 'weights', [x.get_shape()[1], output_shape], dtype=dtype)
bias = vs.get_variable('bias', [output_shape], dtype=dtype)
else: else:
weights = vs.get_variable('weights', [x.get_shape()[1], output_shape], weights = vs.get_variable('weights', [x.get_shape()[1], output_shape],
initializer=init_ops.random_normal_initializer( initializer=init_ops.random_normal_initializer(
init_mean, init_stddev)) init_mean, init_stddev, dtype=dtype),
dtype=dtype)
bias = vs.get_variable('bias', [output_shape], bias = vs.get_variable('bias', [output_shape],
initializer=init_ops.random_normal_initializer( initializer=init_ops.random_normal_initializer(
init_mean, init_stddev)) init_mean, init_stddev, dtype=dtype),
dtype=dtype)
logging_ops.histogram_summary('linear_regression.weights', weights) logging_ops.histogram_summary('linear_regression.weights', weights)
logging_ops.histogram_summary('linear_regression.bias', bias) logging_ops.histogram_summary('linear_regression.bias', bias)
return losses_ops.mean_squared_error_regressor(x, y, weights, bias) return losses_ops.mean_squared_error_regressor(x, y, weights, bias)
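The substance of this change is that the parameter dtype is now derived from the input instead of implicitly defaulting to float32, so float64 inputs get float64 weights and initializers. A sketch of the pattern against a 1.x-style graph API of this era (a simplified helper, not the actual linear_regression code):

import tensorflow as tf  # assumes the graph-mode tf.get_variable API

def make_linear_params(x, output_dim):
  # Derive the parameter dtype from the input tensor, as the diff above does.
  dtype = x.dtype.base_dtype
  weights = tf.get_variable(
      'weights', [x.get_shape()[1], output_dim], dtype=dtype,
      initializer=tf.random_normal_initializer(0.0, 1.0, dtype=dtype))
  bias = tf.get_variable('bias', [output_dim], dtype=dtype)
  return weights, bias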
@ -135,19 +139,22 @@ def logistic_regression(x,
with vs.variable_scope('logistic_regression'): with vs.variable_scope('logistic_regression'):
logging_ops.histogram_summary('%s.x' % vs.get_variable_scope().name, x) logging_ops.histogram_summary('%s.x' % vs.get_variable_scope().name, x)
logging_ops.histogram_summary('%s.y' % vs.get_variable_scope().name, y) logging_ops.histogram_summary('%s.y' % vs.get_variable_scope().name, y)
dtype = x.dtype.base_dtype
# Set up the requested initialization. # Set up the requested initialization.
if init_mean is None: if init_mean is None:
weights = vs.get_variable('weights', weights = vs.get_variable(
[x.get_shape()[1], y.get_shape()[-1]]) 'weights', [x.get_shape()[1], y.get_shape()[-1]], dtype=dtype)
bias = vs.get_variable('bias', [y.get_shape()[-1]]) bias = vs.get_variable('bias', [y.get_shape()[-1]], dtype=dtype)
else: else:
weights = vs.get_variable('weights', weights = vs.get_variable('weights',
[x.get_shape()[1], y.get_shape()[-1]], [x.get_shape()[1], y.get_shape()[-1]],
initializer=init_ops.random_normal_initializer( initializer=init_ops.random_normal_initializer(
init_mean, init_stddev)) init_mean, init_stddev, dtype=dtype),
dtype=dtype)
bias = vs.get_variable('bias', [y.get_shape()[-1]], bias = vs.get_variable('bias', [y.get_shape()[-1]],
initializer=init_ops.random_normal_initializer( initializer=init_ops.random_normal_initializer(
init_mean, init_stddev)) init_mean, init_stddev, dtype=dtype),
dtype=dtype)
logging_ops.histogram_summary('%s.weights' % vs.get_variable_scope().name, logging_ops.histogram_summary('%s.weights' % vs.get_variable_scope().name,
weights) weights)
logging_ops.histogram_summary('%s.bias' % vs.get_variable_scope().name, logging_ops.histogram_summary('%s.bias' % vs.get_variable_scope().name,

View File

@ -535,8 +535,12 @@ class LoggingTrainable(EveryN):
class SummarySaver(EveryN): class SummarySaver(EveryN):
"""Saves summaries every N steps.""" """Saves summaries every N steps."""
def __init__(self, summary_op, save_steps=100, output_dir=None, def __init__(self,
summary_writer=None): summary_op,
save_steps=100,
output_dir=None,
summary_writer=None,
scaffold=None):
"""Initializes a `SummarySaver` monitor. """Initializes a `SummarySaver` monitor.
Args: Args:
@ -548,6 +552,7 @@ class SummarySaver(EveryN):
if no `summary_writer` is supplied. if no `summary_writer` is supplied.
summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed, summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
one will be created accordingly. one will be created accordingly.
scaffold: `Scaffold` to get summary_op if it's not provided.
""" """
# TODO(ipolosukhin): Implement every N seconds. # TODO(ipolosukhin): Implement every N seconds.
super(SummarySaver, self).__init__(every_n_steps=save_steps) super(SummarySaver, self).__init__(every_n_steps=save_steps)
@ -555,6 +560,7 @@ class SummarySaver(EveryN):
self._summary_writer = summary_writer self._summary_writer = summary_writer
if summary_writer is None and output_dir: if summary_writer is None and output_dir:
self._summary_writer = summary_io.SummaryWriter(output_dir) self._summary_writer = summary_io.SummaryWriter(output_dir)
self._scaffold = scaffold
# TODO(mdan): Throw an error if output_dir and summary_writer are None. # TODO(mdan): Throw an error if output_dir and summary_writer are None.
def set_estimator(self, estimator): def set_estimator(self, estimator):
@ -565,14 +571,17 @@ class SummarySaver(EveryN):
def every_n_step_begin(self, step): def every_n_step_begin(self, step):
super(SummarySaver, self).every_n_step_begin(step) super(SummarySaver, self).every_n_step_begin(step)
if self._summary_op is None and self._scaffold is not None:
self._summary_op = self._scaffold.summary_op
if self._summary_op is not None: if self._summary_op is not None:
return [self._summary_op] return [self._summary_op]
return [] return []
def every_n_step_end(self, step, outputs): def every_n_step_end(self, step, outputs):
super(SummarySaver, self).every_n_step_end(step, outputs) super(SummarySaver, self).every_n_step_end(step, outputs)
if self._summary_op is not None:
summary_strs = _extract_output(outputs, self._summary_op) summary_strs = _extract_output(outputs, self._summary_op)
if self._summary_writer and self._summary_op is not None: if self._summary_writer:
self._summary_writer.add_summary(summary_strs, step) self._summary_writer.add_summary(summary_strs, step)
return False return False
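With this change the monitor can be built before the summary op exists: if no op was supplied, it is looked up on the scaffold at step time, once the scaffold has been finalized. The fallback pattern in isolation (a hypothetical holder class, not the real SummarySaver):

class LazyOpHolder(object):
  """Hypothetical sketch of the summary-op fallback."""

  def __init__(self, op=None, scaffold=None):
    self._op = op
    self._scaffold = scaffold

  def step_begin(self):
    # Prefer the explicitly supplied op; otherwise borrow the scaffold's,
    # which only exists after the scaffold has been finalized.
    if self._op is None and self._scaffold is not None:
      self._op = self._scaffold.summary_op
    return [self._op] if self._op is not None else []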
@ -923,36 +932,88 @@ class ExportMonitor(EveryN):
default_batch_size=self._default_batch_size) default_batch_size=self._default_batch_size)
class CheckpointSaver(EveryN): class CheckpointSaver(BaseMonitor):
"""Saves checkpoints every N steps.""" """Saves checkpoints every N steps."""
def __init__(self, every_n_steps, saver, checkpoint_dir, def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename="model.ckpt", checkpoint_basename="model.ckpt",
first_n_steps=-1): scaffold=None):
"""Initialize CheckpointSaver monitor. """Initialize CheckpointSaver monitor.
Args: Args:
every_n_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_dir: `str`, base directory for the checkpoint files. checkpoint_dir: `str`, base directory for the checkpoint files.
save_secs: `int`, save every N secs.
save_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_basename: `str`, base name for the checkpoint files. checkpoint_basename: `str`, base name for the checkpoint files.
first_n_steps: `int`, if positive, save every step during the scaffold: `Scaffold`, use to get saver object.
first `first_n_steps` steps.
Raises:
ValueError: If both `save_steps` and `save_secs` are not `None`.
ValueError: If both `save_steps` and `save_secs` are `None`.
""" """
logging.info("Create CheckpointSaver") logging.info("Create CheckpointSaver")
super(CheckpointSaver, self).__init__(every_n_steps=every_n_steps, super(CheckpointSaver, self).__init__()
first_n_steps=first_n_steps)
self._saver = saver self._saver = saver
self._summary_writer = SummaryWriterCache.get(checkpoint_dir) self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
self._scaffold = scaffold
self._save_secs = save_secs
self._save_steps = save_steps
self._last_saved_time = None
self._last_begin_step = None
self._last_saved_step = None
def every_n_post_step(self, step, session): if save_steps is None and save_secs is None:
raise ValueError("Either save_steps or save_secs should be provided")
if (save_steps is not None) and (save_secs is not None):
raise ValueError("Can not provide both save_steps and save_secs.")
def begin(self, max_steps=None):
super(CheckpointSaver, self).begin(max_steps)
self._last_saved_time = None
self._last_begin_step = None
self._last_saved_step = None
def step_begin(self, step):
super(CheckpointSaver, self).step_begin(step)
self._last_begin_step = step
def post_step(self, step, session):
super(CheckpointSaver, self).post_step(step, session)
if self._last_saved_time is None:
self._save(step, session)
if self._save_steps is not None:
if step >= self._last_saved_step + self._save_steps:
self._save(step, session)
if self._save_secs is not None:
if time.time() >= self._last_saved_time + self._save_secs:
self._save(step, session)
def end(self, session=None):
super(CheckpointSaver, self).end(session)
self._save(self._last_begin_step, session)
def _save(self, step, session):
"""Saves the latest checkpoint."""
if step == self._last_saved_step:
return
logging.info("Saving checkpoints for %d into %s.", step, self._save_path) logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
self._last_saved_time = time.time()
self._last_saved_step = step
if self._saver is None:
self._scaffold.saver.save(session, self._save_path, global_step=step)
else:
self._saver.save(session, self._save_path, global_step=step) self._saver.save(session, self._save_path, global_step=step)
if self._summary_writer:
self._summary_writer.add_session_log( self._summary_writer.add_session_log(
SessionLog(status=SessionLog.CHECKPOINT, SessionLog(
checkpoint_path=self._save_path), status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
step) step)
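Summarized, the new policy is: exactly one of save_secs or save_steps must be given; the first step after begin() always saves; later saves are gated by elapsed steps or elapsed wall time; and end() forces a final save. A pure-Python sketch of that gating, with the actual Saver call replaced by a boolean return:

import time

class SavePolicy(object):
  """Hypothetical distillation of the CheckpointSaver timing logic."""

  def __init__(self, save_secs=None, save_steps=None):
    if save_steps is None and save_secs is None:
      raise ValueError('Either save_steps or save_secs should be provided')
    if save_steps is not None and save_secs is not None:
      raise ValueError('Can not provide both save_steps and save_secs.')
    self._save_secs = save_secs
    self._save_steps = save_steps
    self._last_saved_time = None
    self._last_saved_step = None

  def post_step(self, step):
    if self._last_saved_time is None:        # first step seen: always save
      return self._save(step)
    if (self._save_steps is not None and
        step >= self._last_saved_step + self._save_steps):
      return self._save(step)
    if (self._save_secs is not None and
        time.time() >= self._last_saved_time + self._save_secs):
      return self._save(step)
    return False

  def _save(self, step):
    if step == self._last_saved_step:        # skip duplicate saves
      return False
    self._last_saved_time = time.time()
    self._last_saved_step = step
    return True                              # stands in for saver.save(...)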

View File

@ -119,47 +119,85 @@ class Scaffold(object):
keep_checkpoint_max: Optional parameter to use to construct a saver if keep_checkpoint_max: Optional parameter to use to construct a saver if
none is already there in the graph. none is already there in the graph.
""" """
if global_step_tensor is None:
global_step_tensor = contrib_variables.get_or_create_global_step()
self.global_step_tensor = global_step_tensor
if init_op is None:
init_op = Scaffold._get_or_default('init_op', ops.GraphKeys.INIT_OP,
variables.initialize_all_variables)
self.init_op = init_op
self.init_feed_dict = init_feed_dict
# NOTE(touts): modifying the init function to be passed the scaffold is a # NOTE(touts): modifying the init function to be passed the scaffold is a
# hack to make it easy to find the saver. Is there a better way? # hack to make it easy to find the saver. Is there a better way?
if init_fn: if init_fn:
self.init_fn = lambda sess: init_fn(self, sess) self._init_fn = lambda sess: init_fn(self, sess)
else: else:
self.init_fn = None self._init_fn = None
if ready_op is None:
ready_op = Scaffold._get_or_default( self._global_step_tensor = global_step_tensor
self._init_op = init_op
self._ready_op = ready_op
self._local_init_op = local_init_op
self._summary_op = summary_op
self._saver = saver
self._keep_checkpoint_max = keep_checkpoint_max
self._init_feed_dict = init_feed_dict
def finalize(self):
"""Creates operations if needed and finalizes the graph."""
if self._global_step_tensor is None:
self._global_step_tensor = contrib_variables.get_or_create_global_step()
if self._init_op is None:
self._init_op = Scaffold._get_or_default(
'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
if self._ready_op is None:
self._ready_op = Scaffold._get_or_default(
'ready_op', ops.GraphKeys.READY_OP, 'ready_op', ops.GraphKeys.READY_OP,
variables.report_uninitialized_variables) variables.report_uninitialized_variables)
self.ready_op = ready_op if self._local_init_op is None:
if local_init_op is None: self._local_init_op = Scaffold._get_or_default(
local_init_op = Scaffold._get_or_default('local_init_op', 'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
ops.GraphKeys.LOCAL_INIT_OP,
Scaffold._default_local_init_op) Scaffold._default_local_init_op)
self.local_init_op = local_init_op if self._summary_op is None:
if summary_op is None: self._summary_op = Scaffold._get_or_default(
summary_op = Scaffold._get_or_default('summary_op', 'summary_op', ops.GraphKeys.SUMMARY_OP,
ops.GraphKeys.SUMMARY_OP,
logging_ops.merge_all_summaries) logging_ops.merge_all_summaries)
self.summary_op = summary_op
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
if saver is None: if self._saver is None:
saver = Scaffold._get_or_default( self._saver = Scaffold._get_or_default(
'saver', 'saver',
ops.GraphKeys.SAVERS, ops.GraphKeys.SAVERS,
lambda: training_saver.Saver(sharded=True, lambda: training_saver.Saver(sharded=True,
max_to_keep=keep_checkpoint_max)) max_to_keep=self._keep_checkpoint_max))
# pylint: enable=g-long-lambda # pylint: enable=g-long-lambda
self.saver = saver
ops.get_default_graph().finalize() ops.get_default_graph().finalize()
@property
def global_step_tensor(self):
return self._global_step_tensor
@property
def init_fn(self):
return self._init_fn
@property
def init_op(self):
return self._init_op
@property
def ready_op(self):
return self._ready_op
@property
def local_init_op(self):
return self._local_init_op
@property
def summary_op(self):
return self._summary_op
@property
def saver(self):
return self._saver
@property
def init_feed_dict(self):
return self._init_feed_dict
@staticmethod @staticmethod
def _get_or_default(arg_name, collection_key, default_constructor): def _get_or_default(arg_name, collection_key, default_constructor):
"""Get from cache or create a default operation.""" """Get from cache or create a default operation."""
@ -213,9 +251,10 @@ class SupervisedSession(object):
self._config = config self._config = config
self._monitors = monitors or [] self._monitors = monitors or []
self._scaffold = scaffold or Scaffold() self._scaffold = scaffold or Scaffold()
# Finalize and write the graph. for monitor in self._monitors:
self._graph.finalize() monitor.begin(max_steps=None)
# Create the session. # Create the session.
self._scaffold.finalize()
self._session_manager = sm.SessionManager( self._session_manager = sm.SessionManager(
local_init_op=self._scaffold.local_init_op, local_init_op=self._scaffold.local_init_op,
ready_op=self._scaffold.ready_op, ready_op=self._scaffold.ready_op,
@ -223,8 +262,6 @@ class SupervisedSession(object):
self._sess = recoverable_session.RecoverableSession(self._create_session) self._sess = recoverable_session.RecoverableSession(self._create_session)
# Call the begin() method of monitors. # Call the begin() method of monitors.
self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor) self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
for monitor in self._monitors:
monitor.begin(max_steps=None)
# Write the graph out, note: this uses self._init_step. # Write the graph out, note: this uses self._init_step.
self.write_graph() self.write_graph()
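The net effect of the Scaffold refactor is that defaults (global step, init op, summary op, saver) are no longer built in __init__ but in finalize(), which SupervisedSession now calls after the monitors' begin() hooks and just before creating the session. The deferred-defaults pattern on its own (illustrative names, not the real Scaffold):

class DeferredDefaults(object):
  """Hypothetical sketch: optional pieces are materialized only in finalize()."""

  def __init__(self, init_op=None, summary_op=None):
    self._init_op = init_op
    self._summary_op = summary_op

  def finalize(self, default_init_op, default_summary_op):
    # Anything the caller did not supply is constructed here, once,
    # right before the session is created.
    if self._init_op is None:
      self._init_op = default_init_op()
    if self._summary_op is None:
      self._summary_op = default_summary_op()
    return self

  @property
  def init_op(self):
    return self._init_op

  @property
  def summary_op(self):
    return self._summary_op

scaffold = DeferredDefaults()
scaffold.finalize(lambda: 'init_op', lambda: 'merged_summaries')
print(scaffold.summary_op)   # 'merged_summaries'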

View File

@ -76,9 +76,8 @@ class CoordinatedSessionTest(tf.test.TestCase):
self.assertFalse(coord_sess.should_stop()) self.assertFalse(coord_sess.should_stop())
self.assertEqual(0, coord_sess.run(c)) self.assertEqual(0, coord_sess.run(c))
self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1})) self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
'both fed and fetched'): coord_sess.run([None], feed_dict={c: 2})
coord_sess.run(c, feed_dict={c: 2})
self.assertTrue(coord.should_stop()) self.assertTrue(coord.should_stop())
self.assertTrue(coord_sess.should_stop()) self.assertTrue(coord_sess.should_stop())
@ -101,9 +100,8 @@ class CoordinatedSessionTest(tf.test.TestCase):
self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1})) self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
for t in threads: for t in threads:
self.assertTrue(t.is_alive()) self.assertTrue(t.is_alive())
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
'both fed and fetched'): coord_sess.run([None], feed_dict={c: 2})
coord_sess.run(c, feed_dict={c: 2})
for t in threads: for t in threads:
self.assertFalse(t.is_alive()) self.assertFalse(t.is_alive())
self.assertTrue(coord.should_stop()) self.assertTrue(coord.should_stop())

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import six import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
@ -31,6 +32,68 @@ class DataFeederTest(tf.test.TestCase):
# pylint: disable=undefined-variable # pylint: disable=undefined-variable
"""Tests for `DataFeeder`.""" """Tests for `DataFeeder`."""
def _assert_raises(self, input_data):
with self.assertRaisesRegexp(TypeError, 'annot convert'):
data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
def test_input_uint32(self):
self._assert_raises(np.matrix([[1, 2], [3, 4]], dtype=np.uint32))
def test_input_uint64(self):
self._assert_raises(np.matrix([[1, 2], [3, 4]], dtype=np.uint64))
def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data):
feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
self.assertEqual(expected_np_dtype, feeder.input_dtype)
with tf.Graph().as_default() as g, self.test_session(g):
inp, _ = feeder.input_builder()
self.assertEqual(expected_tf_dtype, inp.dtype)
def test_input_int8(self):
self._assert_dtype(
np.int8, tf.int8, np.matrix([[1, 2], [3, 4]], dtype=np.int8))
def test_input_int16(self):
self._assert_dtype(
np.int16, tf.int16, np.matrix([[1, 2], [3, 4]], dtype=np.int16))
def test_input_int32(self):
self._assert_dtype(
np.int32, tf.int32, np.matrix([[1, 2], [3, 4]], dtype=np.int32))
def test_input_int64(self):
self._assert_dtype(
np.int64, tf.int64, np.matrix([[1, 2], [3, 4]], dtype=np.int64))
def test_input_uint8(self):
self._assert_dtype(
np.uint8, tf.uint8, np.matrix([[1, 2], [3, 4]], dtype=np.uint8))
def test_input_uint16(self):
self._assert_dtype(
np.uint16, tf.uint16, np.matrix([[1, 2], [3, 4]], dtype=np.uint16))
def test_input_float16(self):
self._assert_dtype(
np.float16, tf.float16, np.matrix([[1, 2], [3, 4]], dtype=np.float16))
def test_input_float32(self):
self._assert_dtype(
np.float32, tf.float32, np.matrix([[1, 2], [3, 4]], dtype=np.float32))
def test_input_float64(self):
self._assert_dtype(
np.float64, tf.float64, np.matrix([[1, 2], [3, 4]], dtype=np.float64))
def test_input_bool(self):
self._assert_dtype(
np.bool, tf.bool,
np.array([[False for _ in xrange(2)] for _ in xrange(2)]))
def test_input_string(self):
input_data = np.array([['str%d' % i for i in xrange(2)] for _ in xrange(2)])
self._assert_dtype(input_data.dtype, tf.string, input_data)
def test_unsupervised(self): def test_unsupervised(self):
data = np.matrix([[1, 2], [2, 3], [3, 4]]) data = np.matrix([[1, 2], [2, 3], [3, 4]])
feeder = data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1) feeder = data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1)

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

Some files were not shown because too many files have changed in this diff.