Merge pull request #3656 from zheng-xq/branch_129393964
Branch 129393964
This commit is contained in:
commit c99b9bef4c
@@ -37,7 +37,10 @@ config_setting(

 package_group(
     name = "internal",
-    packages = ["//tensorflow/..."],
+    packages = [
+        "//learning/vis/...",
+        "//tensorflow/...",
+    ],
 )

 sh_binary(
@@ -482,7 +482,6 @@ static void TF_Run_Helper(
     result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
   }
   if (!result.ok()) {
-    LOG(ERROR) << result.error_message();
     status->status = result;
     return;
   }
@@ -73,6 +73,46 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "grad_op_registry",
+    srcs = ["framework/grad_op_registry.cc"],
+    hdrs = ["framework/grad_op_registry.h"],
+    deps = [
+        ":ops",
+        ":scope",
+    ],
+)
+
+cc_library(
+    name = "math_grad",
+    srcs = ["gradients/math_grad.cc"],
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":ops",
+        ":scope",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+    ],
+)
+
+tf_cc_test(
+    name = "gradients/math_grad_test",
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":math_grad",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_gen_op_wrappers_cc(
     name = "cc_ops",
     op_lib_names = [
tensorflow/cc/framework/grad_op_registry.cc (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {

// static
GradOpRegistry* GradOpRegistry::Global() {
  static GradOpRegistry* grad_op_registry = new GradOpRegistry;
  return grad_op_registry;
}

bool GradOpRegistry::Register(const string& op, GradFunc func) {
  CHECK(registry_.insert({op, func}).second) << "Existing gradient for " << op;
  return true;
}

Status GradOpRegistry::Lookup(const string& op, GradFunc* func) {
  auto iter = registry_.find(op);
  if (iter == registry_.end()) {
    return errors::NotFound("No gradient defined for op: ", op);
  }
  *func = iter->second;
  return Status::OK();
}

}  // end namespace ops
}  // namespace tensorflow
tensorflow/cc/framework/grad_op_registry.h (new file, 75 lines)
@@ -0,0 +1,75 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_

#include <unordered_map>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"

namespace tensorflow {
namespace ops {

// GradFunc is the signature for all gradient functions in GradOpRegistry.
// Implementations should add operations to compute the gradient outputs of
// 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'.
typedef Status (*GradFunc)(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs);

// GradOpRegistry maintains a static registry of gradient functions.
// Gradient functions are indexed in the registry by the forward op name (i.e.
// "MatMul" -> MatMulGrad func).
class GradOpRegistry {
 public:
  // Registers 'func' as the gradient function for 'op'.
  // Returns true if registration was successful; check-fails otherwise.
  bool Register(const string& op, GradFunc func);

  // Sets 'func' to the gradient function for 'op' and returns Status OK if
  // the gradient function for 'op' exists in the registry.
  // Note that 'func' can be null for ops that have registered no-gradient with
  // the registry.
  // Returns error status otherwise.
  Status Lookup(const string& op, GradFunc* func);

  // Returns a pointer to the global gradient function registry.
  static GradOpRegistry* Global();

 private:
  std::unordered_map<string, GradFunc> registry_;
};

}  // namespace ops

// Macros used to define gradient functions for ops.
#define REGISTER_GRADIENT_OP(name, fn) \
  REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_NO_GRADIENT_OP(name) \
  REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)

#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \
  static bool unused_ret_val_##ctr =             \
      ::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
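The registration macros above are the only public hook into the registry. As a minimal sketch of how a gradient function would be defined with this API (this example is not part of the commit, which only registers "Const" and "MatMul" below; the op choice and the generated Neg wrapper from cc_ops are assumptions):

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// Illustrative only: d(-x)/dx = -1, so the gradient of Neg is simply the
// negated incoming gradient.
Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return Status::OK();
}
// Static registration: Register() runs at load time via the unused-bool
// trick in REGISTER_GRADIENT_OP_UNIQ.
REGISTER_GRADIENT_OP("Neg", NegGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow

A backward pass would then recover the function with GradOpRegistry::Global()->Lookup("Neg", &grad_fn) and invoke it with the forward op and its incoming gradients.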
@@ -18,6 +18,44 @@ limitations under the License.
 namespace tensorflow {
 namespace ops {

+Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
+
+Output Operation::input(int i) const {
+  CHECK_NOTNULL(node_);
+  CHECK_GE(i, 0);
+  CHECK_LT(i, node_->num_inputs());
+  // Handle the case where the input was unknown at the time this
+  // Operation was constructed.
+  if (inputs_[i].first == nullptr && inputs_[i].second == -1) {
+    for (const Edge* e : node_->in_edges()) {
+      if (e->IsControlEdge()) continue;
+      if (e->dst_input() == i) {
+        return Output(e->src(), e->src_output());
+      }
+    }
+  }
+  return Output(inputs_[i].first, inputs_[i].second);
+}
+
+Output Operation::output(int i) const {
+  CHECK_NOTNULL(node_);
+  CHECK_GE(i, 0);
+  CHECK_LT(i, node_->num_outputs());
+  return Output(node_, i);
+}
+
+Operation::Inputs Operation::GetInputs(Node* node) {
+  Operation::Inputs inputs;
+  if (node != nullptr) {
+    inputs.resize(node->num_inputs(), {nullptr, -1});
+    for (const Edge* e : node->in_edges()) {
+      if (e->IsControlEdge()) continue;
+      inputs[e->dst_input()] = std::make_pair(e->src(), e->src_output());
+    }
+  }
+  return inputs;
+}
+
 Input::Initializer::Initializer(
     const std::initializer_list<Input::Initializer>& v) {
   if (v.size() < 1) {
@@ -27,17 +27,29 @@ limitations under the License.
 namespace tensorflow {
 namespace ops {

+class Output;
+
 // Represents a node in the computation graph.
 class Operation {
  public:
   Operation() : node_(nullptr) {}
-  explicit Operation(Node* n) : node_(n) {}
+  explicit Operation(Node* n);
+
+  int num_inputs() const { return node_->num_inputs(); }
+  DataType input_type(int o) const { return node_->input_type(o); }
+  Output input(int i) const;

   int num_outputs() const { return node_->num_outputs(); }
   DataType output_type(int o) const { return node_->output_type(o); }
+  Output output(int i) const;

   Node* node() const { return node_; }

  private:
+  typedef std::vector<std::pair<Node*, int64>> Inputs;
+  static Inputs GetInputs(Node* node);
+
+  Inputs inputs_;
   Node* node_;
 };

@@ -81,7 +93,7 @@ class Input {
       tensor = t;
     }

-    explicit Initializer(const Tensor& t) : tensor(t) {}
+    Initializer(const Tensor& t) : tensor(t) {}  // NOLINT(runtime/explicit)

     // Construct from a scalar value and an explicit shape
     template <typename T, typename = typename std::enable_if<
tensorflow/cc/gradients/math_grad.cc (new file, 91 lines)
@@ -0,0 +1,91 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {
namespace {

// TODO(andydavis) Move this to a more appropriate file.
REGISTER_NO_GRADIENT_OP("Const");

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
                        const Output& x1, const bool adj_x1, const Output& y0,
                        const bool adj_y0, const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  auto dx =
      MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
  grad_outputs->push_back(dx);
  auto dy =
      MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
  grad_outputs->push_back(dy);
  return Status::OK();
}

// MatMulGrad common used to read and check node attr state, and determine
// proper MatMul products for gradients based on input matrix transposition
// combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
                            op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
                            grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
                            op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
                          grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
                          grad_outputs);
}

REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
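For reference, the four branches of MatMulGradCommon implement the standard matrix-calculus identities for the gradient of a product, with dz the incoming gradient of z:

\[
\begin{aligned}
z = x\,y &: & dx &= dz\,y^\top, & dy &= x^\top dz \\
z = x\,y^\top &: & dx &= dz\,y, & dy &= dz^\top x \\
z = x^\top y &: & dx &= y\,dz^\top, & dy &= x\,dz \\
z = x^\top y^\top &: & dx &= y^\top dz^\top, & dy &= dz^\top x^\top
\end{aligned}
\]

Each right-hand side is itself a single MatMul with the appropriate transpose attributes, which is why MatMulGradHelper only ever emits two MatMul ops.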
tensorflow/cc/gradients/math_grad_test.cc (new file, 183 lines)
@@ -0,0 +1,183 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
class MathGradTest : public ::testing::Test {
 protected:
  MathGradTest() : root_(Scope::NewRootScope()) {}

  void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
                         const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    TF_EXPECT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
    EXPECT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  void CallGradFunction(const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
    GradFunc grad_fn;
    TF_EXPECT_OK(GradOpRegistry::Global()->Lookup(op.node()->name(), &grad_fn));
    TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
    TF_EXPECT_OK(root_.status());
  }

  Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
                       const bool t_y) {
    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    TF_EXPECT_OK(root_.status());
    Tensor out;
    GetTensor(root_, z, &out);
    return out;
  }

  void RandMatMulGradData(const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();
    // x.shape = [m, k]
    const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());
    // y.shape = [k, n]
    const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());
    // z.shape = [m, n]
    data->emplace_back(DT_FLOAT, TensorShape({m, n}));
    RandTensor(&data->back());
  }

  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  int Rand() { return 1 + (random::New64() % 10); }

  // TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
  // Note: they should be moved to a general/non-grad specific testutil class.
  void GetTensors(const Scope& scope, OutputList tensors,
                  std::vector<Tensor>* out) {
    SessionOptions options;
    std::unique_ptr<Session> session(NewSession(options));
    GraphDef def;
    scope.graph()->ToGraphDef(&def);

    graph::SetDefaultDevice("/cpu:0", &def);

    TF_CHECK_OK(session->Create(def));
    std::vector<string> names;
    for (const auto& t : tensors) {
      names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
    }
    TF_CHECK_OK(session->Run({}, names, {}, out));
    TF_CHECK_OK(session->Close());
  }

  void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
    std::vector<Tensor> outputs;
    GetTensors(scope, {tensor}, &outputs);
    *out = outputs[0];
  }

  Scope root_;
};

TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  std::vector<Tensor> data;
  RandMatMulGradData(false, false, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  std::vector<Tensor> data;
  RandMatMulGradData(true, false, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  std::vector<Tensor> data;
  RandMatMulGradData(false, true, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  std::vector<Tensor> data;
  RandMatMulGradData(true, true, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
}

}  // namespace
}  // namespace tensorflow
@@ -99,7 +99,16 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/beta_test.py"],
     additional_deps = [
         ":distributions_py",
-        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_tests(
+    name = "binomial_test",
+    size = "small",
+    srcs = ["python/kernel_tests/binomial_test.py"],
+    additional_deps = [
+        ":distributions_py",
         "//tensorflow/python:platform_test",
     ],
     tags = ["notsan"],
@@ -179,9 +188,8 @@ cuda_py_tests(
 )

 cuda_py_tests(
-    name = "kullback_leibler_test",
-    size = "small",
-    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    name = "laplace_test",
+    srcs = ["python/kernel_tests/laplace_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
@@ -190,13 +198,14 @@ cuda_py_tests(
 )

 cuda_py_tests(
-    name = "laplace_test",
-    srcs = ["python/kernel_tests/laplace_test.py"],
+    name = "multinomial_test",
+    srcs = ["python/kernel_tests/multinomial_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
     ],
+    tags = ["notsan"],
 )

 cuda_py_tests(
@@ -239,6 +248,15 @@ cuda_py_tests(
     srcs = ["python/kernel_tests/uniform_test.py"],
     additional_deps = [
         ":distributions_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
+cuda_py_tests(
+    name = "kullback_leibler_test",
+    size = "small",
+    srcs = ["python/kernel_tests/kullback_leibler_test.py"],
+    additional_deps = [
         "//tensorflow/python:platform_test",
     ],
 )
@@ -25,6 +25,7 @@ initialized with parameters that define the distributions.

 ### Univariate (scalar) distributions

+@@Binomial
 @@Bernoulli
 @@Beta
 @@Categorical
@@ -50,6 +51,7 @@ initialized with parameters that define the distributions.

 @@Dirichlet
 @@DirichletMultinomial
+@@Multinomial

 ### Transformed distributions

@@ -79,6 +81,7 @@ from __future__ import print_function

 from tensorflow.contrib.distributions.python.ops.bernoulli import *
 from tensorflow.contrib.distributions.python.ops.beta import *
+from tensorflow.contrib.distributions.python.ops.binomial import *
 from tensorflow.contrib.distributions.python.ops.categorical import *
 from tensorflow.contrib.distributions.python.ops.chi2 import *
 from tensorflow.contrib.distributions.python.ops.dirichlet import *
@@ -89,6 +92,7 @@ from tensorflow.contrib.distributions.python.ops.gamma import *
 from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
 from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
 from tensorflow.contrib.distributions.python.ops.laplace import *
+from tensorflow.contrib.distributions.python.ops.multinomial import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
 from tensorflow.contrib.distributions.python.ops.normal import *
 from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
@@ -57,10 +57,17 @@ class BernoulliTest(tf.test.TestCase):
       self.assertAllClose(scipy.special.logit(p), dist.logits.eval())

   def testInvalidP(self):
-    invalid_ps = [1.01, -0.01, 2., -3.]
+    invalid_ps = [1.01, 2.]
     for p in invalid_ps:
       with self.test_session():
-        with self.assertRaisesOpError("x <= y"):
+        with self.assertRaisesOpError("p has components greater than 1"):
+          dist = tf.contrib.distributions.Bernoulli(p=p)
+          dist.p.eval()
+
+    invalid_ps = [-0.01, -3.]
+    for p in invalid_ps:
+      with self.test_session():
+        with self.assertRaisesOpError("Condition x >= 0"):
           dist = tf.contrib.distributions.Bernoulli(p=p)
           dist.p.eval()

tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf


class BinomialTest(tf.test.TestCase):

  def testSimpleShapes(self):
    with self.test_session():
      p = np.float32(np.random.beta(1, 1))
      binom = tf.contrib.distributions.Binomial(n=1., p=p)
      self.assertAllEqual([], binom.event_shape().eval())
      self.assertAllEqual([], binom.batch_shape().eval())
      self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
      self.assertEqual(tf.TensorShape([]), binom.get_batch_shape())

  def testComplexShapes(self):
    with self.test_session():
      p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      self.assertAllEqual([], binom.event_shape().eval())
      self.assertAllEqual([3, 2], binom.batch_shape().eval())
      self.assertEqual(tf.TensorShape([]), binom.get_event_shape())
      self.assertEqual(tf.TensorShape([3, 2]), binom.get_batch_shape())

  def testNProperty(self):
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      self.assertEqual((2, 1), binom.n.get_shape())
      self.assertAllClose(n, binom.n.eval())

  def testPProperty(self):
    p = [[0.1, 0.2, 0.7]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=3., p=p)
      self.assertEqual((1, 3), binom.p.get_shape())
      self.assertEqual((1, 3), binom.logits.get_shape())
      self.assertAllClose(p, binom.p.eval())

  def testLogitsProperty(self):
    logits = [[0., 9., -0.5]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=3., logits=logits)
      self.assertEqual((1, 3), binom.p.get_shape())
      self.assertEqual((1, 3), binom.logits.get_shape())
      self.assertAllClose(logits, binom.logits.eval())

  def testPmfNandCountsAgree(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      binom.pmf([2., 3, 2]).eval()
      binom.pmf([3., 1, 2]).eval()
      with self.assertRaisesOpError('Condition x >= 0.*'):
        binom.pmf([-1., 4, 2]).eval()
      with self.assertRaisesOpError('Condition x <= y.*'):
        binom.pmf([7., 3, 0]).eval()

  def testPmf_non_integer_counts(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      # No errors with integer n.
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      binom.pmf([2., 3, 2]).eval()
      binom.pmf([3., 1, 2]).eval()
      # Both equality and integer checking fail.
      with self.assertRaisesOpError('Condition x == y.*'):
        binom.pmf([1.0, 2.5, 1.5]).eval()

      binom = tf.contrib.distributions.Binomial(n=n, p=p, validate_args=False)
      binom.pmf([1., 2., 3.]).eval()
      # Non-integer arguments work.
      binom.pmf([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = 0.5
      counts = 1.
      pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
      self.assertAllClose(0.5, pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = 0.1
      counts = 3.
      binom = tf.contrib.distributions.Binomial(n=5., p=p)
      pmf = binom.pmf(counts)
      self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      p = [[0.1, 0.9]]
      counts = [[1., 2.]]
      pmf = tf.contrib.distributions.Binomial(n=3., p=p).pmf(counts)
      self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
      self.assertEqual((1, 2), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      p = [0.1, 0.4]
      counts = [[1.], [0.]]
      pmf = tf.contrib.distributions.Binomial(n=1., p=p).pmf(counts)
      self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
      self.assertEqual((2, 2), pmf.get_shape())

  def testBinomialMean(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      expected_means = stats.binom.mean(n, p)
      self.assertEqual((3,), binom.mean().get_shape())
      self.assertAllClose(expected_means, binom.mean().eval())

  def testBinomialVariance(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      expected_variances = stats.binom.var(n, p)
      self.assertEqual((3,), binom.variance().get_shape())
      self.assertAllClose(expected_variances, binom.variance().eval())

  def testBinomialMode(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      expected_modes = [0., 1, 4]
      self.assertEqual((3,), binom.mode().get_shape())
      self.assertAllClose(expected_modes, binom.mode().eval())

  def testBinomialMultipleMode(self):
    with self.test_session():
      n = 9.
      p = [0.1, 0.2, 0.7]
      binom = tf.contrib.distributions.Binomial(n=n, p=p)
      # For the case where (n + 1) * p is an integer, the modes are:
      # (n + 1) * p and (n + 1) * p - 1. In this case, we get back
      # the larger of the two modes.
      expected_modes = [1., 2, 7]
      self.assertEqual((3,), binom.mode().get_shape())
      self.assertAllClose(expected_modes, binom.mode().eval())


if __name__ == '__main__':
  tf.test.main()
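For reference, the expected values in testBinomialMode and testBinomialMultipleMode follow from the standard mode identity for \(X \sim \mathrm{Binomial}(n, p)\):

\[
\operatorname{mode}(X) = \lfloor (n+1)p \rfloor,
\]

with both \((n+1)p\) and \((n+1)p - 1\) being modes when \((n+1)p\) is an integer (the implementation returns the larger one). For example, n = 5, p = 0.7 gives \(\lfloor 4.2 \rfloor = 4\), and n = 9, p = 0.1 gives \((n+1)p = 1\), hence modes 0 and 1 and an expected result of 1.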
@@ -65,7 +65,7 @@ class DirichletMultinomialTest(tf.test.TestCase):
       dist.pmf([3., 0, 2]).eval()
       with self.assertRaisesOpError('Condition x >= 0.*'):
         dist.pmf([-1., 4, 2]).eval()
-      with self.assertRaisesOpError('Condition x == y.*'):
+      with self.assertRaisesOpError('counts do not sum to n'):
         dist.pmf([3., 3, 0]).eval()

   def testPmf_non_integer_counts(self):
tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py (new file, 226 lines)
@@ -0,0 +1,226 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class MultinomialTest(tf.test.TestCase):

  def testSimpleShapes(self):
    with self.test_session():
      p = [.1, .3, .6]
      dist = tf.contrib.distributions.Multinomial(n=1., p=p)
      self.assertEqual(3, dist.event_shape().eval())
      self.assertAllEqual([], dist.batch_shape().eval())
      self.assertEqual(tf.TensorShape([3]), dist.get_event_shape())
      self.assertEqual(tf.TensorShape([]), dist.get_batch_shape())

  def testComplexShapes(self):
    with self.test_session():
      p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      self.assertEqual(2, dist.event_shape().eval())
      self.assertAllEqual([3, 2], dist.batch_shape().eval())
      self.assertEqual(tf.TensorShape([2]), dist.get_event_shape())
      self.assertEqual(tf.TensorShape([3, 2]), dist.get_batch_shape())

  def testNProperty(self):
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      self.assertEqual((2, 1), dist.n.get_shape())
      self.assertAllClose(n, dist.n.eval())

  def testPProperty(self):
    p = [[0.1, 0.2, 0.7]]
    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(n=3., p=p)
      self.assertEqual((1, 3), dist.p.get_shape())
      self.assertEqual((1, 3), dist.logits.get_shape())
      self.assertAllClose(p, dist.p.eval())

  def testLogitsProperty(self):
    logits = [[0., 9., -0.5]]
    with self.test_session():
      multinom = tf.contrib.distributions.Multinomial(n=3., logits=logits)
      self.assertEqual((1, 3), multinom.p.get_shape())
      self.assertEqual((1, 3), multinom.logits.get_shape())
      self.assertAllClose(logits, multinom.logits.eval())

  def testPmfNandCountsAgree(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      dist.pmf([2., 3, 0]).eval()
      dist.pmf([3., 0, 2]).eval()
      with self.assertRaisesOpError('Condition x >= 0.*'):
        dist.pmf([-1., 4, 2]).eval()
      with self.assertRaisesOpError('counts do not sum to n'):
        dist.pmf([3., 3, 0]).eval()

  def testPmf_non_integer_counts(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.test_session():
      # No errors with integer n.
      multinom = tf.contrib.distributions.Multinomial(n=n, p=p)
      multinom.pmf([2., 1, 2]).eval()
      multinom.pmf([3., 0, 2]).eval()
      # Counts don't sum to n.
      with self.assertRaisesOpError('counts do not sum to n'):
        multinom.pmf([2., 3, 2]).eval()
      # Counts are non-integers.
      with self.assertRaisesOpError('Condition x == y.*'):
        multinom.pmf([1.0, 2.5, 1.5]).eval()

      multinom = tf.contrib.distributions.Multinomial(
          n=n, p=p, validate_args=False)
      multinom.pmf([1., 2., 2.]).eval()
      # Non-integer arguments work.
      multinom.pmf([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = [0.5, 0.5]
      counts = [1., 0]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose(0.5, pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    with self.test_session():
      # Both zero-batches. No broadcast
      p = [0.1, 0.9]
      counts = [3., 2]
      dist = tf.contrib.distributions.Multinomial(n=5., p=p)
      pmf = dist.pmf(counts)
      # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
      self.assertAllClose(81./10000, pmf.eval())
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      p = [[0.1, 0.9]]
      counts = [[1., 0], [0, 1]]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose([0.1, 0.9], pmf.eval())
      self.assertEqual((2), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      p = [0.1, 0.9]
      counts = [[1., 0], [0, 1]]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose([0.1, 0.9], pmf.eval())
      self.assertEqual((2), pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [[1., 0]]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose(pmf.eval(), [0.1, 0.7])
      self.assertEqual((2), pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [1., 0]
      pmf = tf.contrib.distributions.Multinomial(n=1., p=p).pmf(counts)
      self.assertAllClose(pmf.eval(), [0.1, 0.7])
      self.assertEqual(pmf.get_shape(), (2))

  def testPmfShapeCountsStretched_N(self):
    with self.test_session():
      # [2, 2, 2]
      p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
      # [2, 2]
      n = [[3., 3], [3, 3]]
      # [2]
      counts = [2., 1]
      pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
      pmf.eval()
      self.assertEqual(pmf.get_shape(), (2, 2))

  def testPmfShapeCountsPStretched_N(self):
    with self.test_session():
      p = [0.1, 0.9]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      pmf = tf.contrib.distributions.Multinomial(n=n, p=p).pmf(counts)
      pmf.eval()
      self.assertEqual((4, 3), pmf.get_shape())

  def testMultinomialMean(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      expected_means = 5 * np.array(p, dtype=np.float32)
      self.assertEqual((3,), dist.mean().get_shape())
      self.assertAllClose(expected_means, dist.mean().eval())

  def testMultinomialVariance(self):
    with self.test_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      expected_variances = [
          [9./20, -1/10, -7/20], [-1/10, 4/5, -7/10], [-7/20, -7/10, 21/20]]
      self.assertEqual((3, 3), dist.variance().get_shape())
      self.assertAllClose(expected_variances, dist.variance().eval())

  def testMultinomialVariance_batch(self):
    with self.test_session():
      # Shape [2]
      n = [5.] * 2
      # Shape [4, 1, 2]
      p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
      dist = tf.contrib.distributions.Multinomial(n=n, p=p)
      # Shape [2, 2]
      inner_var = [[9./20, -9/20], [-9/20, 9/20]]
      # Shape [4, 2, 2, 2]
      expected_variances = [[inner_var, inner_var]] * 4
      self.assertEqual((4, 2, 2, 2), dist.variance().get_shape())
      self.assertAllClose(expected_variances, dist.variance().eval())

  def testVariance_multidimensional(self):
    # Shape [3, 5, 4]
    p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
    # Shape [6, 3, 3]
    p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)

    ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
    ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)

    with self.test_session():
      dist = tf.contrib.distributions.Multinomial(ns, p)
      dist2 = tf.contrib.distributions.Multinomial(ns2, p2)

      variance = dist.variance()
      variance2 = dist2.variance()
      self.assertEqual((3, 5, 4, 4), variance.get_shape())
      self.assertEqual((6, 3, 3, 3), variance2.get_shape())


if __name__ == '__main__':
  tf.test.main()
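For reference, the hand-computed entries in testMultinomialVariance come from the multinomial covariance formula for \(X \sim \mathrm{Multinomial}(n, p)\):

\[
\operatorname{Cov}(X_i, X_j) = n\,p_i\,(\delta_{ij} - p_j).
\]

With n = 5 and p = (0.1, 0.2, 0.7): \(\operatorname{Var}(X_1) = 5 \cdot 0.1 \cdot 0.9 = 9/20\) and \(\operatorname{Cov}(X_1, X_2) = -5 \cdot 0.1 \cdot 0.2 = -1/10\), matching expected_variances above.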
@@ -369,5 +369,87 @@ class MultivariateNormalCholeskyTest(tf.test.TestCase):
     self.assertEqual((3, 5), tuple(mvn.batch_shape().eval()))


+class MultivariateNormalFullTest(tf.test.TestCase):
+
+  def setUp(self):
+    self._rng = np.random.RandomState(42)
+
+  def _random_mu_and_sigma(self, batch_shape, event_shape):
+    # This ensures sigma is positive def.
+    mat_shape = batch_shape + event_shape + event_shape
+    mat = self._rng.randn(*mat_shape)
+    sigma = tf.batch_matmul(mat, mat, adj_y=True).eval()
+
+    mu_shape = batch_shape + event_shape
+    mu = self._rng.randn(*mu_shape)
+
+    return mu, sigma
+
+  def testKLNonBatch(self):
+    batch_shape = ()
+    event_shape = (2,)
+    with self.test_session():
+      mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+      mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
+      mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
+      mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
+
+      kl = distributions.kl(mvn_a, mvn_b)
+      self.assertEqual(batch_shape, kl.get_shape())
+
+      kl_v = kl.eval()
+      expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
+      self.assertAllClose(expected_kl, kl_v)
+
+  def testKLBatch(self):
+    batch_shape = (2,)
+    event_shape = (3,)
+    with self.test_session():
+      mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+      mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
+      mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
+      mvn_b = distributions.MultivariateNormalFull(mu_b, sigma_b)
+
+      kl = distributions.kl(mvn_a, mvn_b)
+      self.assertEqual(batch_shape, kl.get_shape())
+
+      kl_v = kl.eval()
+      expected_kl_0 = _compute_non_batch_kl(
+          mu_a[0, :], sigma_a[0, :, :], mu_b[0, :], sigma_b[0, :])
+      expected_kl_1 = _compute_non_batch_kl(
+          mu_a[1, :], sigma_a[1, :, :], mu_b[1, :], sigma_b[1, :])
+      self.assertAllClose(expected_kl_0, kl_v[0])
+      self.assertAllClose(expected_kl_1, kl_v[1])
+
+  def testKLTwoIdenticalDistributionsIsZero(self):
+    batch_shape = (2,)
+    event_shape = (3,)
+    with self.test_session():
+      mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
+      mvn_a = distributions.MultivariateNormalFull(mu_a, sigma_a)
+
+      # Should be zero since KL(p || p) = 0.
+      kl = distributions.kl(mvn_a, mvn_a)
+      self.assertEqual(batch_shape, kl.get_shape())
+
+      kl_v = kl.eval()
+      self.assertAllClose(np.zeros(*batch_shape), kl_v)
+
+
+def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
+  """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
+  # Check using numpy operations
+  # This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
+  # So it is important to also check that KL(mvn, mvn) = 0.
+  sigma_b_inv = np.linalg.inv(sigma_b)
+
+  t = np.trace(sigma_b_inv.dot(sigma_a))
+  q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
+  k = mu_a.shape[0]
+  l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
+
+  return 0.5 * (t + q - k + l)
+
+
 if __name__ == "__main__":
   tf.test.main()
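For reference, _compute_non_batch_kl evaluates the closed form for the KL divergence between two multivariate normals in \(\mathbb{R}^k\):

\[
KL\!\left(\mathcal{N}(\mu_a, \Sigma_a) \,\|\, \mathcal{N}(\mu_b, \Sigma_b)\right)
= \tfrac{1}{2}\left[\operatorname{tr}(\Sigma_b^{-1}\Sigma_a)
+ (\mu_b - \mu_a)^\top \Sigma_b^{-1} (\mu_b - \mu_a)
- k
+ \ln\frac{\det \Sigma_b}{\det \Sigma_a}\right],
\]

whose four terms correspond exactly to t, q, k, and l in the helper.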
@@ -19,15 +19,13 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops import kullback_leibler  # pylint: disable=line-too-long
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
@@ -38,10 +36,6 @@ class Bernoulli(distribution.Distribution):

   The Bernoulli distribution is parameterized by p, the probability of a
   positive event.
-
-  Note, the following methods of the base class aren't implemented:
-    * cdf
-    * log_cdf
   """

   def __init__(self,
@@ -64,10 +58,10 @@ class Bernoulli(distribution.Distribution):
       dtype: dtype for samples.
       validate_args: Whether to assert that `0 <= p <= 1`. If not
         validate_args, `log_pmf` may return nans.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: A name for this distribution.

     Raises:
@ -77,27 +71,8 @@ class Bernoulli(distribution.Distribution):
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
self._validate_args = validate_args
|
self._validate_args = validate_args
|
||||||
check_op = check_ops.assert_less_equal
|
self._logits, self._p = distribution_util.get_logits_and_prob(
|
||||||
if p is None and logits is None:
|
name=name, logits=logits, p=p, validate_args=validate_args)
|
||||||
raise ValueError("Must pass p or logits.")
|
|
||||||
elif p is not None and logits is not None:
|
|
||||||
raise ValueError("Must pass either p or logits, not both.")
|
|
||||||
elif p is None:
|
|
||||||
with ops.op_scope([logits], name):
|
|
||||||
self._logits = array_ops.identity(logits, name="logits")
|
|
||||||
with ops.name_scope(name):
|
|
||||||
with ops.name_scope("p"):
|
|
||||||
self._p = math_ops.sigmoid(self._logits)
|
|
||||||
elif logits is None:
|
|
||||||
with ops.name_scope(name):
|
|
||||||
with ops.name_scope("p"):
|
|
||||||
p = array_ops.identity(p)
|
|
||||||
one = constant_op.constant(1., p.dtype)
|
|
||||||
zero = constant_op.constant(0., p.dtype)
|
|
||||||
self._p = control_flow_ops.with_dependencies(
|
|
||||||
[check_op(p, one), check_op(zero, p)] if validate_args else [], p)
|
|
||||||
with ops.name_scope("logits"):
|
|
||||||
self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p)
|
|
||||||
with ops.name_scope(name):
|
with ops.name_scope(name):
|
||||||
with ops.name_scope("q"):
|
with ops.name_scope("q"):
|
||||||
self._q = 1. - self._p
|
self._q = 1. - self._p
|
||||||
@@ -184,8 +159,12 @@ class Bernoulli(distribution.Distribution):
       event = ops.convert_to_tensor(event, name="event")
       event = math_ops.cast(event, self.logits.dtype)
       logits = self.logits
-      if ((event.get_shape().ndims is not None) or
-          (logits.get_shape().ndims is not None) or
+      # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
+      # so we do this here.
+      # TODO(b/30637701): Check dynamic shape, and don't broadcast if the
+      # dynamic shapes are the same.
+      if (not event.get_shape().is_fully_defined() or
+          not logits.get_shape().is_fully_defined() or
           event.get_shape() != logits.get_shape()):
         logits = array_ops.ones_like(event) * logits
         event = array_ops.ones_like(logits) * event
@@ -206,8 +185,7 @@ class Bernoulli(distribution.Distribution):
     with ops.name_scope(self.name):
       with ops.op_scope([self.p, n], name):
         n = ops.convert_to_tensor(n, name="n")
-        new_shape = array_ops.concat(
-            0, [array_ops.expand_dims(n, 0), self.batch_shape()])
+        new_shape = array_ops.concat(0, ([n], self.batch_shape()))
         uniform = random_ops.random_uniform(
             new_shape, seed=seed, dtype=dtypes.float32)
         sample = math_ops.less(uniform, self.p)
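Note: the sampling path above draws a uniform tensor of shape `[n] + batch_shape` and thresholds it against `p`. A minimal NumPy sketch of the same trick (the helper name and seeding here are illustrative, not part of this change):

```python
import numpy as np

def bernoulli_sample(p, n, seed=None):
    # Draw Bernoulli(p) samples by thresholding uniforms, as the op above does.
    rng = np.random.RandomState(seed)
    # Prepending n to p's batch shape mirrors the
    # array_ops.concat(0, ([n], self.batch_shape())) construction.
    uniform = rng.uniform(size=(n,) + np.shape(p))
    return (uniform < p).astype(np.int32)

samples = bernoulli_sample(np.array([.1, .9]), n=10000, seed=0)
print(samples.mean(axis=0))  # close to [0.1, 0.9]
```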
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """The Beta distribution class."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -95,6 +96,7 @@ class Beta(distribution.Distribution):
   x = [.2, .3, .9]
   dist.pdf(x)  # Shape [2]
   ```
+
   """
 
   def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
@@ -102,20 +104,20 @@ class Beta(distribution.Distribution):
     """Initialize a batch of Beta distributions.
 
     Args:
-      a: Positive `float` or `double` tensor with shape broadcastable to
+      a: Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions. This also defines the
         dtype of the distribution.
-      b: Positive `float` or `double` tensor with shape broadcastable to
+      b: Positive floating point tensor with shape broadcastable to
         `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions.
       validate_args: Whether to assert valid values for parameters `a` and `b`,
-        and `x` in `prob` and `log_prob`. If False, correct behavior is not
+        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
         guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -127,6 +129,7 @@ class Beta(distribution.Distribution):
     # Define a 2-batch.
     dist = Beta([1.0, 2.0], [4.0, 5.0])
     ```
+
     """
     with ops.op_scope([a, b], name):
       with ops.control_dependencies([
@@ -276,8 +279,14 @@ class Beta(distribution.Distribution):
           array_ops.ones_like(a_b_sum, dtype=self.dtype)))
     else:
       return control_flow_ops.with_dependencies([
-          check_ops.assert_less(one, a),
-          check_ops.assert_less(one, b)], mode)
+          check_ops.assert_less(
+              one, a,
+              message="mode not defined for components of a <= 1"
+          ),
+          check_ops.assert_less(
+              one, b,
+              message="mode not defined for components of b <= 1"
+          )], mode)
 
   def entropy(self, name="entropy"):
     """Entropy of the distribution in nats."""
@@ -306,7 +315,7 @@ class Beta(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x: Non-negative `float` or `double`, tensor whose shape can
+      x: Non-negative floating point tensor whose shape can
         be broadcast with `self.a` and `self.b`. For fixed leading
         dimensions, the last dimension represents counts for the corresponding
         Beta distribution in `self.a` and `self.b`. `x` is only legal if
@@ -334,7 +343,7 @@ class Beta(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x: Non-negative `float`, `double` tensor whose shape can
+      x: Non-negative floating point tensor whose shape can
         be broadcast with `self.a` and `self.b`. For fixed leading
         dimensions, the last dimension represents x for the corresponding Beta
         distribution in `self.a` and `self.b`. `x` is only legal if is
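Note: the new assertion messages in Beta's `mode()` encode the validity condition directly. A hedged NumPy sketch of the same contract (the helper is illustrative; the closed-form mode `(a - 1) / (a + b - 2)` is the standard Beta result, not something this patch changes):

```python
import numpy as np

def beta_mode(a, b, allow_nan_stats=False):
    # Mode of Beta(a, b) is (a - 1) / (a + b - 2), defined only for a > 1, b > 1.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    mode = (a - 1.) / (a + b - 2.)
    undefined = (a <= 1.) | (b <= 1.)
    if undefined.any():
        if not allow_nan_stats:
            raise ValueError("mode not defined for components of a <= 1 or b <= 1")
        mode = np.where(undefined, np.nan, mode)
    return mode

print(beta_mode([2., 5.], [4., 5.]))            # [0.25 0.5]
print(beta_mode(.5, .5, allow_nan_stats=True))  # nan
```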
340
tensorflow/contrib/distributions/python/ops/binomial.py
Normal file
@@ -0,0 +1,340 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Binomial distribution class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=line-too-long
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+# pylint: enable=line-too-long
+
+
+class Binomial(distribution.Distribution):
+  """Binomial distribution.
+
+  This distribution is parameterized by a vector `p` of probabilities and `n`,
+  the total counts.
+
+  #### Mathematical details
+
+  The Binomial is a distribution over the number of successes in `n` independent
+  trials, with each trial having the same probability of success `p`.
+  The probability mass function (pmf):
+
+  ```pmf(k) = n! / (k! * (n - k)!) * (p)^k * (1 - p)^(n - k)```
+
+  #### Examples
+
+  Create a single distribution, corresponding to 5 coin flips.
+
+  ```python
+  dist = Binomial(n=5., p=.5)
+  ```
+
+  Create a single distribution (using logits), corresponding to 5 coin flips.
+
+  ```python
+  dist = Binomial(n=5., logits=0.)
+  ```
+
+  Creates 3 distributions with the third distribution most likely to have
+  successes.
+
+  ```python
+  p = [.2, .3, .8]
+  # n will be broadcast to [4., 4., 4.], to match p.
+  dist = Binomial(n=4., p=p)
+  ```
+
+  The distribution functions can be evaluated on counts.
+
+  ```python
+  # counts same shape as p.
+  counts = [1., 2, 3]
+  dist.prob(counts)  # Shape [3]
+
+  # p will be broadcast to [[.2, .3, .8], [.2, .3, .8]] to match counts.
+  counts = [[1., 2, 1], [2, 2, 4]]
+  dist.prob(counts)  # Shape [2, 3]
+
+  # p will be broadcast to shape [5, 7, 3] to match counts.
+  counts = [[...]]  # Shape [5, 7, 3]
+  dist.prob(counts)  # Shape [5, 7, 3]
+  ```
+  """
+
+  def __init__(self,
+               n,
+               logits=None,
+               p=None,
+               validate_args=True,
+               allow_nan_stats=False,
+               name="Binomial"):
+    """Initialize a batch of Binomial distributions.
+
+    Args:
+      n: Non-negative floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
+        Defines this as a batch of `N1 x ... x Nm` different Binomial
+        distributions. Its components should be equal to integer values.
+      logits: Floating point tensor representing the log-odds of a
+        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
+        the same dtype as `n`. Each entry represents logits for the probability
+        of success for independent Binomial distributions.
+      p: Positive floating point tensor with shape broadcastable to
+        `[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
+        probability of success for independent Binomial distributions.
+      validate_args: Whether to assert valid values for parameters `n` and `p`,
+        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
+        guaranteed.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
+      name: The name to prefix Ops created by this distribution class.
+
+    Examples:
+
+    ```python
+    # Define 1-batch of a binomial distribution.
+    dist = Binomial(n=2., p=.9)
+
+    # Define a 2-batch.
+    dist = Binomial(n=[4., 5], p=[.1, .3])
+    ```
+
+    """
+
+    self._logits, self._p = distribution_util.get_logits_and_prob(
+        name=name, logits=logits, p=p, validate_args=validate_args)
+
+    with ops.op_scope([n], name):
+      with ops.control_dependencies([
+          check_ops.assert_non_negative(
+              n, message="n has negative components."),
+          distribution_util.assert_integer_form(
+              n, message="n has non-integer components."
+          )] if validate_args else []):
+        self._n = array_ops.identity(n, name="convert_n")
+
+    self._name = name
+    self._validate_args = validate_args
+    self._allow_nan_stats = allow_nan_stats
+
+    self._mean = self._n * self._p
+    self._get_batch_shape = self._mean.get_shape()
+    self._get_event_shape = tensor_shape.TensorShape([])
+
+  @property
+  def name(self):
+    """Name to prepend to all ops."""
+    return self._name
+
+  @property
+  def dtype(self):
+    """dtype of samples from this distribution."""
+    return self._p.dtype
+
+  @property
+  def validate_args(self):
+    """Boolean describing behavior on invalid input."""
+    return self._validate_args
+
+  @property
+  def allow_nan_stats(self):
+    """Boolean describing behavior when a stat is undefined for batch member."""
+    return self._allow_nan_stats
+
+  def batch_shape(self, name="batch_shape"):
+    """Batch dimensions of this instance as a 1-D int32 `Tensor`.
+
+    The product of the dimensions of the `batch_shape` is the number of
+    independent distributions of this kind the instance represents.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `batch_shape`
+    """
+    return array_ops.shape(self._mean)
+
+  def get_batch_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `batch_shape`. May be only partially defined.
+
+    Returns:
+      batch shape
+    """
+    return self._get_batch_shape
+
+  def event_shape(self, name="event_shape"):
+    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
+
+    Args:
+      name: name to give to the op
+
+    Returns:
+      `Tensor` `event_shape`
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([], name):
+        return constant_op.constant([], name=name, dtype=dtypes.int32)
+
+  def get_event_shape(self):
+    """`TensorShape` available at graph construction time.
+
+    Same meaning as `event_shape`. May be only partially defined.
+
+    Returns:
+      event shape
+    """
+    return self._get_event_shape
+
+  @property
+  def n(self):
+    """Number of trials."""
+    return self._n
+
+  @property
+  def logits(self):
+    """Log-odds."""
+    return self._logits
+
+  @property
+  def p(self):
+    """Probability of success."""
+    return self._p
+
+  def mean(self, name="mean"):
+    """Mean of the distribution."""
+    with ops.name_scope(self.name):
+      return array_ops.identity(self._mean, name=name)
+
+  def variance(self, name="variance"):
+    """Variance of the distribution."""
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return self._n * self._p * (1 - self._p)
+
+  def std(self, name="std"):
+    """Standard deviation of the distribution."""
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return math_ops.sqrt(self.variance())
+
+  def mode(self, name="mode"):
+    """Mode of the distribution.
+
+    Note that when `(n + 1) * p` is an integer, there are actually two modes.
+    Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
+    only the larger of the two modes.
+
+    Args:
+      name: The name for this op.
+
+    Returns:
+      The mode of the Binomial distribution.
+    """
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p], name):
+        return math_ops.floor((self._n + 1) * self._p)
+
+  def log_prob(self, counts, name="log_prob"):
+    """`Log(P[counts])`, computed for every batch member.
+
+    For each batch member of counts `k`, `P[counts]` is the probability that
+    after sampling `n` draws from this Binomial distribution, the number of
+    successes is `k`. Note that different sequences of draws can result in the
+    same counts, thus the probability includes a combinatorial coefficient.
+
+    Args:
+      counts: Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
+        less than or equal to `n` and its components are equal to integer
+        values.
+      name: Name to give this Op, defaults to "log_prob".
+
+    Returns:
+      Log probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    n = self._n
+    p = self._p
+    with ops.name_scope(self.name):
+      with ops.op_scope([self._n, self._p, counts], name):
+        counts = self._check_counts(counts)
+
+        prob_prob = counts * math_ops.log(p) + (
+            n - counts) * math_ops.log(1 - p)
+
+        combinations = math_ops.lgamma(n + 1) - math_ops.lgamma(
+            counts + 1) - math_ops.lgamma(n - counts + 1)
+        log_prob = prob_prob + combinations
+        return log_prob
+
+  def prob(self, counts, name="prob"):
+    """`P[counts]`, computed for every batch member.
+
+
+    For each batch member of counts `k`, `P[counts]` is the probability that
+    after sampling `n` draws from this Binomial distribution, the number of
+    successes is `k`. Note that different sequences of draws can result in the
+    same counts, thus the probability includes a combinatorial coefficient.
+
+    Args:
+      counts: Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.p` and `self.n`. `counts` is only legal if it is
+        less than or equal to `n` and its components are equal to integer
+        values.
+      name: Name to give this Op, defaults to "prob".
+
+    Returns:
+      Probabilities for each record, shape `[N1,...,Nm]`.
+    """
+    return super(Binomial, self).prob(counts, name=name)
+
+  @property
+  def is_continuous(self):
+    return False
+
+  @property
+  def is_reparameterized(self):
+    return False
+
+  def _check_counts(self, counts):
+    """Check counts for proper shape, values, then return tensor version."""
+    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
+    if not self.validate_args:
+      return counts
+    return control_flow_ops.with_dependencies([
+        check_ops.assert_non_negative(
+            counts, message="counts has negative components."),
+        check_ops.assert_less_equal(
+            counts, self._n, message="counts are not less than or equal to n."),
+        distribution_util.assert_integer_form(
+            counts, message="counts have non-integer components.")], counts)
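Note: `log_prob` above assembles `log pmf(k)` from the docstring's formula, using `lgamma` for the factorials. A quick self-contained check of that identity (pure Python, illustrative only):

```python
import math

def binomial_log_prob(k, n, p):
    # log pmf(k) = log C(n, k) + k*log(p) + (n - k)*log(1 - p), with
    # log C(n, k) computed via lgamma(m + 1) = log(m!), as log_prob does.
    combinations = (math.lgamma(n + 1)
                    - math.lgamma(k + 1)
                    - math.lgamma(n - k + 1))
    return combinations + k * math.log(p) + (n - k) * math.log(1 - p)

# pmf(2) for Binomial(n=5, p=0.5) is C(5, 2) / 2**5 = 10 / 32.
assert abs(math.exp(binomial_log_prob(2, 5, 0.5)) - 10. / 32.) < 1e-12
```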
@@ -34,11 +34,6 @@ class Categorical(distribution.Distribution):
 
   The categorical distribution is parameterized by the log-probabilities
   of a set of classes.
 
-  Note, the following methods of the base class aren't implemented:
-    * mean
-    * cdf
-    * log_cdf
   """
 
   def __init__(
@@ -57,10 +52,10 @@ class Categorical(distribution.Distribution):
         indexes into the classes.
       dtype: The type of the event samples (default: int32).
       validate_args: Unused in this distribution.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: A name for this distribution (optional).
     """
     self._allow_nan_stats = allow_nan_stats
@@ -177,8 +172,7 @@ class Categorical(distribution.Distribution):
       samples = math_ops.cast(samples, self._dtype)
       ret = array_ops.reshape(
           array_ops.transpose(samples),
-          array_ops.concat(
-              0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
+          array_ops.concat(0, ([n], self.batch_shape())))
       ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                     .concatenate(self.get_batch_shape()))
       return ret
@@ -42,15 +42,15 @@ class Chi2(gamma.Gamma):
     """Construct Chi2 distributions with parameter `df`.
 
     Args:
-      df: `float` or `double` tensor, the degrees of freedom of the
+      df: Floating point tensor, the degrees of freedom of the
         distribution(s). `df` must contain only positive values.
       validate_args: Whether to assert that `df > 0`, and that `x > 0` in the
-        methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
+        methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
     """
     # Even though all stats of chi2 are defined for valid parameters, this is
@@ -19,9 +19,8 @@ from __future__ import print_function
 
 # pylint: disable=line-too-long
 
-import numpy as np
-
 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -29,7 +28,6 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
@@ -37,24 +35,6 @@ from tensorflow.python.ops import special_math_ops
 # pylint: enable=line-too-long
 
 
-def _assert_close(x, y, data=None, summarize=None, name=None):
-  if x.dtype.is_integer:
-    return check_ops.assert_equal(
-        x, y, data=data, summarize=summarize, name=name)
-
-  with ops.op_scope([x, y, data], name, "assert_close"):
-    x = ops.convert_to_tensor(x, name="x")
-    y = ops.convert_to_tensor(y, name="y")
-    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
-    if data is None:
-      data = [
-          "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
-          y.name, y
-      ]
-    condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
-    return logging_ops.Assert(condition, data, summarize=summarize)
-
-
 class Dirichlet(distribution.Distribution):
   """Dirichlet distribution.
 
@@ -117,6 +97,7 @@ class Dirichlet(distribution.Distribution):
   x = [.2, .3, .5]
   dist.prob(x)  # Shape [2]
   ```
+
   """
 
   def __init__(self,
@@ -127,16 +108,16 @@ class Dirichlet(distribution.Distribution):
     """Initialize a batch of Dirichlet distributions.
 
    Args:
-      alpha: Positive `float` or `double` tensor with shape broadcastable to
+      alpha: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm`
        different `k` class Dirichlet distributions.
      validate_args: Whether to assert valid values for parameters `alpha` and
-        `x` in `prob` and `log_prob`. If False, correct behavior is not
+        `x` in `prob` and `log_prob`. If `False`, correct behavior is not
        guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.
 
    Examples:
@@ -149,6 +130,7 @@ class Dirichlet(distribution.Distribution):
     # Define a 2-batch of 3-class distributions.
     dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
     ```
+
     """
     with ops.op_scope([alpha], name):
       alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
@@ -302,7 +284,9 @@ class Dirichlet(distribution.Distribution):
               array_ops.ones_like(self._alpha, dtype=self.dtype)))
     else:
       return control_flow_ops.with_dependencies([
-          check_ops.assert_less(one, self._alpha)
+          check_ops.assert_less(
+              one, self._alpha,
+              message="mode not defined for components of alpha <= 1")
       ], mode)
 
   def entropy(self, name="entropy"):
@@ -334,7 +318,7 @@ class Dirichlet(distribution.Distribution):
     """`Log(P[counts])`, computed for every batch member.
 
     Args:
-      x: Non-negative `float` or `double`, tensor whose shape can
+      x: Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`. For fixed leading dimensions, the last
         dimension represents counts for the corresponding Dirichlet distribution
         in `self.alpha`. `x` is only legal if it sums up to one.
@@ -359,7 +343,7 @@ class Dirichlet(distribution.Distribution):
     """`P[x]`, computed for every batch member.
 
     Args:
-      x: Non-negative `float`, `double` tensor whose shape can
+      x: Non-negative tensor with dtype `dtype` and whose shape can
         be broadcast with `self.alpha`. For fixed leading dimensions, the last
         dimension represents x for the corresponding Dirichlet distribution in
         `self.alpha` and `self.beta`. `x` is only legal if it sums up to one.
@@ -407,7 +391,8 @@ class Dirichlet(distribution.Distribution):
       x = ops.convert_to_tensor(x, name="x_before_deps")
       candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1])
       one = constant_op.constant(1., self.dtype)
-      dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one),
-                      _assert_close(one, candidate_one)
+      dependencies = [check_ops.assert_positive(x), check_ops.assert_less(
+          x, one, message="x has components greater than or equal to 1"),
+                      distribution_util.assert_close(one, candidate_one)
       ] if self.validate_args else []
       return control_flow_ops.with_dependencies(dependencies, x)
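Note: the local `_assert_close` moves to `distribution_util.assert_close`, which tolerates one unit of the dtype's resolution (`np.finfo(...).resolution`) and falls back to exact equality for integer dtypes. A NumPy sketch of the tolerance logic (illustrative; `is_close` is not a function in the patch):

```python
import numpy as np

def is_close(x, y):
    # Mirrors assert_close: elementwise |x - y| <= resolution of the dtype.
    x, y = np.asarray(x), np.asarray(y)
    tol = np.finfo(x.dtype).resolution
    return bool(np.all(np.abs(x - y) <= tol))

x = np.array([.2, .3, .5], dtype=np.float32)
print(is_close(x.sum(), np.float32(1.)))  # True: x lies on the simplex
```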
@@ -13,13 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 """The Dirichlet Multinomial distribution class."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=line-too-long
 
-from tensorflow.contrib.distributions.python.ops import distribution  # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -30,34 +32,6 @@ from tensorflow.python.ops import special_math_ops
 # pylint: enable=line-too-long
 
 
-def _assert_integer_form(x):
-  """Check x for integer components (or floats that are equal to integers)."""
-  x = ops.convert_to_tensor(x, name='x')
-  casted_x = math_ops.to_int64(x)
-  return check_ops.assert_equal(x, math_ops.cast(
-      math_ops.round(casted_x), x.dtype))
-
-
-def _log_combinations(n, counts, name='log_combinations'):
-  """Log number of ways counts could have come in."""
-  # First a bit about the number of ways counts could have come in:
-  # E.g. if counts = [1, 2], then this is 3 choose 2.
-  # In general, this is (sum counts)! / sum(counts!)
-  # The sum should be along the last dimension of counts. This is the
-  # "distribution" dimension. Here n a priori represents the sum of counts.
-  with ops.op_scope([counts], name):
-    # To compute factorials, use the fact that Gamma(n + 1) = n!
-    # Compute two terms, each a sum over counts. Compute each for each
-    # batch member.
-    # Log Gamma((sum counts) + 1) = Log((sum counts)!)
-    total_permutations = math_ops.lgamma(n + 1)
-    # sum(Log Gamma(counts + 1)) = Log sum(counts!)
-    counts_factorial = math_ops.lgamma(counts + 1)
-    redundant_permutations = math_ops.reduce_sum(counts_factorial,
-                                                 reduction_indices=[-1])
-    return total_permutations - redundant_permutations
-
-
 class DirichletMultinomial(distribution.Distribution):
   """DirichletMultinomial mixture distribution.
 
@@ -126,6 +100,7 @@ class DirichletMultinomial(distribution.Distribution):
   counts = [2, 1, 0]
   dist.pmf(counts)  # Shape [2]
   ```
+
   """
 
   # TODO(b/27419586) Change docstring for dtype of alpha once int allowed.
@@ -134,26 +109,26 @@ class DirichletMultinomial(distribution.Distribution):
                alpha,
                validate_args=True,
                allow_nan_stats=False,
-               name='DirichletMultinomial'):
+               name="DirichletMultinomial"):
     """Initialize a batch of DirichletMultinomial distributions.
 
     Args:
-      n: Non-negative `float` or `double` tensor, whose dtype is the same as
+      n: Non-negative floating point tensor, whose dtype is the same as
         `alpha`. The shape is broadcastable to `[N1,..., Nm]` with `m >= 0`.
         Defines this as a batch of `N1 x ... x Nm` different Dirichlet
-        multinomial distributions. Its components should be equal to integral
+        multinomial distributions. Its components should be equal to integer
         values.
-      alpha: Positive `float` or `double` tensor, whose dtype is the same as
+      alpha: Positive floating point tensor, whose dtype is the same as
         `n` with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines
         this as a batch of `N1 x ... x Nm` different `k` class Dirichlet
         multinomial distributions.
       validate_args: Whether to assert valid values for parameters `alpha` and
-        `n`, and `x` in `prob` and `log_prob`. If False, correct behavior is
+        `n`, and `x` in `prob` and `log_prob`. If `False`, correct behavior is
         not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.
 
     Examples:
@@ -166,6 +141,7 @@ class DirichletMultinomial(distribution.Distribution):
     # Define a 2-batch of 3-class distributions.
     dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
     ```
+
     """
     self._allow_nan_stats = allow_nan_stats
     self._validate_args = validate_args
@@ -221,7 +197,7 @@ class DirichletMultinomial(distribution.Distribution):
     """dtype of samples from this distribution."""
     return self._alpha.dtype
 
-  def mean(self, name='mean'):
+  def mean(self, name="mean"):
     """Class means for every batch member."""
     alpha = self._alpha
     alpha_sum = self._alpha_sum
@@ -231,7 +207,7 @@ class DirichletMultinomial(distribution.Distribution):
       mean_no_n = alpha / array_ops.expand_dims(alpha_sum, -1)
       return array_ops.expand_dims(n, -1) * mean_no_n
 
-  def variance(self, name='mean'):
+  def variance(self, name="mean"):
     """Class variances for every batch member.
 
     The variance for each batch member is defined as the following:
@@ -273,7 +249,7 @@ class DirichletMultinomial(distribution.Distribution):
       variance *= array_ops.expand_dims(shared_factor, -1)
       return variance
 
-  def batch_shape(self, name='batch_shape'):
+  def batch_shape(self, name="batch_shape"):
     """Batch dimensions of this instance as a 1-D int32 `Tensor`.
 
     The product of the dimensions of the `batch_shape` is the number of
@@ -299,7 +275,7 @@ class DirichletMultinomial(distribution.Distribution):
     """
     return self._get_batch_shape
 
-  def event_shape(self, name='event_shape'):
+  def event_shape(self, name="event_shape"):
     """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
 
     Args:
@@ -322,15 +298,15 @@ class DirichletMultinomial(distribution.Distribution):
     """
     return self._get_event_shape
 
-  def cdf(self, x, name='cdf'):
+  def cdf(self, x, name="cdf"):
     raise NotImplementedError(
-        'DirichletMultinomial does not have a well-defined cdf.')
+        "DirichletMultinomial does not have a well-defined cdf.")
 
-  def log_cdf(self, x, name='log_cdf'):
+  def log_cdf(self, x, name="log_cdf"):
     raise NotImplementedError(
-        'DirichletMultinomial does not have a well-defined cdf.')
+        "DirichletMultinomial does not have a well-defined cdf.")
 
-  def log_prob(self, counts, name='log_prob'):
+  def log_prob(self, counts, name="log_prob"):
     """`Log(P[counts])`, computed for every batch member.
 
     For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
@@ -340,12 +316,11 @@ class DirichletMultinomial(distribution.Distribution):
     probability includes a combinatorial coefficient.
 
     Args:
-      counts: Non-negative `float` or `double` tensor whose dtype is the same
-        `self` and whose shape can be broadcast with `self.alpha`. For fixed
-        leading dimensions, the last dimension represents counts for the
-        corresponding Dirichlet Multinomial distribution in `self.alpha`.
-        `counts` is only legal if it sums up to `n` and its components are
-        equal to integral values.
+      counts: Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.alpha`. For fixed leading dimensions, the last
+        dimension represents counts for the corresponding Dirichlet Multinomial
+        distribution in `self.alpha`. `counts` is only legal if it sums up to
+        `n` and its components are equal to integer values.
       name: Name to give this Op, defaults to "log_prob".
 
     Returns:
@@ -359,20 +334,11 @@ class DirichletMultinomial(distribution.Distribution):
 
       ordered_prob = (special_math_ops.lbeta(alpha + counts) -
                       special_math_ops.lbeta(alpha))
-      log_prob = ordered_prob + _log_combinations(n, counts)
-      # If alpha = counts = [[]], ordered_prob carries the right shape, which
-      # is []. However, since reduce_sum([[]]) = [0], log_combinations = [0],
-      # which is not correct. Luckily, [] + [0] = [], so the sum is fine, but
-      # shape must be inferred from ordered_prob. We must also make this
-      # broadcastable with n, so this is multiplied by n to ensure the shape
-      # is correctly inferred.
-      # Note also that tf.constant([]).get_shape() =
-      # TensorShape([Dimension(0)])
-      broadcasted_tensor = ordered_prob * n
-      log_prob.set_shape(broadcasted_tensor.get_shape())
+      log_prob = ordered_prob + distribution_util.log_combinations(
+          n, counts)
       return log_prob
 
-  def prob(self, counts, name='prob'):
+  def prob(self, counts, name="prob"):
     """`P[counts]`, computed for every batch member.
 
     For each batch of counts `[c_1,...,c_k]`, `P[counts]` is the probability
@@ -382,12 +348,11 @@ class DirichletMultinomial(distribution.Distribution):
     probability includes a combinatorial coefficient.
 
     Args:
-      counts: Non-negative `float` or `double` tensor whose dtype is the same
-        `self` and whose shape can be broadcast with `self.alpha`. For fixed
-        leading dimensions, the last dimension represents counts for the
-        corresponding Dirichlet Multinomial distribution in `self.alpha`.
-        `counts` is only legal if it sums up to `n` and its components are
-        equal to integral values.
+      counts: Non-negative tensor with dtype `dtype` and whose shape can be
+        broadcast with `self.alpha`. For fixed leading dimensions, the last
+        dimension represents counts for the corresponding Dirichlet Multinomial
+        distribution in `self.alpha`. `counts` is only legal if it sums up to
+        `n` and its components are equal to integer values.
       name: Name to give this Op, defaults to "prob".
 
     Returns:
@@ -397,18 +362,21 @@ class DirichletMultinomial(distribution.Distribution):
 
   def _check_counts(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
-    counts = ops.convert_to_tensor(counts, name='counts')
+    counts = ops.convert_to_tensor(counts, name="counts")
     if not self.validate_args:
       return counts
     candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
 
     return control_flow_ops.with_dependencies([
         check_ops.assert_non_negative(counts),
-        check_ops.assert_equal(self._n, candidate_n),
-        _assert_integer_form(counts)], counts)
+        check_ops.assert_equal(
+            self._n, candidate_n,
+            message="counts do not sum to n"
+        ),
+        distribution_util.assert_integer_form(counts)], counts)
 
   def _check_alpha(self, alpha):
-    alpha = ops.convert_to_tensor(alpha, name='alpha')
+    alpha = ops.convert_to_tensor(alpha, name="alpha")
     if not self.validate_args:
       return alpha
     return control_flow_ops.with_dependencies(
@@ -416,11 +384,12 @@ class DirichletMultinomial(distribution.Distribution):
         check_ops.assert_positive(alpha)], alpha)
 
   def _check_n(self, n):
-    n = ops.convert_to_tensor(n, name='n')
+    n = ops.convert_to_tensor(n, name="n")
    if not self.validate_args:
      return n
    return control_flow_ops.with_dependencies(
-        [check_ops.assert_non_negative(n), _assert_integer_form(n)], n)
+        [check_ops.assert_non_negative(n),
+         distribution_util.assert_integer_form(n)], n)
 
   @property
   def is_continuous(self):
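Note: `_log_combinations` moves to `distribution_util.log_combinations`, which computes the multinomial coefficient in log space via `lgamma(m + 1) = log(m!)`. A small pure-Python sketch of the identity (illustrative only):

```python
import math

def log_combinations(n, counts):
    # n! / prod_i(counts_i!), in log space, using lgamma(m + 1) = log(m!).
    total_permutations = math.lgamma(n + 1)
    redundant_permutations = sum(math.lgamma(c + 1) for c in counts)
    return total_permutations - redundant_permutations

# counts = [1, 2]: 3! / (1! * 2!) = 3 orderings.
assert abs(math.exp(log_combinations(3, [1, 2])) - 3.) < 1e-12
```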
177
tensorflow/contrib/distributions/python/ops/distribution_util.py
Normal file
@@ -0,0 +1,177 @@
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities for probability distributions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import logging_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
|
def assert_close(
|
||||||
|
x, y, data=None, summarize=None, message=None, name="assert_close"):
|
||||||
|
"""Assert that that x and y are within machine epsilon of each other.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Numeric `Tensor`
|
||||||
|
y: Numeric `Tensor`
|
||||||
|
data: The tensors to print out if the condition is `False`. Defaults to
|
||||||
|
error message and first few entries of `x` and `y`.
|
||||||
|
summarize: Print this many entries of each tensor.
|
||||||
|
message: A string to prefix to the default message.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
|
||||||
|
"""
|
||||||
|
message = message or ""
|
||||||
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
|
y = ops.convert_to_tensor(y, name="y")
|
||||||
|
|
||||||
|
if x.dtype.is_integer:
|
||||||
|
return check_ops.assert_equal(
|
||||||
|
x, y, data=data, summarize=summarize, message=message, name=name)
|
||||||
|
|
||||||
|
with ops.op_scope([x, y, data], name, "assert_close"):
|
||||||
|
tol = np.finfo(x.dtype.as_numpy_dtype).resolution
|
||||||
|
if data is None:
|
||||||
|
data = [
|
||||||
|
message,
|
||||||
|
"Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
|
||||||
|
y.name, y
|
||||||
|
]
|
||||||
|
condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
|
||||||
|
return logging_ops.Assert(
|
||||||
|
condition, data, summarize=summarize)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_integer_form(
|
||||||
|
x, data=None, summarize=None, message=None, name="assert_integer_form"):
|
||||||
|
"""Assert that x has integer components (or floats equal to integers).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Numeric `Tensor`
|
||||||
|
data: The tensors to print out if the condition is `False`. Defaults to
|
||||||
|
error message and first few entries of `x` and `y`.
|
||||||
|
summarize: Print this many entries of each tensor.
|
||||||
|
message: A string to prefix to the default message.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Op raising `InvalidArgumentError` if round(x) != x.
|
||||||
|
"""
|
||||||
|
|
||||||
|
message = message or "x has non-integer components"
|
||||||
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
|
casted_x = math_ops.to_int64(x)
|
||||||
|
return check_ops.assert_equal(
|
||||||
|
x, math_ops.cast(math_ops.round(casted_x), x.dtype),
|
||||||
|
data=data, summarize=summarize, message=message, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logits_and_prob(
    logits=None, p=None, multidimensional=False, validate_args=True, name=None):
  """Converts logits to probabilities and vice-versa, and returns both.

  Args:
    logits: Numeric `Tensor` representing log-odds.
    p: Numeric `Tensor` representing probabilities.
    multidimensional: Given `p` a [N1, N2, ... k] dimensional tensor,
      whether the last dimension represents the probability between k classes.
      This will additionally assert that the values in the last dimension
      sum to one. If `False`, will instead assert that each value is in
      `[0, 1]`.
    validate_args: Whether to assert `0 <= p <= 1` if multidimensional is
      `False`, otherwise that the last dimension of `p` sums to one.
    name: A name for this operation (optional).

  Returns:
    Tuple with `logits` and `p`. If `p` has an entry that is `0` or `1`, then
    the corresponding entry in the returned logits will be `-Inf` and `Inf`
    respectively.

  Raises:
    ValueError: if neither `p` nor `logits` were passed in, or both were.
  """
  if p is None and logits is None:
    raise ValueError("Must pass p or logits.")
  elif p is not None and logits is not None:
    raise ValueError("Must pass either p or logits, not both.")
  elif p is None:
    with ops.op_scope([logits], name):
      logits = array_ops.identity(logits, name="logits")
    with ops.name_scope(name):
      with ops.name_scope("p"):
        p = math_ops.sigmoid(logits)
  elif logits is None:
    with ops.name_scope(name):
      with ops.name_scope("p"):
        p = array_ops.identity(p)
        if validate_args:
          one = constant_op.constant(1., p.dtype)
          dependencies = [check_ops.assert_non_negative(p)]
          if multidimensional:
            dependencies += [assert_close(
                math_ops.reduce_sum(p, reduction_indices=[-1]),
                one, message="p does not sum to 1.")]
          else:
            dependencies += [check_ops.assert_less_equal(
                p, one, message="p has components greater than 1.")]
          p = control_flow_ops.with_dependencies(dependencies, p)
      with ops.name_scope("logits"):
        logits = math_ops.log(p) - math_ops.log(1. - p)
  return (logits, p)
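For orientation, a sketch of the round trip this helper performs. Assumptions: the module path below, and probabilities that already sum to one:

```python
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import distribution_util

# Start from class probabilities; get back elementwise log-odds plus the
# validated probabilities. multidimensional=True checks that the last axis
# sums to one instead of checking each entry against [0, 1].
p = tf.constant([0.2, 0.3, 0.5])
logits, p = distribution_util.get_logits_and_prob(
    p=p, multidimensional=True, validate_args=True)
```
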
def log_combinations(n, counts, name="log_combinations"):
  """Log multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the log of the multinomial coefficient:

  ```n! / (prod_i n_i!)```

  where `i` runs over all `k` classes.

  Args:
    n: Numeric `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Numeric `Tensor` broadcastable with `n`. This represents counts
      in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
    and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.op_scope([n, counts], name):
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial,
                                                 reduction_indices=[-1])
    return total_permutations - redundant_permutations
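The log-space identity used above, `log[n! / (n_1!...n_k!)] = lgamma(n+1) - sum_i lgamma(n_i+1)`, is easy to sanity-check in plain Python:

```python
import math

def log_combinations(n, counts):
  # Log multinomial coefficient via log-gamma, for numerical stability.
  return math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)

# 4 draws split as [1, 0, 3]: the coefficient is 4! / (1! 0! 3!) = 4.
assert abs(math.exp(log_combinations(4, [1, 0, 3])) - 4.0) < 1e-9
```
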
@ -46,15 +46,15 @@ class Exponential(gamma.Gamma):
     """Construct Exponential distribution with parameter `lam`.

     Args:
-      lam: `float` or `double` tensor, the rate of the distribution(s).
+      lam: Floating point tensor, the rate of the distribution(s).
         `lam` must contain only positive values.
       validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the
-        methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
+        methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.
     """
     # Even though all statistics of are defined for valid inputs, this is not
@ -95,8 +95,7 @@ class Exponential(gamma.Gamma):
     broadcast_shape = self._lam.get_shape()
     with ops.op_scope([self.lam, n], name, "ExponentialSample"):
       n = ops.convert_to_tensor(n, name="n")
-      shape = array_ops.concat(
-          0, [array_ops.pack([n]), array_ops.shape(self._lam)])
+      shape = array_ops.concat(0, ([n], array_ops.shape(self._lam)))
       # Sample uniformly-at-random from the open-interval (0, 1).
       sampled = random_ops.random_uniform(
           shape, minval=np.nextafter(
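The `np.nextafter` call above nudges `minval` to the smallest float above zero, so the uniform draw lands in the open interval (0, 1) and a downstream `log` stays finite. A standalone numpy sketch of the idea follows; this is the textbook inverse-CDF sampler, not necessarily the exact op sequence the class uses:

```python
import numpy as np

# Smallest float32 strictly greater than 0; using it as minval keeps the
# draw inside the open interval (0, 1), so log(u) is always finite.
tiny = np.nextafter(np.float32(0.), np.float32(1.))
u = np.random.uniform(low=tiny, high=1., size=5)
lam = 0.5
samples = -np.log(u) / lam  # inverse-CDF sampling for Exponential(lam)
```
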
@ -69,19 +69,19 @@ class Gamma(distribution.Distribution):
     broadcasting (e.g. `alpha + beta` is a valid operation).

     Args:
-      alpha: `float` or `double` tensor, the shape params of the
+      alpha: Floating point tensor, the shape params of the
         distribution(s).
         alpha must contain only positive values.
-      beta: `float` or `double` tensor, the inverse scale params of the
+      beta: Floating point tensor, the inverse scale params of the
         distribution(s).
         beta must contain only positive values.
       validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
-        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
+        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.

     Raises:
@ -213,9 +213,12 @@ class Gamma(distribution.Distribution):
         nan = np.nan * self._ones()
         return math_ops.select(alpha_ge_1, mode_if_defined, nan)
       else:
-        one = ops.convert_to_tensor(1.0, dtype=self.dtype)
+        one = constant_op.constant(1.0, dtype=self.dtype)
         return control_flow_ops.with_dependencies(
-            [check_ops.assert_less(one, alpha)], mode_if_defined)
+            [check_ops.assert_less(
+                one, alpha,
+                message="mode not defined for components of alpha <= 1"
+            )], mode_if_defined)

   def variance(self, name="variance"):
     """Variance of each batch member."""
@ -69,18 +69,18 @@ class InverseGamma(distribution.Distribution):
     broadcasting (e.g. `alpha + beta` is a valid operation).

     Args:
-      alpha: `float` or `double` tensor, the shape params of the
+      alpha: Floating point tensor, the shape params of the
         distribution(s).
         alpha must contain only positive values.
-      beta: `float` or `double` tensor, the scale params of the distribution(s).
+      beta: Floating point tensor, the scale params of the distribution(s).
         beta must contain only positive values.
       validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in
-        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False
+        the methods `prob(x)` and `log_prob(x)`. If `validate_args` is `False`
         and the inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prepend to all ops created by this distribution.

     Raises:
@ -206,9 +206,12 @@ class InverseGamma(distribution.Distribution):
         nan = np.nan * self._ones()
         return math_ops.select(alpha_gt_1, mean_if_defined, nan)
       else:
-        one = ops.convert_to_tensor(1.0, dtype=self.dtype)
+        one = constant_op.constant(1.0, dtype=self.dtype)
         return control_flow_ops.with_dependencies(
-            [check_ops.assert_less(one, alpha)], mean_if_defined)
+            [check_ops.assert_less(
+                one, alpha,
+                message="mean not defined for components of alpha <= 1")],
+            mean_if_defined)

   def mode(self, name="mode"):
     """Mode of each batch member.
@ -250,9 +253,12 @@ class InverseGamma(distribution.Distribution):
         nan = np.nan * self._ones()
         return math_ops.select(alpha_gt_2, var_if_defined, nan)
       else:
-        two = ops.convert_to_tensor(2.0, dtype=self.dtype)
+        two = constant_op.constant(2.0, dtype=self.dtype)
         return control_flow_ops.with_dependencies(
-            [check_ops.assert_less(two, alpha)], var_if_defined)
+            [check_ops.assert_less(
+                two, alpha,
+                message="variance not defined for components of alpha <= 2")],
+            var_if_defined)

   def log_prob(self, x, name="log_prob"):
     """Log prob of observations in `x` under these InverseGamma distribution(s).
@ -34,9 +34,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
   Args:
     dist_a: instance of distributions.Distribution.
     dist_b: instance of distributions.Distribution.
-    allow_nan: If False (default), a runtime error is raised
+    allow_nan: If `False` (default), a runtime error is raised
       if the KL returns NaN values for any batch entry of the given
-      distributions. If True, the KL may return a NaN for the given entry.
+      distributions. If `True`, the KL may return a NaN for the given entry.
     name: (optional) Name scope to use for created operations.

   Returns:
@ -60,17 +60,17 @@ class Laplace(distribution.Distribution):
     broadcasting (e.g., `loc / scale` is a valid operation).

     Args:
-      loc: `float` or `double` tensor which characterizes the location (center)
+      loc: Floating point tensor which characterizes the location (center)
         of the distribution.
-      scale: `float` or `double`, positive-valued tensor which characterzes the
-        spread of the distribution.
+      scale: Positive floating point tensor which characterizes the spread of
+        the distribution.
       validate_args: Whether to validate input with asserts. If `validate_args`
         is `False`, and the inputs are invalid, correct behavior is not
         guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to give Ops created by the initializer.

     Raises:
@ -294,8 +294,7 @@ class Laplace(distribution.Distribution):
     with ops.op_scope([self._loc, self._scale, n], name):
       n = ops.convert_to_tensor(n)
       n_val = tensor_util.constant_value(n)
-      shape = array_ops.concat(
-          0, [array_ops.pack([n]), self.batch_shape()])
+      shape = array_ops.concat(0, ([n], self.batch_shape()))
       # Sample uniformly-at-random from the open-interval (-1, 1).
       uniform_samples = random_ops.random_uniform(
           shape=shape,
343 tensorflow/contrib/distributions/python/ops/multinomial.py Normal file
@ -0,0 +1,343 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multinomial distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=line-too-long

from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

# pylint: enable=line-too-long


class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This distribution is parameterized by a vector `p` of probability
  parameters for `k` classes and `n`, the total number of draws.

  #### Mathematical details

  The Multinomial is a distribution over k-class count data, meaning
  for each k-tuple of non-negative integer `counts = [n_1,...,n_k]`, we have a
  probability of these draws being made from the distribution. The distribution
  has hyperparameters `p = (p_1,...,p_k)`, and probability mass
  function (pmf):

  ```pmf(counts) = n! / (n_1!...n_k!) * (p_1)^n_1*(p_2)^n_2*...(p_k)^n_k```

  where `n = sum_j n_j` and `n!` is `n` factorial.

  #### Examples

  Create a 3-class distribution, with the 3rd class the most likely to be
  drawn, using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(n=4., logits=logits)
  ```

  Create a 3-class distribution, with the 3rd class the most likely to be
  drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(n=4., p=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(n=[4., 5], p=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]
  ```
  """

  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=True,
               allow_nan_stats=False,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      n: Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
        and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
        different `k` class Multinomial distributions.
      p: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as
        a batch of `N1 x ... x Nm` different `k` class Multinomial
        distributions. `p`'s components in the last portion of its shape should
        sum up to 1.
      validate_args: Whether to assert valid values for parameters `n` and `p`,
        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
        guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of 2-class multinomial distribution,
    # also known as a Binomial distribution.
    dist = Multinomial(n=2., p=[.1, .9])

    # Define a 2-batch of 3-class distributions.
    dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
    ```

    """
    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args,
        multidimensional=True)
    with ops.op_scope([n, self._p], name):
      with ops.control_dependencies([
          check_ops.assert_non_negative(
              n, message="n has negative components."),
          distribution_util.assert_integer_form(
              n, message="n has non-integer components."
          )] if validate_args else []):
        self._n = array_ops.identity(n, name="convert_n")
        self._name = name

        self._validate_args = validate_args
        self._allow_nan_stats = allow_nan_stats

        self._mean = array_ops.expand_dims(n, -1) * self._p
        # Only used for inferring shape.
        self._broadcast_shape = math_ops.reduce_sum(self._mean,
                                                    reduction_indices=[-1],
                                                    keep_dims=False)

        self._get_batch_shape = self._broadcast_shape.get_shape()
        self._get_event_shape = (
            self._mean.get_shape().with_rank_at_least(1)[-1:])

  @property
  def n(self):
    """Number of trials."""
    return self._n

  @property
  def p(self):
    """Event probabilities."""
    return self._p

  @property
  def logits(self):
    """Log-odds."""
    return self._logits

  @property
  def name(self):
    """Name to prepend to all ops."""
    return self._name

  @property
  def dtype(self):
    """dtype of samples from this distribution."""
    return self._p.dtype

  @property
  def validate_args(self):
    """Boolean describing behavior on invalid input."""
    return self._validate_args

  @property
  def allow_nan_stats(self):
    """Boolean describing behavior when a stat is undefined for batch member."""
    return self._allow_nan_stats

  def batch_shape(self, name="batch_shape"):
    """Batch dimensions of this instance as a 1-D int32 `Tensor`.

    The product of the dimensions of the `batch_shape` is the number of
    independent distributions of this kind the instance represents.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `batch_shape`
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._broadcast_shape], name):
        return array_ops.shape(self._broadcast_shape)

  def get_batch_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `batch_shape`. May be only partially defined.

    Returns:
      batch shape
    """
    return self._get_batch_shape

  def event_shape(self, name="event_shape"):
    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `event_shape`
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self._mean], name):
        return array_ops.gather(array_ops.shape(self._mean),
                                [array_ops.rank(self._mean) - 1])

  def get_event_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `event_shape`. May be only partially defined.

    Returns:
      event shape
    """
    return self._get_event_shape

  def mean(self, name="mean"):
    """Mean of the distribution."""
    with ops.name_scope(self.name):
      return array_ops.identity(self._mean, name=name)

  def variance(self, name="variance"):
    """Variance of the distribution."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._n, self._p, self._mean], name):
        p = array_ops.expand_dims(
            self._p * array_ops.expand_dims(
                array_ops.ones_like(self._n), -1), -1)
        variance = -math_ops.batch_matmul(
            array_ops.expand_dims(self._mean, -1), p, adj_y=True)
        variance += array_ops.batch_matrix_diag(self._mean)
        return variance

  def log_prob(self, counts, name="log_prob"):
    """`Log(P[counts])`, computed for every batch member.

    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
    that after sampling `n` draws from this Multinomial distribution, the
    number of draws falling in class `j` is `n_j`. Note that different
    sequences of draws can result in the same counts, thus the probability
    includes a combinatorial coefficient.

    Args:
      counts: Non-negative tensor with dtype `dtype` and whose shape can
        be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
        the last dimension represents counts for the corresponding Multinomial
        distribution in `self.p`. `counts` is only legal if it sums up to `n`
        and its components are equal to integer values.
      name: Name to give this Op, defaults to "log_prob".

    Returns:
      Log probabilities for each record, shape `[N1,...,Nm]`.
    """
    n = self._n
    p = self._p
    with ops.name_scope(self.name):
      with ops.op_scope([n, p, counts], name):
        counts = self._check_counts(counts)

        prob_prob = math_ops.reduce_sum(counts * math_ops.log(self._p),
                                        reduction_indices=[-1])
        log_prob = prob_prob + distribution_util.log_combinations(
            n, counts)
        return log_prob

  def prob(self, counts, name="prob"):
    """`P[counts]`, computed for every batch member.

    For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
    that after sampling `n` draws from this Multinomial distribution, the
    number of draws falling in class `j` is `n_j`. Note that different
    sequences of draws can result in the same counts, thus the probability
    includes a combinatorial coefficient.

    Args:
      counts: Non-negative tensor with dtype `dtype` and whose shape can
        be broadcast with `self.p` and `self.n`. For fixed leading dimensions,
        the last dimension represents counts for the corresponding Multinomial
        distribution in `self.p`. `counts` is only legal if it sums up to `n`
        and its components are equal to integer values.
      name: Name to give this Op, defaults to "prob".

    Returns:
      Probabilities for each record, shape `[N1,...,Nm]`.
    """
    return super(Multinomial, self).prob(counts, name=name)

  @property
  def is_continuous(self):
    return False

  @property
  def is_reparameterized(self):
    return False

  def _check_counts(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    counts = ops.convert_to_tensor(counts, name="counts_before_deps")
    candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1])
    if not self.validate_args:
      return counts

    return control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(
            counts, message="counts has negative components."),
        check_ops.assert_equal(
            self._n, candidate_n, message="counts do not sum to n."),
        distribution_util.assert_integer_form(
            counts, message="counts have non-integer components.")], counts)
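A note on the `variance` method above: the batched matmul plus `batch_matrix_diag` assembles the standard Multinomial covariance `n * (diag(p) - p p^T)` op by op. A numpy cross-check of that formula (my own illustration, not code from the diff):

```python
import numpy as np

n, p = 4., np.array([0.2, 0.3, 0.5])
cov = n * (np.diag(p) - np.outer(p, p))
# Diagonal entries are the Binomial marginal variances n * p_i * (1 - p_i).
assert np.allclose(np.diag(cov), n * p * (1 - p))
```
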
@ -21,6 +21,7 @@ from __future__ import print_function
 import math

 from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import kullback_leibler
 from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky
 from tensorflow.contrib.distributions.python.ops import operator_pd_diag
 from tensorflow.contrib.distributions.python.ops import operator_pd_full
@ -104,9 +105,9 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
     which determines the covariance.

     Args:
-      mu: `float` or `double` tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
-      cov: `float` or `double` instance of `OperatorPDBase` with same `dtype`
-        as `mu` and shape `[N1,...,Nb, k, k]`.
+      mu: Floating point tensor with shape `[N1,...,Nb, k]`, `b >= 0`.
+      cov: Instance of `OperatorPDBase` with same `dtype` as `mu` and shape
+        `[N1,...,Nb, k, k]`.
       validate_args: Whether to validate input with asserts. If `validate_args`
         is `False`, and the inputs are invalid, correct behavior is not
         guaranteed.
@ -149,7 +150,7 @@ class MultivariateNormalOperatorPD(distribution.Distribution):
       else:
         return mu

-      # Static checks could not be run, so possibly do dyamic checks.
+      # Static checks could not be run, so possibly do dynamic checks.
       if not self.validate_args:
         return mu
       else:
@ -465,7 +466,7 @@ class MultivariateNormalDiag(MultivariateNormalOperatorPD):
     The mean of `X_i` is `mu[i]`, and the standard deviation is `diag_stdev[i]`.

     Args:
-      mu: Rank `N + 1` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
+      mu: Rank `N + 1` floating point tensor with shape `[N1,...,Nb, k]`,
         `b >= 0`.
       diag_stdev: Rank `N + 1` `Tensor` with same `dtype` and shape as `mu`,
         representing the standard deviations. Must be positive.
@ -580,13 +581,13 @@ class MultivariateNormalDiagPlusVDVT(MultivariateNormalOperatorPD):
     ```

     Args:
-      mu: Rank `n + 1` `float` or `double` tensor with shape `[N1,...,Nn, k]`,
+      mu: Rank `n + 1` floating point tensor with shape `[N1,...,Nn, k]`,
         `n >= 0`. The means.
-      diag_large: Optional rank `n + 1` `float` or `double` tensor, shape
+      diag_large: Optional rank `n + 1` floating point tensor, shape
         `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `M`.
-      v: Rank `n + 1` `float` or `double` tensor, shape `[N1,...,Nn, k, r]`
+      v: Rank `n + 1` floating point tensor, shape `[N1,...,Nn, k, r]`
         `n >= 0`. Defines the matrix `V`.
-      diag_small: Rank `n + 1` `float` or `double` tensor, shape
+      diag_small: Rank `n + 1` floating point tensor, shape
         `[N1,...,Nn, k]` `n >= 0`. Defines the diagonal matrix `D`. Default
         is `None`, which means `D` will be the identity matrix.
       validate_args: Whether to validate input with asserts. If `validate_args`
@ -669,7 +670,7 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD):
     factors, such that the covariance of each batch member is `chol chol^T`.

     Args:
-      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
+      mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
         `b >= 0`.
       chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
         `[N1,...,Nb, k, k]`. The upper triangular part is ignored (treated as
@ -749,7 +750,7 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
     User must provide means `mu` and `sigma`, the mean and covariance.

     Args:
-      mu: `(N+1)-D` `float` or `double` tensor with shape `[N1,...,Nb, k]`,
+      mu: `(N+1)-D` floating point tensor with shape `[N1,...,Nb, k]`,
         `b >= 0`.
       sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape
         `[N1,...,Nb, k, k]`. Each batch member must be positive definite.
@ -772,3 +773,72 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD):
         allow_nan_stats=allow_nan_stats,
         validate_args=validate_args,
         name=name)
+
+
+def _kl_mvn_mvn_brute_force(mvn_a, mvn_b, name=None):
+  """Batched KL divergence `KL(mvn_a || mvn_b)` for multivariate normals.
+
+  With `X`, `Y` both multivariate normals in `R^k` with means `mu_a`, `mu_b`
+  and covariances `C_a`, `C_b` respectively,
+
+  ```
+  KL(X || Y) = 0.5 * ( T + Q - k + L ),
+  T := trace(C_b^{-1} C_a),
+  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
+  L := Log[Det(C_b)] - Log[Det(C_a)]
+  ```
+
+  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
+  methods for solving systems with `C_b` may be available, a dense version of
+  (the square root of) `C_a` is used, so performance is `O(B s k^2)` where `B`
+  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
+  and `y`.
+
+  Args:
+    mvn_a: Instance of subclass of `MultivariateNormalOperatorPD`.
+    mvn_b: Instance of subclass of `MultivariateNormalOperatorPD`.
+    name: (optional) name to use for created ops. Default "kl_mvn_mvn".
+
+  Returns:
+    Batchwise `KL(mvn_a || mvn_b)`.
+  """
+  # Access the "private" OperatorPD that each mvn is built from.
+  cov_a = mvn_a._cov  # pylint: disable=protected-access
+  cov_b = mvn_b._cov  # pylint: disable=protected-access
+  mu_a = mvn_a.mu
+  mu_b = mvn_b.mu
+  inputs = [mu_a, mu_b] + cov_a.inputs + cov_b.inputs
+
+  with ops.op_scope(inputs, name, "kl_mvn_mvn"):
+    # If Ca = AA', Cb = BB', then
+    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
+    #                  = tr[inv(B) A A' inv(B)']
+    #                  = tr[(inv(B) A) (inv(B) A)']
+    #                  = sum_{ik} (inv(B) A)_{ik}^2
+    # The second equality follows from the cyclic permutation property.
+    b_inv_a = cov_b.sqrt_solve(cov_a.sqrt_to_dense())
+    t = math_ops.reduce_sum(
+        math_ops.square(b_inv_a),
+        reduction_indices=[-1, -2])
+    q = cov_b.inv_quadratic_form_on_vectors(mu_b - mu_a)
+    k = math_ops.cast(cov_a.vector_space_dimension(), mvn_a.dtype)
+    one_half_l = cov_b.sqrt_log_det() - cov_a.sqrt_log_det()
+    return 0.5 * (t + q - k) + one_half_l
+
+
+# Register KL divergences.
+kl_classes = [
+    MultivariateNormalFull,
+    MultivariateNormalCholesky,
+    MultivariateNormalDiag,
+    MultivariateNormalDiagPlusVDVT,
+]
+
+
+for mvn_aa in kl_classes:
+  # Register when they are the same here, and do not register when they are
+  # the same below because that would result in a repeated registration.
+  kullback_leibler.RegisterKL(mvn_aa, mvn_aa)(_kl_mvn_mvn_brute_force)
+  for mvn_bb in kl_classes:
+    if mvn_bb != mvn_aa:
+      kullback_leibler.RegisterKL(mvn_aa, mvn_bb)(_kl_mvn_mvn_brute_force)
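The closed form implemented above is easy to verify numerically; here is a standalone numpy sketch, independent of the `OperatorPD` machinery in this diff:

```python
import numpy as np

def kl_mvn(mu_a, cov_a, mu_b, cov_b):
  # KL(N(mu_a, C_a) || N(mu_b, C_b)) = 0.5 * (T + Q - k + L)
  k = mu_a.shape[0]
  cov_b_inv = np.linalg.inv(cov_b)
  t = np.trace(cov_b_inv.dot(cov_a))
  d = mu_b - mu_a
  q = d.dot(cov_b_inv).dot(d)
  l = np.log(np.linalg.det(cov_b)) - np.log(np.linalg.det(cov_a))
  return 0.5 * (t + q - k + l)

# KL between a distribution and itself must be zero.
mu = np.array([0.5, -1.])
cov = np.array([[2., 0.3], [0.3, 1.]])
assert abs(kl_mvn(mu, cov, mu, cov)) < 1e-12
```
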
@ -92,15 +92,15 @@ class Normal(distribution.Distribution):
     broadcasting (e.g. `mu + sigma` is a valid operation).

     Args:
-      mu: `float` or `double` tensor, the means of the distribution(s).
-      sigma: `float` or `double` tensor, the stddevs of the distribution(s).
+      mu: Floating point tensor, the means of the distribution(s).
+      sigma: Floating point tensor, the stddevs of the distribution(s).
         sigma must contain only positive values.
       validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
-        False, correct output is not guaranteed when input is invalid.
+        `False`, correct output is not guaranteed when input is invalid.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to give Ops created by the initializer.

     Raises:
@ -321,8 +321,7 @@ class Normal(distribution.Distribution):
     with ops.op_scope([self._mu, self._sigma, n], name):
       broadcast_shape = (self._mu + self._sigma).get_shape()
       n = ops.convert_to_tensor(n)
-      shape = array_ops.concat(
-          0, [array_ops.pack([n]), array_ops.shape(self.mean())])
+      shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
       sampled = random_ops.random_normal(
           shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)
@ -82,6 +82,7 @@ class StudentT(distribution.Distribution):
   # returning a length 2 tensor.
   dist.pdf(3.0)
   ```
+
   """

   def __init__(self,
@ -99,19 +100,19 @@ class StudentT(distribution.Distribution):
     broadcasting (e.g. `df + mu + sigma` is a valid operation).

     Args:
-      df: `float` or `double` tensor, the degrees of freedom of the
+      df: Floating point tensor, the degrees of freedom of the
         distribution(s). `df` must contain only positive values.
-      mu: `float` or `double` tensor, the means of the distribution(s).
-      sigma: `float` or `double` tensor, the scaling factor for the
+      mu: Floating point tensor, the means of the distribution(s).
+      sigma: Floating point tensor, the scaling factor for the
         distribution(s). `sigma` must contain only positive values.
         Note that `sigma` is not the standard deviation of this distribution.
       validate_args: Whether to assert that `df > 0, sigma > 0`. If
-        `validate_args` is False and inputs are invalid, correct behavior is not
-        guaranteed.
+        `validate_args` is `False` and inputs are invalid, correct behavior is
+        not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to give Ops created by the initializer.

     Raises:
@ -185,9 +186,12 @@ class StudentT(distribution.Distribution):
         nan = np.nan + self._zeros()
         return math_ops.select(df_gt_1, result_if_defined, nan)
       else:
-        one = ops.convert_to_tensor(1.0, dtype=self.dtype)
+        one = constant_op.constant(1.0, dtype=self.dtype)
         return control_flow_ops.with_dependencies(
-            [check_ops.assert_less(one, self._df)], result_if_defined)
+            [check_ops.assert_less(
+                one, self._df,
+                message="mean not defined for components of df <= 1"
+            )], result_if_defined)

   def mode(self, name="mode"):
     with ops.name_scope(self.name):
||||||
@ -232,9 +236,12 @@ class StudentT(distribution.Distribution):
|
|||||||
result_where_defined,
|
result_where_defined,
|
||||||
self._zeros() + np.nan)
|
self._zeros() + np.nan)
|
||||||
else:
|
else:
|
||||||
one = ops.convert_to_tensor(1.0, self.dtype)
|
one = constant_op.constant(1.0, dtype=self.dtype)
|
||||||
return control_flow_ops.with_dependencies(
|
return control_flow_ops.with_dependencies(
|
||||||
[check_ops.assert_less(one, self._df)], result_where_defined)
|
[check_ops.assert_less(
|
||||||
|
one, self._df,
|
||||||
|
message="variance not defined for components of df <= 1"
|
||||||
|
)], result_where_defined)
|
||||||
|
|
||||||
def std(self, name="std"):
|
def std(self, name="std"):
|
||||||
with ops.name_scope(self.name):
|
with ops.name_scope(self.name):
|
||||||
@ -348,8 +355,7 @@ class StudentT(distribution.Distribution):
|
|||||||
# Let X = R*cos(theta), and let Y = R*sin(theta).
|
# Let X = R*cos(theta), and let Y = R*sin(theta).
|
||||||
# Then X ~ t_df and Y ~ t_df.
|
# Then X ~ t_df and Y ~ t_df.
|
||||||
# The variates X and Y are not independent.
|
# The variates X and Y are not independent.
|
||||||
shape = array_ops.concat(0, [array_ops.pack([2, n]),
|
shape = array_ops.concat(0, ([2, n], self.batch_shape()))
|
||||||
self.batch_shape()])
|
|
||||||
uniform = random_ops.random_uniform(shape=shape,
|
uniform = random_ops.random_uniform(shape=shape,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
|
@ -57,6 +57,7 @@ class TransformedDistribution(distribution.Distribution):
       name="LogitNormalTransformedDistribution"
   )
   ```
+
   """

   def __init__(self,
@ -67,14 +67,14 @@ class Uniform(distribution.Distribution):
     ```

     Args:
-      a: `float` or `double` tensor, the minimum endpoint.
-      b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
+      a: Floating point tensor, the minimum endpoint.
+      b: Floating point tensor, the maximum endpoint. Must be > `a`.
-      validate_args: Whether to assert that `a > b`. If `validate_args` is False
-        and inputs are invalid, correct behavior is not guaranteed.
+      validate_args: Whether to assert that `a > b`. If `validate_args` is
+        `False` and inputs are invalid, correct behavior is not guaranteed.
-      allow_nan_stats: Boolean, default False. If False, raise an exception if
-        a statistic (e.g. mean/mode/etc...) is undefined for any batch member.
-        If True, batch members with valid parameters leading to undefined
-        statistics will return NaN for this statistic.
+      allow_nan_stats: Boolean, default `False`. If `False`, raise an
+        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+        batch member. If `True`, batch members with valid parameters leading to
+        undefined statistics will return NaN for this statistic.
       name: The name to prefix Ops created by this distribution class.

     Raises:
@ -83,8 +83,9 @@ class Uniform(distribution.Distribution):
     self._allow_nan_stats = allow_nan_stats
     self._validate_args = validate_args
     with ops.op_scope([a, b], name):
-      with ops.control_dependencies([check_ops.assert_less(a, b)] if
-                                    validate_args else []):
+      with ops.control_dependencies([check_ops.assert_less(
+          a, b, message="uniform not defined when a > b.")] if validate_args
+                                    else []):
         a = array_ops.identity(a, name="a")
         b = array_ops.identity(b, name="b")
@ -228,7 +229,7 @@ class Uniform(distribution.Distribution):
       n = ops.convert_to_tensor(n, name="n")
       n_val = tensor_util.constant_value(n)

-      shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
+      shape = array_ops.concat(0, ([n], self.batch_shape()))
       samples = random_ops.random_uniform(shape=shape,
                                           dtype=self.dtype,
                                           seed=seed)
@ -94,6 +94,30 @@ tf_py_test(
     ],
 )

+tf_py_test(
+    name = "gmm_test",
+    srcs = [
+        "python/ops/gmm_test.py",
+    ],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_py_test(
+    name = "gmm_ops_test",
+    srcs = [
+        "python/ops/gmm_ops_test.py",
+    ],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_py_test(
     name = "factorization_ops_test",
     srcs = ["python/ops/factorization_ops_test.py"],
|
|||||||
col_factors2 = [x.eval() for x in wals_model.col_factors]
|
col_factors2 = [x.eval() for x in wals_model.col_factors]
|
||||||
|
|
||||||
for c1, c2 in zip(col_factors1, col_factors2):
|
for c1, c2 in zip(col_factors1, col_factors2):
|
||||||
self.assertAllClose(c1, c2, atol=1e-3)
|
self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
|
||||||
|
|
||||||
def test_als_transposed(self):
|
def test_als_transposed(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -383,7 +383,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
regularization=1e-5,
|
regularization=1e-5,
|
||||||
row_weights=None,
|
row_weights=None,
|
||||||
col_weights=None)
|
col_weights=None)
|
||||||
self.simple_train(model, inp, 15)
|
self.simple_train(model, inp, 25)
|
||||||
row_factor = model.row_factors[0].eval()
|
row_factor = model.row_factors[0].eval()
|
||||||
col_factor = model.col_factors[0].eval()
|
col_factor = model.col_factors[0].eval()
|
||||||
self.assertAllClose(data,
|
self.assertAllClose(data,
|
||||||
@ -407,7 +407,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
regularization=1e-5,
|
regularization=1e-5,
|
||||||
row_weights=[0] * rows,
|
row_weights=[0] * rows,
|
||||||
col_weights=[0] * cols)
|
col_weights=[0] * cols)
|
||||||
self.simple_train(model, inp, 15)
|
self.simple_train(model, inp, 25)
|
||||||
row_factor = model.row_factors[0].eval()
|
row_factor = model.row_factors[0].eval()
|
||||||
col_factor = model.col_factors[0].eval()
|
col_factor = model.col_factors[0].eval()
|
||||||
self.assertAllClose(data,
|
self.assertAllClose(data,
|
||||||
@ -438,7 +438,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
regularization=0.001,
|
regularization=0.001,
|
||||||
row_weights=row_wts,
|
row_weights=row_wts,
|
||||||
col_weights=col_wts)
|
col_weights=col_wts)
|
||||||
self.simple_train(model, inp, 10)
|
self.simple_train(model, inp, 25)
|
||||||
row_factor = model.row_factors[0].eval()
|
row_factor = model.row_factors[0].eval()
|
||||||
col_factor = model.col_factors[0].eval()
|
col_factor = model.col_factors[0].eval()
|
||||||
out = np.dot(row_factor, np.transpose(col_factor))
|
out = np.dot(row_factor, np.transpose(col_factor))
|
||||||
@ -446,7 +446,7 @@ class WalsModelTest(tf.test.TestCase):
|
|||||||
for j in xrange(cols):
|
for j in xrange(cols):
|
||||||
if keep_index([i, j]):
|
if keep_index([i, j]):
|
||||||
self.assertNear(data[i][j], out[i][j],
|
self.assertNear(data[i][j], out[i][j],
|
||||||
err=0.2, msg="%d, %d" % (i, j))
|
err=0.4, msg="%d, %d" % (i, j))
|
||||||
else:
|
else:
|
||||||
self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
|
self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
|
||||||
|
|
||||||
|
211 tensorflow/contrib/factorization/python/ops/gmm.py Normal file
@ -0,0 +1,211 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implementation of Gaussian mixture model (GMM) clustering.

This goes on top of skflow API.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import tensorflow as tf

from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.python.ops.control_flow_ops import with_dependencies


class GMM(estimator.Estimator, TransformerMixin):
  """GMM clustering."""
  SCORES = 'scores'
  ASSIGNMENTS = 'assignments'
  ALL_SCORES = 'all_scores'

  def __init__(self,
               num_clusters,
               model_dir=None,
               random_seed=0,
               params='wmc',
               initial_clusters='random',
               covariance_type='full',
               batch_size=128,
               steps=10,
               continue_training=False,
               config=None,
               verbose=1):
    """Creates a model for running GMM training and inference.

    Args:
      num_clusters: number of clusters to train.
      model_dir: the directory to save the model results and log files.
      random_seed: Python integer. Seed for PRNG used to initialize centers.
      params: Controls which parameters are updated in the training process.
        Can contain any combination of "w" for weights, "m" for means,
        and "c" for covars.
      initial_clusters: specifies how to initialize the clusters for training.
        See gmm_ops.gmm for the possible values.
      covariance_type: one of "full", "diag".
      batch_size: See TensorFlowEstimator.
      steps: See TensorFlowEstimator.
      continue_training: See TensorFlowEstimator.
      config: See TensorFlowEstimator.
      verbose: See TensorFlowEstimator.
    """
    super(GMM, self).__init__(
        model_dir=model_dir,
        config=config)
    self.batch_size = batch_size
    self.steps = steps
    self.continue_training = continue_training
    self.verbose = verbose
    self._num_clusters = num_clusters
    self._params = params
    self._training_initial_clusters = initial_clusters
    self._covariance_type = covariance_type
    self._training_graph = None
    self._random_seed = random_seed

  def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
    """Trains a GMM clustering on x.

    Note: See TensorFlowEstimator for logic for continuous training and graph
    construction across multiple calls to fit.

    Args:
      x: training input matrix of shape [n_samples, n_features].
      y: labels. Should be None.
      monitors: List of `Monitor` objects to print training progress and
        invoke early stopping.
      logdir: the directory to save the log file that can be used for optional
        visualization.
      steps: number of training steps. If not None, overrides the value passed
        in constructor.

    Returns:
      Returns self.
    """
    if logdir is not None:
      self._model_dir = logdir
    self._data_feeder = data_feeder.setup_train_data_feeder(
        x, None, self._num_clusters, self.batch_size)
    self._train_model(input_fn=self._data_feeder.input_builder,
                      feed_fn=self._data_feeder.get_feed_dict_fn(),
                      steps=steps or self.steps,
                      monitors=monitors,
                      init_feed_fn=self._data_feeder.get_feed_dict_fn())
    return self

  def predict(self, x, batch_size=None):
    """Predict cluster id for each element in x.

    Args:
      x: 2-D matrix or iterator.
      batch_size: size to use for batching up x for querying the model.

    Returns:
      Array with same number of rows as x, containing cluster ids.
    """
    return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ASSIGNMENTS]

  def score(self, x, batch_size=None):
    """Predict total sum of distances to nearest clusters.

    Args:
      x: 2-D matrix or iterator.
      batch_size: size to use for batching up x for querying the model.

    Returns:
      Total score.
    """
    return np.sum(self.evaluate(x=x, batch_size=batch_size)[GMM.SCORES])

  def transform(self, x, batch_size=None):
    """Transforms each element in x to distances to cluster centers.

    Args:
      x: 2-D matrix or iterator.
      batch_size: size to use for batching up x for querying the model.

    Returns:
      Array with same number of rows as x, and num_clusters columns, containing
      distances to the cluster centers.
    """
    return super(GMM, self).predict(x=x, batch_size=batch_size)[GMM.ALL_SCORES]

  def clusters(self):
    """Returns cluster centers."""
    clusters = checkpoints.load_variable(
        self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
    return np.squeeze(clusters, 1)

  def covariances(self):
    """Returns the covariances."""
    return checkpoints.load_variable(
        self.model_dir,
        gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)

  def _get_train_ops(self, features, _):
    (_,
     _,
     losses,
     training_op) = gmm_ops.gmm(
         features,
         self._training_initial_clusters,
         self._num_clusters,
         self._random_seed,
         self._covariance_type,
         self._params)
    incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
    loss = tf.reduce_sum(losses)
|
training_op = with_dependencies([training_op, incr_step], loss)
|
||||||
|
return training_op, loss
|
||||||
|
|
||||||
|
def _get_predict_ops(self, features):
|
||||||
|
(all_scores,
|
||||||
|
model_predictions,
|
||||||
|
_,
|
||||||
|
_) = gmm_ops.gmm(
|
||||||
|
features,
|
||||||
|
self._training_initial_clusters,
|
||||||
|
self._num_clusters,
|
||||||
|
self._random_seed,
|
||||||
|
self._covariance_type,
|
||||||
|
self._params)
|
||||||
|
return {
|
||||||
|
GMM.ALL_SCORES: all_scores[0],
|
||||||
|
GMM.ASSIGNMENTS: model_predictions[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_eval_ops(self, features, _, unused_metrics):
|
||||||
|
(_,
|
||||||
|
_,
|
||||||
|
losses,
|
||||||
|
_) = gmm_ops.gmm(
|
||||||
|
features,
|
||||||
|
self._training_initial_clusters,
|
||||||
|
self._num_clusters,
|
||||||
|
self._random_seed,
|
||||||
|
self._covariance_type,
|
||||||
|
self._params)
|
||||||
|
return {
|
||||||
|
GMM.SCORES: tf.reduce_sum(losses),
|
||||||
|
}
|
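For orientation, here is a minimal usage sketch of the estimator defined above, assuming the contrib-era TF 0.x API shown in this file; the synthetic data and hyperparameter values are illustrative and not part of the commit.

# Hypothetical usage sketch of the GMM estimator above; data and
# hyperparameters are made up for illustration.
import numpy as np
from tensorflow.contrib.factorization.python.ops.gmm import GMM

points = np.random.rand(1000, 2).astype(np.float32)  # [n_samples, n_features]
gmm = GMM(num_clusters=3, covariance_type='full', batch_size=128, steps=20)
gmm.fit(x=points)            # trains with the EM graph built by gmm_ops.gmm
ids = gmm.predict(points)    # cluster id per input row
total = gmm.score(points)    # summed distance to the assigned clusters
centers = gmm.clusters()     # [num_clusters, n_features] array of means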
461  tensorflow/contrib/factorization/python/ops/gmm_ops.py  Normal file
@@ -0,0 +1,461 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Gaussian mixture model (GMM) operations."""
# TODO(xavigonzalvo): Factor out covariance matrix operations to make
# code reusable for different types (e.g. diag).

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.python.ops.embedding_ops import embedding_lookup

# Machine epsilon.
MEPS = np.finfo(float).eps
FULL_COVARIANCE = 'full'
DIAG_COVARIANCE = 'diag'


def _covariance(x, diag):
  """Defines the covariance operation of a matrix.

  Args:
    x: a matrix Tensor. Dimension 0 should contain the number of examples.
    diag: if True, it computes the diagonal covariance.

  Returns:
    A Tensor representing the covariance of x. In the case of
    diagonal matrix just the diagonal is returned.
  """
  num_points = tf.to_float(tf.shape(x)[0])
  x -= tf.reduce_mean(x, 0, keep_dims=True)
  if diag:
    cov = tf.reduce_sum(
        tf.square(x), 0, keep_dims=True) / (num_points - 1)
  else:
    cov = tf.matmul(x, x, transpose_a=True) / (num_points - 1)
  return cov


def _init_clusters_random(data, num_clusters, random_seed):
  """Does random initialization of clusters.

  Args:
    data: a list of Tensors with a matrix of data, each row is an example.
    num_clusters: an integer with the number of clusters.
    random_seed: Seed for PRNG used to initialize seeds.

  Returns:
    A Tensor with num_clusters random rows of data.
  """
  assert isinstance(data, list)
  num_data = tf.add_n([tf.shape(inp)[0] for inp in data])
  with tf.control_dependencies([tf.assert_less_equal(num_clusters, num_data)]):
    indices = tf.random_uniform([num_clusters],
                                minval=0,
                                maxval=tf.cast(num_data, tf.int64),
                                seed=random_seed,
                                dtype=tf.int64)
  indices = tf.cast(indices, tf.int32) % num_data
  clusters_init = embedding_lookup(data, indices, partition_strategy='div')
  return clusters_init


class GmmAlgorithm(object):
  """TensorFlow Gaussian mixture model clustering class."""
  CLUSTERS_VARIABLE = 'clusters'
  CLUSTERS_COVS_VARIABLE = 'clusters_covs'

  def __init__(self, data, num_classes, initial_means=None, params='wmc',
               covariance_type=FULL_COVARIANCE, random_seed=0):
    """Constructor.

    Args:
      data: a list of Tensors with data, each row is a new example.
      num_classes: number of clusters.
      initial_means: a Tensor with a matrix of means. If None, means are
        computed by sampling randomly.
      params: Controls which parameters are updated in the training
        process. Can contain any combination of "w" for weights, "m" for
        means, and "c" for covariances.
      covariance_type: one of "full", "diag".
      random_seed: Seed for PRNG used to initialize seeds.

    Raises:
      Exception if covariance type is unknown.
    """
    self._params = params
    self._random_seed = random_seed
    self._covariance_type = covariance_type
    if self._covariance_type not in [DIAG_COVARIANCE, FULL_COVARIANCE]:
      raise Exception(  # pylint: disable=g-doc-exception
          'programmer error: Invalid covariance type: %s' %
          self._covariance_type)
    # Create sharded variables for multiple shards. The following
    # lists are indexed by shard.
    # Probability per example in a class.
    num_shards = len(data)
    self._probs = [None] * num_shards
    # Prior probability.
    self._prior_probs = [None] * num_shards
    # Membership weights w_{ik} where "i" is the i-th example and "k"
    # is the k-th mixture.
    self._w = [None] * num_shards
    # Number of examples in a class.
    self._points_in_k = [None] * num_shards
    first_shard = data[0]
    self._dimensions = tf.shape(first_shard)[1]
    self._num_classes = num_classes
    # Small value to guarantee that covariances are invertible.
    self._min_var = tf.diag(tf.ones(tf.pack([self._dimensions]))) * 1e-3
    self._create_variables(data, initial_means)
    # Operations of partial statistics for the computation of the means.
    self._w_mul_x = []
    # Operations of partial statistics for the computation of the covariances.
    self._w_mul_x2 = []
    self._define_graph(data)

  def _create_variables(self, data, initial_means=None):
    """Initializes GMM algorithm.

    Args:
      data: a list of Tensors with data, each row is a new example.
      initial_means: a Tensor with a matrix of means.
    """
    first_shard = data[0]
    # Initialize means: num_classes X 1 X dimensions.
    if initial_means is not None:
      self._means = tf.Variable(tf.expand_dims(initial_means, 1),
                                name=self.CLUSTERS_VARIABLE,
                                validate_shape=False, dtype=tf.float32)
    else:
      # Sample data randomly.
      self._means = tf.Variable(tf.expand_dims(
          _init_clusters_random(data, self._num_classes, self._random_seed), 1),
                                name=self.CLUSTERS_VARIABLE,
                                validate_shape=False)

    # Initialize covariances.
    if self._covariance_type == FULL_COVARIANCE:
      cov = _covariance(first_shard, False) + self._min_var
      # A matrix per class, num_classes X dimensions X dimensions.
      covs = tf.tile(
          tf.expand_dims(cov, 0), [self._num_classes, 1, 1])
    elif self._covariance_type == DIAG_COVARIANCE:
      cov = _covariance(first_shard, True) + self._min_var
      # A diagonal per row, num_classes X dimensions.
      covs = tf.tile(tf.expand_dims(tf.diag_part(cov), 0),
                     [self._num_classes, 1])
    self._covs = tf.Variable(covs, name='clusters_covs', validate_shape=False)
    # Mixture weights, representing the probability that a randomly
    # selected unobservable data (in EM terms) was generated by component k.
    self._alpha = tf.Variable(tf.tile([1.0 / self._num_classes],
                                      [self._num_classes]))

  def training_ops(self):
    """Returns the training operation."""
    return self._train_ops

  def alphas(self):
    return self._alpha

  def clusters(self):
    """Returns the clusters with dimensions num_classes X 1 X num_dimensions."""
    return self._means

  def covariances(self):
    """Returns the covariances matrices."""
    return self._covs

  def assignments(self):
    """Returns a list of Tensors with the matrix of assignments per shard."""
    ret = []
    for w in self._w:
      ret.append(tf.argmax(w, 1))
    return ret

  def scores(self):
    """Returns the distances to each class.

    Returns:
      A tuple with two Tensors. The first contains the distance to
      each class. The second contains the distance to the assigned
      class.
    """
    return (self._all_scores, self._scores)

  def _define_graph(self, data):
    """Define graph for a single iteration.

    Args:
      data: a list of Tensors defining the training data.
    """
    for shard_id, shard in enumerate(data):
      self._num_examples = tf.shape(shard)[0]
      shard = tf.expand_dims(shard, 0)
      self._define_log_prob_operation(shard_id, shard)
      self._define_prior_log_prob_operation(shard_id)
      self._define_expectation_operation(shard_id)
      self._define_partial_maximization_operation(shard_id, shard)
    self._define_maximization_operation(len(data))
    self._define_distance_to_clusters(data)

  def _define_full_covariance_probs(self, shard_id, shard):
    """Defines the full covariance probabilities per example in a class.

    Updates a matrix with dimension num_examples X num_classes.

    Args:
      shard_id: id of the current shard.
      shard: current data shard, 1 X num_examples X dimensions.
    """
    diff = shard - self._means
    cholesky = tf.batch_cholesky(self._covs + self._min_var)
    log_det_covs = 2.0 * tf.reduce_sum(tf.log(
        tf.batch_matrix_diag_part(cholesky)), 1)
    x_mu_cov = tf.square(tf.batch_matrix_triangular_solve(
        cholesky, tf.transpose(diff, perm=[0, 2, 1]),
        lower=True))
    diag_m = tf.transpose(tf.reduce_sum(x_mu_cov, 1))
    self._probs[shard_id] = -0.5 * (
        diag_m + tf.to_float(self._dimensions) * tf.log(2 * np.pi) +
        log_det_covs)

  def _define_diag_covariance_probs(self, shard_id, shard):
    """Defines the diagonal covariance probabilities per example in a class.

    Args:
      shard_id: id of the current shard.
      shard: current data shard, 1 X num_examples X dimensions.

    Returns a matrix num_examples * num_classes.
    """
    # num_classes X 1
    # TODO(xavigonzalvo): look into alternatives to log for
    # reparametrization of variance parameters.
    det_expanded = tf.reduce_sum(tf.log(self._covs + 1e-3),
                                 1, keep_dims=True)
    diff = shard - self._means
    x2 = tf.square(diff)
    cov_expanded = tf.expand_dims(1.0 / (self._covs + 1e-3), 2)
    # num_classes X num_examples
    x2_cov = tf.batch_matmul(x2, cov_expanded)
    x2_cov = tf.transpose(tf.squeeze(x2_cov, [2]))
    self._probs[shard_id] = -0.5 * (
        tf.to_float(self._dimensions) * tf.log(2.0 * np.pi) +
        tf.transpose(det_expanded) + x2_cov)

  def _define_log_prob_operation(self, shard_id, shard):
    """Probability per example in a class.

    Updates a matrix with dimension num_examples X num_classes.

    Args:
      shard_id: id of the current shard.
      shard: current data shard, 1 X num_examples X dimensions.
    """
    # TODO(xavigonzalvo): Use the pdf defined in
    # third_party/tensorflow/contrib/distributions/python/ops/gaussian.py
    if self._covariance_type == FULL_COVARIANCE:
      self._define_full_covariance_probs(shard_id, shard)
    elif self._covariance_type == DIAG_COVARIANCE:
      self._define_diag_covariance_probs(shard_id, shard)
    self._probs[shard_id] += tf.log(self._alpha)

  def _define_prior_log_prob_operation(self, shard_id):
    """Computes the prior probability of all samples.

    Updates a vector where each item is the prior probability of an
    input example.

    Args:
      shard_id: id of current shard_id.
    """
    self._prior_probs[shard_id] = tf.log(
        tf.reduce_sum(tf.exp(self._probs[shard_id]), 1, keep_dims=True))

  def _define_expectation_operation(self, shard_id):
    # Shape broadcasting.
    probs = tf.expand_dims(self._probs[shard_id], 0)
    # Membership weights are computed as:
    # w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}
    #               {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}
    # where "i" is the i-th example, "k" is the k-th mixture, theta are
    # the model parameters and y_i the observations.
    # These are defined for each shard.
    self._w[shard_id] = tf.reshape(
        tf.exp(probs - self._prior_probs[shard_id]),
        tf.pack([self._num_examples, self._num_classes]))

  def _define_partial_maximization_operation(self, shard_id, shard):
    """Computes the partial statistics of the means and covariances.

    Args:
      shard_id: current shard id.
      shard: current data shard, 1 X num_examples X dimensions.
    """
    # Soft assignment of each data point to each of the k clusters.
    self._points_in_k[shard_id] = tf.reduce_sum(self._w[shard_id], 0,
                                                keep_dims=True)
    # Partial means.
    w_mul_x = tf.expand_dims(
        tf.matmul(self._w[shard_id],
                  tf.squeeze(shard, [0]), transpose_a=True), 1)
    self._w_mul_x.append(w_mul_x)
    # Partial covariances.
    x = tf.concat(0, [shard for _ in range(self._num_classes)])
    x_trans = tf.transpose(x, perm=[0, 2, 1])
    x_mul_w = tf.concat(0, [
        tf.expand_dims(x_trans[k, :, :] * self._w[shard_id][:, k], 0)
        for k in range(self._num_classes)])
    self._w_mul_x2.append(tf.batch_matmul(x_mul_w, x))

  def _define_maximization_operation(self, num_batches):
    """Maximization operations."""
    # TODO(xavigonzalvo): some of these operations could be moved to C++.
    # Compute the effective number of data points assigned to component k.
    with tf.control_dependencies(self._w):
      points_in_k = tf.squeeze(tf.add_n(self._points_in_k), squeeze_dims=[0])
      # Update alpha.
      if 'w' in self._params:
        final_points_in_k = points_in_k / num_batches
        num_examples = tf.to_float(tf.reduce_sum(final_points_in_k))
        self._alpha_op = self._alpha.assign(
            final_points_in_k / (num_examples + MEPS))
      else:
        self._alpha_op = tf.no_op()
      self._train_ops = [self._alpha_op]

      # Update means.
      points_in_k_expanded = tf.reshape(points_in_k,
                                        [self._num_classes, 1, 1])
      if 'm' in self._params:
        self._means_op = self._means.assign(
            tf.div(tf.add_n(self._w_mul_x), points_in_k_expanded + MEPS))
      else:
        self._means_op = tf.no_op()
      # means are (num_classes x 1 x dims)

      # Update covariances.
      with tf.control_dependencies([self._means_op]):
        b = tf.add_n(self._w_mul_x2) / (points_in_k_expanded + MEPS)
        new_covs = []
        for k in range(self._num_classes):
          mean = self._means.ref()[k, :, :]
          square_mean = tf.matmul(mean, mean, transpose_a=True)
          new_cov = b[k, :, :] - square_mean + self._min_var
          if self._covariance_type == FULL_COVARIANCE:
            new_covs.append(tf.expand_dims(new_cov, 0))
          elif self._covariance_type == DIAG_COVARIANCE:
            new_covs.append(tf.expand_dims(tf.diag_part(new_cov), 0))
        new_covs = tf.concat(0, new_covs)
        if 'c' in self._params:
          # Train operations don't need to take care of the means
          # because covariances already depend on it.
          with tf.control_dependencies([self._means_op, new_covs]):
            self._train_ops.append(
                tf.assign(self._covs, new_covs, validate_shape=False))

  def _define_distance_to_clusters(self, data):
    """Defines the Mahalanobis distance to the assigned Gaussian."""
    # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input -
    # mean) from log probability function.
    self._all_scores = []
    for shard in data:
      all_scores = []
      shard = tf.expand_dims(shard, 0)
      for c in xrange(self._num_classes):
        if self._covariance_type == FULL_COVARIANCE:
          cov = self._covs[c, :, :]
        elif self._covariance_type == DIAG_COVARIANCE:
          cov = tf.diag(self._covs[c, :])
        inverse = tf.matrix_inverse(cov + self._min_var)
        inv_cov = tf.tile(
            tf.expand_dims(inverse, 0),
            tf.pack([self._num_examples, 1, 1]))
        diff = tf.transpose(shard - self._means[c, :, :], perm=[1, 0, 2])
        m_left = tf.batch_matmul(diff, inv_cov)
        all_scores.append(tf.sqrt(tf.batch_matmul(
            m_left, tf.transpose(diff, perm=[0, 2, 1]))))
      self._all_scores.append(tf.reshape(
          tf.concat(1, all_scores),
          tf.pack([self._num_examples, self._num_classes])))

    # Distance to the associated class.
    self._all_scores = tf.concat(0, self._all_scores)
    assignments = tf.concat(0, self.assignments())
    rows = tf.to_int64(tf.range(0, self._num_examples))
    indices = tf.concat(1, [tf.expand_dims(rows, 1),
                            tf.expand_dims(assignments, 1)])
    self._scores = tf.gather_nd(self._all_scores, indices)

  def _define_loglikelihood_operation(self):
    """Defines the total log-likelihood of current iteration."""
    self._ll_op = []
    for prior_probs in self._prior_probs:
      self._ll_op.append(tf.reduce_sum(tf.log(prior_probs)))
    tf.scalar_summary('ll', tf.reduce_sum(self._ll_op))


def gmm(inp, initial_clusters, num_clusters, random_seed,
        covariance_type=FULL_COVARIANCE, params='wmc'):
  """Creates the graph for Gaussian mixture model (GMM) clustering.

  Args:
    inp: An input tensor or list of input tensors.
    initial_clusters: Specifies the clusters used during
      initialization. Can be a tensor or numpy array, or a function
      that generates the clusters. Can also be "random" to specify
      that clusters should be chosen randomly from input data. Note: type
      is diverse to be consistent with skflow.
    num_clusters: number of clusters.
    random_seed: Python integer. Seed for PRNG used to initialize centers.
    covariance_type: one of "diag", "full".
    params: Controls which parameters are updated in the training
      process. Can contain any combination of "w" for weights, "m" for
      means, and "c" for covariances.

  Returns:
    Note: a tuple of lists is returned to be consistent with skflow.
    A tuple consisting of:
      all_scores: A matrix (or list of matrices) of dimensions (num_input,
        num_clusters) where the value is the distance of an input vector and a
        cluster center.
      assignments: A vector (or list of vectors). Each element in the vector
        corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      scores: Similar to assignments but specifies the distance to the
        assigned cluster instead.
      training_op: an op that runs an iteration of training.
  """
  initial_means = None
  if initial_clusters != 'random' and not isinstance(
      initial_clusters, tf.Tensor):
    initial_means = tf.constant(initial_clusters, dtype=tf.float32)

  # Implementation of GMM.
  inp = inp if isinstance(inp, list) else [inp]
  gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params,
                          covariance_type, random_seed)
  training_ops = gmm_tool.training_ops()
  assignments = gmm_tool.assignments()
  all_scores, scores = gmm_tool.scores()
  return [all_scores], [assignments], [scores], tf.group(*training_ops)
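A minimal sketch of driving gmm_ops.gmm directly with the TF 0.x session API, much as the tests below do; the data, seed, and iteration count are illustrative assumptions. Each run of training_op performs one EM iteration: the expectation step, then whichever of the alpha/means/covariance updates the params string enables.

# Hedged sketch; shapes and constants are made up for illustration.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import gmm_ops

data = tf.constant(np.random.rand(500, 2), dtype=tf.float32)
_, assignments, _, training_op = gmm_ops.gmm(
    data, 'random', num_clusters=3, random_seed=7)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  for _ in range(20):        # one EM iteration per run
    sess.run(training_op)
  ids = np.squeeze(sess.run(assignments))  # cluster id per row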
198  tensorflow/contrib/factorization/python/ops/gmm_ops_test.py  Normal file
@@ -0,0 +1,198 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for gmm_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.contrib.factorization.python.ops import gmm_ops
from tensorflow.python.platform import tf_logging as logging


class GmmOpsTest(tf.test.TestCase):

  def setUp(self):
    self.num_examples = 1000
    self.iterations = 40
    self.seed = 4
    tf.set_random_seed(self.seed)
    np.random.seed(self.seed * 2)
    self.data, self.true_assignments = self.make_data(self.num_examples)
    # Generate more complicated data.
    self.centers = [[1, 1], [-1, 0.5], [2, 1]]
    self.more_data, self.more_true_assignments = self.make_data_from_centers(
        self.num_examples, self.centers)

  @staticmethod
  def make_data(num_vectors):
    """Generates 2-dimensional data centered on (2,2), (-1,-1).

    Args:
      num_vectors: number of training examples.

    Returns:
      A tuple containing the data as a numpy array and the cluster ids.
    """
    vectors = []
    classes = []
    for _ in xrange(num_vectors):
      if np.random.random() > 0.5:
        vectors.append([np.random.normal(2.0, 0.6),
                        np.random.normal(2.0, 0.9)])
        classes.append(0)
      else:
        vectors.append([np.random.normal(-1.0, 0.4),
                        np.random.normal(-1.0, 0.5)])
        classes.append(1)
    return np.asarray(vectors), classes

  @staticmethod
  def make_data_from_centers(num_vectors, centers):
    """Generates 2-dimensional data with random centers.

    Args:
      num_vectors: number of training examples.
      centers: a list of random 2-dimensional centers.

    Returns:
      A tuple containing the data as a numpy array and the cluster ids.
    """
    vectors = []
    classes = []
    for _ in xrange(num_vectors):
      current_class = np.random.random_integers(0, len(centers) - 1)
      vectors.append([np.random.normal(centers[current_class][0],
                                       np.random.random_sample()),
                      np.random.normal(centers[current_class][1],
                                       np.random.random_sample())])
      classes.append(current_class)
    return np.asarray(vectors), classes

  def test_covariance(self):
    start_time = time.time()
    data = self.data.T
    np_cov = np.cov(data)
    logging.info('Numpy took %f', time.time() - start_time)

    start_time = time.time()
    with self.test_session() as sess:
      op = gmm_ops._covariance(
          tf.constant(data.T, dtype=tf.float32),
          False)
      op_diag = gmm_ops._covariance(
          tf.constant(data.T, dtype=tf.float32),
          True)
      tf.initialize_all_variables().run()
      tf_cov = sess.run(op)
      np.testing.assert_array_almost_equal(np_cov, tf_cov)
      logging.info('TensorFlow took %f', time.time() - start_time)
      tf_cov = sess.run(op_diag)
      np.testing.assert_array_almost_equal(
          np.diag(np_cov), np.ravel(tf_cov), decimal=5)

  def test_simple_cluster(self):
    """Tests that the clusters are correct."""
    num_classes = 2
    graph = tf.Graph()
    with graph.as_default() as g:
      g.seed = 5
      with self.test_session() as sess:
        data = tf.constant(self.data, dtype=tf.float32)
        _, assignments, _, training_op = gmm_ops.gmm(data, 'random',
                                                     num_classes,
                                                     random_seed=self.seed)

        tf.initialize_all_variables().run()
        for _ in xrange(self.iterations):
          sess.run(training_op)
        assignments = sess.run(assignments)
        accuracy = np.mean(
            np.asarray(self.true_assignments) == np.squeeze(assignments))
        logging.info('Accuracy: %f', accuracy)
        self.assertGreater(accuracy, 0.98)

  def testParams(self):
    """Tests that the params work as intended."""
    num_classes = 2
    with self.test_session() as sess:
      # Experiment 1. Update weights only.
      data = tf.constant(self.data, dtype=tf.float32)
      gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
                                      [[3.0, 3.0], [0.0, 0.0]], 'w')
      training_ops = gmm_tool.training_ops()
      tf.initialize_all_variables().run()
      for _ in xrange(self.iterations):
        sess.run(training_ops)

      # Only the probability to each class is updated.
      alphas = sess.run(gmm_tool.alphas())
      self.assertGreater(alphas[1], 0.6)
      means = sess.run(gmm_tool.clusters())
      np.testing.assert_almost_equal(
          np.expand_dims([[3.0, 3.0], [0.0, 0.0]], 1), means)
      covs = sess.run(gmm_tool.covariances())
      np.testing.assert_almost_equal(covs[0], covs[1])

      # Experiment 2. Update means and covariances.
      gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
                                      [[3.0, 3.0], [0.0, 0.0]], 'mc')
      training_ops = gmm_tool.training_ops()
      tf.initialize_all_variables().run()
      for _ in xrange(self.iterations):
        sess.run(training_ops)
      alphas = sess.run(gmm_tool.alphas())
      self.assertAlmostEqual(alphas[0], alphas[1])
      means = sess.run(gmm_tool.clusters())
      np.testing.assert_almost_equal(
          np.expand_dims([[2.0, 2.0], [-1.0, -1.0]], 1), means, decimal=1)
      covs = sess.run(gmm_tool.covariances())
      np.testing.assert_almost_equal(
          [[0.371111, -0.0050774], [-0.0050774, 0.8651744]],
          covs[0], decimal=4)
      np.testing.assert_almost_equal(
          [[0.146976, 0.0259463], [0.0259463, 0.2543971]],
          covs[1], decimal=4)

      # Experiment 3. Update covariances only.
      gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
                                      [[-1.0, -1.0], [1.0, 1.0]], 'c')
      training_ops = gmm_tool.training_ops()
      tf.initialize_all_variables().run()
      for _ in xrange(self.iterations):
        sess.run(training_ops)
      alphas = sess.run(gmm_tool.alphas())
      self.assertAlmostEqual(alphas[0], alphas[1])
      means = sess.run(gmm_tool.clusters())
      np.testing.assert_almost_equal(
          np.expand_dims([[-1.0, -1.0], [1.0, 1.0]], 1), means)
      covs = sess.run(gmm_tool.covariances())
      np.testing.assert_almost_equal(
          [[0.1299582, 0.0435872], [0.0435872, 0.2558578]],
          covs[0], decimal=5)
      np.testing.assert_almost_equal(
          [[3.195385, 2.6989155], [2.6989155, 3.3881593]],
          covs[1], decimal=5)


if __name__ == '__main__':
  tf.test.main()
172  tensorflow/contrib/factorization/python/ops/gmm_test.py  Normal file
@@ -0,0 +1,172 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for ops.gmm."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.contrib.factorization.python.ops.gmm import GMM
from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering as KMeans
from tensorflow.contrib.learn.python.learn.estimators import run_config

FLAGS = tf.app.flags.FLAGS


class GMMTest(tf.test.TestCase):

  def setUp(self):
    np.random.seed(3)
    tf.set_random_seed(2)
    self.num_centers = 2
    self.num_dims = 2
    self.num_points = 4000
    self.batch_size = 100
    self.true_centers = self.make_random_centers(self.num_centers,
                                                 self.num_dims)
    self.points, self.assignments, self.scores = self.make_random_points(
        self.true_centers,
        self.num_points)
    self.true_score = np.add.reduce(self.scores)

    # Use initial means from kmeans (just like scikit-learn does).
    clusterer = KMeans(num_clusters=self.num_centers)
    clusterer.fit(self.points, steps=30)
    self.initial_means = clusterer.clusters()

  @staticmethod
  def make_random_centers(num_centers, num_dims):
    return np.round(np.random.rand(num_centers,
                                   num_dims).astype(np.float32) * 500)

  @staticmethod
  def make_random_points(centers, num_points):
    num_centers, num_dims = centers.shape
    assignments = np.random.choice(num_centers, num_points)
    offsets = np.round(np.random.randn(num_points,
                                       num_dims).astype(np.float32) * 20)
    points = centers[assignments] + offsets
    means = [np.mean(points[assignments == center], axis=0)
             for center in xrange(num_centers)]
    covs = [np.cov(points[assignments == center].T)
            for center in xrange(num_centers)]
    scores = []
    for r in xrange(num_points):
      scores.append(np.sqrt(np.dot(
          np.dot(points[r, :] - means[assignments[r]],
                 np.linalg.inv(covs[assignments[r]])),
          points[r, :] - means[assignments[r]])))
    return (points, assignments, scores)

  def test_clusters(self):
    """Tests the shape of the clusters."""
    gmm = GMM(self.num_centers,
              initial_clusters=self.initial_means,
              batch_size=self.batch_size,
              steps=40,
              continue_training=True,
              random_seed=4,
              config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(x=self.points, steps=0)
    clusters = gmm.clusters()
    self.assertAllEqual(list(clusters.shape),
                        [self.num_centers, self.num_dims])

  def test_fit(self):
    gmm = GMM(self.num_centers,
              initial_clusters='random',
              batch_size=self.batch_size,
              random_seed=4,
              config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(x=self.points, steps=1)
    score1 = gmm.score(x=self.points)
    gmm = GMM(self.num_centers,
              initial_clusters='random',
              batch_size=self.batch_size,
              random_seed=4,
              config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(x=self.points, steps=10)
    score2 = gmm.score(x=self.points)
    self.assertGreater(score1, score2)
    self.assertNear(self.true_score, score2, self.true_score * 0.15)

  def test_infer(self):
    gmm = GMM(self.num_centers,
              initial_clusters=self.initial_means,
              batch_size=self.batch_size,
              steps=40,
              continue_training=True,
              random_seed=4,
              config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(x=self.points, steps=60)
    clusters = gmm.clusters()

    # Make a small test set.
    points, true_assignments, true_offsets = (
        self.make_random_points(clusters, 40))

    assignments = np.ravel(gmm.predict(points))
    self.assertAllEqual(true_assignments, assignments)

    # Test score.
    score = gmm.score(points)
    self.assertNear(score, np.sum(true_offsets), 4.05)

  def _compare_with_sklearn(self, cov_type):
    # sklearn version.
    iterations = 40
    np.random.seed(5)
    sklearn_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])
    sklearn_means = np.asarray([[144.83417719, 254.20130341],
                                [274.38754816, 353.16074346]])
    sklearn_covs = np.asarray([[[395.0081194, -4.50389512],
                                [-4.50389512, 408.27543989]],
                               [[385.17484203, -31.27834935],
                                [-31.27834935, 391.74249925]]])

    # skflow version.
    gmm = GMM(self.num_centers,
              initial_clusters=self.initial_means,
              covariance_type=cov_type,
              batch_size=self.num_points,
              steps=iterations,
              continue_training=True,
              config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(self.points)
    skflow_assignments = gmm.predict(self.points[:10, :]).astype(int)
    self.assertAllClose(sklearn_assignments,
                        np.ravel(skflow_assignments))
    self.assertAllClose(sklearn_means, gmm.clusters())
    if cov_type == 'full':
      self.assertAllClose(sklearn_covs, gmm.covariances(), rtol=0.01)
    else:
      for d in [0, 1]:
        self.assertAllClose(np.diag(sklearn_covs[d]),
                            gmm.covariances()[d, :], rtol=0.01)

  def test_compare_full(self):
    self._compare_with_sklearn('full')

  def test_compare_diag(self):
    self._compare_with_sklearn('diag')


if __name__ == '__main__':
  tf.test.main()
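For comparison, a rough sklearn analogue of _compare_with_sklearn above, using the pre-0.18 sklearn.mixture.GMM API (later replaced by GaussianMixture); the argument and attribute names are recalled from that era and should be treated as assumptions.

# Hedged sketch; stand-in data, and old sklearn API recalled from memory.
import numpy as np
from sklearn import mixture

points = np.random.rand(4000, 2).astype(np.float32) * 500  # stand-in data
gmm = mixture.GMM(n_components=2, covariance_type='full', n_iter=40)
gmm.fit(points)
assignments = gmm.predict(points[:10, :])      # compare with GMM.predict()
means, covs = gmm.means_, gmm.covars_          # compare with clusters()/covariances()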
@@ -153,9 +153,11 @@ class KMeansTest(tf.test.TestCase):
   def test_fit_with_cosine_distance(self):
     # Create points on y=x and y=1.5x lines to check the cosine similarity.
     # Note that euclidean distance will give different results in this case.
-    points = np.array([[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]])
+    points = np.array(
+        [[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]], dtype=np.float32)
     # true centers are the unit vectors on lines y=x and y=1.5x
-    true_centers = np.array([[0.70710678, 0.70710678], [0.5547002, 0.83205029]])
+    true_centers = np.array(
+        [[0.70710678, 0.70710678], [0.5547002, 0.83205029]], dtype=np.float32)
     kmeans = KMeans(2,
                     initial_clusters=kmeans_ops.RANDOM_INIT,
                     distance_metric=kmeans_ops.COSINE_DISTANCE,
@@ -168,8 +170,9 @@ class KMeansTest(tf.test.TestCase):
                         np.sort(true_centers, axis=0))

   def test_transform_with_cosine_distance(self):
-    points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
-                       [-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]])
+    points = np.array(
+        [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
+         [0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)

     true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
                                       keepdims=True))[0],
@@ -180,8 +183,8 @@ class KMeansTest(tf.test.TestCase):
                     initial_clusters=kmeans_ops.RANDOM_INIT,
                     distance_metric=kmeans_ops.COSINE_DISTANCE,
                     use_mini_batch=self.use_mini_batch,
-                    config=self.config(3))
-    kmeans.fit(x=points, steps=30, batch_size=8)
+                    config=self.config(5))
+    kmeans.fit(x=points, steps=50, batch_size=8)

     centers = normalize(kmeans.clusters())
     self.assertAllClose(np.sort(centers, axis=0),
@@ -193,16 +196,16 @@ class KMeansTest(tf.test.TestCase):
     self.assertAllClose(transform, true_transform, atol=1e-3)

   def test_predict_with_cosine_distance(self):
-    points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
-                       [-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]).astype(
-                           np.float32)
+    points = np.array(
+        [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],
+         [0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]], dtype=np.float32)
     true_centers = np.array(
         [normalize(np.mean(normalize(points)[0:4, :],
                            axis=0,
                            keepdims=True))[0],
          normalize(np.mean(normalize(points)[4:, :],
                            axis=0,
-                           keepdims=True))[0]])
+                           keepdims=True))[0]], dtype=np.float32)
     true_assignments = [0] * 4 + [1] * 4
     true_score = len(points) - np.tensordot(normalize(points),
                                             true_centers[true_assignments])
@@ -230,14 +233,14 @@ class KMeansTest(tf.test.TestCase):
     # the less populated centers.
     points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
                        [-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
-                       [-3., -3.1], [-3.2, -3.], [-3., -3.]]).astype(np.float32)
+                       [-3., -3.1], [-3.2, -3.], [-3., -3.]], dtype=np.float32)
     true_centers = np.array(
         [normalize(np.mean(normalize(points)[0:2, :], axis=0,
                            keepdims=True))[0],
          normalize(np.mean(normalize(points)[2:4, :], axis=0,
                            keepdims=True))[0],
          normalize(np.mean(normalize(points)[4:, :], axis=0,
-                           keepdims=True))[0]])
+                           keepdims=True))[0]], dtype=np.float32)
     true_assignments = [0] * 2 + [1] * 2 + [2] * 8
     true_score = len(points) - np.tensordot(normalize(points),
                                             true_centers[true_assignments])
@@ -262,7 +265,7 @@ class KMeansTest(tf.test.TestCase):
     self.assertAllClose(score, true_score, atol=1e-2)

   def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
-    points = np.array([[2.0, 3.0], [1.6, 8.2]])
+    points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)

     with self.assertRaisesOpError('less'):
       kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
@@ -270,7 +273,7 @@ class KMeansTest(tf.test.TestCase):

   def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
       self):
-    points = np.array([[2.0, 3.0], [1.6, 8.2]])
+    points = np.array([[2.0, 3.0], [1.6, 8.2]], dtype=np.float32)

     with self.assertRaisesOpError(AssertionError):
       kmeans = KMeans(num_clusters=3,
@@ -21,6 +21,7 @@
 #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -63,13 +64,11 @@ class FileDeleter {

 class DecodeAudioOp : public OpKernel {
  public:
-  explicit DecodeAudioOp(OpKernelConstruction* context)
-      : OpKernel(context) {
+  explicit DecodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
     file_format_ = str_util::Lowercase(file_format_);
     const std::set<string> valid_file_formats(
-        kValidFileFormats,
-        kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
+        kValidFileFormats, kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
     OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
                 errors::InvalidArgument(
                     "file_format arg must be in {",
@@ -80,8 +79,7 @@ class DecodeAudioOp : public OpKernel {
     OP_REQUIRES(context, samples_per_second_ > 0,
                 errors::InvalidArgument("samples_per_second must be > 0."));

-    OP_REQUIRES_OK(
-        context, context->GetAttr("channel_count", &channel_count_));
+    OP_REQUIRES_OK(context, context->GetAttr("channel_count", &channel_count_));
     OP_REQUIRES(context, channel_count_ > 0,
                 errors::InvalidArgument("channel_count must be > 0."));
   }
@@ -117,14 +115,13 @@ class DecodeAudioOp : public OpKernel {
       LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
                  << "'. Returning empty tensor.";
       Tensor* output = nullptr;
-      OP_REQUIRES_OK(
-          context, context->allocate_output(0, TensorShape({0, 0}), &output));
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, TensorShape({0, 0}), &output));
       return;
     } else {
       OP_REQUIRES_OK(context, result);
     }
-    OP_REQUIRES(
-        context, !output_samples.empty(),
+    OP_REQUIRES(context, !output_samples.empty(),
                 errors::Unknown("No output created by FFmpeg."));
     OP_REQUIRES(
         context, output_samples.size() % channel_count_ == 0,
@@ -133,8 +130,8 @@ class DecodeAudioOp : public OpKernel {
     // Copy the output data to the output Tensor.
     Tensor* output = nullptr;
     const int64 frame_count = output_samples.size() / channel_count_;
-    OP_REQUIRES_OK(
-        context, context->allocate_output(
-            0, TensorShape({frame_count, channel_count_}), &output));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({frame_count, channel_count_}), &output));
     auto matrix = output->tensor<float, 2>();
     for (int32 frame = 0; frame < frame_count; ++frame) {
@@ -159,6 +156,15 @@ REGISTER_OP("DecodeAudio")
     .Attr("file_format: string")
     .Attr("samples_per_second: int")
     .Attr("channel_count: int")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      int64 channels;
+      if (c->GetAttr("channel_count", &channels).ok()) {
+        c->set_output(0, c->Matrix(c->UnknownDim(), channels));
+      } else {
+        c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
+      }
+      return Status::OK();
+    })
     .Doc(R"doc(
 Processes the contents of an audio file into a tensor using FFmpeg to decode
 the file.
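A hedged sketch of what the new DecodeAudio shape function buys at graph-construction time, seen through the contrib ffmpeg Python wrapper; the wrapper call and file path are assumptions for illustration.

# Hedged sketch; wrapper name and path are assumptions.
import tensorflow as tf
from tensorflow.contrib import ffmpeg

binary = tf.read_file('/tmp/example.wav')  # hypothetical input file
audio = ffmpeg.decode_audio(binary, file_format='wav',
                            samples_per_second=44100, channel_count=2)
# With the channel_count attr known statically, the inferred shape
# should now be (?, 2) rather than fully unknown.
print(audio.get_shape())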
@@ -16,6 +16,7 @@
 #include <limits>

 #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -24,8 +25,7 @@ namespace ffmpeg {

 class EncodeAudioOp : public OpKernel {
  public:
-  explicit EncodeAudioOp(OpKernelConstruction* context)
-      : OpKernel(context) {
+  explicit EncodeAudioOp(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
     file_format_ = str_util::Lowercase(file_format_);
     OP_REQUIRES(context, file_format_ == "wav",
@@ -35,15 +35,15 @@ class EncodeAudioOp : public OpKernel {
         context, context->GetAttr("samples_per_second", &samples_per_second_));
     OP_REQUIRES(context, samples_per_second_ > 0,
                 errors::InvalidArgument("samples_per_second must be > 0."));
-    OP_REQUIRES_OK(
-        context, context->GetAttr("bits_per_second", &bits_per_second_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("bits_per_second", &bits_per_second_));
   }

   void Compute(OpKernelContext* context) override {
     // Get and verify the input data.
-    OP_REQUIRES(context, context->num_inputs() == 1,
-                errors::InvalidArgument(
-                    "EncodeAudio requires exactly one input."));
+    OP_REQUIRES(
+        context, context->num_inputs() == 1,
+        errors::InvalidArgument("EncodeAudio requires exactly one input."));
     const Tensor& contents = context->input(0);
     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(contents.shape()),
                 errors::InvalidArgument(
@@ -88,6 +88,7 @@ REGISTER_OP("EncodeAudio")
     .Attr("file_format: string")
     .Attr("samples_per_second: int")
     .Attr("bits_per_second: int = 192000")
+    .SetShapeFn(shape_inference::ScalarShape)
     .Doc(R"doc(
 Processes a `Tensor` containing sampled audio with the number of channels
 and length of the audio specified by the dimensions of the `Tensor`. The
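With `shape_inference::ScalarShape` registered, EncodeAudio's output is known at graph-construction time to be a scalar string rather than an unknown shape. A hedged usage sketch, assuming the `tf.contrib.ffmpeg.encode_audio` Python wrapper around this op (API as of this branch):

  import tensorflow as tf

  # [samples, channels] float32 audio: 100 samples of single-channel silence.
  audio = tf.zeros([100, 1], dtype=tf.float32)
  encoded = tf.contrib.ffmpeg.encode_audio(
      audio, file_format='wav', samples_per_second=16000)
  # Shape inference can now report a scalar instead of <unknown>.
  print(encoded.get_shape())  # => ()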
@@ -91,6 +91,15 @@ py_test(
     deps = ["//tensorflow:tensorflow_py"],
 )

+py_test(
+    name = "sampling_ops_threading_test",
+    size = "small",
+    srcs = ["python/ops/sampling_ops_threading_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
@@ -30,6 +30,7 @@

 ## Deprecation
 @@deprecated
+@@deprecated_arg_values

 ## Arg_Scope
 @@arg_scope
@@ -21,4 +21,5 @@ from __future__ import print_function
 # pylint: disable=wildcard-import
 from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
 from tensorflow.contrib.framework.python.framework.deprecation import deprecated
+from tensorflow.contrib.framework.python.framework.deprecation import deprecated_arg_values
 from tensorflow.contrib.framework.python.framework.tensor_util import *
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import functools
+import inspect
 import re

 from tensorflow.python.platform import tf_logging as logging
@@ -34,45 +36,77 @@ def _get_qualified_name(function):
   return function.__name__


-def _add_deprecation_to_docstring(doc, date, instructions):
+def _add_deprecation_to_docstring(
+    doc, instructions, no_doc_str, suffix_str, notice):
   """Adds a deprecation notice to a docstring."""
   if not doc:
-    lines = ['DEPRECATED FUNCTION']
+    lines = [no_doc_str]
   else:
     lines = doc.splitlines()
-    lines[0] += ' (deprecated)'
+    lines[0] += ' ' + suffix_str

-  notice = [
-      '',
-      'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
-      'Instructions for updating:',
-      '%s' % instructions,
-  ]
+  notice = [''] + notice + [instructions]

   if len(lines) > 1:
     # Make sure that we keep our distance from the main body
     if lines[1].strip():
-      notice += ['']
+      notice.append('')

-    lines = [lines[0]] + notice + lines[1:]
+    lines[1:1] = notice
   else:
     lines += notice

   return '\n'.join(lines)


+def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
+  """Adds a deprecation notice to a docstring for deprecated functions."""
+  return _add_deprecation_to_docstring(
+      doc, instructions,
+      'DEPRECATED FUNCTION',
+      '(deprecated)', [
+          'THIS FUNCTION IS DEPRECATED. It will be removed after %s.' % date,
+          'Instructions for updating:'])
+
+
+def _add_deprecated_arg_notice_to_docstring(doc, date, instructions):
+  """Adds a deprecation notice to a docstring for deprecated arguments."""
+  return _add_deprecation_to_docstring(
+      doc, instructions,
+      'DEPRECATED FUNCTION ARGUMENTS',
+      '(deprecated arguments)', [
+          'SOME ARGUMENTS ARE DEPRECATED. '
+          'They will be removed after %s.' % date,
+          'Instructions for updating:'])
+
+
+def _validate_deprecation_args(date, instructions):
+  if not date:
+    raise ValueError('Tell us what date this will be deprecated!')
+  if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
+    raise ValueError('Date must be YYYY-MM-DD.')
+  if not instructions:
+    raise ValueError('Don\'t deprecate things without conversion instructions!')
+
+
+def _validate_callable(func, decorator_name):
+  if not hasattr(func, '__call__'):
+    raise ValueError(
+        '%s is not a function. If this is a property, '
+        'apply @%s after @property.' % (func, decorator_name))
+
+
 def deprecated(date, instructions):
   """Decorator for marking functions or methods deprecated.

-  This decorator adds a deprecation warning to a function's docstring. It has
-  the following format:
+  This decorator logs a deprecation warning whenever the decorated function is
+  called. It has the following format:

     <function> (from <module>) is deprecated and will be removed after <date>.
     Instructions for updating:
     <instructions>

-  whenever the decorated function is called. <function> will include the class
-  name if it is a method.
+  <function> will include the class name if it is a method.

   It also edits the docstring of the function: ' (deprecated)' is appended
   to the first line of the docstring and a deprecation notice is prepended
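For concreteness, the generalized helper threads its `no_doc_str`, `suffix_str`, and `notice` parameters through the same splicing logic as before. A standalone sketch mirroring the algorithm (same steps, outside the library), showing the text it produces for a one-line docstring; dates and messages illustrative:

  # Mirrors _add_deprecation_to_docstring from the hunk above.
  def add_notice(doc, instructions, no_doc_str, suffix_str, notice):
    lines = doc.splitlines() if doc else [no_doc_str]
    if doc:
      lines[0] += ' ' + suffix_str
    notice = [''] + notice + [instructions]
    if len(lines) > 1:
      if lines[1].strip():
        notice.append('')
      lines[1:1] = notice  # splice notice right after the summary line
    else:
      lines += notice
    return '\n'.join(lines)

  print(add_notice(
      'Computes foo.', 'Use bar instead.', 'DEPRECATED FUNCTION',
      '(deprecated)',
      ['THIS FUNCTION IS DEPRECATED. It will be removed after 2016-07-04.',
       'Instructions for updating:']))
  # Computes foo. (deprecated)
  #
  # THIS FUNCTION IS DEPRECATED. It will be removed after 2016-07-04.
  # Instructions for updating:
  # Use bar instead.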
@@ -90,28 +124,73 @@ def deprecated(date, instructions):
   Raises:
     ValueError: If date is not in ISO 8601 format, or instructions are empty.
   """
-  if not date:
-    raise ValueError('Tell us what date this will be deprecated!')
-  if not re.match(r'20\d\d-[01]\d-[0123]\d', date):
-    raise ValueError('Date must be YYYY-MM-DD.')
-  if not instructions:
-    raise ValueError('Don\'t deprecate things without conversion instructions!')
+  _validate_deprecation_args(date, instructions)

   def deprecated_wrapper(func):
     """Deprecation wrapper."""
-    if not hasattr(func, '__call__'):
-      raise ValueError(
-          '%s is not a function.'
-          'If this is a property, apply @deprecated after @property.' % func)
+    _validate_callable(func, 'deprecated')
+    @functools.wraps(func)
     def new_func(*args, **kwargs):
-      logging.warning('%s (from %s) is deprecated and will be removed after %s.'
-                      '\nInstructions for updating:\n%s',
-                      _get_qualified_name(func), func.__module__,
-                      date, instructions)
+      logging.warning(
+          '%s (from %s) is deprecated and will be removed after %s.\n'
+          'Instructions for updating:\n%s',
+          _get_qualified_name(func), func.__module__, date, instructions)
       return func(*args, **kwargs)
-    new_func.__name__ = func.__name__
-    new_func.__doc__ = _add_deprecation_to_docstring(func.__doc__, date,
-                                                     instructions)
-    new_func.__dict__.update(func.__dict__)
+    new_func.__doc__ = _add_deprecated_function_notice_to_docstring(
+        func.__doc__, date, instructions)
     return new_func
   return deprecated_wrapper
+
+
+def deprecated_arg_values(date, instructions, **deprecated_kwargs):
+  """Decorator for marking specific function argument values as deprecated.
+
+  This decorator logs a deprecation warning whenever the decorated function is
+  called with the deprecated argument values. It has the following format:
+
+    Calling <function> (from <module>) with <arg>=<value> is deprecated and
+    will be removed after <date>. Instructions for updating:
+    <instructions>
+
+  <function> will include the class name if it is a method.
+
+  It also edits the docstring of the function: ' (deprecated arguments)' is
+  appended to the first line of the docstring and a deprecation notice is
+  prepended to the rest of the docstring.
+
+  Args:
+    date: String. The date the function is scheduled to be removed. Must be
+      ISO 8601 (YYYY-MM-DD).
+    instructions: String. Instructions on how to update code using the
+      deprecated function.
+    **deprecated_kwargs: The deprecated argument values.
+
+  Returns:
+    Decorated function or method.
+
+  Raises:
+    ValueError: If date is not in ISO 8601 format, or instructions are empty.
+  """
+  _validate_deprecation_args(date, instructions)
+  if not deprecated_kwargs:
+    raise ValueError('Specify which argument values are deprecated.')
+
+  def deprecated_wrapper(func):
+    """Deprecation decorator."""
+    _validate_callable(func, 'deprecated_arg_values')
+    @functools.wraps(func)
+    def new_func(*args, **kwargs):
+      """Deprecation wrapper."""
+      named_args = inspect.getcallargs(func, *args, **kwargs)
+      for arg_name, arg_value in deprecated_kwargs.items():
+        if arg_name in named_args and named_args[arg_name] == arg_value:
+          logging.warning(
+              'Calling %s (from %s) with %s=%s is deprecated and will be '
+              'removed after %s.\nInstructions for updating:\n%s',
+              _get_qualified_name(func), func.__module__,
+              arg_name, arg_value, date, instructions)
+      return func(*args, **kwargs)
+    new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
+        func.__doc__, date, instructions)
+    return new_func
+  return deprecated_wrapper
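A short usage sketch of the two decorators as they now behave; dates and messages illustrative, module path taken from the import added above:

  from tensorflow.contrib.framework.python.framework.deprecation import (
      deprecated, deprecated_arg_values)

  @deprecated('2016-09-01', 'Use new_add instead.')
  def old_add(a, b):
    """Adds two numbers."""
    return a + b

  @deprecated_arg_values('2016-09-01', 'Stop passing fast=True.', fast=True)
  def compute(x, fast=False):
    """Computes something."""
    return x * 2

  old_add(1, 2)          # logs the deprecation warning on every call
  compute(3)             # no warning: fast != True
  compute(3, fast=True)  # logs the deprecated-argument-value warning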
@@ -56,7 +56,7 @@ class DeprecationTest(tf.test.TestCase):

     Args:
       arg0: Arg 0.
-      arg1: Arg 0.
+      arg1: Arg 1.

     Returns:
       Sum of args.
@@ -73,13 +73,38 @@ class DeprecationTest(tf.test.TestCase):
         "\n"
         "\n    Args:"
         "\n      arg0: Arg 0."
-        "\n      arg1: Arg 0."
+        "\n      arg1: Arg 1."
         "\n"
         "\n    Returns:"
         "\n      Sum of args."
         "\n    " % (date, instructions),
         _fn.__doc__)
-    self.assertEqual({}, _fn.__dict__)
+
+    # Assert calling new fn issues log warning.
+    self.assertEqual(3, _fn(1, 2))
+    self.assertEqual(1, mock_warning.call_count)
+    (args, _) = mock_warning.call_args
+    self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
+    self._assert_subset(set([date, instructions]), set(args[1:]))
+
+  @tf.test.mock.patch.object(logging, "warning", autospec=True)
+  def test_static_fn_with_one_line_doc(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated(date, instructions)
+    def _fn(arg0, arg1):
+      """fn doc."""
+      return arg0 + arg1
+
+    # Assert function docs are properly updated.
+    self.assertEqual("_fn", _fn.__name__)
+    self.assertEqual(
+        "fn doc. (deprecated)"
+        "\n"
+        "\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
+        "\nInstructions for updating:\n%s" % (date, instructions),
+        _fn.__doc__)

     # Assert calling new fn issues log warning.
     self.assertEqual(3, _fn(1, 2))
@@ -106,7 +131,6 @@ class DeprecationTest(tf.test.TestCase):
         "\nInstructions for updating:"
         "\n%s" % (date, instructions),
         _fn.__doc__)
-    self.assertEqual({}, _fn.__dict__)

     # Assert calling new fn issues log warning.
     self.assertEqual(3, _fn(1, 2))
@@ -131,7 +155,7 @@ class DeprecationTest(tf.test.TestCase):

     Args:
       arg0: Arg 0.
-      arg1: Arg 0.
+      arg1: Arg 1.

     Returns:
       Sum of args.
@@ -147,7 +171,7 @@ class DeprecationTest(tf.test.TestCase):
         "\n"
         "\n    Args:"
         "\n      arg0: Arg 0."
-        "\n      arg1: Arg 0."
+        "\n      arg1: Arg 1."
         "\n"
         "\n    Returns:"
         "\n      Sum of args."
@@ -161,6 +185,36 @@ class DeprecationTest(tf.test.TestCase):
     self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
     self._assert_subset(set([date, instructions]), set(args[1:]))

+  @tf.test.mock.patch.object(logging, "warning", autospec=True)
+  def test_instance_fn_with_one_line_doc(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    class _Object(object):
+
+      def __init(self):
+        pass
+
+      @deprecation.deprecated(date, instructions)
+      def _fn(self, arg0, arg1):
+        """fn doc."""
+        return arg0 + arg1
+
+    # Assert function docs are properly updated.
+    self.assertEqual(
+        "fn doc. (deprecated)"
+        "\n"
+        "\nTHIS FUNCTION IS DEPRECATED. It will be removed after %s."
+        "\nInstructions for updating:\n%s" % (date, instructions),
+        getattr(_Object, "_fn").__doc__)
+
+    # Assert calling new fn issues log warning.
+    self.assertEqual(3, _Object()._fn(1, 2))
+    self.assertEqual(1, mock_warning.call_count)
+    (args, _) = mock_warning.call_args
+    self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
+    self._assert_subset(set([date, instructions]), set(args[1:]))
+
   @tf.test.mock.patch.object(logging, "warning", autospec=True)
   def test_instance_fn_no_doc(self, mock_warning):
     date = "2016-07-04"
@@ -280,5 +334,155 @@ class DeprecationTest(tf.test.TestCase):
     self._assert_subset(set([date, instructions]), set(args[1:]))


+class DeprecatedArgsTest(tf.test.TestCase):
+
+  def _assert_subset(self, expected_subset, actual_set):
+    self.assertTrue(
+        actual_set.issuperset(expected_subset),
+        msg="%s is not a superset of %s." % (actual_set, expected_subset))
+
+  def test_deprecated_illegal_args(self):
+    instructions = "This is how you update..."
+    with self.assertRaisesRegexp(ValueError, "date"):
+      deprecation.deprecated_arg_values(
+          None, instructions, deprecated=True)
+    with self.assertRaisesRegexp(ValueError, "date"):
+      deprecation.deprecated_arg_values(
+          "", instructions, deprecated=True)
+    with self.assertRaisesRegexp(ValueError, "YYYY-MM-DD"):
+      deprecation.deprecated_arg_values(
+          "07-04-2016", instructions, deprecated=True)
+    date = "2016-07-04"
+    with self.assertRaisesRegexp(ValueError, "instructions"):
+      deprecation.deprecated_arg_values(
+          date, None, deprecated=True)
+    with self.assertRaisesRegexp(ValueError, "instructions"):
+      deprecation.deprecated_arg_values(
+          date, "", deprecated=True)
+    with self.assertRaisesRegexp(ValueError, "argument"):
+      deprecation.deprecated_arg_values(
+          date, instructions)
+
+  @tf.test.mock.patch.object(logging, "warning", autospec=True)
+  def test_static_fn_with_doc(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated_arg_values(date, instructions, deprecated=True)
+    def _fn(arg0, arg1, deprecated=True):
+      """fn doc.
+
+      Args:
+        arg0: Arg 0.
+        arg1: Arg 1.
+        deprecated: Deprecated!
+
+      Returns:
+        Sum of args.
+      """
+      return arg0 + arg1 if deprecated else arg1 + arg0
+
+    # Assert function docs are properly updated.
+    self.assertEqual("_fn", _fn.__name__)
+    self.assertEqual(
+        "fn doc. (deprecated arguments)"
+        "\n"
+        "\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
+        "\nInstructions for updating:\n%s"
+        "\n"
+        "\n      Args:"
+        "\n        arg0: Arg 0."
+        "\n        arg1: Arg 1."
+        "\n        deprecated: Deprecated!"
+        "\n"
+        "\n      Returns:"
+        "\n        Sum of args."
+        "\n      " % (date, instructions),
+        _fn.__doc__)
+
+    # Assert calling new fn with non-deprecated value logs nothing.
+    self.assertEqual(3, _fn(1, 2, deprecated=False))
+    self.assertEqual(0, mock_warning.call_count)
+
+    # Assert calling new fn with deprecated value issues log warning.
+    self.assertEqual(3, _fn(1, 2, deprecated=True))
+    self.assertEqual(1, mock_warning.call_count)
+    (args, _) = mock_warning.call_args
+    self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
+    self._assert_subset(set([date, instructions]), set(args[1:]))
+
+    # Assert calling new fn with default deprecated value issues log warning.
+    self.assertEqual(3, _fn(1, 2))
+    self.assertEqual(2, mock_warning.call_count)
+
+  @tf.test.mock.patch.object(logging, "warning", autospec=True)
+  def test_static_fn_with_one_line_doc(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated_arg_values(date, instructions, deprecated=True)
+    def _fn(arg0, arg1, deprecated=True):
+      """fn doc."""
+      return arg0 + arg1 if deprecated else arg1 + arg0
+
+    # Assert function docs are properly updated.
+    self.assertEqual("_fn", _fn.__name__)
+    self.assertEqual(
+        "fn doc. (deprecated arguments)"
+        "\n"
+        "\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
+        "\nInstructions for updating:\n%s" % (date, instructions),
+        _fn.__doc__)
+
+    # Assert calling new fn with non-deprecated value logs nothing.
+    self.assertEqual(3, _fn(1, 2, deprecated=False))
+    self.assertEqual(0, mock_warning.call_count)
+
+    # Assert calling new fn with deprecated value issues log warning.
+    self.assertEqual(3, _fn(1, 2, deprecated=True))
+    self.assertEqual(1, mock_warning.call_count)
+    (args, _) = mock_warning.call_args
+    self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
+    self._assert_subset(set([date, instructions]), set(args[1:]))
+
+    # Assert calling new fn with default deprecated value issues log warning.
+    self.assertEqual(3, _fn(1, 2))
+    self.assertEqual(2, mock_warning.call_count)
+
+  @tf.test.mock.patch.object(logging, "warning", autospec=True)
+  def test_static_fn_no_doc(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated_arg_values(date, instructions, deprecated=True)
+    def _fn(arg0, arg1, deprecated=True):
+      return arg0 + arg1 if deprecated else arg1 + arg0
+
+    # Assert function docs are properly updated.
+    self.assertEqual("_fn", _fn.__name__)
+    self.assertEqual(
+        "DEPRECATED FUNCTION ARGUMENTS"
+        "\n"
+        "\nSOME ARGUMENTS ARE DEPRECATED. They will be removed after %s."
+        "\nInstructions for updating:"
+        "\n%s" % (date, instructions),
+        _fn.__doc__)
+
+    # Assert calling new fn with non-deprecated value logs nothing.
+    self.assertEqual(3, _fn(1, 2, deprecated=False))
+    self.assertEqual(0, mock_warning.call_count)
+
+    # Assert calling new fn issues log warning.
+    self.assertEqual(3, _fn(1, 2, deprecated=True))
+    self.assertEqual(1, mock_warning.call_count)
+    (args, _) = mock_warning.call_args
+    self.assertRegexpMatches(args[0], r"deprecated and will be removed after")
+    self._assert_subset(set([date, instructions]), set(args[1:]))
+
+    # Assert calling new fn with default deprecated value issues log warning.
+    self.assertEqual(3, _fn(1, 2))
+    self.assertEqual(2, mock_warning.call_count)
+
+
 if __name__ == "__main__":
   tf.test.main()
@@ -27,6 +27,7 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.training import input as input_ops
 from tensorflow.python.training import queue_runner
@@ -34,10 +35,8 @@ __all__ = ['stratified_sample',
            'stratified_sample_unknown_dist',]


-# TODO(joelshor): Use an exponential-moving-average to estimate the initial
-# class distribution and remove the requirement that it be provided.
-def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
-                      enqueue_many=False, queue_capacity=16,
+def stratified_sample(tensors, labels, target_probs, batch_size,
+                      init_probs=None, enqueue_many=False, queue_capacity=16,
                       threads_per_queue=1, name=None):
   """Stochastically creates batches based on per-class probabilities.

@@ -52,11 +51,12 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
         batch, according to enqueue_many.
     labels: Tensor for label of data. Label is a single integer or a batch,
         depending on enqueue_many. It is not a one-hot vector.
-    init_probs: Class proportions in the data. An object whose type has a
-        registered Tensor conversion function.
     target_probs: Target class proportions in batch. An object whose type has a
         registered Tensor conversion function.
     batch_size: Size of batch to be returned.
+    init_probs: Class proportions in the data. An object whose type has a
+        registered Tensor conversion function, or `None` for estimating the
+        initial distribution.
     enqueue_many: Bool. If true, interpret input tensors as having a batch
         dimension.
     queue_capacity: Capacity of the large queue that holds input examples.
@@ -81,10 +81,9 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
     data, label = data_provider.Get(['data', 'label'])

     # Get stratified batch according to per-class probabilities.
-    init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)]
     target_probs = [...distribution you want...]
     [data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
-        [data], label, init_probs, target_probs)
+        [data], label, target_probs)

     # Run batch through network.
     ...
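Under the new signature, `init_probs` is an optional keyword that follows `batch_size`; passing `None` (the default) switches on online estimation. A hedged call sketch against this branch's `tf.contrib.framework.sampling_ops`, with illustrative tensors and probabilities:

  import tensorflow as tf

  data = tf.random_uniform([4])           # one example's features
  label = tf.constant(1, dtype=tf.int32)  # its integer class label
  target_probs = [0.5, 0.5]               # desired class mix per batch

  # init_probs defaults to None, which estimates the data distribution
  # online from the labels actually seen.
  [data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
      [data], label, target_probs, batch_size=32)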
@@ -92,22 +91,34 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
   with ops.op_scope(tensors + [labels], name, 'stratified_sample'):
     tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
     labels = ops.convert_to_tensor(labels)
-    init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
     target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
     # Reduce the case of a single example to that of a batch of size 1.
     if not enqueue_many:
       tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
       labels = array_ops.expand_dims(labels, 0)

+    # If `init_probs` is `None`, set up online estimation of data distribution.
+    if init_probs is None:
+      # We use `target_probs` to get the number of classes, so its shape must
+      # be fully defined at graph construction time.
+      target_probs.get_shape().assert_is_fully_defined()
+      init_probs = _estimate_data_distribution(
+          labels, target_probs.get_shape().num_elements())
+    else:
+      init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
+
     # Validate that input is consistent.
     tensor_list, labels, [init_probs, target_probs] = _verify_input(
         tensor_list, labels, [init_probs, target_probs])

     # Check that all zero initial probabilities also have zero target
     # probabilities.
-    assert_op = logging_ops.Assert(math_ops.reduce_all(math_ops.logical_or(
-        math_ops.not_equal(init_probs, 0),
-        math_ops.equal(target_probs, 0))), [init_probs, target_probs])
+    assert_op = logging_ops.Assert(
+        math_ops.reduce_all(math_ops.logical_or(
+            math_ops.not_equal(init_probs, 0),
+            math_ops.equal(target_probs, 0))),
+        ['All classes with zero initial probability must also have zero target '
+         'probability: ', init_probs, target_probs])
     init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)

     # Calculate acceptance sampling probabilities.
@@ -212,6 +223,40 @@ def stratified_sample_unknown_dist(tensors, labels, probs, batch_size,
       per_class_queues, probs, batch_size)


+def _estimate_data_distribution(labels, num_classes):
+  """Estimate data distribution as labels are seen."""
+  # Variable to track running count of classes. Add 1 to avoid
+  # division-by-zero, and to guarantee that calculation of acceptance
+  # probabilities is (mostly) correct.
+  num_examples_per_class_seen = variables.Variable(
+      initial_value=[1] * num_classes, trainable=False, name='class_count',
+      dtype=dtypes.int64)
+
+  # Update the class-count based on what labels are seen in batch.
+  num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
+      math_ops.reduce_sum(array_ops.one_hot(labels, num_classes,
+                                            dtype=dtypes.int64), 0))
+
+  # Normalize count into a probability.
+  # NOTE: Without the `+= 0` line below, the test
+  # `testMultiThreadedEstimateDataDistribution` fails. The reason is that
+  # before this line, `num_examples_per_class_seen` is a Tensor that shares a
+  # buffer with an underlying `ref` object. When the `ref` is changed by
+  # another thread, `num_examples_per_class_seen` changes as well. Since this
+  # can happen in the middle of the normalization computation, we get
+  # probabilities that are very far from summing to one. Adding `+= 0` copies
+  # the contents of the tensor to a new buffer, which will be consistent from
+  # the start to the end of the normalization computation.
+  num_examples_per_class_seen += 0
+  init_prob_estimate = math_ops.truediv(
+      num_examples_per_class_seen,
+      math_ops.reduce_sum(num_examples_per_class_seen))
+
+  # Must return float32 (not float64) to agree with downstream `_verify_input`
+  # checks.
+  return math_ops.cast(init_prob_estimate, dtypes.float32)
+
+
 def _verify_input(tensor_list, labels, probs_list):
   """Verify that batched inputs are well-formed."""
   checked_probs_list = []
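The estimator above is just a running class count normalized into a probability vector; the `[1] * num_classes` initial value keeps the divisor nonzero from the first step. A standalone NumPy sketch of the same arithmetic:

  import numpy as np

  def estimate_data_distribution(labels_seen, num_classes):
    # Counts start at one per class, mirroring initial_value=[1]*num_classes.
    counts = np.ones(num_classes, dtype=np.int64)
    for label in labels_seen:
      counts[label] += 1
    # Normalize into float32 to agree with downstream _verify_input checks.
    return counts.astype(np.float32) / float(counts.sum())

  print(estimate_data_distribution([0, 0, 1, 0], num_classes=2))
  # => [0.6667, 0.3333]: skewed toward the over-represented class 0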
@@ -20,6 +20,7 @@ from __future__ import print_function

 import numpy as np
 import tensorflow as tf
+from tensorflow.python.platform import tf_logging as logging


 class SamplingOpsTest(tf.test.TestCase):
@@ -33,15 +34,22 @@ class SamplingOpsTest(tf.test.TestCase):

     # Curry the rejection sampler so we can easily run the same tests on both
     # stratified_sample and stratified_sample_unknown_dist.
-    def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
+    def curried_sampler(tensors, labels, probs, batch_size, enqueue_many=True):
       return tf.contrib.framework.sampling_ops.stratified_sample(
-          val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
+          tensors=tensors,
+          labels=labels,
+          target_probs=probs,
+          batch_size=batch_size,
+          init_probs=initial_p,
+          enqueue_many=enqueue_many)

     samplers = [
         tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist,
         curried_sampler,
     ]

     for sampler in samplers:
+      logging.info('Now testing `%s`', sampler.__class__.__name__)
       # Label must have only batch dimension if enqueue_many is True.
       with self.assertRaises(ValueError):
         sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True)
@@ -70,20 +78,21 @@ class SamplingOpsTest(tf.test.TestCase):

       # Probabilities shape must be fully defined.
       with self.assertRaises(ValueError):
-        sampler(val, label, tf.placeholder(tf.float32, shape=[None]),
-                batch_size)
+        sampler(
+            val, label, tf.placeholder(
+                tf.float32, shape=[None]), batch_size)

     # In the rejection sampling case, make sure that probability lengths are
     # the same.
     with self.assertRaises(ValueError):
       tf.contrib.framework.sampling_ops.stratified_sample(
-          val, label, [.2] * 5, [.1] * 10, batch_size)
+          val, label, [.1] * 10, batch_size, init_probs=[.2] * 5)

     # In the rejection sampling case, make sure that zero initial probability
     # classes also have zero target probability.
     with self.assertRaises(ValueError):
       tf.contrib.framework.sampling_ops.stratified_sample(
-          val, label, [0, .5, .5], [.2, .4, .4], batch_size)
+          val, label, [.2, .4, .4], batch_size, init_probs=[0, .5, .5])

     # Probabilities must be 1D.
     with self.assertRaises(ValueError):
@@ -116,14 +125,16 @@ class SamplingOpsTest(tf.test.TestCase):
       # Run session that should fail.
       with self.test_session() as sess:
         with self.assertRaises(tf.errors.InvalidArgumentError):
-          sess.run([val_tf, lbl_tf], feed_dict={label_ph: illegal_label,
-                                                probs_ph: valid_probs})
+          sess.run([val_tf, lbl_tf],
+                   feed_dict={label_ph: illegal_label,
+                              probs_ph: valid_probs})

     for illegal_prob in illegal_probs:
       # Run session that should fail.
       with self.test_session() as sess:
         with self.assertRaises(tf.errors.InvalidArgumentError):
-          sess.run([prob_tf], feed_dict={label_ph: valid_labels,
-                                         probs_ph: illegal_prob})
+          sess.run([prob_tf],
+                   feed_dict={label_ph: valid_labels,
+                              probs_ph: illegal_prob})

   def batchingBehaviorHelper(self, sampler):
@@ -152,15 +163,14 @@ class SamplingOpsTest(tf.test.TestCase):
     lbl_input_batch = tf.ones([], dtype=tf.int32)
     probs = np.array([0, 1, 0, 0, 0])
     batches = tf.contrib.framework.sampling_ops.stratified_sample(
-        val_input_batch, lbl_input_batch, probs, probs, batch_size)
+        val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
     batches += tf.contrib.framework.sampling_ops.stratified_sample(
-        val_input_batch, lbl_input_batch, probs, probs, batch_size)
+        val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
     batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
         val_input_batch, lbl_input_batch, probs, batch_size)
     batches += tf.contrib.framework.sampling_ops.stratified_sample_unknown_dist(
         val_input_batch, lbl_input_batch, probs, batch_size)
-    summary_op = tf.merge_summary(tf.get_collection(
-        tf.GraphKeys.SUMMARIES))
+    summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES))

     with self.test_session() as sess:
       coord = tf.train.Coordinator()
@@ -177,9 +187,15 @@ class SamplingOpsTest(tf.test.TestCase):

   def testRejectionBatchingBehavior(self):
     initial_p = [0, .3, 0, .7, 0]

     def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
       return tf.contrib.framework.sampling_ops.stratified_sample(
-          val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
+          val,
+          lbls,
+          probs,
+          batch,
+          init_probs=initial_p,
+          enqueue_many=enqueue_many)

     self.batchingBehaviorHelper(curried_sampler)

@@ -190,8 +206,7 @@ class SamplingOpsTest(tf.test.TestCase):
     lbl2 = 3
     # This cond allows the necessary class queues to be populated.
     label = tf.cond(
-        tf.greater(.5, tf.random_uniform([])),
-        lambda: tf.constant(lbl1),
+        tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
         lambda: tf.constant(lbl2))
     val = [np.array([1, 4]) * label]
     probs = tf.placeholder(tf.float32, shape=[5])
@@ -243,7 +258,8 @@ class SamplingOpsTest(tf.test.TestCase):
     # Run graph to make sure there are no shape-related runtime errors.
     for vals, labels in legal_input_pairs:
       with self.test_session() as sess:
-        sess.run([val_tf, labels_tf], feed_dict={vals_ph: vals,
-                                                 labels_ph: labels})
+        sess.run([val_tf, labels_tf],
+                 feed_dict={vals_ph: vals,
+                            labels_ph: labels})

   def dataListHelper(self, sampler):
@@ -251,8 +267,8 @@ class SamplingOpsTest(tf.test.TestCase):
     val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3]
     lbl_input_batch = tf.ones([], dtype=tf.int32)
     probs = np.array([0, 1, 0, 0, 0])
-    val_list, lbls = sampler(
-        val_input_batch, lbl_input_batch, probs, batch_size)
+    val_list, lbls = sampler(val_input_batch, lbl_input_batch, probs,
+                             batch_size)

     # Check output shapes.
     self.assertTrue(isinstance(val_list, list))
@@ -277,9 +293,16 @@ class SamplingOpsTest(tf.test.TestCase):

   def testRejectionDataListInput(self):
     initial_p = [0, 1, 0, 0, 0]

     def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
       return tf.contrib.framework.sampling_ops.stratified_sample(
-          val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
+          val,
+          lbls,
+          probs,
+          batch,
+          init_probs=initial_p,
+          enqueue_many=enqueue_many)

     self.dataListHelper(curried_sampler)

   def normalBehaviorHelper(self, sampler):
@@ -289,8 +312,7 @@ class SamplingOpsTest(tf.test.TestCase):
     lbl2 = 3
     # This cond allows the necessary class queues to be populated.
     label = tf.cond(
-        tf.greater(.5, tf.random_uniform([])),
-        lambda: tf.constant(lbl1),
+        tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
         lambda: tf.constant(lbl2))
     val = [np.array([1, 4]) * label]
     probs = np.array([.8, 0, 0, .2, 0])
@@ -302,6 +324,9 @@ class SamplingOpsTest(tf.test.TestCase):
     data_l = []
     label_l = []
     with self.test_session() as sess:
+      # Need to initialize variables that keep running total of classes seen.
+      tf.initialize_all_variables().run()
+
       coord = tf.train.Coordinator()
       threads = tf.train.start_queue_runners(coord=coord)

@@ -337,10 +362,26 @@ class SamplingOpsTest(tf.test.TestCase):

   def testRejectionNormalBehavior(self):
     initial_p = [.7, 0, 0, .3, 0]

     def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
       return tf.contrib.framework.sampling_ops.stratified_sample(
-          val, lbls, initial_p, probs, batch, enqueue_many=enqueue_many)
+          val,
+          lbls,
+          probs,
+          batch,
+          init_probs=initial_p,
+          enqueue_many=enqueue_many)

     self.normalBehaviorHelper(curried_sampler)

+  def testRejectionNormalBehaviorWithOnlineInitPEstimate(self):
+
+    def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
+      return tf.contrib.framework.sampling_ops.stratified_sample(
+          val, lbls, probs, batch, init_probs=None, enqueue_many=enqueue_many)
+
+    self.normalBehaviorHelper(curried_sampler)
+
+
 if __name__ == '__main__':
   tf.test.main()
@@ -0,0 +1,65 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class SamplingOpsThreadingTest(tf.test.TestCase):
+
+  def testMultiThreadedEstimateDataDistribution(self):
+    num_classes = 10
+
+    # Set up graph.
+    tf.set_random_seed(1234)
+    label = tf.cast(tf.round(tf.random_uniform([1]) * num_classes), tf.int32)
+
+    prob_estimate = tf.contrib.framework.sampling_ops._estimate_data_distribution(  # pylint: disable=line-too-long
+        label, num_classes)
+    # Check that prob_estimate is well-behaved in a multithreaded context.
+    _, _, [prob_estimate] = tf.contrib.framework.sampling_ops._verify_input(
+        [], label, [prob_estimate])
+
+    # Use queues to run multiple threads over the graph, each of which
+    # fetches `prob_estimate`.
+    queue = tf.FIFOQueue(
+        capacity=25,
+        dtypes=[prob_estimate.dtype],
+        shapes=[prob_estimate.get_shape()])
+    enqueue_op = queue.enqueue([prob_estimate])
+    tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 25))
+    out_tensor = queue.dequeue()
+
+    # Run the multi-threaded session.
+    with self.test_session() as sess:
+      # Need to initialize variables that keep running total of classes seen.
+      tf.initialize_all_variables().run()
+
+      coord = tf.train.Coordinator()
+      threads = tf.train.start_queue_runners(coord=coord)
+
+      for _ in range(25):
+        sess.run([out_tensor])
+
+      coord.request_stop()
+      coord.join(threads)
+
+
+if __name__ == '__main__':
+  tf.test.main()
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"

 namespace tensorflow {
 REGISTER_OP("SparseFeatureCross")
@@ -31,6 +32,12 @@ REGISTER_OP("SparseFeatureCross")
     .Attr("dense_types: list({int64, string}) >= 0")
     .Attr("out_type: {int64, string}")
     .Attr("internal_type: {int64, string}")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
+      c->set_output(1, c->Vector(c->UnknownDim()));
+      c->set_output(2, c->Vector(2));
+      return Status::OK();
+    })
     .Doc(R"doc(
 Generates sparse cross from a list of sparse tensors.

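The shape function added here pins down SparseFeatureCross's three outputs as the components of a rank-2 sparse tensor: an indices matrix with two columns, a values vector of unknown length, and a dense-shape vector of length 2. A small illustrative sketch of what those shapes correspond to (values hypothetical):

  import numpy as np

  # The op's three outputs describe one rank-2 sparse tensor:
  indices = np.array([[0, 0], [1, 0], [1, 1]])       # [nnz, 2]; nnz unknown
  values = np.array([b'a_X_c', b'b_X_c', b'b_X_d'])  # [nnz] crossed keys
  dense_shape = np.array([2, 2])                     # [2]: batch x max crosses
  assert indices.shape[1] == 2 and dense_shape.shape == (2,)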
@@ -193,35 +193,36 @@ class _SparseColumn(_FeatureColumn,
               combiner="sum",
               dtype=dtypes.string):
     if is_integerized and bucket_size is None:
-      raise ValueError("bucket_size should be set if is_integerized=True. "
+      raise ValueError("bucket_size must be set if is_integerized is True. "
                        "column_name: {}".format(column_name))

     if is_integerized and not dtype.is_integer:
-      raise ValueError("dtype should be an integer if is_integerized is True. "
-                       "Column {}.".format(column_name))
+      raise ValueError("dtype must be an integer if is_integerized is True. "
+                       "dtype: {}, column_name: {}.".format(dtype, column_name))

     if bucket_size is None and lookup_config is None:
-      raise ValueError("one of bucket_size or lookup_config should be "
-                       "set. column_name: {}".format(column_name))
+      raise ValueError("one of bucket_size or lookup_config must be set. "
+                       "column_name: {}".format(column_name))

     if bucket_size is not None and lookup_config:
       raise ValueError("one and only one of bucket_size or lookup_config "
-                       "should be set. column_name: {}".format(column_name))
+                       "must be set. column_name: {}".format(column_name))

     if bucket_size is not None and bucket_size < 2:
-      raise ValueError("bucket_size should be at least 2. "
-                       "column_name: {}".format(column_name))
+      raise ValueError("bucket_size must be at least 2. "
+                       "bucket_size: {}, column_name: {}".format(bucket_size,
+                                                                 column_name))

     if ((lookup_config) and
         (not isinstance(lookup_config, _SparseIdLookupConfig))):
       raise TypeError(
-          "lookup_config should be an instance of _SparseIdLookupConfig. "
+          "lookup_config must be an instance of _SparseIdLookupConfig. "
           "Given one is in type {} for column_name {}".format(
               type(lookup_config), column_name))

     if (lookup_config and lookup_config.vocabulary_file and
         lookup_config.vocab_size is None):
-      raise ValueError("vocab_size should be defined. "
+      raise ValueError("vocab_size must be defined. "
                        "column_name: {}".format(column_name))

     return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized,
@@ -262,8 +263,8 @@ class _SparseColumn(_FeatureColumn,
                          input_tensor,
                          weight_collections=None,
                          trainable=True):
-    raise ValueError("Column {} is not supported in DNN. "
-                     "Please use embedding_column.".format(self))
+    raise ValueError("SparseColumn is not supported in DNN. "
+                     "Please use embedding_column. column: {}".format(self))

   def to_weighted_sum(self,
                       input_tensor,
@@ -279,7 +280,7 @@ class _SparseColumn(_FeatureColumn,
         initializer=init_ops.zeros_initializer,
         combiner=self.combiner,
         trainable=trainable,
-        name=self.name + "_weights")
+        name=self.name)


 class _SparseColumnIntegerized(_SparseColumn):
@@ -291,8 +292,8 @@ class _SparseColumnIntegerized(_SparseColumn):
               combiner="sum",
               dtype=dtypes.int64):
     if not dtype.is_integer:
-      raise ValueError("dtype should be an integer. Given {}".format(
-          column_name))
+      raise ValueError("dtype must be an integer. "
+                       "dtype: {}, column_name: {}".format(dtype, column_name))

     return super(_SparseColumnIntegerized, cls).__new__(cls,
                                                         column_name,
@@ -507,8 +508,8 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
                          input_tensor,
                          weight_collections=None,
                          trainable=True):
-    raise ValueError("Column {} is not supported in DNN. "
-                     "Please use embedding_column.".format(self))
+    raise ValueError("WeightedSparseColumn is not supported in DNN. "
+                     "Please use embedding_column. column: {}".format(self))

   def to_weighted_sum(self,
                       input_tensor,
@@ -524,7 +525,7 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
         initializer=init_ops.zeros_initializer,
         combiner=self.sparse_id_column.combiner,
         trainable=trainable,
-        name=self.name + "_weights")
+        name=self.name)


 def weighted_sparse_column(sparse_id_column,
@@ -609,7 +610,9 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None):
     if initializer is not None and not callable(initializer):
-      raise ValueError("initializer must be callable if specified.")
+      raise ValueError("initializer must be callable if specified. "
+                       "Embedding of column_name: {}".format(
+                           sparse_id_column.name))

     if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
       raise ValueError("Must specify both `ckpt_to_load_from` and "
@@ -674,7 +677,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
         initializer=self.initializer,
         combiner=self.combiner,
         trainable=trainable,
-        name=self.name + "_weights")
+        name=self.name)
     if self.ckpt_to_load_from is not None:
       weights_to_restore = embedding_weights
       if len(embedding_weights) == 1:
@@ -690,8 +693,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
                       num_outputs=1,
                       weight_collections=None,
                       trainable=True):
-    raise ValueError("Column {} is not supported in linear models. "
-                     "Please use sparse_column.".format(self))
+    raise ValueError("EmbeddingColumn is not supported in linear models. "
+                     "Please use sparse_column. column: {}".format(self))


 def embedding_column(sparse_id_column,
@ -744,7 +747,8 @@ class _HashedEmbeddingColumn(collections.namedtuple(
|
|||||||
combiner="mean",
|
combiner="mean",
|
||||||
initializer=None):
|
initializer=None):
|
||||||
if initializer is not None and not callable(initializer):
|
if initializer is not None and not callable(initializer):
|
||||||
raise ValueError("initializer must be callable if specified.")
|
raise ValueError("initializer must be callable if specified. "
|
||||||
|
"column_name: {}".format(column_name))
|
||||||
if initializer is None:
|
if initializer is None:
|
||||||
stddev = 0.1
|
stddev = 0.1
|
||||||
# TODO(b/25671353): Better initial value?
|
# TODO(b/25671353): Better initial value?
|
||||||
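
[Review note] Both embedding validators above now reject non-callable initializers with a message that carries the column name. A sketch of the failure mode, assuming the era's contrib.layers API:

    import tensorflow as tf

    ids = tf.contrib.layers.sparse_column_with_hash_bucket(
        "ids", hash_bucket_size=100)
    try:
      # 0.5 is a constant, not a callable initializer.
      tf.contrib.layers.embedding_column(ids, dimension=4, initializer=0.5)
    except ValueError as e:
      # initializer must be callable if specified.
      # Embedding of column_name: ids
      print(e)
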
@@ -770,7 +774,7 @@ class _HashedEmbeddingColumn(collections.namedtuple(
                       weight_collections=None,
                       trainable=True):
     embeddings = _create_embeddings(
-        name=self.name + "_weights",
+        name=self.name,
         shape=[self.size],
         initializer=self.initializer,
         dtype=dtypes.float32,
@@ -815,10 +819,14 @@ def hashed_embedding_column(column_name,
 
   """
   if (dimension < 1) or (size < 1):
-    raise ValueError("Dimension and size must be greater than 0.")
+    raise ValueError("Dimension and size must be greater than 0. "
+                     "dimension: {}, size: {}, column_name: {}".format(
+                         dimension, size, column_name))
 
   if combiner not in ("mean", "sqrtn", "sum"):
-    raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
+    raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'. "
+                     "combiner: {}, column_name: {}".format(
+                         combiner, column_name))
 
   return _HashedEmbeddingColumn(column_name, size, dimension, combiner,
                                 initializer)
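
[Review note] hashed_embedding_column now reports dimension, size, and the column name on bad input. An illustration (the column name "query" is hypothetical):

    import tensorflow as tf

    try:
      tf.contrib.layers.hashed_embedding_column("query", size=0, dimension=4)
    except ValueError as e:
      # Dimension and size must be greater than 0.
      # dimension: 4, size: 0, column_name: query
      print(e)
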
@@ -929,14 +937,18 @@ def real_valued_column(column_name,
   """
 
   if not isinstance(dimension, int):
-    raise TypeError("dimension must be an integer")
+    raise TypeError("dimension must be an integer. "
+                    "dimension: {}, column_name: {}".format(dimension,
+                                                            column_name))
 
   if dimension < 1:
-    raise ValueError("dimension must be greater than 0")
+    raise ValueError("dimension must be greater than 0. "
+                     "dimension: {}, column_name: {}".format(dimension,
+                                                             column_name))
 
   if not (dtype.is_integer or dtype.is_floating):
-    raise ValueError("dtype is not convertible to tf.float32. Given {}".format(
-        dtype))
+    raise ValueError("dtype must be convertible to float. "
+                     "dtype: {}, column_name: {}".format(dtype, column_name))
 
   if default_value is None:
     return _RealValuedColumn(column_name, dimension, default_value, dtype)
@@ -957,9 +969,10 @@ def real_valued_column(column_name,
 
   if isinstance(default_value, list):
     if len(default_value) != dimension:
-      raise ValueError("The length of default_value is not equal to the "
-                       "value of dimension. default_value is {}.".format(
-                           default_value))
+      raise ValueError(
+          "The length of default_value must be equal to dimension. "
+          "default_value: {}, dimension: {}, column_name: {}".format(
+              default_value, dimension, column_name))
     # Check if the values in the list are all integers or are convertible to
     # floats.
     is_list_all_int = True
@@ -980,8 +993,9 @@ def real_valued_column(column_name,
         default_value = [float(v) for v in default_value]
         return _RealValuedColumn(column_name, dimension, default_value, dtype)
 
-  raise TypeError("default_value is not compatible with dtype. "
-                  "default_value is {}.".format(default_value))
+  raise TypeError("default_value must be compatible with dtype. "
+                  "default_value: {}, dtype: {}, column_name: {}".format(
+                      default_value, dtype, column_name))
 
 
 class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
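
[Review note] The real_valued_column checks above follow the same message convention; the updated tests later in this diff assert on these exact strings. For example:

    import tensorflow as tf

    try:
      tf.contrib.layers.real_valued_column("age", dimension=0)
    except ValueError as e:
      # dimension must be greater than 0. dimension: 0, column_name: age
      print(e)
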
@@ -1008,10 +1022,12 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
   def __new__(cls, source_column, boundaries):
     if not isinstance(source_column, _RealValuedColumn):
       raise TypeError(
-          "source_column should be an instance of _RealValuedColumn.")
+          "source_column must be an instance of _RealValuedColumn. "
+          "source_column: {}".format(source_column))
 
     if not isinstance(boundaries, list) or not boundaries:
-      raise ValueError("boundaries must be a list and it should not be empty.")
+      raise ValueError("boundaries must be a non-empty list. "
+                       "boundaries: {}".format(boundaries))
 
     # We allow bucket boundaries to be monotonically increasing
     # (ie a[i+1] >= a[i]). When two bucket boundaries are the same, we
@@ -1023,7 +1039,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
       elif boundaries[i] < boundaries[i + 1]:
         sanitized_boundaries.append(boundaries[i])
       else:
-        raise ValueError("boundaries must be a sorted list")
+        raise ValueError("boundaries must be a sorted list. "
+                         "boundaries: {}".format(boundaries))
     sanitized_boundaries.append(boundaries[len(boundaries) - 1])
 
     return super(_BucketizedColumn, cls).__new__(cls, source_column,
@@ -1104,7 +1121,7 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
         initializer=init_ops.zeros_initializer,
         combiner="sum",
         trainable=trainable,
-        name=self.name + "_weights")
+        name=self.name)
 
 
 def bucketized_column(source_column, boundaries):
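
[Review note] Example of the tightened _BucketizedColumn message, mirroring testBucketizedColumnRequiresSortedBuckets below:

    import tensorflow as tf

    price = tf.contrib.layers.real_valued_column("price")
    try:
      tf.contrib.layers.bucketized_column(price, [5, 0, 4])
    except ValueError as e:
      # boundaries must be a sorted list. boundaries: [5, 0, 4]
      print(e)
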
@@ -1186,18 +1203,21 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
               ckpt_to_load_from=None, tensor_name_in_ckpt=None):
     for column in columns:
       if not _CrossedColumn._is_crossable(column):
-        raise TypeError("columns should be a set of "
-                        "_SparseColumn, _CrossedColumn, or _BucketizedColumn. "
-                        "Column is {}".format(column))
+        raise TypeError("columns must be a set of _SparseColumn, "
+                        "_CrossedColumn, or _BucketizedColumn instances. "
+                        "column: {}".format(column))
 
     if len(columns) < 2:
-      raise ValueError("columns should contain at least 2 elements.")
+      raise ValueError("columns must contain at least 2 elements. "
+                       "columns: {}".format(columns))
 
     if not isinstance(hash_bucket_size, int):
-      raise TypeError("hash_bucket_size should be an int.")
+      raise TypeError("hash_bucket_size must be an int. "
+                      "hash_bucket_size: {}".format(hash_bucket_size))
 
     if hash_bucket_size < 2:
-      raise ValueError("hash_bucket_size should be at least 2.")
+      raise ValueError("hash_bucket_size must be at least 2. "
+                       "hash_bucket_size: {}".format(hash_bucket_size))
 
     if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
       raise ValueError("Must specify both `ckpt_to_load_from` and "
@@ -1275,8 +1295,8 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
                    input_tensor,
                    weight_collections=None,
                    trainable=True):
-    raise ValueError("Column {} is not supported in DNN. "
-                     "Please use embedding_column.".format(self))
+    raise ValueError("CrossedColumn is not supported in DNN. "
+                     "Please use embedding_column. column: {}".format(self))
 
   def to_weighted_sum(self,
                       input_tensor,
@@ -1292,7 +1312,7 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
         initializer=init_ops.zeros_initializer,
         combiner=self.combiner,
         trainable=trainable,
-        name=self.name + "_weights")
+        name=self.name)
     if self.ckpt_to_load_from is not None:
       weights_to_restore = embedding_weights
       if len(embedding_weights) == 1:
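
[Review note] The _CrossedColumn checks now echo the offending arguments as well. A sketch:

    import tensorflow as tf

    a = tf.contrib.layers.sparse_column_with_hash_bucket(
        "a", hash_bucket_size=100)
    try:
      tf.contrib.layers.crossed_column(set([a]), hash_bucket_size=10000)
    except ValueError as e:
      # columns must contain at least 2 elements. columns: {...}
      print(e)
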
@@ -1337,7 +1357,7 @@ def crossed_column(columns, hash_bucket_size, combiner="sum",
 
 class DataFrameColumn(_FeatureColumn,
                       collections.namedtuple("DataFrameColumn",
-                                             ["name", "series"])):
+                                             ["column_name", "series"])):
   """Represents a feature column produced from a `DataFrame`.
 
   Instances of this class are immutable. A `DataFrame` column may be dense or
@@ -1345,13 +1365,17 @@ class DataFrameColumn(_FeatureColumn,
   batch_size.
 
   Args:
-    name: a name for this column
+    column_name: a name for this column
     series: a `Series` to be wrapped, which has already had its base features
       substituted with `PredefinedSeries`.
   """
 
-  def __new__(cls, name, series):
-    return super(DataFrameColumn, cls).__new__(cls, name, series)
+  def __new__(cls, column_name, series):
+    return super(DataFrameColumn, cls).__new__(cls, column_name, series)
 
+  @property
+  def name(self):
+    return self.column_name
+
   @property
   def config(self):
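
[Review note] The rename works because a property named `name` would shadow a namedtuple field of the same name; moving the stored field to `column_name` frees `name` to become a read-only alias, so callers that read `column.name` are unaffected. The pattern in a standalone sketch:

    import collections

    class Column(collections.namedtuple("Column", ["column_name", "series"])):

      @property
      def name(self):
        return self.column_name

    c = Column(column_name="age", series=None)
    print(c.name)  # age
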
@@ -1379,6 +1403,16 @@ class DataFrameColumn(_FeatureColumn,
                                  input_tensor,
                                  weight_collections=None,
                                  trainable=True):
-    return input_tensor
+    # DataFrame typically provides Tensors of shape [batch_size],
+    # but Estimator requires shape [batch_size, 1]
+    dims = input_tensor.get_shape().ndims
+    if dims == 0:
+      raise ValueError(
+          "Can't build input layer from tensor of shape (): {}".format(
+              self.column_name))
+    elif dims == 1:
+      return array_ops.expand_dims(input_tensor, 1)
+    else:
+      return input_tensor
 
   # TODO(soergel): This mirrors RealValuedColumn for now, but should become
@@ -1547,7 +1581,7 @@ def _create_embeddings(name, shape, dtype, initializer, trainable,
     with just one variable.
 
   Args:
-    name: A string specifying the name of the embedding variable.
+    name: A string. The name of the embedding variable will be name + _weights.
     shape: shape of the embeddding. Note this is not the shape of partitioned
       variables.
     dtype: type of the embedding. Also the shape of each partitioned variable.
@@ -1609,7 +1643,7 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
     A Tensor with shape [batch_size, dimension] and embedding Variable.
   """
 
-  embeddings = _create_embeddings(name=name,
+  embeddings = _create_embeddings(name=name + "_weights",
                                   shape=[vocab_size, dimension],
                                   dtype=dtypes.float32,
                                   initializer=initializer,
@@ -1621,4 +1655,4 @@ def _create_embedding_lookup(input_tensor, weight_tensor, vocab_size, dimension,
                                      sparse_weights=weight_tensor,
                                      default_id=0,
                                      combiner=combiner,
-                                     name=name), embeddings
+                                     name=name + "_weights"), embeddings
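
[Review note] Two things above are worth reading together: the "_weights" suffix moves out of every call site and into _create_embedding_lookup, and DataFrameColumn now fixes up tensor rank with expand_dims, turning a [batch_size] vector into the [batch_size, 1] matrix Estimator expects. The underlying op, in isolation (TF 1.x-era session API):

    import tensorflow as tf

    batch = tf.constant([1.0, 2.0, 3.0])   # shape [batch_size]
    expanded = tf.expand_dims(batch, 1)    # shape [batch_size, 1]
    with tf.Session() as sess:
      print(sess.run(tf.shape(expanded)))  # [3 1]
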
@@ -60,14 +60,17 @@ class FeatureColumnTest(tf.test.TestCase):
     self.assertEqual(b.dimension, 10)
     self.assertTrue(b.default_value is None)
 
-    # dimension is an integer
-    with self.assertRaises(TypeError):
+    with self.assertRaisesRegexp(TypeError, "dimension must be an integer"):
       tf.contrib.layers.real_valued_column("d3", dimension=1.0)
 
-    # dimension is a positive integer
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(ValueError,
+                                 "dimension must be greater than 0"):
       tf.contrib.layers.real_valued_column("d3", dimension=0)
 
+    with self.assertRaisesRegexp(ValueError,
+                                 "dtype must be convertible to float"):
+      tf.contrib.layers.real_valued_column("d3", dtype=tf.string)
+
     # default_value is an integer.
     c1 = tf.contrib.layers.real_valued_column("c1", default_value=2)
     self.assertListEqual(list(c1.default_value), [2.])
@@ -92,15 +95,18 @@ class FeatureColumnTest(tf.test.TestCase):
                                               dimension=4,
                                               default_value=2.)
     self.assertListEqual(list(d2.default_value), [2., 2., 2., 2.])
-    with self.assertRaises(TypeError):
+    with self.assertRaisesRegexp(TypeError,
+                                 "default_value must be compatible with dtype"):
       tf.contrib.layers.real_valued_column("d3",
                                            default_value=2.,
                                            dtype=tf.int32)
 
-    # default_value is neither interger nor float.
-    with self.assertRaises(TypeError):
+    # default_value is neither integer nor float.
+    with self.assertRaisesRegexp(
+        TypeError, "default_value must be compatible with dtype"):
       tf.contrib.layers.real_valued_column("e1", default_value="string")
-    with self.assertRaises(TypeError):
+    with self.assertRaisesRegexp(
+        TypeError, "default_value must be compatible with dtype"):
       tf.contrib.layers.real_valued_column("e1",
                                            dimension=3,
                                            default_value=[1, 3., "string"])
@@ -125,11 +131,13 @@ class FeatureColumnTest(tf.test.TestCase):
                                               dimension=3,
                                               default_value=[2., 2, 2])
     self.assertListEqual(list(g2.default_value), [2., 2., 2.])
-    with self.assertRaises(TypeError):
+    with self.assertRaisesRegexp(
+        TypeError, "default_value must be compatible with dtype"):
       tf.contrib.layers.real_valued_column("g3",
                                            default_value=[2.],
                                            dtype=tf.int32)
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(
+        ValueError, "The length of default_value must be equal to dimension"):
       tf.contrib.layers.real_valued_column("g4",
                                            dimension=3,
                                            default_value=[2.])
@@ -140,11 +148,19 @@ class FeatureColumnTest(tf.test.TestCase):
     self.assertEqual(a.name, "aaa_BUCKETIZED")
 
   def testBucketizedColumnRequiresRealValuedColumn(self):
-    with self.assertRaises(TypeError):
+    with self.assertRaisesRegexp(
+        TypeError, "source_column must be an instance of _RealValuedColumn"):
       tf.contrib.layers.bucketized_column("bbb", [0])
+    with self.assertRaisesRegexp(
+        TypeError, "source_column must be an instance of _RealValuedColumn"):
+      tf.contrib.layers.bucketized_column(
+          tf.contrib.layers.sparse_column_with_integerized_feature(
+              column_name="bbb", bucket_size=10),
+          [0])
 
   def testBucketizedColumnRequiresSortedBuckets(self):
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(
+        ValueError, "boundaries must be a sorted list"):
       tf.contrib.layers.bucketized_column(
           tf.contrib.layers.real_valued_column("ccc"), [5, 0, 4])
 
@@ -173,7 +189,10 @@ class FeatureColumnTest(tf.test.TestCase):
   def testCrossedColumnNotSupportRealValuedColumn(self):
     b = tf.contrib.layers.sparse_column_with_hash_bucket("bbb",
                                                          hash_bucket_size=100)
-    with self.assertRaises(TypeError):
+    with self.assertRaisesRegexp(
+        TypeError,
+        "columns must be a set of _SparseColumn, _CrossedColumn, "
+        "or _BucketizedColumn instances"):
       tf.contrib.layers.crossed_column(
           set([b, tf.contrib.layers.real_valued_column("real")]),
           hash_bucket_size=10000)
@@ -194,7 +213,8 @@ class FeatureColumnTest(tf.test.TestCase):
                           "weights": tf.VarLenFeature(tf.int32)},
                          weighted_ids.config)
 
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(ValueError,
+                                 "dtype is not convertible to float"):
       weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights",
                                                               dtype=tf.string)
 
@@ -211,7 +231,8 @@ class FeatureColumnTest(tf.test.TestCase):
                                            [1], dtype=tf.int32)},
                          rvc.config)
 
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(ValueError,
+                                 "dtype must be convertible to float"):
       tf.contrib.layers.real_valued_column("rvc", dtype=tf.string)
 
   def testSparseColumnDtypes(self):
@@ -222,7 +243,8 @@ class FeatureColumnTest(tf.test.TestCase):
         "sc", 10, dtype=tf.int32)
     self.assertDictEqual({"sc": tf.VarLenFeature(dtype=tf.int32)}, sc.config)
 
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(ValueError,
+                                 "dtype must be an integer"):
       tf.contrib.layers.sparse_column_with_integerized_feature("sc",
                                                                10,
                                                                dtype=tf.float32)
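
[Review note] The test hunks above all apply one mechanical upgrade: assertRaises only checks the exception type, while assertRaisesRegexp also matches the message, so the new column-aware messages are now pinned by tests. The pattern in miniature, independent of TensorFlow:

    import unittest

    class MessageTest(unittest.TestCase):

      def test_message_is_checked(self):
        # Fails if either the type or the message text regresses.
        with self.assertRaisesRegexp(ValueError, "dimension must be greater"):
          raise ValueError("dimension must be greater than 0. dimension: 0")

    if __name__ == "__main__":
      unittest.main()
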
@@ -70,7 +70,7 @@ def multi_class_target(n_classes, label_name=None, weight_column_name=None):
       will be multiplied by the loss of the example.
 
   Returns:
-    An instance of _TargetColumn
+    An instance of _MultiClassTargetColumn.
 
   Raises:
     ValueError: if n_classes is < 2
@@ -33,6 +33,7 @@ from tensorflow.contrib.learn.python.learn import preprocessing
 from tensorflow.contrib.learn.python.learn import utils
 from tensorflow.contrib.learn.python.learn.dataframe import *
 from tensorflow.contrib.learn.python.learn.estimators import *
+from tensorflow.contrib.learn.python.learn.evaluable import Evaluable
 from tensorflow.contrib.learn.python.learn.experiment import Experiment
 from tensorflow.contrib.learn.python.learn.graph_actions import evaluate
 from tensorflow.contrib.learn.python.learn.graph_actions import infer
@@ -41,4 +42,5 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
 from tensorflow.contrib.learn.python.learn.graph_actions import run_n
 from tensorflow.contrib.learn.python.learn.graph_actions import train
 from tensorflow.contrib.learn.python.learn.learn_io import *
+from tensorflow.contrib.learn.python.learn.trainable import Trainable
 # pylint: enable=wildcard-import
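
[Review note] With the two re-exports above, the Evaluable and Trainable interfaces become part of the package's public surface. Assuming tf.contrib.learn forwards this module's names (as it did at the time), the effect is:

    import tensorflow as tf

    print(hasattr(tf.contrib.learn, "Evaluable"))   # True
    print(hasattr(tf.contrib.learn, "Trainable"))   # True
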
@@ -30,6 +30,7 @@ from tensorflow.contrib.learn.python.learn.dataframe.transform import Transform
 # Transforms
 from tensorflow.contrib.learn.python.learn.dataframe.transforms.boolean_mask import BooleanMask
 from tensorflow.contrib.learn.python.learn.dataframe.transforms.difference import Difference
+from tensorflow.contrib.learn.python.learn.dataframe.transforms.hashes import HashFast
 from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import NumpySource
 from tensorflow.contrib.learn.python.learn.dataframe.transforms.in_memory_source import PandasSource
 from tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source import ReaderSource
@@ -117,10 +117,11 @@ class DataFrame(object):
       value = [value]
     self.assign(**dict(zip(key, value)))
 
-  def build(self):
+  def build(self, **kwargs):
     # We do not allow passing a cache here, because that would encourage
     # working around the rule that DataFrames cannot be expected to be
     # synced with each other (e.g., they shuffle independently).
     cache = {}
-    tensors = {name: c.build(cache) for name, c in self._columns.items()}
+    tensors = {name: c.build(cache, **kwargs)
+               for name, c in self._columns.items()}
     return tensors
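
[Review note] This is the top of a kwargs pipeline threaded through the rest of the diff: DataFrame.build -> Series.build -> Transform.build_transitive -> _apply_transform. A dependency-free sketch of the shape of that plumbing (class names here are illustrative, not the real ones):

    class FakeColumn(object):

      def build(self, cache, **kwargs):
        return "built with %r" % (kwargs,)

    class FakeFrame(object):

      def __init__(self, columns):
        self._columns = columns

      def build(self, **kwargs):
        cache = {}
        # Every keyword argument is forwarded, unmodified, to each column.
        return {name: c.build(cache, **kwargs)
                for name, c in self._columns.items()}

    print(FakeFrame({"x": FakeColumn()}).build(num_epochs=2))
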
@@ -91,7 +91,8 @@ def _build_alternate_universe(
 def to_feature_columns_and_input_fn(dataframe,
                                     base_input_keys_with_defaults,
                                     feature_keys,
-                                    target_keys=None):
+                                    target_keys=None,
+                                    **kwargs):
   """Build a list of FeatureColumns and an input_fn for use with Estimator.
 
   Args:
@@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe,
       These may include base features and/or derived features.
     target_keys: the names of columns to be used as targets. None is
       acceptable for unsupervised learning.
+    **kwargs: Additional keyword arguments, unused here.
 
   Returns:
     A tuple of two elements:
@@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe,
 
   # Build an input_fn suitable for use with Estimator.
   def input_fn():
+    """An input_fn() for feeding the given set of DataFrameColumns."""
     # It's important to build all the tensors together in one DataFrame.
     # If we did df.select() for both key sets and then build those, the two
     # resulting DataFrames would be shuffled independently.
-    tensors = limited_dataframe.build()
+    tensors = limited_dataframe.build(**kwargs)
 
     base_input_features = {key: tensors[key] for key in base_input_keys}
     targets = {key: tensors[key] for key in target_keys}
@@ -98,7 +98,7 @@ class Series(object):
       return transform_cls
     return register
 
-  def build(self, cache):
+  def build(self, cache, **kwargs):
     """Returns a Tensor."""
     raise NotImplementedError()
 
@@ -122,7 +122,7 @@ class PredefinedSeries(Series):
   def required_base_features(self):
     return {self.name: self.feature_spec}
 
-  def build(self, cache):
+  def build(self, cache, **kwargs):
     try:
       return cache[self.name]
     except KeyError:
@@ -171,10 +171,11 @@ class TransformedSeries(Series):
      result.update(s.required_base_features)
    return result
 
-  def build(self, cache=None):
+  def build(self, cache=None, **kwargs):
     if cache is None:
       cache = {}
-    all_outputs = self._transform.build_transitive(self._input_series, cache)
+    all_outputs = self._transform.build_transitive(
+        self._input_series, cache, **kwargs)
     return getattr(all_outputs, self._output_name)
 
   def __repr__(self):
@@ -28,6 +28,7 @@ from tensorflow.contrib.learn.python.learn.dataframe import dataframe as df
 from tensorflow.contrib.learn.python.learn.dataframe.transforms import batch
 from tensorflow.contrib.learn.python.learn.dataframe.transforms import csv_parser
 from tensorflow.contrib.learn.python.learn.dataframe.transforms import example_parser
+from tensorflow.contrib.learn.python.learn.dataframe.transforms import hashes
 from tensorflow.contrib.learn.python.learn.dataframe.transforms import in_memory_source
 from tensorflow.contrib.learn.python.learn.dataframe.transforms import reader_source
 from tensorflow.contrib.learn.python.learn.dataframe.transforms import sparsify
@@ -83,7 +84,8 @@ class TensorFlowDataFrame(df.DataFrame):
           graph=None,
           session=None,
           start_queues=True,
-          initialize_variables=True):
+          initialize_variables=True,
+          **kwargs):
     """Builds and runs the columns of the `DataFrame` and yields batches.
 
     This is a generator that yields a dictionary mapping column names to
@@ -97,6 +99,7 @@ class TensorFlowDataFrame(df.DataFrame):
       start_queues: if true, queues will be started before running and halted
         after producting `n` batches.
       initialize_variables: if true, variables will be initialized.
+      **kwargs: Additional keyword arguments e.g. `num_epochs`.
 
     Yields:
       A dictionary, mapping column names to the values resulting from running
@@ -107,7 +110,7 @@ class TensorFlowDataFrame(df.DataFrame):
     with graph.as_default():
       if session is None:
         session = sess.Session()
-      self_built = self.build()
+      self_built = self.build(**kwargs)
       keys = list(self_built.keys())
       cols = list(self_built.values())
       if initialize_variables:
@@ -157,6 +160,52 @@ class TensorFlowDataFrame(df.DataFrame):
                            "Original error: {}").format(type(col), e))
       return result
 
+  def split(self, index_series, proportion, batch_size=None):
+    """Deterministically split a `DataFrame` into two `DataFrame`s.
+
+    Note this split is only as deterministic as the underlying hash function;
+    see `tf.string_to_hash_bucket_fast`. The hash function is deterministic
+    for a given binary, but may change occasionally. The only way to achieve
+    an absolute guarantee that the split `DataFrame`s do not change across runs
+    is to materialize them.
+
+    Note too that the allocation of a row to one partition or the
+    other is evaluated independently for each row, so the exact number of rows
+    in each partition is binomially distributed.
+
+    Args:
+      index_series: a `Series` of unique strings, whose hash will determine the
+        partitioning; or the name in this `DataFrame` of such a `Series`.
+        (This `Series` must contain strings because TensorFlow provides hash
+        ops only for strings, and there are no number-to-string converter ops.)
+      proportion: The proportion of the rows to select for the 'left'
+        partition; the remaining (1 - proportion) rows form the 'right'
+        partition.
+      batch_size: the batch size to use when rebatching the left and right
+        `DataFrame`s. If None (default), the `DataFrame`s are not rebatched;
+        thus their batches will have variable sizes, according to which rows
+        are selected from each batch of the original `DataFrame`.
+
+    Returns:
+      Two `DataFrame`s containing the partitioned rows.
+    """
+    # TODO(soergel): allow seed?
+    if isinstance(index_series, str):
+      index_series = self[index_series]
+    num_buckets = 1000000  # close enough for simple splits
+    hashed_input, = hashes.HashFast(num_buckets)(index_series)
+    threshold = int(num_buckets * proportion)
+    left = hashed_input < threshold
+    right = ~left
+    left_rows = self.select_rows(left)
+    right_rows = self.select_rows(right)
+
+    if batch_size:
+      left_rows = left_rows.batch(batch_size=batch_size, shuffle=False)
+      right_rows = right_rows.batch(batch_size=batch_size, shuffle=False)
+
+    return left_rows, right_rows
+
   def run_once(self):
     """Creates a new 'Graph` and `Session` and runs a single batch.
 
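
[Review note] split() decides each row's partition by hashing its key and comparing against a threshold; typical usage would be left, right = df.split("id", 0.8). The per-row decision, restated with raw ops (runnable under the TF 1.x session API):

    import tensorflow as tf

    ids = tf.constant(["row-0", "row-1", "row-2", "row-3"])
    num_buckets = 1000000  # mirrors the constant in split()
    hashed = tf.string_to_hash_bucket_fast(ids, num_buckets)
    left_mask = hashed < int(num_buckets * 0.8)  # ~80% of rows land left
    with tf.Session() as sess:
      print(sess.run(left_mask))
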
@@ -208,7 +257,7 @@ class TensorFlowDataFrame(df.DataFrame):
 
   @classmethod
   def _from_csv_base(cls, filepatterns, get_default_values, has_header,
-                     column_names, num_epochs, num_threads, enqueue_size,
+                     column_names, num_threads, enqueue_size,
                      batch_size, queue_capacity, min_after_dequeue, shuffle,
                      seed):
     """Create a `DataFrame` from CSV files.
@@ -223,9 +272,6 @@ class TensorFlowDataFrame(df.DataFrame):
         each column, given the column names.
       has_header: whether or not the CSV files have headers.
       column_names: a list of names for the columns in the CSV files.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -265,7 +311,6 @@ class TensorFlowDataFrame(df.DataFrame):
         reader_kwargs=reader_kwargs,
         enqueue_size=enqueue_size,
         batch_size=batch_size,
-        num_epochs=num_epochs,
         queue_capacity=queue_capacity,
         shuffle=shuffle,
         min_after_dequeue=min_after_dequeue,
@@ -287,7 +332,6 @@ class TensorFlowDataFrame(df.DataFrame):
                default_values,
                has_header=True,
                column_names=None,
-               num_epochs=None,
                num_threads=1,
                enqueue_size=None,
                batch_size=32,
@@ -306,9 +350,6 @@ class TensorFlowDataFrame(df.DataFrame):
       default_values: a list of default values for each column.
       has_header: whether or not the CSV files have headers.
       column_names: a list of names for the columns in the CSV files.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -332,7 +373,7 @@ class TensorFlowDataFrame(df.DataFrame):
       return default_values
 
     return cls._from_csv_base(filepatterns, get_default_values, has_header,
-                              column_names, num_epochs, num_threads,
+                              column_names, num_threads,
                               enqueue_size, batch_size, queue_capacity,
                               min_after_dequeue, shuffle, seed)
 
@@ -342,7 +383,6 @@ class TensorFlowDataFrame(df.DataFrame):
                feature_spec,
                has_header=True,
                column_names=None,
-               num_epochs=None,
                num_threads=1,
                enqueue_size=None,
                batch_size=32,
@@ -362,9 +402,6 @@ class TensorFlowDataFrame(df.DataFrame):
         `VarLenFeature`.
       has_header: whether or not the CSV files have headers.
       column_names: a list of names for the columns in the CSV files.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -387,7 +424,7 @@ class TensorFlowDataFrame(df.DataFrame):
       return [_get_default_value(feature_spec[name]) for name in column_names]
 
     dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
-                                   column_names, num_epochs, num_threads,
+                                   column_names, num_threads,
                                    enqueue_size, batch_size, queue_capacity,
                                    min_after_dequeue, shuffle, seed)
 
@@ -405,7 +442,6 @@ class TensorFlowDataFrame(df.DataFrame):
                filepatterns,
                features,
                reader_cls=io_ops.TFRecordReader,
-               num_epochs=None,
                num_threads=1,
                enqueue_size=None,
                batch_size=32,
@@ -421,9 +457,6 @@ class TensorFlowDataFrame(df.DataFrame):
         `FixedLenFeature`.
       reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
         read the `Example`s.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -454,7 +487,6 @@ class TensorFlowDataFrame(df.DataFrame):
         filenames,
         enqueue_size=enqueue_size,
         batch_size=batch_size,
-        num_epochs=num_epochs,
         queue_capacity=queue_capacity,
         shuffle=shuffle,
         min_after_dequeue=min_after_dequeue,
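
[Review note] Every num_epochs parameter removed above moves to run time: the value now travels through run(**kwargs) into the build chain instead of being fixed at construction. A sketch of the intended call, assuming a TensorFlowDataFrame df already built from one of the from_* factories:

    # Hypothetical usage; `df` is assumed, not constructed here.
    for batch in df.run(num_epochs=1):
      print(sorted(batch.keys()))
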
@@ -223,13 +223,14 @@ class Transform(object):
     # pylint: disable=not-callable
     return self.return_type(*output_series)
 
-  def build_transitive(self, input_series, cache=None):
+  def build_transitive(self, input_series, cache=None, **kwargs):
     """Apply this `Transform` to the provided `Series`, producing 'Tensor's.
 
     Args:
       input_series: None, a `Series`, or a list of input `Series`, acting as
         positional arguments.
       cache: a dict from Series reprs to Tensors.
+      **kwargs: Additional keyword arguments, unused here.
 
     Returns:
       A namedtuple of the output Tensors.
@@ -244,7 +245,7 @@ class Transform(object):
     if len(input_series) != self.input_valency:
       raise ValueError("Expected %s input Series but received %s." %
                        (self.input_valency, len(input_series)))
-    input_tensors = [series.build(cache) for series in input_series]
+    input_tensors = [series.build(cache, **kwargs) for series in input_series]
 
     # Note we cache each output individually, not just the entire output
     # tuple. This allows using the graph as the cache, since it can sensibly
@@ -254,7 +255,7 @@ class Transform(object):
     output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
 
     if None in output_tensors:
-      result = self._apply_transform(input_tensors)
+      result = self._apply_transform(input_tensors, **kwargs)
       for output_name, output_repr in zip(self.output_names, output_reprs):
         cache[output_repr] = getattr(result, output_name)
     else:
@@ -264,12 +265,13 @@ class Transform(object):
     return result
 
   @abstractmethod
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.
 
     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.
 
     Returns:
       A namedtuple of Tensors representing the transformed output.
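
[Review note] build_transitive now forwards **kwargs into every _apply_transform, so each concrete Transform must accept (and may ignore) them; that is the mechanical change repeated across the transform files below. The contract in miniature (a stand-in class, not the real Transform):

    class Doubler(object):  # stand-in for a transform.Transform subclass

      def _apply_transform(self, input_tensors, **kwargs):
        del kwargs  # accepted for interface compatibility, unused here
        return [2 * t for t in input_tensors]

    print(Doubler()._apply_transform([1, 2, 3], num_epochs=1))  # [2, 4, 6]
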
@@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform):
   def name(self):
     return "Batch"
 
-  def _apply_transform(self, transform_input):
+  def _apply_transform(self, transform_input, **kwargs):
     batched = input_ops.batch(transform_input,
                               batch_size=self.batch_size,
                               num_threads=self.num_threads,
@@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform):
   def seed(self):
     return self._seed
 
-  def _apply_transform(self, transform_input):
+  def _apply_transform(self, transform_input, **kwargs):
     batched = input_ops.shuffle_batch(transform_input,
                                       batch_size=self.batch_size,
                                       capacity=self.queue_capacity,
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     # TODO(jamieas): consider supporting sparse inputs.
     if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
         input_tensors[1], ops.SparseTensor):
@@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     input_tensor = input_tensors[0]
     if isinstance(input_tensor, ops.SparseTensor):
       result = ops.SparseTensor(input_tensor.indices,
@@ -77,18 +77,21 @@ class BooleanMask(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.
 
     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.
 
     Returns:
       A namedtuple of Tensors representing the transformed output.
     """
     input_tensor = input_tensors[0]
     mask = input_tensors[1]
+    if mask.get_shape().ndims > 1:
+      mask = array_ops.squeeze(mask)
 
     if isinstance(input_tensor, ops.SparseTensor):
       mask_fn = sparse_boolean_mask
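
[Review note] The new squeeze in BooleanMask tolerates masks of shape [batch_size, 1] as well as [batch_size]. The two ops it composes, in isolation (TF 1.x session API):

    import tensorflow as tf

    mask = tf.constant([[True], [False], [True]])  # shape [batch_size, 1]
    squeezed = tf.squeeze(mask)                    # shape [batch_size]
    kept = tf.boolean_mask(tf.constant([1, 2, 3]), squeezed)
    with tf.Session() as sess:
      print(sess.run(kept))  # [1 3]
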
@@ -58,7 +58,7 @@ class CSVParser(transform.Transform):
   def default_values(self):
     return self._default_values
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     default_consts = [constant_op.constant(d, shape=[1])
                       for d in self._default_values]
     parsed_values = parsing_ops.decode_csv(input_tensors[0],
@@ -47,12 +47,13 @@ class Densify(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.
 
     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.
 
     Returns:
       A namedtuple of Tensors representing the transformed output.
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@ class Difference(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
                      isinstance(input_tensors[1], ops.SparseTensor))
 
@@ -61,7 +61,7 @@ class ExampleParser(transform.Transform):
   def feature_definitions(self):
     return self._ordered_features
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     parsed_values = parsing_ops.parse_example(input_tensors[0],
                                               features=self._ordered_features)
     # pylint: disable=not-callable
@ -0,0 +1,68 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Hashes the content of a `Series` into buckets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.learn.python.learn.dataframe import transform
+from tensorflow.python.ops import string_ops
+
+
+class HashFast(transform.Transform):
+  """Perform a fast hash of a `Series`."""
+
+  def __init__(self, num_buckets):
+    """Initialize `HashFast`.
+
+    Args:
+      num_buckets: The number of hash buckets to use.
+    """
+    # TODO(soergel): allow seed?
+    super(HashFast, self).__init__()
+    self._num_buckets = num_buckets
+
+  @property
+  def name(self):
+    return "HashFast"
+
+  @property
+  def input_valency(self):
+    return 1
+
+  @property
+  def _output_names(self):
+    return "output",
+
+  def _apply_transform(self, input_tensors, **kwargs):
+    """Applies the transformation to the `transform_input`.
+
+    Args:
+      input_tensors: a list of Tensors representing the input to
+        the Transform.
+      **kwargs: additional keyword arguments, unused here.
+
+    Returns:
+      A namedtuple of Tensors representing the transformed output.
+    """
+    result = string_ops.string_to_hash_bucket_fast(input_tensors[0],
+                                                   self._num_buckets,
+                                                   name=None)
+    # pylint: disable=not-callable
+    return self.return_type(result)
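Note on the new file above: `HashFast` is a thin wrapper around `string_ops.string_to_hash_bucket_fast`, which buckets strings with a fast, unseeded hash. A minimal sketch of the underlying behavior (the inputs and printed values are illustrative, not from this commit):

    import tensorflow as tf
    from tensorflow.python.ops import string_ops

    words = tf.constant(["cat", "dog", "cat"])
    # Equal strings always land in the same bucket; the hash is stable across
    # runs but not seeded, hence the TODO(soergel) about allowing a seed.
    buckets = string_ops.string_to_hash_bucket_fast(words, 1000)
    with tf.Session() as sess:
      print(sess.run(buckets))  # e.g. [421, 87, 421]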
@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform):
   def input_valency(self):
     return 0
 
-  def _apply_transform(self, transform_input):
+  def _apply_transform(self, transform_input, **kwargs):
     queue = feeding_functions.enqueue_data(self.data,
                                            self.queue_capacity,
                                            self.shuffle,
@ -32,7 +32,6 @@ class ReaderSource(transform.Transform):
                reader_kwargs=None,
                enqueue_size=None,
                batch_size=1,
-               num_epochs=None,
                queue_capacity=None,
                shuffle=False,
                min_after_dequeue=None,
@ -49,9 +48,6 @@ class ReaderSource(transform.Transform):
         is constructed.
       enqueue_size: block size for each read operation.
       batch_size: The desired batch size of output. Defaults to 1.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
       shuffle: Whether records will be shuffled before returning. Defaults to
         false.
@ -73,7 +69,6 @@ class ReaderSource(transform.Transform):
     self._batch_size = batch_size
     self._queue_capacity = (batch_size * 10 if queue_capacity is None else
                             queue_capacity)
-    self._num_epochs = num_epochs
     self._shuffle = shuffle
     self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
                                   is None else min_after_dequeue)
@ -100,10 +95,6 @@ class ReaderSource(transform.Transform):
   def batch_size(self):
     return self._batch_size
 
-  @transform.parameter
-  def num_epochs(self):
-    return self._num_epochs
-
   @transform.parameter
   def queue_capacity(self):
     return self._queue_capacity
@ -136,9 +127,10 @@ class ReaderSource(transform.Transform):
   def _output_names(self):
     return ("index", "value")
 
-  def _apply_transform(self, transform_input):
-    filename_queue = input_ops.string_input_producer(self.work_units,
-                                                     num_epochs=self.num_epochs,
-                                                     shuffle=self.shuffle,
-                                                     seed=self.seed)
+  def _apply_transform(self, transform_input, **kwargs):
+    filename_queue = input_ops.string_input_producer(
+        self.work_units,
+        num_epochs=kwargs.get("num_epochs"),
+        shuffle=self.shuffle,
+        seed=self.seed)
     reader_ops = []
@ -174,7 +166,6 @@ def TextFileSource(file_names,
                    reader_kwargs=None,
                    enqueue_size=1,
                    batch_size=1,
-                   num_epochs=None,
                    queue_capacity=None,
                    shuffle=False,
                    min_after_dequeue=None,
@ -185,7 +176,6 @@ def TextFileSource(file_names,
                       reader_kwargs=reader_kwargs,
                       enqueue_size=enqueue_size,
                       batch_size=batch_size,
-                      num_epochs=num_epochs,
                       queue_capacity=queue_capacity,
                       shuffle=shuffle,
                       min_after_dequeue=min_after_dequeue,
@ -197,7 +187,6 @@ def TFRecordSource(file_names,
                    reader_kwargs=None,
                    enqueue_size=1,
                    batch_size=1,
-                   num_epochs=None,
                    queue_capacity=None,
                    shuffle=False,
                    min_after_dequeue=None,
@ -208,7 +197,6 @@ def TFRecordSource(file_names,
                       reader_kwargs=reader_kwargs,
                       enqueue_size=enqueue_size,
                       batch_size=batch_size,
-                      num_epochs=num_epochs,
                       queue_capacity=queue_capacity,
                       shuffle=shuffle,
                       min_after_dequeue=min_after_dequeue,
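The `ReaderSource` change above removes `num_epochs` from the constructor and instead reads it from `**kwargs` at apply time (`kwargs.get("num_epochs")`), so the epoch limit becomes a per-build argument rather than a frozen attribute of the transform. A hedged sketch of the calling pattern this enables (`build_source` and its plumbing are hypothetical; only the kwarg name comes from the diff):

    # One source instance, built twice with different epoch limits.
    def build_source(source, transform_input, num_epochs=None):
      # The keyword argument travels through the **kwargs of
      # _apply_transform rather than through __init__.
      return source._apply_transform(transform_input, num_epochs=num_epochs)

    # train_tensors = build_source(source, inputs, num_epochs=10)
    # eval_tensors = build_source(source, inputs, num_epochs=1)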
@ -52,12 +52,13 @@ class Sparsify(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.
 
     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.
 
     Returns:
       A namedtuple of Tensors representing the transformed output.
@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@ -44,7 +44,7 @@ class Sum(transform.Transform):
   def _output_names(self):
     return "output",
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
                      isinstance(input_tensors[1], ops.SparseTensor))
 
@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@ -43,7 +43,8 @@ UNARY_TRANSFORMS = [("__neg__", math_ops.neg),
                     ("lgamma", math_ops.lgamma),
                     ("digamma", math_ops.digamma),
                     ("erf", math_ops.erf),
-                    ("erfc", math_ops.erfc)]
+                    ("erfc", math_ops.erfc),
+                    ("__invert__", math_ops.logical_not, bool)]
 
 DOC_FORMAT_STRING = (
     "A `Transform` that wraps the `{0}` operation. "
@ -52,7 +53,7 @@ DOC_FORMAT_STRING = (
 
 
 # pylint: disable=unused-argument
-def register_unary_op(registered_name, operation):
+def register_unary_op(registered_name, operation, ignore_dtype=None):
   """Creates a `Transform` that wraps a unary tensorflow operation.
 
   If `registered_name` is specified, the `Transform` is registered as a member
@ -62,6 +63,8 @@ def register_unary_op(registered_name, operation):
     registered_name: the name of the member function of `Series` corresponding
       to the returned `Transform`.
     operation: a unary TensorFlow operation.
+    ignore_dtype: an optional dtype, not used here but needed for symmetry with
+      test.
   """
 
   doc = DOC_FORMAT_STRING.format(operation.__name__, operation.__doc__)
@ -78,7 +81,7 @@ def register_unary_op(registered_name, operation):
   def _output_names(self):
     return "output"
 
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     input_tensor = input_tensors[0]
     if isinstance(input_tensor, ops.SparseTensor):
       result = ops.SparseTensor(input_tensor.indices,
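With the new `ignore_dtype` parameter, entries in `UNARY_TRANSFORMS` may now carry an optional third element (a dtype such as `bool` for `__invert__`) that the registration itself ignores but the tests can use to pick suitable inputs. A sketch of how such entries can be unpacked (this loop is illustrative; it is not quoted from the module):

    for entry in UNARY_TRANSFORMS:
      registered_name, operation = entry[0], entry[1]
      # Optional third element: a dtype hint consumed only by the tests.
      ignore_dtype = entry[2] if len(entry) > 2 else None
      register_unary_op(registered_name, operation, ignore_dtype)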
@ -29,14 +29,10 @@ from tensorflow.contrib.learn.python.learn.estimators import _sklearn
 
 def iris_input_fn(num_epochs=None):
   iris = tf.contrib.learn.datasets.load_iris()
-  features = tf.cast(
-      tf.reshape(
-          tf.constant(iris.data), [-1, 4]), tf.float32)
+  features = tf.reshape(tf.constant(iris.data), [-1, 4])
   if num_epochs:
     features = tf.train.limit_epochs(features, num_epochs=num_epochs)
-  target = tf.cast(
-      tf.reshape(
-          tf.constant(iris.target), [-1]), tf.int64)
+  target = tf.reshape(tf.constant(iris.target), [-1])
   return features, target
 
 
@ -20,11 +20,13 @@ from __future__ import division
 from __future__ import print_function
 
 import math
+import re
 
 import six
 
 from tensorflow.contrib import layers
 from tensorflow.contrib.layers.python.layers import feature_column_ops
+from tensorflow.contrib.learn.python.learn.utils import checkpoints
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import gradients
@ -47,31 +49,31 @@ class _ComposableModel(object):
   def __init__(self,
                num_label_columns,
                optimizer,
-               weight_collection_name,
                gradient_clip_norm,
-               num_ps_replicas):
+               num_ps_replicas,
+               scope):
     """Common initialization for all _ComposableModel objects.
 
     Args:
       num_label_columns: The number of label/target columns.
       optimizer: An instance of `tf.Optimizer` used to apply gradients to
        the model. If `None`, will use a FTRL optimizer.
-      weight_collection_name: A string defining the name to use for the
-        collection of weights (e.g. 'dnn').
      gradient_clip_norm: A float > 0. If provided, gradients are clipped
        to their global norm with this clipping ratio. See
        tf.clip_by_global_norm for more details.
      num_ps_replicas: The number of parameter server replicas.
+      scope: Scope for variables created in this model.
     """
     self._num_label_columns = num_label_columns
     self._optimizer = optimizer
-    self._weight_collection_name = weight_collection_name
     self._gradient_clip_norm = gradient_clip_norm
     self._num_ps_replicas = num_ps_replicas
+    self._scope = scope
     self._feature_columns = None
 
-  def get_weight_collection_name(self):
-    return self._weight_collection_name
+  def get_scope_name(self):
+    """Returns the scope name used by this model for variables."""
+    return self._scope
 
   def build_model(self, features, feature_columns, is_training):
     """Builds the model that can calculate the logits.
@ -114,7 +116,7 @@ class _ComposableModel(object):
 
   def _get_vars(self):
     if self._get_feature_columns():
-      return ops.get_collection(self._weight_collection_name)
+      return ops.get_collection(self._scope)
     return []
 
   def _get_optimizer(self):
@ -142,7 +144,8 @@ class LinearComposableModel(_ComposableModel):
                num_label_columns,
                optimizer=None,
                gradient_clip_norm=None,
-               num_ps_replicas=0):
+               num_ps_replicas=0,
+               scope=None):
     """Initializes LinearComposableModel objects.
 
     Args:
@ -153,13 +156,49 @@ class LinearComposableModel(_ComposableModel):
        to their global norm with this clipping ratio. See
        tf.clip_by_global_norm for more details.
      num_ps_replicas: The number of parameter server replicas.
+      scope: Optional scope for variables created in this model. If scope
+        is not supplied, it will default to 'linear'.
     """
+    scope = "linear" if not scope else scope
     super(LinearComposableModel, self).__init__(
         num_label_columns=num_label_columns,
         optimizer=optimizer,
-        weight_collection_name="linear",
         gradient_clip_norm=gradient_clip_norm,
-        num_ps_replicas=num_ps_replicas)
+        num_ps_replicas=num_ps_replicas,
+        scope=scope)
+
+  def get_weights(self, model_dir):
+    """Returns weights per feature of the linear part.
+
+    Args:
+      model_dir: Directory where model parameters, graph and etc. are saved.
+
+    Returns:
+      The weights created by this model (without the optimizer weights).
+    """
+    all_variables = [name for name, _ in checkpoints.list_variables(model_dir)]
+    values = {}
+    optimizer_regex = r".*/" + self._get_optimizer().get_name() + r"(_\d)?$"
+    for name in all_variables:
+      if (name.startswith(self._scope + "/") and
+          name != self._scope + "/bias_weight" and
+          not re.match(optimizer_regex, name)):
+        values[name] = checkpoints.load_variable(model_dir, name)
+    if len(values) == 1:
+      return values[list(values.keys())[0]]
+    return values
+
+  def get_bias(self, model_dir):
+    """Returns bias of the model.
+
+    Args:
+      model_dir: Directory where model parameters, graph and etc. are saved.
+
+    Returns:
+      The bias weights created by this model.
+    """
+    return checkpoints.load_variable(model_dir,
+                                     name=(self._scope+"/bias_weight"))
 
   def build_model(self, features, feature_columns, is_training):
     """See base class."""
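The new `get_weights`/`get_bias` methods read trained values directly from the checkpoint in `model_dir` via the `checkpoints` utility, instead of pulling live variables out of a graph collection. A minimal sketch of that pattern, assuming `model_dir` holds a trained checkpoint (the helper name here is hypothetical):

    from tensorflow.contrib.learn.python.learn.utils import checkpoints

    def read_scope_variables(model_dir, scope):
      # List every variable recorded in the checkpoint, then load only
      # those created under the given scope.
      names = [name for name, _ in checkpoints.list_variables(model_dir)]
      return {name: checkpoints.load_variable(model_dir, name)
              for name in names if name.startswith(scope + "/")}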
@ -168,12 +207,12 @@ class LinearComposableModel(_ComposableModel):
         max_partitions=self._num_ps_replicas,
         min_slice_size=64 << 20)
     with variable_scope.variable_op_scope(
-        features.values(), "linear", partitioner=partitioner) as scope:
+        features.values(), self._scope, partitioner=partitioner) as scope:
       logits, _, _ = layers.weighted_sum_from_feature_columns(
           columns_to_tensors=features,
           feature_columns=self._get_feature_columns(),
           num_outputs=self._num_label_columns,
-          weight_collections=[self._weight_collection_name],
+          weight_collections=[self._scope],
           scope=scope)
     return logits
 
@ -200,7 +239,8 @@ class DNNComposableModel(_ComposableModel):
                activation_fn=nn.relu,
                dropout=None,
                gradient_clip_norm=None,
-               num_ps_replicas=0):
+               num_ps_replicas=0,
+               scope=None):
     """Initializes DNNComposableModel objects.
 
     Args:
@ -217,17 +257,50 @@ class DNNComposableModel(_ComposableModel):
        to their global norm with this clipping ratio. See
        tf.clip_by_global_norm for more details.
      num_ps_replicas: The number of parameter server replicas.
+      scope: Optional scope for variables created in this model. If no scope
+        is supplied, one is generated.
     """
+    scope = "dnn" if not scope else scope
     super(DNNComposableModel, self).__init__(
         num_label_columns=num_label_columns,
         optimizer=optimizer,
-        weight_collection_name="DNN",
        gradient_clip_norm=gradient_clip_norm,
-        num_ps_replicas=num_ps_replicas)
+        num_ps_replicas=num_ps_replicas,
+        scope=scope)
     self._hidden_units = hidden_units
     self._activation_fn = activation_fn
     self._dropout = dropout
 
+  def get_weights(self, model_dir):
+    """Returns the weights of the model.
+
+    Args:
+      model_dir: Directory where model parameters, graph and etc. are saved.
+
+    Returns:
+      The weights created by this model.
+    """
+    return [checkpoints.load_variable(
+        model_dir, name=(self._scope+"/hiddenlayer_%d/weights" % i))
+            for i, _ in enumerate(self._hidden_units)] + [
+                checkpoints.load_variable(
+                    model_dir, name=(self._scope+"/logits/weights"))]
+
+  def get_bias(self, model_dir):
+    """Returns the bias of the model.
+
+    Args:
+      model_dir: Directory where model parameters, graph and etc. are saved.
+
+    Returns:
+      The bias weights created by this model.
+    """
+    return [checkpoints.load_variable(
+        model_dir, name=(self._scope+"/hiddenlayer_%d/biases" % i))
+            for i, _ in enumerate(self._hidden_units)] + [
+                checkpoints.load_variable(
+                    model_dir, name=(self._scope+"/logits/biases"))]
+
   def _add_hidden_layer_summary(self, value, tag):
     # TODO(zakaria): Move this code to tf.learn and add test.
     logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
@ -244,12 +317,12 @@ class DNNComposableModel(_ComposableModel):
             min_slice_size=64 << 20))
     with variable_scope.variable_op_scope(
         features.values(),
-        "input_from_feature_columns",
+        self._scope + "/input_from_feature_columns",
         partitioner=input_layer_partitioner) as scope:
       net = layers.input_from_feature_columns(
           features,
           self._get_feature_columns(),
-          weight_collections=[self._weight_collection_name],
+          weight_collections=[self._scope],
          scope=scope)
 
     hidden_layer_partitioner = (
@ -257,13 +330,13 @@ class DNNComposableModel(_ComposableModel):
             max_partitions=self._num_ps_replicas))
     for layer_id, num_hidden_units in enumerate(self._hidden_units):
       with variable_scope.variable_op_scope(
-          [net], "hiddenlayer_%d" % layer_id,
+          [net], self._scope + "/hiddenlayer_%d" % layer_id,
           partitioner=hidden_layer_partitioner) as scope:
         net = layers.fully_connected(
             net,
             num_hidden_units,
             activation_fn=self._activation_fn,
-            variables_collections=[self._weight_collection_name],
+            variables_collections=[self._scope],
             scope=scope)
         if self._dropout is not None and is_training:
           net = layers.dropout(
@ -272,15 +345,15 @@ class DNNComposableModel(_ComposableModel):
       self._add_hidden_layer_summary(net, scope.name)
 
     with variable_scope.variable_op_scope(
-        [net], "dnn_logits",
+        [net], self._scope + "/logits",
         partitioner=hidden_layer_partitioner) as scope:
       logits = layers.fully_connected(
           net,
           self._num_label_columns,
          activation_fn=None,
-          variables_collections=[self._weight_collection_name],
+          variables_collections=[self._scope],
          scope=scope)
-    self._add_hidden_layer_summary(logits, "dnn_logits")
+    self._add_hidden_layer_summary(logits, "logits")
     return logits
 
   def _get_default_optimizer(self, optimizer_name=None):
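Since the `scope` string now names both the variable scope and the weight collection, the DNN model's checkpoint layout becomes predictable from the scope alone; note also the logits scope is renamed from `dnn_logits` to `<scope>/logits`. Assuming the default scope "dnn" and two hidden layers, the variable names `get_weights` expects would look like the following (an illustrative list, not quoted from the code):

    expected_names = [
        "dnn/hiddenlayer_0/weights", "dnn/hiddenlayer_0/biases",
        "dnn/hiddenlayer_1/weights", "dnn/hiddenlayer_1/biases",
        "dnn/logits/weights", "dnn/logits/biases",  # formerly dnn_logits/*
    ]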
@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import tempfile
+
 import tensorflow as tf
 
 from tensorflow.contrib import layers
@ -42,7 +44,7 @@ class _BaseEstimatorForTest(estimator.BaseEstimator):
   def __init__(self,
                target_column,
                feature_columns):
-    super(_BaseEstimatorForTest, self).__init__()
+    super(_BaseEstimatorForTest, self).__init__(model_dir=tempfile.mkdtemp())
     self._target_column = target_column
     self._feature_columns = feature_columns
 
@ -71,9 +71,9 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
 
     Args:
       target_column: A _TargetColumn object.
-      model_dir: Directory to save model parameters, graph and etc. This can also
-        be used to load checkpoints from the directory into a estimator to continue
-        training a previously saved model.
+      model_dir: Directory to save model parameters, graph and etc. This can
+        also be used to load checkpoints from the directory into a estimator
+        to continue training a previously saved model.
       linear_feature_columns: An iterable containing all the feature columns
         used by linear part of the model. All items in the set should be
         instances of classes derived from `FeatureColumn`.
@ -102,8 +102,8 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
       ValueError: If both linear_feature_columns and dnn_features_columns are
         empty at the same time.
     """
-    super(_DNNLinearCombinedBaseEstimator, self).__init__(model_dir=model_dir,
-                                                          config=config)
+    super(_DNNLinearCombinedBaseEstimator, self).__init__(
+        model_dir=model_dir, config=config)
 
     num_ps_replicas = config.num_ps_replicas if config else 0
 
@ -124,8 +124,6 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
 
     self._linear_feature_columns = linear_feature_columns
     self._linear_optimizer = linear_optimizer
-    self._linear_weight_collection = (
-        self._linear_model.get_weight_collection_name())
     self._dnn_feature_columns = dnn_feature_columns
     self._dnn_hidden_units = dnn_hidden_units
     self._centered_bias_weight_collection = "centered_bias"
@ -135,38 +133,24 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
   @property
   def linear_weights_(self):
     """Returns weights per feature of the linear part."""
-    all_variables = self.get_variable_names()
-    # TODO(ispir): Figure out a better way to retrieve variables for features.
-    # for example using feature info / columns.
-    values = {}
-    for name in all_variables:
-      if (name.startswith("linear/") and name.rfind("/") == 6 and
-          name != "linear/bias_weight"):
-        values[name] = self.get_variable_value(name)
-    if len(values) == 1:
-      return values[list(values.keys())[0]]
-    return values
+    return self._linear_model.get_weights(model_dir=self._model_dir)
 
   @property
   def linear_bias_(self):
     """Returns bias of the linear part."""
-    return (self.get_variable_value("linear/bias_weight") +
+    return (self._linear_model.get_bias(model_dir=self._model_dir) +
             self.get_variable_value("centered_bias_weight"))
 
   @property
   def dnn_weights_(self):
     """Returns weights of deep neural network part."""
-    return [self.get_variable_value("hiddenlayer_%d/weights" % i)
-            for i, _ in enumerate(self._dnn_hidden_units)] + [
-                self.get_variable_value("dnn_logits/weights")]
+    return self._dnn_model.get_weights(model_dir=self._model_dir)
 
   @property
   def dnn_bias_(self):
     """Returns bias of deep neural network part."""
-    return [self.get_variable_value("hiddenlayer_%d/biases" % i)
-            for i, _ in enumerate(self._dnn_hidden_units)] + [
-                self.get_variable_value("dnn_logits/biases"),
-                self.get_variable_value("centered_bias_weight")]
+    return (self._dnn_model.get_bias(model_dir=self._model_dir) +
+            [self.get_variable_value("centered_bias_weight")])
 
   def _get_feature_dict(self, features):
     if isinstance(features, dict):
@ -347,9 +331,9 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
     """Constructs a DNNLinearCombinedClassifier instance.
 
     Args:
-      model_dir: Directory to save model parameters, graph and etc. This can also
-        be used to load checkpoints from the directory into a estimator to continue
-        training a previously saved model.
+      model_dir: Directory to save model parameters, graph and etc. This can
+        also be used to load checkpoints from the directory into a estimator
+        to continue training a previously saved model.
       n_classes: number of target classes. Default is binary classification.
       weight_column_name: A string defining feature column name representing
         weights. It is used to down weight or boost examples during training.
@ -532,9 +516,9 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
     """Initializes a DNNLinearCombinedRegressor instance.
 
     Args:
-      model_dir: Directory to save model parameters, graph and etc. This can also
-        be used to load checkpoints from the directory into a estimator to continue
-        training a previously saved model.
+      model_dir: Directory to save model parameters, graph and etc. This can
+        also be used to load checkpoints from the directory into a estimator
+        to continue training a previously saved model.
       weight_column_name: A string defining feature column name representing
         weights. It is used to down weight or boost examples during training. It
         will be multiplied by the loss of the example.
@ -23,6 +23,7 @@ import tempfile
 
 import numpy as np
 import tensorflow as tf
 
 from tensorflow.contrib.learn.python.learn.estimators import _sklearn
 
+
@ -458,10 +459,39 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
     self.assertLess(loss2, 0.01)
     self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
 
-    self.assertNotIn('dnn_logits/biases', classifier.get_variable_names())
-    self.assertNotIn('dnn_logits/weights', classifier.get_variable_names())
+    self.assertNotIn('dnn/logits/biases', classifier.get_variable_names())
+    self.assertNotIn('dnn/logits/weights', classifier.get_variable_names())
     self.assertEquals(1, len(classifier.linear_bias_))
-    self.assertEquals(100, len(classifier.linear_weights_))
+    self.assertEquals(2, len(classifier.linear_weights_))
+    self.assertEquals(1, len(classifier.linear_weights_['linear/age/weight']))
+    self.assertEquals(
+        100, len(classifier.linear_weights_['linear/language_weights']))
+
+  def testLinearOnlyOneFeature(self):
+    """Tests that linear-only instantiation works for one feature only."""
+    def input_fn():
+      return {
+          'language': tf.SparseTensor(values=['english'],
+                                      indices=[[0, 0]],
+                                      shape=[1, 1])
+      }, tf.constant([[1]])
+
+    language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 99)
+
+    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+        linear_feature_columns=[language])
+    classifier.fit(input_fn=input_fn, steps=100)
+    loss1 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
+    classifier.fit(input_fn=input_fn, steps=200)
+    loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
+    self.assertLess(loss2, loss1)
+    self.assertLess(loss2, 0.01)
+    self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
+
+    self.assertNotIn('dnn/logits/biases', classifier.get_variable_names())
+    self.assertNotIn('dnn/logits/weights', classifier.get_variable_names())
+    self.assertEquals(1, len(classifier.linear_bias_))
+    self.assertEquals(99, len(classifier.linear_weights_))
+
   def testDNNOnly(self):
     """Tests that DNN-only instantiation works."""
@ -31,7 +31,9 @@ import six
 
 from tensorflow.contrib import framework as contrib_framework
 from tensorflow.contrib import layers
+from tensorflow.contrib.learn.python.learn import evaluable
 from tensorflow.contrib.learn.python.learn import graph_actions
+from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
 from tensorflow.contrib.learn.python.learn.estimators import run_config
 from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
@ -138,7 +140,8 @@ def _get_arguments(func):
     return _get_arguments(func.func)
 
 
-class BaseEstimator(sklearn.BaseEstimator):
+class BaseEstimator(
+    sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
   """Abstract BaseEstimator class to train and evaluate TensorFlow models.
 
   Concrete implementation of this class should provide the following functions:
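`BaseEstimator` now also implements the `Evaluable` and `Trainable` interfaces, which is why the `fit` and `evaluate` docstrings below collapse to "See `Trainable`." and "See `Evaluable`.": the argument documentation moves to the interface definitions. A hedged sketch of what such an interface looks like, inferred from the `fit` signature in this diff rather than quoted from `trainable.py`:

    import abc

    class Trainable(object):
      """Interface for objects trainable via fit()."""
      __metaclass__ = abc.ABCMeta

      @abc.abstractmethod
      def fit(self, x=None, y=None, input_fn=None, steps=None,
              batch_size=None, monitors=None, max_steps=None):
        """Trains a model; the arguments are documented once here."""
        raise NotImplementedError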
@ -158,9 +161,9 @@ class BaseEstimator(sklearn.BaseEstimator):
     """Initializes a BaseEstimator instance.
 
     Args:
-      model_dir: Directory to save model parameters, graph and etc. This can also
-        be used to load checkpoints from the directory into a estimator to continue
-        training a previously saved model.
+      model_dir: Directory to save model parameters, graph and etc. This can
+        also be used to load checkpoints from the directory into a estimator to
+        continue training a previously saved model.
       config: A RunConfig instance.
     """
     # Model directory.
@ -196,34 +199,8 @@ class BaseEstimator(sklearn.BaseEstimator):
 
   def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
           monitors=None, max_steps=None):
-    """Trains a model given training data `x` predictions and `y` targets.
-
-    Args:
-      x: Matrix of shape [n_samples, n_features...]. Can be iterator that
-        returns arrays of features. The training input samples for fitting the
-        model. If set, `input_fn` must be `None`.
-      y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
-        iterator that returns array of targets. The training target values
-        (class labels in classification, real numbers in regression). If set,
-        `input_fn` must be `None`.
-      input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
-        `None`.
-      steps: Number of steps for which to train model. If `None`, train forever.
-        If set, `max_steps` must be `None`.
-      batch_size: minibatch size to use on the input, defaults to first
-        dimension of `x`. Must be `None` if `input_fn` is provided.
-      monitors: List of `BaseMonitor` subclass instances. Used for callbacks
-        inside the training loop.
-      max_steps: Number of total steps for which to train model. If `None`,
-        train forever. If set, `steps` must be `None`.
-
-        Two calls to `fit(steps=100)` means 200 training
-        iterations. On the other hand, two calls to `fit(max_steps=100)` means
-        that the second call will not do any iteration since first call did
-        all 100 steps.
-
-    Returns:
-      `self`, for chaining.
-
+    # pylint: disable=g-doc-args,g-doc-return-or-yield
+    """See `Trainable`.
+
     Raises:
       ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
@ -284,61 +261,11 @@ class BaseEstimator(sklearn.BaseEstimator):
     return self.fit(x=x, y=y, input_fn=input_fn, steps=steps,
                     batch_size=batch_size, monitors=monitors)
 
-  def evaluate(self,
-               x=None,
-               y=None,
-               input_fn=None,
-               feed_fn=None,
-               batch_size=None,
-               steps=None,
-               metrics=None,
-               name=None):
-    """Evaluates given model with provided evaluation data.
-
-    Evaluates on the given input data. If `input_fn` is provided, that
-    input function should raise an end-of-input exception (`OutOfRangeError` or
-    `StopIteration`) after one epoch of the training data has been provided.
-
-    By default, the whole evaluation dataset is used. If `steps` is provided,
-    only `steps` batches of size `batch_size` are processed.
-
-    The return value is a dict containing the metrics specified in `metrics`, as
-    well as an entry `global_step` which contains the value of the global step
-    for which this evaluation was performed.
-
-    Args:
-      x: Matrix of shape [n_samples, n_features...]. Can be iterator that
-        returns arrays of features. The training input samples for fitting the
-        model. If set, `input_fn` must be `None`.
-      y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
-        iterator that returns array of targets. The training target values
-        (class labels in classification, real numbers in regression). If set,
-        `input_fn` must be `None`.
-      input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
-        `None`.
-      feed_fn: Function creating a feed dict every time it is called. Called
-        once per iteration.
-      batch_size: minibatch size to use on the input, defaults to first
-        dimension of `x`, if specified. Must be `None` if `input_fn` is
-        provided.
-      steps: Number of steps for which to evaluate model. If `None`, evaluate
-        until running tensors generated by `metrics` raises an exception.
-      metrics: Dict of metric ops to run. If `None`, the default metric
-        functions are used; if `{}`, no metrics are used. If model has one
-        output (i.e., returning single predction), keys are `str`, e.g.
-        `'accuracy'` - just a name of the metric that will show up in
-        the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
-        `('accuracy', 'classes')`- name of the metric and name of `Tensor` in
-        the predictions to run this metric on.
-
-        Metric ops should support streaming, e.g., returning
-        update_op and value tensors. See more details in
-        ../../../../metrics/python/metrics/ops/streaming_metrics.py.
-      name: Name of the evaluation if user needs to run multiple evaluations on
-        different data sets, such as on training data vs test data.
-
-    Returns:
-      Returns `dict` with evaluation results.
-
+  def evaluate(
+      self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
+      steps=None, metrics=None, name=None):
+    # pylint: disable=g-doc-args,g-doc-return-or-yield
+    """See `Evaluable`.
+
     Raises:
       ValueError: If at least one of `x` or `y` is provided, and at least one of
@ -571,7 +498,7 @@ class BaseEstimator(sklearn.BaseEstimator):
         log_every_steps=log_every_steps,
         supervisor_is_chief=(self._config.task == 0),
         supervisor_master=self._config.master,
-        supervisor_save_model_steps=self._config.save_checkpoints_steps,
+        supervisor_save_model_secs=self._config.save_checkpoints_secs,
         keep_checkpoint_max=self._config.keep_checkpoint_max,
         feed_fn=feed_fn,
         steps=steps,
@ -770,9 +697,9 @@ class Estimator(BaseEstimator):
       is passed to Estimator in `params` parameter. This allows
       to configure Estimators from hyper parameter tunning.
 
-      model_dir: Directory to save model parameters, graph and etc. This can also
-        be used to load checkpoints from the directory into a estimator to continue
-        training a previously saved model.
+      model_dir: Directory to save model parameters, graph and etc. This can
+        also be used to load checkpoints from the directory into a estimator to
+        continue training a previously saved model.
       config: Configuration object.
       params: `dict` of hyper parameters that will be passed into `model_fn`.
         Keys are names of parameters, values are basic python types.
@ -36,32 +36,26 @@ _IRIS_INPUT_DIM = 4
 
 def boston_input_fn(num_epochs=None):
   boston = tf.contrib.learn.datasets.load_boston()
-  features = tf.cast(
-      tf.reshape(tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM]), tf.float32)
+  features = tf.reshape(tf.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
   if num_epochs:
     features = tf.train.limit_epochs(features, num_epochs=num_epochs)
-  target = tf.cast(
-      tf.reshape(tf.constant(boston.target), [-1, 1]), tf.float32)
+  target = tf.reshape(tf.constant(boston.target), [-1, 1])
   return features, target
 
 
 def iris_input_fn():
   iris = tf.contrib.learn.datasets.load_iris()
-  features = tf.cast(
-      tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM]), tf.float32)
-  target = tf.cast(
-      tf.reshape(tf.constant(iris.target), [-1]), tf.int32)
+  features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
+  target = tf.reshape(tf.constant(iris.target), [-1])
   return features, target
 
 
 def boston_eval_fn():
   boston = tf.contrib.learn.datasets.load_boston()
   n_examples = len(boston.target)
-  features = tf.cast(
-      tf.reshape(tf.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]),
-      tf.float32)
-  target = tf.cast(
-      tf.reshape(tf.constant(boston.target), [n_examples, 1]), tf.float32)
+  features = tf.reshape(
+      tf.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM])
+  target = tf.reshape(tf.constant(boston.target), [n_examples, 1])
   return tf.concat(0, [features, features]), tf.concat(0, [target, target])
 
 
@ -188,7 +182,7 @@ class EstimatorTest(tf.test.TestCase):
     with self.assertRaises(tf.contrib.learn.NotFittedError):
       _ = est.evaluate(
           x=boston.data,
-          y=boston.target.astype(np.float32))
+          y=boston.target.astype(np.float64))
     with self.assertRaises(tf.contrib.learn.NotFittedError):
       est.predict(x=boston.data)
 
@ -197,10 +191,11 @@ class EstimatorTest(tf.test.TestCase):
     output_dir = tempfile.mkdtemp()
     est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
                                      model_dir=output_dir)
-    est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=50)
+    float64_target = boston.target.astype(np.float64)
+    est.fit(x=boston.data, y=float64_target, steps=50)
     scores = est.evaluate(
         x=boston.data,
-        y=boston.target.astype(np.float32),
+        y=float64_target,
         metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
     del est
     # Create another estimator object with the same output dir.
@ -210,19 +205,19 @@ class EstimatorTest(tf.test.TestCase):
     # Check we can evaluate and predict.
     scores2 = est2.evaluate(
         x=boston.data,
-        y=boston.target.astype(np.float32),
+        y=float64_target,
         metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
     self.assertAllClose(scores2['MSE'],
                         scores['MSE'])
     predictions = est2.predict(x=boston.data)
-    other_score = _sklearn.mean_squared_error(predictions, boston.target)
+    other_score = _sklearn.mean_squared_error(predictions, float64_target)
     self.assertAllClose(other_score, scores['MSE'])
 
     # Check we can keep training.
-    est2.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
+    est2.fit(x=boston.data, y=float64_target, steps=100)
     scores3 = est2.evaluate(
         x=boston.data,
-        y=boston.target.astype(np.float32),
+        y=float64_target,
         metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
     self.assertLess(scores3['MSE'], scores['MSE'])
 
@ -230,15 +225,16 @@ class EstimatorTest(tf.test.TestCase):
     boston = tf.contrib.learn.datasets.load_boston()
     est = tf.contrib.learn.Estimator(model_fn=linear_model_params_fn,
                                      params={'learning_rate': 0.01})
-    est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
+    est.fit(x=boston.data, y=boston.target, steps=100)
 
   def testBostonAll(self):
     boston = tf.contrib.learn.datasets.load_boston()
     est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
-    est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
+    float64_target = boston.target.astype(np.float64)
+    est.fit(x=boston.data, y=float64_target, steps=100)
     scores = est.evaluate(
         x=boston.data,
-        y=boston.target.astype(np.float32),
+        y=float64_target,
         metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
     predictions = est.predict(x=boston.data)
     other_score = _sklearn.mean_squared_error(predictions, boston.target)
@ -277,7 +273,7 @@ class EstimatorTest(tf.test.TestCase):
     iris = tf.contrib.learn.datasets.load_iris()
     est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
     x_iter = itertools.islice(iris.data, 100)
-    y_iter = itertools.islice(np.int32(iris.target), 100)
+    y_iter = itertools.islice(iris.target, 100)
     est.fit(x_iter, y_iter, steps=100)
     _ = est.evaluate(input_fn=iris_input_fn, steps=1)
     predictions = est.predict(x=iris.data)['class']
@ -374,19 +370,16 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
         '': tf.FixedLenFeature(shape=expected_shape, dtype=expected_dtype)
     }, feature_column.config)
 
-  # Note: See tf.contrib.learn.io.data_feeder for why int32 converts to float32.
   def testInt32Input(self):
     feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
         np.ones(shape=[7, 8], dtype=np.int32))
-    self._assert_single_feature_column([8], tf.float32, feature_columns)
+    self._assert_single_feature_column([8], tf.int32, feature_columns)
 
   def testInt32InputFn(self):
     feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
         lambda: (tf.ones(shape=[7, 8], dtype=tf.int32), None))
     self._assert_single_feature_column([8], tf.int32, feature_columns)
 
-  # Note: See tf.contrib.learn.io.data_feeder for why int64 doesn't convert to
-  # float64.
   def testInt64Input(self):
     feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
         np.ones(shape=[7, 8], dtype=np.int64))
@ -407,12 +400,10 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
         lambda: (tf.ones(shape=[7, 8], dtype=tf.float32), None))
     self._assert_single_feature_column([8], tf.float32, feature_columns)
 
-  # Note: See tf.contrib.learn.io.data_feeder for why float64 converts to
-  # float32.
   def testFloat64Input(self):
     feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
         np.ones(shape=[7, 8], dtype=np.float64))
-    self._assert_single_feature_column([8], tf.float32, feature_columns)
+    self._assert_single_feature_column([8], tf.float64, feature_columns)
 
   def testFloat64InputFn(self):
     feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
@ -420,9 +411,10 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
    self._assert_single_feature_column([8], tf.float64, feature_columns)

  def testBoolInput(self):
-   feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+   with self.assertRaisesRegexp(
+       ValueError, 'on integer or non floating types are not supported'):
+     tf.contrib.learn.infer_real_valued_columns_from_input(
          np.array([[False for _ in xrange(8)] for _ in xrange(7)]))
-   self._assert_single_feature_column([8], tf.float32, feature_columns)

  def testBoolInputFn(self):
    with self.assertRaisesRegexp(
@ -431,18 +423,12 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
      tf.contrib.learn.infer_real_valued_columns_from_input_fn(
          lambda: (tf.constant(False, shape=[7, 8], dtype=tf.bool), None))

- def testInvalidStringInput(self):
-   # pylint: disable=g-long-lambda
-   with self.assertRaisesRegexp(
-       ValueError, 'could not convert string to float'):
-     tf.contrib.learn.infer_real_valued_columns_from_input(
-         np.array([['foo%d' % i for i in xrange(8)] for _ in xrange(7)]))

  def testStringInput(self):
+   with self.assertRaisesRegexp(
+       ValueError, 'on integer or non floating types are not supported'):
      # pylint: disable=g-long-lambda
-     feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
+     tf.contrib.learn.infer_real_valued_columns_from_input(
          np.array([['%d.0' % i for i in xrange(8)] for _ in xrange(7)]))
-   self._assert_single_feature_column([8], tf.float32, feature_columns)

  def testStringInputFn(self):
    with self.assertRaisesRegexp(
@ -457,13 +443,13 @@ class InferRealValuedColumnsTest(tf.test.TestCase):
    feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
        boston_input_fn)
    self._assert_single_feature_column(
-       [_BOSTON_INPUT_DIM], tf.float32, feature_columns)
+       [_BOSTON_INPUT_DIM], tf.float64, feature_columns)

  def testIrisInputFn(self):
    feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input_fn(
        iris_input_fn)
    self._assert_single_feature_column(
-       [_IRIS_INPUT_DIM], tf.float32, feature_columns)
+       [_IRIS_INPUT_DIM], tf.float64, feature_columns)

if __name__ == '__main__':
  tf.test.main()
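Taken together, these test updates reflect a behavior change: infer_real_valued_columns_from_input now preserves the input dtype (int32, int64, float64) instead of coercing everything to float32, and it rejects bool and string inputs with a ValueError. A minimal sketch mirroring the tests above (illustrative only):

    import numpy as np
    import tensorflow as tf

    # Per testFloat64Input above: the inferred column now reports tf.float64,
    # where it previously reported tf.float32.
    cols = tf.contrib.learn.infer_real_valued_columns_from_input(
        np.ones(shape=[7, 8], dtype=np.float64))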
@ -122,9 +122,9 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
    feature_columns: An iterable containing all the feature columns used by
      the model. All items in the set should be instances of classes derived
      from `FeatureColumn`.
-   model_dir: Directory to save model parameters, graph and etc. This can also
-     be used to load checkpoints from the directory into a estimator to continue
-     training a previously saved model.
+   model_dir: Directory to save model parameters, graph and etc. This can
+     also be used to load checkpoints from the directory into a estimator
+     to continue training a previously saved model.
    n_classes: number of target classes. Default is binary classification.
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
@ -186,8 +186,8 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
        columns_to_tensors=features,
        feature_columns=self._linear_feature_columns,
        num_outputs=self._target_column.num_label_columns,
-       weight_collections=[self._linear_weight_collection],
-       scope="linear")
+       weight_collections=[self._linear_model.get_scope_name()],
+       scope=self._linear_model.get_scope_name())
    with ops.control_dependencies([self._centered_bias()]):
      loss = self._target_column.loss(logits, targets, features)
    logging_ops.scalar_summary("loss", loss)
@ -282,9 +282,9 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
    feature_columns: An iterable containing all the feature columns used by
      the model. All items in the set should be instances of classes derived
      from `FeatureColumn`.
-   model_dir: Directory to save model parameters, graph, etc. This can also
-     be used to load checkpoints from the directory into a estimator to continue
-     training a previously saved model.
+   model_dir: Directory to save model parameters, graph, etc. This can
+     also be used to load checkpoints from the directory into a estimator
+     to continue training a previously saved model.
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
@ -56,7 +56,7 @@ class LogisticRegressor(estimator.Estimator):
    model_fn: Model function. See superclass Estimator for more details. This
      expects the returned predictions to be probabilities in [0.0, 1.0].
    thresholds: List of floating point thresholds to use for accuracy,
-     precision, and recall metrics. If None, defaults to [0.5].
+     precision, and recall metrics. If `None`, defaults to `[0.5]`.
    model_dir: Directory to save model parameters, graphs, etc. This can also
      be used to load checkpoints from the directory into a estimator to continue
      training a previously saved model.
@ -17,8 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-import time
-
import numpy as np
import six
@ -26,18 +24,36 @@ from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.learn.python.learn import monitors as mon

from tensorflow.contrib.learn.python.learn.estimators import estimator
-from tensorflow.contrib.learn.python.learn.estimators import run_config

from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.data import data_ops
from tensorflow.contrib.tensor_forest.python import tensor_forest

+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops


+def _assert_float32(tensors):
+  """Assert all tensors are float32.
+
+  Args:
+    tensors: `Tensor` or `dict` of `Tensor` objects.
+
+  Raises:
+    TypeError: if any tensor is not float32.
+  """
+  if not isinstance(tensors, dict):
+    tensors = [tensors]
+  else:
+    tensors = tensors.values()
+  for tensor in tensors:
+    if tensor.dtype.base_dtype != dtypes.float32:
+      raise TypeError('Expected dtype=float32, %s.' % tensor)
+
+
class LossMonitor(mon.EveryN):
  """Terminates training when training loss stops decreasing."""
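The new _assert_float32 helper accepts either a single Tensor or a dict of Tensors and raises TypeError on the first non-float32 entry. A minimal standalone sketch of its behavior (illustrative values, not part of the diff):

    import tensorflow as tf

    _assert_float32(tf.constant(1.0))  # float32: passes silently
    # A dict is checked value by value; an int64 tensor raises
    # TypeError('Expected dtype=float32, ...').
    _assert_float32({'x': tf.constant(1, dtype=tf.int64)})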
@ -146,6 +162,8 @@ class TensorForestEstimator(estimator.BaseEstimator):
    Returns:
      Tuple of train `Operation` and loss `Tensor`.
    """
+   _assert_float32(features)
+   _assert_float32(targets)
    features, spec = data_ops.ParseDataTensorOrDict(features)
    labels = data_ops.ParseLabelTensorOrDict(targets)
@ -168,6 +186,7 @@ class TensorForestEstimator(estimator.BaseEstimator):
    return train, self.training_loss

  def _get_predict_ops(self, features):
+   _assert_float32(features)
    graph_builder = self.graph_builder_class(
        self.params, device_assigner=self.device_assigner, training=False,
        **self.construction_args)
@ -175,6 +194,8 @@ class TensorForestEstimator(estimator.BaseEstimator):
    return graph_builder.inference_graph(features, data_spec=spec)

  def _get_eval_ops(self, features, targets, metrics):
+   _assert_float32(features)
+   _assert_float32(targets)
    features, spec = data_ops.ParseDataTensorOrDict(features)
    labels = data_ops.ParseLabelTensorOrDict(targets)
@ -19,11 +19,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

+import numpy as np
import tensorflow as tf


class TensorForestTrainerTests(tf.test.TestCase):

+ def testFloat64(self):
+   hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+       num_trees=3, max_nodes=1000, num_classes=3, num_features=4)
+   classifier = tf.contrib.learn.TensorForestEstimator(hparams)
+   iris = tf.contrib.learn.datasets.load_iris()
+   with self.assertRaisesRegexp(TypeError, 'float32'):
+     classifier.fit(x=iris.data, y=iris.target, steps=100)
+
  def testClassification(self):
    """Tests multi-class classification using matrix data as input."""
    hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
@ -31,9 +40,11 @@ class TensorForestTrainerTests(tf.test.TestCase):
    classifier = tf.contrib.learn.TensorForestEstimator(hparams)

    iris = tf.contrib.learn.datasets.load_iris()
+   data = iris.data.astype(np.float32)
+   target = iris.target.astype(np.float32)

-   classifier.fit(x=iris.data, y=iris.target, steps=100)
-   classifier.evaluate(x=iris.data, y=iris.target, steps=10)
+   classifier.fit(x=data, y=target, steps=100)
+   classifier.evaluate(x=data, y=target, steps=10)

  def testRegression(self):
    """Tests multi-class classification using matrix data as input."""
@ -45,9 +56,11 @@ class TensorForestTrainerTests(tf.test.TestCase):
    regressor = tf.contrib.learn.TensorForestEstimator(hparams)

    boston = tf.contrib.learn.datasets.load_boston()
+   data = boston.data.astype(np.float32)
+   target = boston.target.astype(np.float32)

-   regressor.fit(x=boston.data, y=boston.target, steps=100)
-   regressor.evaluate(x=boston.data, y=boston.target, steps=10)
+   regressor.fit(x=data, y=target, steps=100)
+   regressor.evaluate(x=data, y=target, steps=10)


if __name__ == '__main__':
@ -38,8 +38,7 @@ class RunConfig(object):
               save_summary_steps=100,
               save_checkpoints_secs=60,
               keep_checkpoint_max=5,
-              keep_checkpoint_every_n_hours=10000,
-              save_checkpoints_steps=1000):
+              keep_checkpoint_every_n_hours=10000):
    """Constructor.

    Args:
@ -61,7 +60,6 @@ class RunConfig(object):
      keep_checkpoint_every_n_hours: Number of hours between each checkpoint
        to be saved. The default value of 10,000 hours effectively disables
        the feature.
-     save_checkpoints_steps: Number of steps between each checkpoint saving.
    """
    self.master = master
    self.task = task
@ -77,4 +75,3 @@ class RunConfig(object):
    self.save_checkpoints_secs = save_checkpoints_secs
    self.keep_checkpoint_max = keep_checkpoint_max
    self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
-   self.save_checkpoints_steps = save_checkpoints_steps
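With save_checkpoints_steps rolled back, checkpoint cadence in RunConfig is controlled only by save_checkpoints_secs. A hedged usage sketch (assuming the class is exported as tf.contrib.learn.RunConfig):

    # Time-based checkpointing is the only remaining knob:
    config = tf.contrib.learn.RunConfig(
        save_summary_steps=100,
        save_checkpoints_secs=60,  # write a checkpoint every 60 seconds
        keep_checkpoint_max=5)
    # Passing save_checkpoints_steps=... now raises a TypeError
    # (unexpected keyword argument).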
81
tensorflow/contrib/learn/python/learn/evaluable.py
Normal file
@ -0,0 +1,81 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""`Evaluable` interface."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc


class Evaluable(object):
  """Interface for objects that are evaluatable by, e.g., `Experiment`.
  """
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def evaluate(
      self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
      steps=None, metrics=None, name=None):
    """Evaluates given model with provided evaluation data.

    Evaluates on the given input data. If `input_fn` is provided, that
    input function should raise an end-of-input exception (`OutOfRangeError` or
    `StopIteration`) after one epoch of the training data has been provided.

    By default, the whole evaluation dataset is used. If `steps` is provided,
    only `steps` batches of size `batch_size` are processed.

    The return value is a dict containing the metrics specified in `metrics`, as
    well as an entry `global_step` which contains the value of the global step
    for which this evaluation was performed.

    Args:
      x: Matrix of shape [n_samples, n_features...]. Can be iterator that
        returns arrays of features. The training input samples for fitting the
        model. If set, `input_fn` must be `None`.
      y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
        iterator that returns array of targets. The training target values
        (class labels in classification, real numbers in regression). If set,
        `input_fn` must be `None`.
      input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
        `None`.
      feed_fn: Function creating a feed dict every time it is called. Called
        once per iteration. Must be `None` if `input_fn` is provided.
      batch_size: minibatch size to use on the input, defaults to first
        dimension of `x`, if specified. Must be `None` if `input_fn` is
        provided.
      steps: Number of steps for which to evaluate model. If `None`, evaluate
        until running tensors generated by `metrics` raises an exception.
      metrics: Dict of metric ops to run. If `None`, the default metric
        functions are used; if `{}`, no metrics are used. If model has one
        output (i.e., returning single prediction), keys are `str`, e.g.
        `'accuracy'` - just a name of the metric that will show up in
        the logs / summaries. Otherwise, keys are tuple of two `str`, e.g.
        `('accuracy', 'classes')` - name of the metric and name of `Tensor` in
        the predictions to run this metric on.

        Metric ops should support streaming, e.g., returning
        update_op and value tensors. See more details in
        ../../../metrics/python/metrics/ops/streaming_metrics.py.
      name: Name of the evaluation if user needs to run multiple evaluations on
        different data sets, such as on training data vs test data.

    Returns:
      Returns `dict` with evaluation results.
    """
    raise NotImplementedError
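A minimal sketch of a conforming implementation of this interface (the class and its fixed metric are hypothetical, for illustration only):

    class ConstantEvaluable(Evaluable):
      """Toy `Evaluable` that reports a fixed metric (illustration only)."""

      def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
                   batch_size=None, steps=None, metrics=None, name=None):
        # A real implementation would run metric ops over the input; this
        # just returns the dict shape the interface promises, including the
        # `global_step` entry described in the docstring above.
        return {'accuracy': 1.0, 'global_step': 0}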
@ -21,7 +21,9 @@ from __future__ import print_function

import time

+from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import monitors
+from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
@ -47,7 +49,7 @@ class Experiment(object):
    """Constructor for `Experiment`.

    Args:
-     estimator: `Estimator` object.
+     estimator: Object implementing `Trainable` and `Evaluable`.
      train_input_fn: function, returns features and targets for training.
      eval_input_fn: function, returns features and targets for evaluation. If
        `eval_steps` is `None`, this should be configured only to produce for a
@ -67,7 +69,14 @@ class Experiment(object):
      continuous_eval_throttle_secs: Do not re-evaluate unless the last
        evaluation was started at least this many seconds ago for
        continuous_eval().
+
+   Raises:
+     ValueError: if `estimator` does not implement `Evaluable` and `Trainable`.
    """
+   if not isinstance(estimator, evaluable.Evaluable):
+     raise ValueError("`estimator` must implement `Evaluable`.")
+   if not isinstance(estimator, trainable.Trainable):
+     raise ValueError("`estimator` must implement `Trainable`.")
    super(Experiment, self).__init__()
    self._estimator = estimator
    self._train_input_fn = train_input_fn
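Experiment now fails fast at construction time instead of deep inside train() or evaluate(). A hedged sketch of the effect (assuming Experiment is exported as tf.contrib.learn.Experiment; the estimator class and input functions below are hypothetical):

    class NotAnEstimator(object):
      pass

    # Raises ValueError("`estimator` must implement `Evaluable`.") immediately.
    exp = tf.contrib.learn.Experiment(
        estimator=NotAnEstimator(),
        train_input_fn=my_train_input_fn,   # hypothetical input functions
        eval_input_fn=my_eval_input_fn)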
@ -130,7 +130,7 @@ def _supervised_train(graph,
                      log_every_steps=10,
                      supervisor_is_chief=True,
                      supervisor_master='',
-                     supervisor_save_model_steps=1000,
+                     supervisor_save_model_secs=600,
                      keep_checkpoint_max=5,
                      supervisor_save_summaries_steps=100,
                      feed_fn=None,
@ -171,8 +171,8 @@ def _supervised_train(graph,
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
-   supervisor_save_model_steps: Save a checkpoint every
-     `supervisor_save_model_steps` steps when training.
+   supervisor_save_model_secs: Save model every
+     `supervisor_save_model_secs` seconds when training.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
@ -251,15 +251,18 @@ def _supervised_train(graph,
        init_fn=init_fn,
        keep_checkpoint_max=keep_checkpoint_max)
    if supervisor_is_chief:
-     if scaffold.summary_op is not None:
-       monitors.append(monitors_lib.SummarySaver(
-           scaffold.summary_op,
-           save_steps=supervisor_save_summaries_steps,
-           summary_writer=summary_writer))
-     if supervisor_save_model_steps > 0:
-       monitors.append(
-           monitors_lib.CheckpointSaver(supervisor_save_model_steps,
-                                        scaffold.saver, output_dir))
+     monitors.append(
+         monitors_lib.SummarySaver(
+             summary_op=None,
+             save_steps=supervisor_save_summaries_steps,
+             summary_writer=summary_writer,
+             scaffold=scaffold))
+     if supervisor_save_model_secs > 0:
+       monitors.append(
+           monitors_lib.CheckpointSaver(
+               output_dir,
+               save_secs=supervisor_save_model_secs,
+               scaffold=scaffold))

    if steps is not None or max_steps is not None:
      monitors.append(monitors_lib.StopAtStep(steps, max_steps))
@ -30,6 +30,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging

# pylint: disable=g-multiple-import,g-bad-import-order
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
@ -206,6 +207,13 @@ def _access(data, iloc):
    return data[iloc]


+def _check_dtype(dtype):
+  if dtypes.as_dtype(dtype) == dtypes.float64:
+    logging.warn(
+        'float64 is not supported by many models, consider casting to float32.')
+  return dtype
+
+
class DataFeeder(object):
  """Data feeder is an example class to sample data for TF trainer."""
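The new _check_dtype helper only warns; it does not cast. A quick standalone sketch (module-private helper, shown here for illustration):

    import numpy as np

    dtype = _check_dtype(np.float64)  # logs: 'float64 is not supported by many models, ...'
    assert dtype == np.float64        # the dtype itself is returned unchanged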
@ -215,60 +223,82 @@ class DataFeeder(object):
    """Initializes a DataFeeder instance.

    Args:
-     x: feature Nd numpy matrix of shape [n_samples, n_features, ...].
-     y: target vector, either floats for regression or class id for
+     x: Feature Nd numpy matrix of shape `[n_samples, n_features, ...]`.
+     y: Target vector, either floats for regression or class id for
        classification. If matrix, will consider as a sequence
-       of targets. Can be None for unsupervised setting.
-     n_classes: number of classes, 0 and 1 are considered regression, None will
-       pass through the input labels without one-hot conversion.
-     batch_size: mini batch size to accumulate.
-     random_state: numpy RandomState object to reproduce sampling.
+       of targets. Can be `None` for unsupervised setting.
+     n_classes: Number of classes, 0 and 1 are considered regression, `None`
+       will pass through the input labels without one-hot conversion.
+     batch_size: Mini-batch size to accumulate.
+     shuffle: Whether to shuffle `x`.
+     random_state: Numpy `RandomState` object to reproduce sampling.
+     epochs: Number of times to iterate over input data before raising
+       `StopIteration` exception.

    Attributes:
-     x: input features.
-     y: input target.
-     n_classes: number of classes (if None, pass through indices without
+     x: Input features.
+     y: Input target.
+     n_classes: Number of classes (if `None`, pass through indices without
        one-hot conversion).
-     batch_size: mini batch size to accumulate.
-     input_shape: shape of the input.
-     output_shape: shape of the output.
-     input_dtype: dtype of input.
-     output_dtype: dtype of output.
+     batch_size: Mini-batch size to accumulate.
+     input_shape: Shape of the input.
+     output_shape: Shape of the output.
+     input_dtype: DType of input.
+     output_dtype: DType of output.
    """
-   x_dtype = np.int64 if x.dtype == np.int64 else np.float32
+   self._x = check_array(x, dtype=x.dtype)
+   # self.n_classes is None means we're passing in raw target indices.
    y_dtype = (
        np.int64 if n_classes is not None and n_classes > 1 else np.float32)
-   self.x = check_array(x, dtype=x_dtype)
-   # self.n_classes is None means we're passing in raw target indices
    if n_classes is not None:
-     self.y = (None if y is None else check_array(y, dtype=y_dtype))
+     self._y = (None if y is None else check_array(y, dtype=y_dtype))
+   elif isinstance(y, list):
+     self._y = np.array(y)
    else:
-     self.y = y
-     if isinstance(self.y, list):
-       self.y = np.array(y)
+     self._y = y
    self.n_classes = n_classes
    self.max_epochs = epochs
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
-       self.x.shape, None if self.y is None else self.y.shape, n_classes,
+       self._x.shape, None if self._y is None else self._y.shape, n_classes,
        batch_size)
    # Input dtype matches dtype of x.
-   self.input_dtype = x_dtype
+   self._input_dtype = _check_dtype(self._x.dtype)
    # self.n_classes is None means we're passing in raw target indices
-   if n_classes is not None or y is None:
-     self.output_dtype = np.float32
+   if n_classes is not None or self._y is None:
+     self._output_dtype = np.float32
    else:
-     self.output_dtype = self.y.dtype
-   self.shuffle = shuffle
+     self._output_dtype = _check_dtype(self._y.dtype)
+   self._shuffle = shuffle
    self.random_state = np.random.RandomState(
        42) if random_state is None else random_state
-   if self.shuffle:
-     self.indices = self.random_state.permutation(self.x.shape[0])
+   if self._shuffle:
+     self.indices = self.random_state.permutation(self._x.shape[0])
    else:
-     self.indices = np.array(range(self.x.shape[0]))
+     self.indices = np.array(range(self._x.shape[0]))
    self.offset = 0
    self.epoch = 0
    self._epoch_placeholder = None

+ @property
+ def x(self):
+   return self._x
+
+ @property
+ def y(self):
+   return self._y
+
+ @property
+ def shuffle(self):
+   return self._shuffle
+
+ @property
+ def input_dtype(self):
+   return self._input_dtype
+
+ @property
+ def output_dtype(self):
+   return self._output_dtype
+
  @property
  def batch_size(self):
    return self._batch_size
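The refactor turns x, y, shuffle, input_dtype, and output_dtype into read-only properties backed by underscore-prefixed fields, so existing readers keep working while reassignment is blocked. A hedged sketch (the feeder arguments are hypothetical):

    feeder = DataFeeder(x_data, y_data, n_classes=None, batch_size=32)
    print(feeder.input_dtype)  # still readable; forwards to feeder._input_dtype
    feeder.x = other_array     # now raises AttributeError: can't set attribute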
@ -291,7 +321,7 @@ class DataFeeder(object):
    """
    input_shape = [None] + self.input_shape[1:]
    self._input_placeholder = array_ops.placeholder(
-       dtypes.as_dtype(self.input_dtype),
+       dtypes.as_dtype(self._input_dtype),
        input_shape,
        name='input')
    if self.output_shape is None:
@ -299,7 +329,7 @@ class DataFeeder(object):
    else:
      output_shape = [None] + self.output_shape[1:]
      self._output_placeholder = array_ops.placeholder(
-         dtypes.as_dtype(self.output_dtype),
+         dtypes.as_dtype(self._output_dtype),
          output_shape,
          name='output')
    return self._input_placeholder, self._output_placeholder
@ -345,20 +375,20 @@ class DataFeeder(object):
      feed_dict[self._epoch_placeholder.name] = [self.epoch]

    # Take next batch of indices.
-   end = min(self.x.shape[0], self.offset + self._batch_size)
+   end = min(self._x.shape[0], self.offset + self._batch_size)
    batch_indices = self.indices[self.offset:end]

    # Assign input features from random indices.
    inp = (
-       np.array(_access(self.x, batch_indices)).reshape(
+       np.array(_access(self._x, batch_indices)).reshape(
            (batch_indices.shape[0], 1))
-       if len(self.x.shape) == 1 else _access(self.x, batch_indices))
+       if len(self._x.shape) == 1 else _access(self._x, batch_indices))
    feed_dict[self._input_placeholder.name] = inp

    # move offset and reset it if necessary
    self.offset += self._batch_size
-   if self.offset >= self.x.shape[0]:
-     self.indices = self.random_state.permutation(self.x.shape[0])
+   if self.offset >= self._x.shape[0]:
+     self.indices = self.random_state.permutation(self._x.shape[0])
      self.offset = 0
      self.epoch += 1
@ -368,21 +398,21 @@ class DataFeeder(object):

    # assign labels from random indices
    self.output_shape[0] = batch_indices.shape[0]
-   out = np.zeros(self.output_shape, dtype=self.output_dtype)
+   out = np.zeros(self.output_shape, dtype=self._output_dtype)
    for i in xrange(out.shape[0]):
      sample = batch_indices[i]
      # self.n_classes is None means we're passing in raw target indices
      if self.n_classes is None:
-       out[i] = _access(self.y, sample)
+       out[i] = _access(self._y, sample)
      else:
        if self.n_classes > 1:
          if len(self.output_shape) == 2:
-           out.itemset((i, int(_access(self.y, sample))), 1.0)
+           out.itemset((i, int(_access(self._y, sample))), 1.0)
          else:
-           for idx, value in enumerate(_access(self.y, sample)):
+           for idx, value in enumerate(_access(self._y, sample)):
              out.itemset(tuple([i, idx, value]), 1.0)
        else:
-         out[i] = _access(self.y, sample)
+         out[i] = _access(self._y, sample)
    feed_dict[self._output_placeholder.name] = out

    return feed_dict
@ -420,32 +450,28 @@ class StreamingDataFeeder(DataFeeder):
    """
    # pylint: disable=invalid-name,super-init-not-called
    x_first_el = six.next(x)
-   self.x = itertools.chain([x_first_el], x)
+   self._x = itertools.chain([x_first_el], x)
    if y is not None:
      y_first_el = six.next(y)
-     self.y = itertools.chain([y_first_el], y)
+     self._y = itertools.chain([y_first_el], y)
    else:
      y_first_el = None
-     self.y = None
+     self._y = None
    self.n_classes = n_classes
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        [1] + list(x_first_el.shape),
        [1] + list(y_first_el.shape) if y is not None else None,
        n_classes,
        batch_size)
-   self.input_dtype = x_first_el.dtype
-   # Convert float64 to float32, as all the parameters in the model are
-   # floats32 and there is a lot of benefits in using it in NNs.
-   if self.input_dtype == np.float64:
-     self.input_dtype = np.float32
+   self._input_dtype = _check_dtype(x_first_el.dtype)
    # Output types are floats, due to both softmaxes and regression req.
    if n_classes is not None and n_classes > 0:
-     self.output_dtype = np.float32
+     self._output_dtype = np.float32
    elif y is not None:
      if isinstance(y_first_el, list) or isinstance(y_first_el, np.ndarray):
-       self.output_dtype = np.dtype(type(y_first_el[0]))
+       self._output_dtype = _check_dtype(np.dtype(type(y_first_el[0])))
      else:
-       self.output_dtype = np.dtype(type(y_first_el))
+       self._output_dtype = _check_dtype(np.dtype(type(y_first_el)))

  def get_feed_params(self):
    """Function returns a dict with data feed params while training.
@ -472,22 +498,22 @@ class StreamingDataFeeder(DataFeeder):
    """
    if self.stopped:
      raise StopIteration
-   inp = np.zeros(self.input_shape, dtype=self.input_dtype)
-   if self.y is not None:
-     out = np.zeros(self.output_shape, dtype=self.output_dtype)
+   inp = np.zeros(self.input_shape, dtype=self._input_dtype)
+   if self._y is not None:
+     out = np.zeros(self.output_shape, dtype=self._output_dtype)
    for i in xrange(self._batch_size):
      # Add handling when queue ends.
      try:
-       inp[i, :] = six.next(self.x)
+       inp[i, :] = six.next(self._x)
      except StopIteration:
        self.stopped = True
        inp = inp[:i, :]
-       if self.y is not None:
+       if self._y is not None:
          out = out[:i]
        break

-     if self.y is not None:
-       y = six.next(self.y)
+     if self._y is not None:
+       y = six.next(self._y)
      if self.n_classes is not None and self.n_classes > 1:
        if len(self.output_shape) == 2:
          out.itemset((i, y), 1.0)
@ -496,7 +522,7 @@ class StreamingDataFeeder(DataFeeder):
            out.itemset(tuple([i, idx, value]), 1.0)
        else:
          out[i] = y
-   if self.y is None:
+   if self._y is None:
      return {self._input_placeholder.name: inp}
    return {self._input_placeholder.name: inp,
            self._output_placeholder.name: out}
@ -511,6 +537,7 @@ class DaskDataFeeder(object):
  into them. DaskDataFeeder will remove requirement to have full dataset in the
  memory and still do random seeks for sampling of batches.
  """
+
  def __init__(self, x, y, n_classes, batch_size, shuffle=True,
               random_state=None, epochs=None):
    """Initializes a DaskDataFeeder instance.
@ -521,8 +548,10 @@ class DaskDataFeeder(object):
        regression values.
      n_classes: indicator of how many classes the target has.
      batch_size: Mini batch size to accumulate.
+     shuffle: Whether to shuffle the inputs.
      random_state: random state for RNG. Note that it will mutate so use a
        int value for this if you want consistent sized batches.
+     epochs: Number of epochs to run.

    Attributes:
      x: input features.
@ -537,35 +566,33 @@ class DaskDataFeeder(object):
    # pylint: disable=invalid-name,super-init-not-called
    import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    # TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
-   self.x = x
-   self.y = y
+   self._x = x
+   self._y = y
    # save column names
-   self.x_columns = list(x.columns)
+   self._x_columns = list(x.columns)
    if isinstance(y.columns[0], str):
-     self.y_columns = list(y.columns)
+     self._y_columns = list(y.columns)
    else:
      # deal with cases where two DFs have overlapped default numeric colnames
-     self.y_columns = len(self.x_columns) + 1
-     self.y = self.y.rename(columns={y.columns[0]: self.y_columns})
+     self._y_columns = len(self._x_columns) + 1
+     self._y = self._y.rename(columns={y.columns[0]: self._y_columns})

    # TODO(terrytangyuan): deal with unsupervised cases
    # combine into a data frame
-   self.df = dd.multi.concat([self.x, self.y], axis=1)
+   self.df = dd.multi.concat([self._x, self._y], axis=1)
    self.n_classes = n_classes

    x_count = x.count().compute()[0]
-   x_shape = (x_count, len(self.x.columns))
-   y_shape = (x_count, len(self.y.columns))
+   x_shape = (x_count, len(self._x.columns))
+   y_shape = (x_count, len(self._y.columns))
    # TODO(terrytangyuan): Add support for shuffle and epochs.
-   self.shuffle = shuffle
+   self._shuffle = shuffle
    self.epochs = epochs
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_shape, y_shape, n_classes, batch_size)
    self.sample_fraction = self._batch_size / float(x_count)
-   # TODO(ptucker,ipolosukhin): Remove this?
-   # TODO(ipolosukhin): remove or restore.
-   # self.x.dtypes[0], self.y.dtypes[self.y_columns]
-   self.input_dtype, self.output_dtype = np.float32, np.float32
+   self._input_dtype = _check_dtype(self._x.dtypes[0])
+   self._output_dtype = _check_dtype(self._y.dtypes[self._y_columns])
    if random_state is None:
      self.random_state = 66
    else:
@ -597,17 +624,17 @@ class DaskDataFeeder(object):
    sample = self.df.random_split(
        [self.sample_fraction, 1 - self.sample_fraction],
        random_state=self.random_state)
-   inp = extract_pandas_matrix(sample[0][self.x_columns].compute()).tolist()
-   out = extract_pandas_matrix(sample[0][self.y_columns].compute())
+   inp = extract_pandas_matrix(sample[0][self._x_columns].compute()).tolist()
+   out = extract_pandas_matrix(sample[0][self._y_columns].compute())
    # convert to correct dtype
-   inp = np.array(inp, dtype=self.input_dtype)
+   inp = np.array(inp, dtype=self._input_dtype)
    # one-hot encode out for each class for cross entropy loss
    if HAS_PANDAS:
      import pandas as pd  # pylint: disable=g-import-not-at-top
      if not isinstance(out, pd.Series):
        out = out.flatten()
-   out_max = self.y.max().compute().values[0]
-   encoded_out = np.zeros((out.size, out_max + 1), dtype=self.output_dtype)
+   out_max = self._y.max().compute().values[0]
+   encoded_out = np.zeros((out.size, out_max + 1), dtype=self._output_dtype)
    encoded_out[np.arange(out.size), out] = 1
    return {input_placeholder.name: inp,
            output_placeholder.name: encoded_out}
@ -20,12 +20,17 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import input as input_ops
+from tensorflow.python.training import queue_runner

# Default name for key in the feature dict.
KEY_FEATURE_NAME = '__key__'
@ -219,11 +224,18 @@ def read_keyed_batch_examples(
  return queued_examples_with_keys


-def read_keyed_batch_features(
-   file_pattern, batch_size, features, reader,
-   randomize_input=True, num_epochs=None,
-   queue_capacity=10000, reader_num_threads=1,
-   parser_num_threads=1, name=None):
+def read_keyed_batch_features(file_pattern,
+                              batch_size,
+                              features,
+                              reader,
+                              randomize_input=True,
+                              num_epochs=None,
+                              queue_capacity=10000,
+                              reader_num_threads=1,
+                              feature_queue_capacity=100,
+                              num_queue_runners=2,
+                              parser_num_threads=None,
+                              name=None):
  """Adds operations to read, queue, batch and parse `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
@ -251,7 +263,12 @@ def read_keyed_batch_features(
      tf.initialize_local_variables() as shown in the tests.
    queue_capacity: Capacity for input queue.
    reader_num_threads: The number of threads to read examples.
-   parser_num_threads: The number of threads to parse examples.
+   feature_queue_capacity: Capacity of the parsed features queue.
+   num_queue_runners: Number of queue runners to start for the feature queue.
+     Adding multiple queue runners for the parsed example queue helps maintain
+     a full queue when the subsequent computations overall are cheaper than
+     parsing.
+   parser_num_threads: (Deprecated) The number of threads to parse examples.
    name: Name of resulting op.

  Returns:
@ -261,6 +278,11 @@ def read_keyed_batch_features(
  Raises:
    ValueError: for invalid inputs.
  """
+
+ if parser_num_threads:
+   # TODO(sibyl-Aix6ihai): Remove on Sept 3 2016.
+   logging.warning('parser_num_threads is deprecated, it will be removed on '
+                   'Sept 3 2016')
  with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
    keys, examples = read_keyed_batch_examples(
        file_pattern, batch_size, reader, randomize_input=randomize_input,
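A hedged sketch of a call against the new keyword-per-line signature (the file pattern and feature spec are hypothetical, and the function is assumed to be reachable via tf.contrib.learn.io):

    keys, features = tf.contrib.learn.io.read_keyed_batch_features(
        file_pattern='/tmp/data/*.tfrecord',  # hypothetical path
        batch_size=32,
        features={'age': tf.FixedLenFeature([1], tf.int64)},
        reader=tf.TFRecordReader,
        feature_queue_capacity=100,  # new: capacity of the parsed-feature queue
        num_queue_runners=2)         # new: runners that keep that queue full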
@ -268,24 +290,66 @@ def read_keyed_batch_features(
|
|||||||
num_threads=reader_num_threads, read_batch_size=batch_size,
|
num_threads=reader_num_threads, read_batch_size=batch_size,
|
||||||
name=scope)
|
name=scope)
|
||||||
|
|
||||||
if parser_num_threads == 1:
|
# Parse the example.
|
||||||
# Avoid queue overhead for single thread
|
feature_map = parsing_ops.parse_example(examples, features)
|
||||||
return keys, parsing_ops.parse_example(examples, features)
|
|
||||||
|
|
||||||
# Parse features into tensors in many threads and put on the queue.
|
# Lets also add preprocessed tensors into the queue types for each item of
|
||||||
features_list = []
|
# the queue.
|
||||||
for _ in range(parser_num_threads):
|
tensors_to_enqueue = []
|
||||||
feature_dict = parsing_ops.parse_example(examples, features)
|
# Each entry contains the key, and a boolean which indicates whether the
|
||||||
feature_dict[KEY_FEATURE_NAME] = keys
|
# tensor was a sparse tensor.
|
||||||
features_list.append(feature_dict)
|
tensors_mapping = []
|
||||||
queued_features = input_ops.batch_join(
|
# TODO(sibyl-Aix6ihai): Most of the functionality here is about pushing sparse
|
||||||
features_list,
|
# tensors into a queue. This could be taken care in somewhere else so others
|
||||||
batch_size=batch_size,
|
# can reuse it. Also, QueueBase maybe extended to handle sparse tensors
|
||||||
capacity=queue_capacity,
|
# directly.
|
||||||
enqueue_many=True,
|
for key, tensor in feature_map.iteritems():
|
||||||
name='parse_example_batch_join')
|
if isinstance(tensor, ops.SparseTensor):
|
||||||
queued_keys = queued_features.pop(KEY_FEATURE_NAME)
|
tensors_mapping.append((key, True))
|
||||||
return queued_keys, queued_features
|
tensors_to_enqueue.extend([tensor.indices, tensor.values, tensor.shape])
|
||||||
|
else:
|
||||||
|
tensors_mapping.append((key, False))
|
||||||
|
tensors_to_enqueue.append(tensor)
|
||||||
|
tensors_to_enqueue.append(keys)
|
||||||
|
|
||||||
|
queue_dtypes = [x.dtype for x in tensors_to_enqueue]
|
||||||
|
input_queue = data_flow_ops.FIFOQueue(feature_queue_capacity, queue_dtypes)
|
||||||
|
|
||||||
|
# Add a summary op to debug if our feature queue is full or not.
|
||||||
|
logging_ops.scalar_summary('queue/parsed_features/%s/fraction_of_%d_full' %
|
||||||
|
(input_queue.name, feature_queue_capacity),
|
||||||
|
math_ops.cast(input_queue.size(), dtypes.float32)
|
||||||
|
* (1. / feature_queue_capacity))
|
||||||
|
|
||||||
|
# Add multiple queue runners so that the queue is always full. Adding more
|
||||||
|
# than two queue-runners may hog the cpu on the worker to fill up the queue.
|
||||||
|
for _ in range(num_queue_runners):
|
||||||
|
queue_runner.add_queue_runner(
|
||||||
|
queue_runner.QueueRunner(input_queue, [input_queue.enqueue(
|
||||||
|
tensors_to_enqueue)]))
|
||||||
|
|
||||||
|
dequeued_tensors = input_queue.dequeue()
|
||||||
|
|
||||||
|
# Reset shapes on dequeued tensors.
|
||||||
|
for i in range(len(tensors_to_enqueue)):
|
||||||
|
dequeued_tensors[i].set_shape(tensors_to_enqueue[i].get_shape())
|
||||||
|
|
||||||
|
# Recreate feature mapping according to the original dictionary.
|
||||||
|
dequeued_feature_map = {}
|
||||||
|
index = 0
|
||||||
|
for key, is_sparse_tensor in tensors_mapping:
|
||||||
|
if is_sparse_tensor:
|
||||||
|
# Three tensors are (indices, values, shape).
|
||||||
|
dequeued_feature_map[key] = ops.SparseTensor(
|
||||||
|
dequeued_tensors[index], dequeued_tensors[index + 1],
|
||||||
|
dequeued_tensors[index + 2])
|
||||||
|
index += 3
|
||||||
|
else:
|
||||||
|
dequeued_feature_map[key] = dequeued_tensors[index]
|
||||||
|
index += 1
|
||||||
|
dequeued_keys = dequeued_tensors[-1]
|
||||||
|
|
||||||
|
return dequeued_keys, dequeued_feature_map
|
||||||
|
|
||||||
|
|
||||||
def read_batch_features(file_pattern, batch_size, features, reader,
|
def read_batch_features(file_pattern, batch_size, features, reader,
|
||||||
|
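The hunk above swaps a multi-threaded parse_example/batch_join pipeline for a single FIFOQueue that carries flattened sparse tensors. As a reading aid, here is a minimal, framework-free sketch of the flatten/rebuild bookkeeping the new code performs; all names below are illustrative and none come from the commit itself.

import collections

# Stand-in for ops.SparseTensor: just its three component tensors.
Sparse = collections.namedtuple('Sparse', ['indices', 'values', 'shape'])

def pack(feature_map):
  """Flattens a dict of dense/sparse values into one list plus a mapping."""
  mapping, flat = [], []
  for key, value in sorted(feature_map.items()):
    if isinstance(value, Sparse):
      mapping.append((key, True))   # True: this key used three slots
      flat.extend([value.indices, value.values, value.shape])
    else:
      mapping.append((key, False))  # False: this key used one slot
      flat.append(value)
  return mapping, flat

def unpack(mapping, flat):
  """Rebuilds the dict, consuming three slots per sparse entry."""
  result, index = {}, 0
  for key, is_sparse in mapping:
    if is_sparse:
      result[key] = Sparse(flat[index], flat[index + 1], flat[index + 2])
      index += 3
    else:
      result[key] = flat[index]
      index += 1
  return result

features = {'age': [42], 'terms': Sparse([[0, 1]], ['tf'], [1, 8])}
mapping, flat = pack(features)
assert unpack(mapping, flat) == features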
@@ -124,18 +124,18 @@ class GraphIOTest(tf.test.TestCase):
         _VALID_FILE_PATTERN, batch_size, features, randomize_input=False,
         queue_capacity=queue_capacity, reader_num_threads=2,
         parser_num_threads=2, name=name)
-    self.assertEqual("%s/parse_example_batch_join:1" % name,
+    self.assertEqual("%s/fifo_queue_1_Dequeue:0" % name,
                      features["feature"].name)
     file_name_queue_name = "%s/file_name_queue" % name
     file_names_name = "%s/input" % file_name_queue_name
     example_queue_name = "%s/fifo_queue" % name
-    parse_example_queue_name = "%s/parse_example_batch_join" % name
+    parse_example_queue_name = "%s/fifo_queue" % name
     op_nodes = test_util.assert_ops_in_graph({
         file_names_name: "Const",
         file_name_queue_name: "FIFOQueue",
         "%s/read/TFRecordReader" % name: "TFRecordReader",
         example_queue_name: "FIFOQueue",
-        parse_example_queue_name: "QueueDequeueMany",
+        parse_example_queue_name: "FIFOQueue",
         name: "QueueDequeueMany"
     }, g)
     self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
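The updated test asserts graph structure by node name and op type. For readers unfamiliar with `assert_ops_in_graph`, a rough sketch of the kind of check it performs (this is not the helper's actual implementation):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  tf.constant([1, 2], name='input')
# Map every node name to its op type, then assert expectations against it.
ops_by_name = {op.name: op.type for op in g.get_operations()}
assert ops_by_name['input'] == 'Const'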
@@ -19,10 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.contrib import rnn as contrib_rnn
 from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
 from tensorflow.contrib.learn.python.learn.ops import dnn_ops
 from tensorflow.contrib.learn.python.learn.ops import losses_ops
-from tensorflow.contrib import rnn as contrib_rnn
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops as array_ops_
@@ -81,6 +81,7 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0):
   with vs.variable_scope('linear_regression'):
     logging_ops.histogram_summary('linear_regression.x', x)
     logging_ops.histogram_summary('linear_regression.y', y)
+    dtype = x.dtype.base_dtype
     y_shape = y.get_shape()
     if len(y_shape) == 1:
       output_shape = 1
@@ -88,15 +89,18 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0):
       output_shape = y_shape[1]
     # Set up the requested initialization.
     if init_mean is None:
-      weights = vs.get_variable('weights', [x.get_shape()[1], output_shape])
-      bias = vs.get_variable('bias', [output_shape])
+      weights = vs.get_variable(
+          'weights', [x.get_shape()[1], output_shape], dtype=dtype)
+      bias = vs.get_variable('bias', [output_shape], dtype=dtype)
     else:
       weights = vs.get_variable('weights', [x.get_shape()[1], output_shape],
                                 initializer=init_ops.random_normal_initializer(
-                                    init_mean, init_stddev))
+                                    init_mean, init_stddev, dtype=dtype),
+                                dtype=dtype)
       bias = vs.get_variable('bias', [output_shape],
                              initializer=init_ops.random_normal_initializer(
-                                 init_mean, init_stddev))
+                                 init_mean, init_stddev, dtype=dtype),
+                             dtype=dtype)
     logging_ops.histogram_summary('linear_regression.weights', weights)
     logging_ops.histogram_summary('linear_regression.bias', bias)
     return losses_ops.mean_squared_error_regressor(x, y, weights, bias)
@@ -135,19 +139,22 @@ def logistic_regression(x,
   with vs.variable_scope('logistic_regression'):
     logging_ops.histogram_summary('%s.x' % vs.get_variable_scope().name, x)
     logging_ops.histogram_summary('%s.y' % vs.get_variable_scope().name, y)
+    dtype = x.dtype.base_dtype
     # Set up the requested initialization.
     if init_mean is None:
-      weights = vs.get_variable('weights',
-                                [x.get_shape()[1], y.get_shape()[-1]])
-      bias = vs.get_variable('bias', [y.get_shape()[-1]])
+      weights = vs.get_variable(
+          'weights', [x.get_shape()[1], y.get_shape()[-1]], dtype=dtype)
+      bias = vs.get_variable('bias', [y.get_shape()[-1]], dtype=dtype)
     else:
       weights = vs.get_variable('weights',
                                 [x.get_shape()[1], y.get_shape()[-1]],
                                 initializer=init_ops.random_normal_initializer(
-                                    init_mean, init_stddev))
+                                    init_mean, init_stddev, dtype=dtype),
+                                dtype=dtype)
       bias = vs.get_variable('bias', [y.get_shape()[-1]],
                              initializer=init_ops.random_normal_initializer(
-                                 init_mean, init_stddev))
+                                 init_mean, init_stddev, dtype=dtype),
+                             dtype=dtype)
     logging_ops.histogram_summary('%s.weights' % vs.get_variable_scope().name,
                                   weights)
     logging_ops.histogram_summary('%s.bias' % vs.get_variable_scope().name,
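Both regression hunks thread the input's base dtype into `vs.get_variable` and the initializer; without that, the variables default to float32 and fail to compose with float64 inputs. A sketch of the failure mode being fixed (illustrative only; assumes a 1.x-era graph-mode API):

import tensorflow as tf

x = tf.placeholder(tf.float64, [None, 2])
with tf.variable_scope('demo'):
  # Without an explicit dtype this variable would be float32, and
  # tf.matmul(x, w) below would raise a dtype-mismatch error.
  w = tf.get_variable('weights', [2, 1], dtype=x.dtype.base_dtype)
y = tf.matmul(x, w)  # both operands are float64, so this builds cleanly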
@@ -535,8 +535,12 @@ class LoggingTrainable(EveryN):
 class SummarySaver(EveryN):
   """Saves summaries every N steps."""

-  def __init__(self, summary_op, save_steps=100, output_dir=None,
-               summary_writer=None):
+  def __init__(self,
+               summary_op,
+               save_steps=100,
+               output_dir=None,
+               summary_writer=None,
+               scaffold=None):
     """Initializes a `SummarySaver` monitor.

     Args:
@@ -548,6 +552,7 @@ class SummarySaver(EveryN):
         if no `summary_writer` is supplied.
       summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
         one will be created accordingly.
+      scaffold: `Scaffold` to get summary_op if it's not provided.
     """
     # TODO(ipolosukhin): Implement every N seconds.
     super(SummarySaver, self).__init__(every_n_steps=save_steps)
@@ -555,6 +560,7 @@ class SummarySaver(EveryN):
     self._summary_writer = summary_writer
     if summary_writer is None and output_dir:
       self._summary_writer = summary_io.SummaryWriter(output_dir)
+    self._scaffold = scaffold
     # TODO(mdan): Throw an error if output_dir and summary_writer are None.

   def set_estimator(self, estimator):
@@ -565,14 +571,17 @@ class SummarySaver(EveryN):

   def every_n_step_begin(self, step):
     super(SummarySaver, self).every_n_step_begin(step)
+    if self._summary_op is None and self._scaffold is not None:
+      self._summary_op = self._scaffold.summary_op
     if self._summary_op is not None:
       return [self._summary_op]
     return []

   def every_n_step_end(self, step, outputs):
     super(SummarySaver, self).every_n_step_end(step, outputs)
-    summary_strs = _extract_output(outputs, self._summary_op)
-    if self._summary_writer and self._summary_op is not None:
-      self._summary_writer.add_summary(summary_strs, step)
+    if self._summary_op is not None:
+      summary_strs = _extract_output(outputs, self._summary_op)
+      if self._summary_writer:
+        self._summary_writer.add_summary(summary_strs, step)
     return False

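The new `scaffold` argument lets `SummarySaver` resolve its summary op lazily, since a `Scaffold` may only materialize `summary_op` at `finalize()` time, after the monitor has been constructed. The pattern in isolation (a sketch with stand-in names, not the class itself):

class LazySummarySource(object):
  """Prefer an explicit summary op; otherwise ask the scaffold when needed."""

  def __init__(self, summary_op=None, scaffold=None):
    self._summary_op = summary_op
    self._scaffold = scaffold

  def fetches(self):
    # Resolved at step time, not construction time: the scaffold may not
    # have created its summary op until after this object was built.
    if self._summary_op is None and self._scaffold is not None:
      self._summary_op = self._scaffold.summary_op
    return [self._summary_op] if self._summary_op is not None else []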
@@ -923,36 +932,88 @@ class ExportMonitor(EveryN):
         default_batch_size=self._default_batch_size)


-class CheckpointSaver(EveryN):
+class CheckpointSaver(BaseMonitor):
   """Saves checkpoints every N steps."""

-  def __init__(self, every_n_steps, saver, checkpoint_dir,
+  def __init__(self,
+               checkpoint_dir,
+               save_secs=None,
+               save_steps=None,
+               saver=None,
                checkpoint_basename="model.ckpt",
-               first_n_steps=-1):
+               scaffold=None):
     """Initialize CheckpointSaver monitor.

     Args:
-      every_n_steps: `int`, save every N steps.
-      saver: `Saver` object, used for saving.
       checkpoint_dir: `str`, base directory for the checkpoint files.
+      save_secs: `int`, save every N secs.
+      save_steps: `int`, save every N steps.
+      saver: `Saver` object, used for saving.
       checkpoint_basename: `str`, base name for the checkpoint files.
-      first_n_steps: `int`, if positive, save every step during the
-        first `first_n_steps` steps.
+      scaffold: `Scaffold`, use to get saver object.
+
+    Raises:
+      ValueError: If both `save_steps` and `save_secs` are not `None`.
+      ValueError: If both `save_steps` and `save_secs` are `None`.
     """
     logging.info("Create CheckpointSaver")
-    super(CheckpointSaver, self).__init__(every_n_steps=every_n_steps,
-                                          first_n_steps=first_n_steps)
+    super(CheckpointSaver, self).__init__()
     self._saver = saver
     self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
     self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
+    self._scaffold = scaffold
+    self._save_secs = save_secs
+    self._save_steps = save_steps
+    self._last_saved_time = None
+    self._last_begin_step = None
+    self._last_saved_step = None

-  def every_n_post_step(self, step, session):
+    if save_steps is None and save_secs is None:
+      raise ValueError("Either save_steps or save_secs should be provided")
+    if (save_steps is not None) and (save_secs is not None):
+      raise ValueError("Can not provide both save_steps and save_secs.")
+
+  def begin(self, max_steps=None):
+    super(CheckpointSaver, self).begin(max_steps)
+    self._last_saved_time = None
+    self._last_begin_step = None
+    self._last_saved_step = None
+
+  def step_begin(self, step):
+    super(CheckpointSaver, self).step_begin(step)
+    self._last_begin_step = step
+
+  def post_step(self, step, session):
+    super(CheckpointSaver, self).post_step(step, session)
+    if self._last_saved_time is None:
+      self._save(step, session)
+
+    if self._save_steps is not None:
+      if step >= self._last_saved_step + self._save_steps:
+        self._save(step, session)
+
+    if self._save_secs is not None:
+      if time.time() >= self._last_saved_time + self._save_secs:
+        self._save(step, session)
+
+  def end(self, session=None):
+    super(CheckpointSaver, self).end(session)
+    self._save(self._last_begin_step, session)
+
+  def _save(self, step, session):
+    """Saves the latest checkpoint."""
+    if step == self._last_saved_step:
+      return
     logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
-    self._saver.save(session, self._save_path, global_step=step)
-    if self._summary_writer:
-      self._summary_writer.add_session_log(
-          SessionLog(status=SessionLog.CHECKPOINT,
-                     checkpoint_path=self._save_path),
-          step)
+    self._last_saved_time = time.time()
+    self._last_saved_step = step
+    if self._saver is None:
+      self._scaffold.saver.save(session, self._save_path, global_step=step)
+    else:
+      self._saver.save(session, self._save_path, global_step=step)
+    self._summary_writer.add_session_log(
+        SessionLog(
+            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
+        step)

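`CheckpointSaver` moves from the `EveryN` helper onto plain `BaseMonitor` hooks, with mutually exclusive step-based and time-based cadences. The decision logic, extracted into a framework-free sketch (the class name and method split are illustrative):

import time

class SaveCadence(object):
  """Mirrors the save-timing rules in the diff above."""

  def __init__(self, save_secs=None, save_steps=None):
    if save_steps is None and save_secs is None:
      raise ValueError('Either save_steps or save_secs should be provided')
    if save_steps is not None and save_secs is not None:
      raise ValueError('Can not provide both save_steps and save_secs.')
    self._save_secs = save_secs
    self._save_steps = save_steps
    self._last_saved_time = None
    self._last_saved_step = None

  def should_save(self, step):
    if self._last_saved_time is None:  # the first call always saves
      return True
    if self._save_steps is not None:
      return step >= self._last_saved_step + self._save_steps
    return time.time() >= self._last_saved_time + self._save_secs

  def mark_saved(self, step):
    self._last_saved_time = time.time()
    self._last_saved_step = step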
@@ -119,47 +119,85 @@ class Scaffold(object):
       keep_checkpoint_max: Optional parameter to use to construct a saver if
         none is already there in the graph.
     """
-    if global_step_tensor is None:
-      global_step_tensor = contrib_variables.get_or_create_global_step()
-    self.global_step_tensor = global_step_tensor
-    if init_op is None:
-      init_op = Scaffold._get_or_default('init_op', ops.GraphKeys.INIT_OP,
-                                         variables.initialize_all_variables)
-    self.init_op = init_op
-    self.init_feed_dict = init_feed_dict
     # NOTE(touts): modifying the init function to be passed the scaffold is a
     # hack to make it easy to find the saver. Is there a better way?
     if init_fn:
-      self.init_fn = lambda sess: init_fn(self, sess)
+      self._init_fn = lambda sess: init_fn(self, sess)
     else:
-      self.init_fn = None
-    if ready_op is None:
-      ready_op = Scaffold._get_or_default(
-          'ready_op', ops.GraphKeys.READY_OP,
-          variables.report_uninitialized_variables)
-    self.ready_op = ready_op
-    if local_init_op is None:
-      local_init_op = Scaffold._get_or_default('local_init_op',
-                                               ops.GraphKeys.LOCAL_INIT_OP,
-                                               Scaffold._default_local_init_op)
-    self.local_init_op = local_init_op
-    if summary_op is None:
-      summary_op = Scaffold._get_or_default('summary_op',
-                                            ops.GraphKeys.SUMMARY_OP,
-                                            logging_ops.merge_all_summaries)
-    self.summary_op = summary_op
+      self._init_fn = None
+    self._global_step_tensor = global_step_tensor
+    self._init_op = init_op
+    self._ready_op = ready_op
+    self._local_init_op = local_init_op
+    self._summary_op = summary_op
+    self._saver = saver
+    self._keep_checkpoint_max = keep_checkpoint_max
+    self._init_feed_dict = init_feed_dict
+
+  def finalize(self):
+    """Creates operations if needed and finalizes the graph."""
+    if self._global_step_tensor is None:
+      self._global_step_tensor = contrib_variables.get_or_create_global_step()
+    if self._init_op is None:
+      self._init_op = Scaffold._get_or_default(
+          'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
+    if self._ready_op is None:
+      self._ready_op = Scaffold._get_or_default(
+          'ready_op', ops.GraphKeys.READY_OP,
+          variables.report_uninitialized_variables)
+    if self._local_init_op is None:
+      self._local_init_op = Scaffold._get_or_default(
+          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
+          Scaffold._default_local_init_op)
+    if self._summary_op is None:
+      self._summary_op = Scaffold._get_or_default(
+          'summary_op', ops.GraphKeys.SUMMARY_OP,
+          logging_ops.merge_all_summaries)
     # pylint: disable=g-long-lambda
-    if saver is None:
-      saver = Scaffold._get_or_default(
+    if self._saver is None:
+      self._saver = Scaffold._get_or_default(
           'saver',
           ops.GraphKeys.SAVERS,
           lambda: training_saver.Saver(sharded=True,
-                                       max_to_keep=keep_checkpoint_max))
+                                       max_to_keep=self._keep_checkpoint_max))
     # pylint: enable=g-long-lambda
-    self.saver = saver

     ops.get_default_graph().finalize()
+
+  @property
+  def global_step_tensor(self):
+    return self._global_step_tensor
+
+  @property
+  def init_fn(self):
+    return self._init_fn
+
+  @property
+  def init_op(self):
+    return self._init_op
+
+  @property
+  def ready_op(self):
+    return self._ready_op
+
+  @property
+  def local_init_op(self):
+    return self._local_init_op
+
+  @property
+  def summary_op(self):
+    return self._summary_op
+
+  @property
+  def saver(self):
+    return self._saver
+
+  @property
+  def init_feed_dict(self):
+    return self._init_feed_dict

   @staticmethod
   def _get_or_default(arg_name, collection_key, default_constructor):
     """Get from cache or create a default operation."""
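`Scaffold.__init__` now only records what it was given; `finalize()` fills in the defaults and freezes the graph, and read-only properties expose the results. The shape of that pattern, with placeholder defaults standing in for `_get_or_default` (a sketch, not the real class):

class DeferredDefaults(object):
  """Store constructor args untouched; create defaults only in finalize()."""

  def __init__(self, init_op=None, saver=None):
    self._init_op = init_op
    self._saver = saver

  def finalize(self):
    # Defaults are built here so callers can finish constructing the graph
    # first; anything left unset is derived at the last possible moment.
    if self._init_op is None:
      self._init_op = 'default_init_op'  # placeholder for _get_or_default
    if self._saver is None:
      self._saver = 'default_saver'      # placeholder for _get_or_default
    return self

  @property
  def init_op(self):
    return self._init_op

  @property
  def saver(self):
    return self._saver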
@@ -213,9 +251,10 @@ class SupervisedSession(object):
     self._config = config
     self._monitors = monitors or []
     self._scaffold = scaffold or Scaffold()
-    # Finalize and write the graph.
-    self._graph.finalize()
+    for monitor in self._monitors:
+      monitor.begin(max_steps=None)
     # Create the session.
+    self._scaffold.finalize()
     self._session_manager = sm.SessionManager(
         local_init_op=self._scaffold.local_init_op,
         ready_op=self._scaffold.ready_op,
@@ -223,8 +262,6 @@ class SupervisedSession(object):
     self._sess = recoverable_session.RecoverableSession(self._create_session)
     # Call the begin() method of monitors.
     self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
-    for monitor in self._monitors:
-      monitor.begin(max_steps=None)
     # Write the graph out, note: this uses self._init_step.
     self.write_graph()

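The reordering here matters: a monitor's `begin()` may still add ops (for example, summaries) to the graph, while `Scaffold.finalize()` ends by freezing the default graph, after which adding ops raises. A stub-level sketch of the required order (stubs are illustrative, not the real classes):

class StubMonitor(object):
  def begin(self, max_steps=None):
    self.began = True  # real monitors may add ops to the graph here

class StubScaffold(object):
  def finalize(self):
    self.frozen = True  # the real Scaffold freezes the default graph here

monitors, scaffold = [StubMonitor()], StubScaffold()
for monitor in monitors:
  monitor.begin(max_steps=None)  # must run before the graph is frozen
scaffold.finalize()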
@@ -76,9 +76,8 @@ class CoordinatedSessionTest(tf.test.TestCase):
       self.assertFalse(coord_sess.should_stop())
       self.assertEqual(0, coord_sess.run(c))
       self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
-      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
-                                   'both fed and fetched'):
-        coord_sess.run(c, feed_dict={c: 2})
+      with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
+        coord_sess.run([None], feed_dict={c: 2})
       self.assertTrue(coord.should_stop())
       self.assertTrue(coord_sess.should_stop())

@@ -101,9 +100,8 @@ class CoordinatedSessionTest(tf.test.TestCase):
       self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
       for t in threads:
         self.assertTrue(t.is_alive())
-      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
-                                   'both fed and fetched'):
-        coord_sess.run(c, feed_dict={c: 2})
+      with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
+        coord_sess.run([None], feed_dict={c: 2})
       for t in threads:
         self.assertFalse(t.is_alive())
       self.assertTrue(coord.should_stop())
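These tests now provoke the failure with an invalid fetch (`None`) instead of a tensor that is both fed and fetched, so the error is a client-side `TypeError` raised before the graph ever runs. For example:

import tensorflow as tf

with tf.Session() as sess:
  try:
    sess.run([None])  # invalid fetch: rejected in the Python client
  except TypeError:
    handled = True  # no graph execution happened; the check is client-side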
@@ -20,6 +20,7 @@ from __future__ import print_function

 import numpy as np
 import six
+from six.moves import xrange  # pylint: disable=redefined-builtin

 import tensorflow as tf
 # pylint: disable=wildcard-import
@@ -31,6 +32,68 @@ class DataFeederTest(tf.test.TestCase):
   # pylint: disable=undefined-variable
   """Tests for `DataFeeder`."""

+  def _assert_raises(self, input_data):
+    with self.assertRaisesRegexp(TypeError, 'annot convert'):
+      data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
+
+  def test_input_uint32(self):
+    self._assert_raises(np.matrix([[1, 2], [3, 4]], dtype=np.uint32))
+
+  def test_input_uint64(self):
+    self._assert_raises(np.matrix([[1, 2], [3, 4]], dtype=np.uint64))
+
+  def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data):
+    feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1)
+    self.assertEqual(expected_np_dtype, feeder.input_dtype)
+    with tf.Graph().as_default() as g, self.test_session(g):
+      inp, _ = feeder.input_builder()
+      self.assertEqual(expected_tf_dtype, inp.dtype)
+
+  def test_input_int8(self):
+    self._assert_dtype(
+        np.int8, tf.int8, np.matrix([[1, 2], [3, 4]], dtype=np.int8))
+
+  def test_input_int16(self):
+    self._assert_dtype(
+        np.int16, tf.int16, np.matrix([[1, 2], [3, 4]], dtype=np.int16))
+
+  def test_input_int32(self):
+    self._assert_dtype(
+        np.int32, tf.int32, np.matrix([[1, 2], [3, 4]], dtype=np.int32))
+
+  def test_input_int64(self):
+    self._assert_dtype(
+        np.int64, tf.int64, np.matrix([[1, 2], [3, 4]], dtype=np.int64))
+
+  def test_input_uint8(self):
+    self._assert_dtype(
+        np.uint8, tf.uint8, np.matrix([[1, 2], [3, 4]], dtype=np.uint8))
+
+  def test_input_uint16(self):
+    self._assert_dtype(
+        np.uint16, tf.uint16, np.matrix([[1, 2], [3, 4]], dtype=np.uint16))
+
+  def test_input_float16(self):
+    self._assert_dtype(
+        np.float16, tf.float16, np.matrix([[1, 2], [3, 4]], dtype=np.float16))
+
+  def test_input_float32(self):
+    self._assert_dtype(
+        np.float32, tf.float32, np.matrix([[1, 2], [3, 4]], dtype=np.float32))
+
+  def test_input_float64(self):
+    self._assert_dtype(
+        np.float64, tf.float64, np.matrix([[1, 2], [3, 4]], dtype=np.float64))
+
+  def test_input_bool(self):
+    self._assert_dtype(
+        np.bool, tf.bool,
+        np.array([[False for _ in xrange(2)] for _ in xrange(2)]))
+
+  def test_input_string(self):
+    input_data = np.array([['str%d' % i for i in xrange(2)] for _ in xrange(2)])
+    self._assert_dtype(input_data.dtype, tf.string, input_data)
+
   def test_unsupervised(self):
     data = np.matrix([[1, 2], [2, 3], [3, 4]])
     feeder = data_feeder.DataFeeder(data, None, n_classes=0, batch_size=1)
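The new tests pin down the numpy-to-TensorFlow dtype mapping that `DataFeeder` must preserve. The correspondence itself can be checked directly with `tf.as_dtype`, for instance:

import numpy as np
import tensorflow as tf

for np_dtype, tf_dtype in [(np.int8, tf.int8), (np.uint16, tf.uint16),
                           (np.float16, tf.float16), (np.float64, tf.float64)]:
  assert tf.as_dtype(np_dtype) == tf_dtype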
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user