Merge pull request #8958 from rohan100jain/branch_152141388

Branch 152141388
This commit is contained in:
Yifei Feng 2017-04-04 11:07:14 -07:00 committed by GitHub
commit efe5376f3d
211 changed files with 5461 additions and 4624 deletions

View File

@ -202,6 +202,7 @@ filegroup(
"//tensorflow/contrib/boosted_trees:all_files",
"//tensorflow/contrib/boosted_trees/lib:all_files",
"//tensorflow/contrib/boosted_trees/proto:all_files",
"//tensorflow/contrib/boosted_trees/resources:all_files",
"//tensorflow/contrib/cloud:all_files",
"//tensorflow/contrib/cloud/kernels:all_files",
"//tensorflow/contrib/compiler:all_files",
@ -256,6 +257,7 @@ filegroup(
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/xla_tf_graph:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/debug:all_files",
"//tensorflow/core/distributed_runtime:all_files",

View File

@ -51,6 +51,7 @@ genrule(
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tffunction.pb",
],
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
tags = ["manual"],
@ -114,6 +115,15 @@ tf_library(
tags = ["manual"],
)
tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
tags = ["manual"],
)
cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@ -122,6 +132,7 @@ cc_test(
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",

View File

@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -95,6 +96,17 @@ def tfmatmulandadd(_):
math_ops.add(x, y, name='x_y_sum')
def tffunction(_):
@function.Defun(dtypes.int32, dtypes.int32)
def test_func(a, b):
return a + b
x = constant_op.constant([1], name='x_const')
y = constant_op.constant([2], name='y_const')
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@ -112,6 +124,7 @@ def main(_):
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
if __name__ == '__main__':
@ -121,7 +134,6 @@ if __name__ == '__main__':
'--out_dir',
type=str,
default='',
help='Output directory for graphs, checkpoints and savers.'
)
help='Output directory for graphs, checkpoints and savers.')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "func_call" }
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
add_fn.arg0() = 1;
add_fn.arg1() = 2;
EXPECT_TRUE(add_fn.Run());
EXPECT_EQ(add_fn.error_msg(), "");
EXPECT_EQ(add_fn.result0(), 3);
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -50,7 +50,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
}
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 5;
const int kMaxRecursionDepth = 10;
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
int depth, FunctionLibraryRuntime* lib_runtime);

View File

@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
});
}
TEST_F(OpTest, OnesLike) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
});
}
} // anonymous namespace
} // namespace tensorflow

View File

@ -257,6 +257,11 @@ class UnaryOpsTest(XLATestCase):
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.ones_like,
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
def testLogicalOps(self):
self._assertOpOutputMatchesExpected(
math_ops.logical_not,

View File

@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
class OnesLikeOp : public XlaOpKernel {
public:
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
}
};
REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
} // namespace
} // namespace tensorflow

View File

@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(std::move(options)),
initialization_status_(Status::OK()),
next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {}
device_mgr_({device_}) {
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
}
}
XlaCompiler::~XlaCompiler() = default;
@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);
xla::ComputationBuilder builder(client(), name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,

View File

@ -214,6 +214,12 @@ class XlaCompiler {
// This is useful to prune stateful operators that should not be executed
// from a function body.
bool prune_unreachable_nodes = false;
// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
// device is created, and can be used to create metadata objects
// that can be accessed by XLA op kernels.
std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
};
explicit XlaCompiler(Options options);
@ -247,6 +253,7 @@ class XlaCompiler {
Status BuildExecutable(const CompilationResult& result,
std::unique_ptr<xla::LocalExecutable>* executable);
const Options& options() const { return options_; }
xla::Client* client() const { return options_.client; }
XlaCompilationDevice* device() const { return device_; }
const DeviceMgr* device_mgr() const { return &device_mgr_; }
@ -260,6 +267,9 @@ class XlaCompiler {
private:
Options options_;
// Status set to non-OK in the constructor if initialization fails.
Status initialization_status_;
// Returns the next step sequence number.
int64 NextStepId();

View File

@ -17,12 +17,14 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
@ -33,6 +35,65 @@ limitations under the License.
namespace tensorflow {
namespace {
// Helper class to test the ability to pass resources through to XLA
// compiled kernels.
class DummyResourceForTest : public ResourceBase {
public:
string DebugString() override { return "dummy"; }
void Increment() { ++value_; }
int Get() { return value_; }
private:
int value_ = 0;
};
class DummyReadResourceOp : public XlaOpKernel {
public:
explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
DummyResourceForTest* dummy;
OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
rm->default_container(), "dummy", &dummy));
dummy->Increment();
dummy->Unref();
ctx->SetOutput(0, ctx->Input(0));
}
};
class DummyReadResourceCC {
public:
DummyReadResourceCC(const Scope& scope, const Input& value) {
if (!scope.ok()) return;
auto _value = ops::AsNodeOut(scope, value);
if (!scope.ok()) return;
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
this->output_ = Output(ret, 0);
}
Node* node() const { return output_.node(); }
Output output_;
};
REGISTER_OP("DummyReadResource")
.Input("input: int32")
.Output("output: int32")
.Doc(R"doc(
A dummy Op.
input: dummy input.
output: dummy output.
)doc");
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
class XlaCompilerTest : public ::testing::Test {
protected:
void SetUp() override {
@ -224,5 +285,45 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
}
}
// Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest, ResourceManager) {
// Builds a graph that calls the dummy resource Op.
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the argument.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
DummyResourceForTest* resource = new DummyResourceForTest();
// Compiles the graph.
auto options = DefaultOptions();
std::function<Status(ResourceMgr*)> populate_function =
[resource](ResourceMgr* rm) {
resource->Ref();
return rm->Create(rm->default_container(), "dummy", resource);
};
options.populate_resource_manager = &populate_function;
XlaCompiler compiler(options);
auto flr = BuildFunctionLibraryRuntime(compiler);
EXPECT_EQ(0, resource->Get());
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args,
&result));
EXPECT_EQ(1, resource->Get());
resource->Unref();
}
} // namespace
} // namespace tensorflow

View File

@ -354,6 +354,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() {
XlaContext::Get(context_).AddSideEffects();
}
const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const {
return XlaContext::Get(context_).compiler()->options();
}
void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
context_->CtxFailureWithWarning(s);

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@ -182,6 +183,11 @@ class XlaOpKernelContext {
// Returns the underlying OpKernelContext. Use rarely.
OpKernelContext* op_kernel_context() const { return context_; }
// Returns the options passed to the XlaCompiler that is being
// run. Used for, e.g., While to inherit options needed for nested
// computation.
const XlaCompiler::Options& GetCompilerOptions() const;
// TODO(phawkins): find a better home for these helpers.
// Get an XLA lambda to compute Max. This is cached in the

View File

@ -167,6 +167,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
!backend.second.op_filter(kdef.get())) {
continue;
}
VLOG(2) << "XLA op registration: device: " << backend.first
<< " op: " << op.first;
registry.kernel_registrars_.emplace_back(
new kernel_factory::OpKernelRegistrar(
new KernelDef(*kdef), "XlaJitOp", op.second->factory));

View File

@ -6,6 +6,7 @@ package_group(
name = "friends",
packages = [
"//tensorflow/compiler/...",
"//tensorflow/contrib/xla_tf_graph/...",
],
)

View File

@ -1229,8 +1229,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
VLOG(2) << "done with request";
if (!s.ok()) {
NoteError(s);
return first_error_;
return s;
}
return response.is_constant();
}
@ -1255,8 +1254,7 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
VLOG(2) << "done with request";
if (!s.ok()) {
NoteError(s);
return first_error_;
return s;
}
TF_RET_CHECK(response.output().handle() != 0);

View File

@ -120,6 +120,7 @@ class HloComputation {
}
const string& name() const { return name_; }
void set_name(const string& name) { name_ = name; }
// Return a string representation of the computation.
string ToString() const;
@ -257,7 +258,7 @@ class HloComputation {
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
const string name_;
string name_;
HloInstruction* root_instruction_;
// Module containing this computation.

View File

@ -357,7 +357,9 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random,
Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
// Compute the cost of the fused expression.
HloInstruction* fused_expression_root = fusion->fused_expression_root();
HloCostAnalysis visitor(shape_size_);
// Don't compute sizes inside of fused ops. We don't use the size here and the
// operations inside might not have a layout.
HloCostAnalysis visitor([](const Shape&) { return 0; });
TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
// Attribute the cost of the fused expression to the fusion node.

View File

@ -375,6 +375,33 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
}
TEST_F(FusionCostAnalysis, NoLayout) {
Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
// Instructions within a fused op may have no layout.
Shape shape_without_layout = shape_with_layout;
shape_without_layout.clear_layout();
auto c1 = HloInstruction::CreateConstant(
LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
auto c2 =
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
auto broadcast =
HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
c1.get(), broadcast.get());
auto fusion = HloInstruction::CreateFusion(
shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
fusion->FuseInstruction(broadcast.get());
HloCostAnalysis fusion_analysis(ShapeSize);
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
EXPECT_EQ(fusion_analysis.flop_count(), 120);
EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
}
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{

View File

@ -31,20 +31,38 @@ limitations under the License.
namespace xla {
HloComputation* HloModule::AddEntryComputation(
HloModule::HloModule(const string& name,
const VersionedComputationHandle& entry_computation_handle)
: name_(name),
entry_computation_(nullptr),
has_entry_computation_handle_(true),
entry_computation_handle_(entry_computation_handle),
computation_name_uniquer_(/*separator=*/".") {}
HloModule::HloModule(const string& name)
: name_(name),
entry_computation_(nullptr),
computation_name_uniquer_(/*separator=*/".") {}
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
computation->set_name(
computation_name_uniquer_.GetUniqueName(computation->name()));
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
}
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
return AddComputationInternal(std::move(computation));
}
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
return AddComputationInternal(std::move(computation));
}
void HloModule::ReplaceComputations(

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@ -41,19 +42,14 @@ namespace xla {
// computations are owned by the module.
class HloModule {
public:
explicit HloModule(const string& name,
const VersionedComputationHandle& entry_computation_handle)
: name_(name),
entry_computation_(nullptr),
has_entry_computation_handle_(true),
entry_computation_handle_(entry_computation_handle) {}
HloModule(const string& name,
const VersionedComputationHandle& entry_computation_handle);
// Constructor without a versioned computation handle. This constructor should
// only be used for HloModules used outside of the XLA service (eg
// tests). The versioned handle is used by the service in the compilation
// cache.
explicit HloModule(const string& name)
: name_(name), entry_computation_(nullptr) {}
explicit HloModule(const string& name);
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@ -111,6 +107,9 @@ class HloModule {
uint64 RandomNew64() const;
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation);
const string name_;
HloComputation* entry_computation_;
std::vector<std::unique_ptr<HloComputation>> computations_;
@ -125,6 +124,9 @@ class HloModule {
// Versioned handle of the entry computation of the module.
bool has_entry_computation_handle_ = false;
VersionedComputationHandle entry_computation_handle_;
// Unique name generator for computation names, which are unique per module.
NameUniquer computation_name_uniquer_;
};
} // namespace xla

View File

@ -74,6 +74,11 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) {
EXPECT_MATCH(
testing::ListToVec<HloComputation*>(module->MakeComputationPostOrder()),
testing::UnorderedMatcher<HloComputation*>(computation1, computation2));
// We specified the same name for both computations, but the HloModule should
// have made the names unique.
EXPECT_EQ(computation1->name(), "Constant");
EXPECT_EQ(computation2->name(), "Constant.1");
}
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {

View File

@ -633,26 +633,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
switch (operation) {
case TRIOP_CLAMP:
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) &&
(ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) {
return rhs;
}
if (ShapeUtil::Rank(rhs) == 0) {
if (ShapeUtil::Compatible(lhs, ehs)) {
return lhs;
}
return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs;
}
return Unimplemented("not yet implemented: %s, %s <clamp> %s",
lhs.ShortDebugString().c_str(),
ehs.ShortDebugString().c_str(),
rhs.ShortDebugString().c_str());
return InferClampShape(lhs, rhs, ehs);
case TRIOP_SELECT:
return InferSelectShape(lhs, rhs, ehs);
case TRIOP_UPDATE:
@ -1332,6 +1313,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
}
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
// "degenerate" cases, as with binary elementwise ops.
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
const Shape& min, const Shape& operand, const Shape& max) {
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
if (!ShapeUtil::SameElementType(min, operand) ||
!ShapeUtil::SameElementType(max, operand)) {
return InvalidArgument("clamp op with different operand types: %s, %s, %s",
ShapeUtil::HumanString(min).c_str(),
ShapeUtil::HumanString(operand).c_str(),
ShapeUtil::HumanString(max).c_str());
}
if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
(ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
return operand;
}
if (ShapeUtil::IsScalar(operand)) {
if (ShapeUtil::Compatible(min, max)) {
return min;
} else if (ShapeUtil::IsScalar(min)) {
return max;
} else if (ShapeUtil::IsScalar(max)) {
return min;
}
}
return Unimplemented(
"not yet implemented: %s, %s <clamp> %s", min.ShortDebugString().c_str(),
max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
}
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
// "degenerate" cases, as with binary elementwise ops, as well as scalar
// broadcast from all operands, not just the predicate.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
const Shape& pred, const Shape& on_true, const Shape& on_false) {
if (!ShapeUtil::Compatible(on_true, on_false)) {

View File

@ -190,6 +190,10 @@ class ShapeInference {
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
const Shape& max);
// Helper for inferring the shape of Select ops.
static StatusOr<Shape> InferSelectShape(const Shape& pred,
const Shape& on_true,

View File

@ -157,6 +157,99 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
testing::ContainsRegex("pred operand must have PRED element type"));
}
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampAllScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampBadShapes) {
// Type mismatch
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
.ok());
// Dimension mismatch
ASSERT_FALSE(
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
vector_64_, vector_32_, vector_32_)
.ok());
ASSERT_FALSE(
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
vector_32_, vector_64_, vector_32_)
.ok());
ASSERT_FALSE(
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
vector_32_, vector_32_, vector_64_)
.ok());
// Dimension mismatch, where one operand is a scalar
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
.ok());
}
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@ -245,37 +246,69 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) {
ComputationBuilder builder(client_, TestName());
builder.Div(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
struct DivS32Params {
int32 dividend;
int32 divisor;
int32 quotient;
int32 remainder;
};
ComputeAndCompareR0<int32>(&builder, -2, {});
void PrintTo(const DivS32Params& p, std::ostream* os) {
*os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", "
<< p.remainder << "}";
}
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) {
ComputationBuilder builder(client_, TestName());
builder.Rem(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
class DivS32Test : public ClientLibraryTestBase,
public ::testing::WithParamInterface<DivS32Params> {};
ComputeAndCompareR0<int32>(&builder, -1, {});
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
DivS32Params p = GetParam();
ComputationBuilder builder(client_, TestName());
builder.Div(builder.ConstantR0<int32>(p.dividend),
builder.ConstantR0<int32>(p.divisor));
ComputeAndCompareR0<int32>(&builder, p.quotient, {});
}
TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) {
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
DivS32Params p = GetParam();
ComputationBuilder builder(client_, TestName());
builder.Rem(builder.ConstantR0<int32>(INT_MIN),
builder.ConstantR0<int32>(7919));
builder.Rem(builder.ConstantR0<int32>(p.dividend),
builder.ConstantR0<int32>(p.divisor));
ComputeAndCompareR0<int32>(&builder, -1309, {});
ComputeAndCompareR0<int32>(&builder, p.remainder, {});
}
TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) {
ComputationBuilder builder(client_, TestName());
builder.Rem(builder.ConstantR0<int32>(INT_MIN),
builder.ConstantR0<int32>(INT_MAX));
INSTANTIATE_TEST_CASE_P(
DivS32Test_Instantiation, DivS32Test,
::testing::Values(
// Positive divisors.
DivS32Params{5, 2, 2, 1}, //
DivS32Params{-5, 2, -2, -1}, //
DivS32Params{17, 3, 5, 2}, //
DivS32Params{-17, 3, -5, -2}, //
// Negative divisors.
DivS32Params{5, -2, -2, 1}, //
DivS32Params{-5, -2, 2, -1}, //
DivS32Params{17, -3, -5, 2}, //
DivS32Params{-17, -3, 5, -2}, //
// Large positive divisors.
DivS32Params{INT32_MIN, 7919, -271181, -1309}, //
DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, //
DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, //
DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, //
DivS32Params{INT32_MIN, 0x40000000, -2, 0}, //
DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, //
// Large negative divisors.
DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, //
DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, //
DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, //
DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, //
DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, //
DivS32Params{INT32_MIN, -0x40000000, 2, 0}, //
DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
ComputeAndCompareR0<int32>(&builder, -1, {});
}
TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
ComputationBuilder builder(client_, TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
builder.Rem(x, builder.ConstantR0<int32>(80000));

View File

@ -7,8 +7,6 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "if_not_windows")
py_library(
name = "contrib_py",
srcs = glob(["**/*.py"]),
@ -46,6 +44,7 @@ py_library(
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/memory_stats:memory_stats_py",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py",
@ -65,9 +64,7 @@ py_library(
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
] + if_not_windows([
"//tensorflow/contrib/nccl:nccl_py",
]),
],
)
cc_library(

View File

@ -35,6 +35,7 @@ from tensorflow.contrib import image
from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import keras
from tensorflow.contrib import kernel_methods
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
@ -45,6 +46,7 @@ from tensorflow.contrib import lookup
from tensorflow.contrib import losses
from tensorflow.contrib import memory_stats
from tensorflow.contrib import metrics
from tensorflow.contrib import nccl
from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import quantization

View File

@ -160,3 +160,90 @@ cc_test(
"//tensorflow/core:test_main",
],
)
cc_library(
name = "models",
srcs = ["models/multiple_additive_trees.cc"],
hdrs = ["models/multiple_additive_trees.h"],
deps = [
":trees",
":utils",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:framework_headers_lib",
],
)
cc_test(
name = "multiple_additive_trees_test",
size = "small",
srcs = ["models/multiple_additive_trees_test.cc"],
deps = [
":batch_features_testutil",
":models",
":random_tree_gen",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "trees",
srcs = ["trees/decision_tree.cc"],
hdrs = ["trees/decision_tree.h"],
deps = [
":utils",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:framework_headers_lib",
],
)
cc_test(
name = "trees_test",
size = "small",
srcs = ["trees/decision_tree_test.cc"],
deps = [
":trees",
":utils",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "batch_features_testutil",
testonly = 1,
srcs = ["testutil/batch_features_testutil.cc"],
hdrs = ["testutil/batch_features_testutil.h"],
deps = [
":utils",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "random_tree_gen",
srcs = ["testutil/random_tree_gen.cc"],
hdrs = ["testutil/random_tree_gen.h"],
deps = [
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:lib",
],
)
cc_binary(
name = "random_tree_gen_main",
srcs = ["testutil/random_tree_gen_main.cc"],
deps = [
":random_tree_gen",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,140 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
namespace tensorflow {
namespace boosted_trees {
namespace models {
namespace {
void CalculateTreesToKeep(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const std::vector<int32>& trees_to_drop, const int32 num_trees,
const bool only_finalized, std::vector<int32>* trees_to_keep) {
trees_to_keep->reserve(num_trees - trees_to_drop.size());
int32 index = 0;
// This assumes that trees_to_drop is a sorted list of tree ids.
for (int32 tree = 0; tree < num_trees; ++tree) {
if ((!trees_to_drop.empty() && index < trees_to_drop.size() &&
trees_to_drop[index] == tree) ||
(only_finalized && config.tree_metadata_size() > 0 &&
!config.tree_metadata(tree).is_finalized())) {
++index;
continue;
}
trees_to_keep->push_back(tree);
}
}
void UpdatePredictions(
const int32 index_1, const int32 index_2, const float value,
tensorflow::TTypes<float>::Matrix* output_predictions,
tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
(*output_predictions)(index_1, index_2) += value;
if (additional_output_predictions != nullptr) {
(*additional_output_predictions)(index_1, index_2) += value;
}
}
void UpdatePredictionsBasedOnTree(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const int32 tree_idx, const boosted_trees::utils::Example& example,
tensorflow::TTypes<float>::Matrix* output_predictions,
tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx);
const float tree_weight = config.tree_weights(tree_idx);
const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
const auto& leaf_node = tree.nodes(leaf_idx);
QCHECK(leaf_node.has_leaf())
<< "Invalid leaf node: " << leaf_node.DebugString();
if (leaf_node.leaf().has_sparse_vector()) {
const auto& leaf = leaf_node.leaf().sparse_vector();
QCHECK_EQ(leaf.index_size(), leaf.value_size());
for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) {
const float value = tree_weight * leaf.value(class_idx);
UpdatePredictions(example.example_idx, leaf.index(class_idx), value,
output_predictions, additional_output_predictions);
}
} else {
QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
const auto& leaf = leaf_node.leaf().vector();
for (size_t i = 0; i < leaf.value_size(); ++i) {
const float value = tree_weight * leaf.value(i);
UpdatePredictions(example.example_idx, i, value, output_predictions,
additional_output_predictions);
}
}
}
} // namespace
void MultipleAdditiveTrees::Predict(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
const boosted_trees::utils::BatchFeatures& features,
tensorflow::thread::ThreadPool* worker_threads,
tensorflow::TTypes<float>::Matrix output_predictions,
tensorflow::TTypes<float>::Matrix no_dropout_predictions) {
// Zero out predictions as the model is additive.
output_predictions.setZero();
no_dropout_predictions.setZero();
// Get batch size.
const int64 batch_size = features.batch_size();
if (batch_size <= 0) {
return;
}
// Prepare the list of trees to keep.
std::vector<int32> trees_to_keep;
CalculateTreesToKeep(config, trees_to_drop, config.trees_size(),
only_finalized_trees, &trees_to_keep);
// Lambda for doing a block of work.
auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop,
&output_predictions,
&no_dropout_predictions](int64 start, int64 end) {
auto examples_iterable = features.examples_iterable(start, end);
for (const auto& example : examples_iterable) {
for (const int32 tree_idx : trees_to_keep) {
UpdatePredictionsBasedOnTree(config, tree_idx, example,
&output_predictions,
&no_dropout_predictions);
}
// Now do predictions for dropped trees
for (const int32 tree_idx : trees_to_drop) {
UpdatePredictionsBasedOnTree(config, tree_idx, example,
&no_dropout_predictions, nullptr);
}
}
};
// TODO(salehay): parallelize this for low latency in serving path where
// batch size tends to be small but ensemble size tends to be large.
boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
worker_threads, update_predictions);
}
} // namespace models
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -0,0 +1,50 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace boosted_trees {
namespace models {
// Multiple additive trees prediction model.
// This class does not hold state and is thread safe.
class MultipleAdditiveTrees {
public:
// Predict runs tree ensemble on the given batch and updates
// output predictions accordingly. The method also returns predictions that
// we would get if no dropout was applied.
static void Predict(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
const boosted_trees::utils::BatchFeatures& features,
thread::ThreadPool* const thread_pool,
TTypes<float>::Matrix output_predictions,
TTypes<float>::Matrix no_dropout_predictions);
};
} // namespace models
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_

View File

@ -0,0 +1,381 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
using boosted_trees::trees::DecisionTreeEnsembleConfig;
using test::AsTensor;
namespace boosted_trees {
namespace models {
namespace {
const int32 kNumThreadsMultiThreaded = 6;
const int32 kNumThreadsSingleThreaded = 1;
class MultipleAdditiveTreesTest : public ::testing::Test {
protected:
MultipleAdditiveTreesTest() : batch_features_(2) {
// Create a batch of two examples having one dense feature each.
// The shape of the dense matrix is therefore 2x1 as in one row per example
// and one column per feature per example.
auto dense_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
TF_EXPECT_OK(
batch_features_.Initialize({dense_matrix}, {}, {}, {}, {}, {}, {}));
}
boosted_trees::utils::BatchFeatures batch_features_;
};
TEST_F(MultipleAdditiveTreesTest, Empty) {
// Create empty tree ensemble.
DecisionTreeEnsembleConfig tree_ensemble_config;
auto output_tensor = AsTensor<float>({9.0f, 23.0f}, {2, 1});
auto output_matrix = output_tensor.matrix<float>();
auto no_dropout_output_matrix = output_tensor.matrix<float>();
// Predict for both instances.
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
kNumThreadsSingleThreaded);
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{}, batch_features_, &threads, output_matrix,
no_dropout_output_matrix);
EXPECT_EQ(0, output_matrix(0, 0));
EXPECT_EQ(0, output_matrix(1, 0));
// There was no dropout
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
}
}
TEST_F(MultipleAdditiveTreesTest, SingleClass) {
// Add one bias and one stump to ensemble for a single class.
DecisionTreeEnsembleConfig tree_ensemble_config;
auto* tree1 = tree_ensemble_config.add_trees();
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
bias_leaf->add_index(0);
bias_leaf->add_value(-0.4f);
auto* tree2 = tree_ensemble_config.add_trees();
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
dense_split->set_feature_column(0);
dense_split->set_threshold(5.0f);
dense_split->set_left_id(1);
dense_split->set_right_id(2);
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
leaf1->add_index(0);
leaf1->add_value(0.9f);
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
leaf2->add_index(0);
leaf2->add_value(0.2f);
tree_ensemble_config.add_tree_weights(1.0);
tree_ensemble_config.add_tree_weights(1.0);
auto output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
auto output_matrix = output_tensor.matrix<float>();
auto no_dropout_output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
kNumThreadsSingleThreaded);
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{}, batch_features_, &threads, output_matrix,
no_dropout_output_matrix);
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
// No dropout predictions are the same.
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
}
}
// Weighted case
{
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
weighted.set_tree_weights(0, 6.0);
weighted.set_tree_weights(1, 3.2);
MultipleAdditiveTrees::Predict(weighted,
false, // include non-finalized trees
{}, batch_features_, &threads, output_matrix,
no_dropout_output_matrix);
// -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
// -0.4 (bias) + 0.9 (leaf 1).
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0));
// No dropout predictions are the same.
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
}
}
// Drop first tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{0}, batch_features_, &threads,
output_matrix, no_dropout_output_matrix);
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2).
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1).
// No dropout predictions
EXPECT_FLOAT_EQ(
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
}
// Drop second tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{1}, batch_features_, &threads,
output_matrix, no_dropout_output_matrix);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias).
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias).
// No dropout predictions
EXPECT_FLOAT_EQ(
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
}
// Drop all trees.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{0, 1}, batch_features_, &threads,
output_matrix, no_dropout_output_matrix);
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
// No dropout predictions
EXPECT_FLOAT_EQ(
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
}
}
TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Add one bias and one stump to ensemble for two classes.
DecisionTreeEnsembleConfig tree_ensemble_config;
auto* tree1 = tree_ensemble_config.add_trees();
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
bias_leaf->add_index(0);
bias_leaf->add_value(-0.4f);
bias_leaf->add_index(1);
bias_leaf->add_value(-0.7f);
auto* tree2 = tree_ensemble_config.add_trees();
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
dense_split->set_feature_column(0);
dense_split->set_threshold(5.0f);
dense_split->set_left_id(1);
dense_split->set_right_id(2);
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
leaf1->add_index(0);
leaf1->add_value(0.9f);
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
leaf2->add_index(1);
leaf2->add_value(0.2f);
tree_ensemble_config.add_tree_weights(1.0);
tree_ensemble_config.add_tree_weights(1.0);
// Predict for both instances.
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
kNumThreadsSingleThreaded);
auto output_tensor = AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
auto output_matrix = output_tensor.matrix<float>();
auto no_dropout_output_tensor =
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{}, batch_features_, &threads, output_matrix,
no_dropout_output_matrix);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1)
EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias)
// No dropout predictions are the same.
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 2; ++j) {
EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
}
}
}
// Weighted case.
{
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
weighted.set_tree_weights(0, 6.0);
weighted.set_tree_weights(1, 3.2);
MultipleAdditiveTrees::Predict(weighted,
false, // include non-finalized trees
{}, batch_features_, &threads, output_matrix,
no_dropout_output_matrix);
// bias
EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
// bias + leaf 2
EXPECT_FLOAT_EQ(-0.7f * 6 + 0.2f * 3.2, output_matrix(0, 1));
// bias + leaf 2
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9f * 3.2f, output_matrix(1, 0));
// bias
EXPECT_FLOAT_EQ(-0.7f * 6, output_matrix(1, 1));
}
// Dropout first tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{0}, batch_features_, &threads,
output_matrix, no_dropout_output_matrix);
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2)
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2)
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
// No dropout predictions
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
}
// Dropout second tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{1}, batch_features_, &threads,
output_matrix, no_dropout_output_matrix);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias)
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias)
// No dropout predictions
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
}
// Drop both trees.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{0, 1}, batch_features_, &threads,
output_matrix, no_dropout_output_matrix);
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
// No dropout predictions
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
}
}
TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
DecisionTreeEnsembleConfig tree_ensemble_config;
auto* tree1 = tree_ensemble_config.add_trees();
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_vector();
bias_leaf->add_value(-0.4f);
bias_leaf->add_value(-0.7f);
bias_leaf->add_value(3.0f);
auto* tree2 = tree_ensemble_config.add_trees();
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
dense_split->set_feature_column(0);
dense_split->set_threshold(5.0f);
dense_split->set_left_id(1);
dense_split->set_right_id(2);
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
leaf1->add_value(0.9f);
leaf1->add_value(0.8f);
leaf1->add_value(0.7f);
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
leaf2->add_value(0.2f);
leaf2->add_value(0.3f);
leaf2->add_value(0.4f);
tree_ensemble_config.add_tree_weights(1.0);
tree_ensemble_config.add_tree_weights(1.0);
// Predict for both instances.
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
kNumThreadsSingleThreaded);
auto output_tensor =
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
auto output_matrix = output_tensor.matrix<float>();
auto no_dropout_output_tensor =
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config,
false, // include non-finalized trees
{}, batch_features_, &threads, output_matrix,
no_dropout_output_matrix);
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2)
EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2)
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (tree1) + 0.9 (leaf 1)
EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1)); // -0.7 (tree1) + 0.8 (leaf 1)
EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2)); // 3.0 (tree1) + 0.7 (leaf 1)
// No dropout predictions are the same.
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
}
}
}
}
} // namespace
} // namespace models
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -0,0 +1,88 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace boosted_trees {
namespace testutil {
using tensorflow::Tensor;
void RandomlyInitializeBatchFeatures(
tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
boosted_trees::utils::BatchFeatures* batch_features) {
const int64 batch_size = static_cast<int64>(batch_features->batch_size());
// Populate dense features.
std::vector<tensorflow::Tensor> dense_float_features_list;
for (int i = 0; i < num_dense_float_features; ++i) {
std::vector<float> values;
for (int64 j = 0; j < batch_size; ++j) {
values.push_back(rng->RandFloat());
}
auto dense_tensor = Tensor(tensorflow::DT_FLOAT, {batch_size, 1});
tensorflow::test::FillValues<float>(&dense_tensor, values);
dense_float_features_list.push_back(dense_tensor);
}
// Populate sparse features.
std::vector<tensorflow::Tensor> sparse_float_feature_indices_list;
std::vector<tensorflow::Tensor> sparse_float_feature_values_list;
std::vector<tensorflow::Tensor> sparse_float_feature_shapes_list;
for (int i = 0; i < num_sparse_float_features; ++i) {
std::set<uint64> indices;
const double sparsity =
sparsity_lo + rng->RandDouble() * (sparsity_hi - sparsity_lo);
const double density = 1 - sparsity;
for (int64 k = 0; k < static_cast<int64>(density * batch_size) + 1; ++k) {
indices.insert(rng->Uniform64(batch_size));
}
const int64 sparse_values_size = indices.size();
std::vector<int64> indices_vector;
for (auto idx : indices) {
indices_vector.push_back(idx);
indices_vector.push_back(0);
}
auto indices_tensor = Tensor(tensorflow::DT_INT64, {sparse_values_size, 2});
tensorflow::test::FillValues<int64>(&indices_tensor, indices_vector);
sparse_float_feature_indices_list.push_back(indices_tensor);
std::vector<float> values;
for (int64 j = 0; j < sparse_values_size; ++j) {
values.push_back(rng->RandFloat());
}
auto values_tensor = Tensor(tensorflow::DT_FLOAT, {sparse_values_size});
tensorflow::test::FillValues<float>(&values_tensor, values);
sparse_float_feature_values_list.push_back(values_tensor);
auto shape_tensor = Tensor(tensorflow::DT_INT64, {2});
tensorflow::test::FillValues<int64>(&shape_tensor, {batch_size, 1});
sparse_float_feature_shapes_list.push_back(shape_tensor);
}
// TODO(salehay): Add categorical feature generation support.
TF_EXPECT_OK(batch_features->Initialize(
dense_float_features_list, sparse_float_feature_indices_list,
sparse_float_feature_values_list, sparse_float_feature_shapes_list, {},
{}, {}));
}
} // namespace testutil
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -0,0 +1,45 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/random/simple_philox.h"
namespace tensorflow {
namespace boosted_trees {
namespace testutil {
// This method calls Initialize on the given 'batch_features', which will be
// populated with randomly generated feature values when the call returns.
// 'tensors' returns a vector of all tensors used in the initialization,
// because they must outlive 'batch_features'.
//
// All float features will be either missing or uniformly randomly chosen
// from [0, 1). For sparse (float) features, a sparsity is uniformly randomly
// chosen from ['sparsity_lo', 'sparsity_hi') per feature, and each instance
// will have a probability of sparsity of missing that feature, in other words,
// sparsity = 1 - density.
void RandomlyInitializeBatchFeatures(
tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
boosted_trees::utils::BatchFeatures* batch_features);
} // namespace testutil
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_

View File

@ -0,0 +1,211 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace boosted_trees {
namespace testutil {
using tensorflow::boosted_trees::trees::DecisionTreeConfig;
using tensorflow::boosted_trees::trees::TreeNode;
using boosted_trees::trees::DenseFloatBinarySplit;
namespace {
// Append the given nodes to tree with transfer of pointer ownership.
// nodes will not be usable upon return.
template <typename T>
void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
std::reverse(nodes->pointer_begin(), nodes->pointer_end());
while (!nodes->empty()) {
tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
}
}
DenseFloatBinarySplit* GetSplit(TreeNode* node) {
switch (node->node_case()) {
case TreeNode::kSparseFloatBinarySplitDefaultLeft:
return node->mutable_sparse_float_binary_split_default_left()
->mutable_split();
case TreeNode::kSparseFloatBinarySplitDefaultRight:
return node->mutable_sparse_float_binary_split_default_right()
->mutable_split();
case TreeNode::kDenseFloatBinarySplit:
return node->mutable_dense_float_binary_split();
default:
LOG(FATAL) << "Unknown node type encountered.";
}
return nullptr;
}
} // namespace
RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
int dense_feature_size, int sparse_feature_size)
: rng_(rng),
dense_feature_size_(dense_feature_size),
sparse_feature_size_(sparse_feature_size) {}
namespace {
void AddWeightAndMetadata(
boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
// Assign the weight of the tree to 1 and say that this weight was updated
// only once.
ret->add_tree_weights(1.0);
auto* meta = ret->add_tree_metadata();
meta->set_num_tree_weight_updates(1);
}
} // namespace
boosted_trees::trees::DecisionTreeEnsembleConfig
RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
boosted_trees::trees::DecisionTreeEnsembleConfig ret;
*(ret.add_trees()) = Generate(depth);
AddWeightAndMetadata(&ret);
for (int i = 1; i < tree_count; ++i) {
*(ret.add_trees()) = Generate(ret.trees(0));
AddWeightAndMetadata(&ret);
}
return ret;
}
DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
DecisionTreeConfig ret = tree;
for (auto& node : *ret.mutable_nodes()) {
if (node.node_case() == TreeNode::kLeaf) {
node.mutable_leaf()->mutable_sparse_vector()->set_value(
0, rng_->RandFloat());
continue;
}
// Original node is a split. Re-generate it's type but retain the split node
// indices.
DenseFloatBinarySplit* split = GetSplit(&node);
const int left_id = split->left_id();
const int right_id = split->right_id();
GenerateSplit(&node, left_id, right_id);
}
return ret;
}
DecisionTreeConfig RandomTreeGen::Generate(int depth) {
DecisionTreeConfig ret;
// Add root,
TreeNode* node = ret.add_nodes();
GenerateSplit(node, 1, 2);
if (depth == 1) {
// Add left and right leaves.
TreeNode* left = ret.add_nodes();
left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
TreeNode* right = ret.add_nodes();
right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
right->mutable_leaf()->mutable_sparse_vector()->add_value(
rng_->RandFloat());
return ret;
} else {
DecisionTreeConfig left_branch = Generate(depth - 1);
DecisionTreeConfig right_branch = Generate(depth - 1);
Combine(&ret, &left_branch, &right_branch);
return ret;
}
}
void RandomTreeGen::Combine(DecisionTreeConfig* root,
DecisionTreeConfig* left_branch,
DecisionTreeConfig* right_branch) {
const int left_branch_size = left_branch->nodes_size();
CHECK_EQ(1, root->nodes_size());
// left_branch starts its index at 1. right_branch starts its index at
// (left_branch_size + 1).
auto* root_node = root->mutable_nodes(0);
DenseFloatBinarySplit* root_split = GetSplit(root_node);
root_split->set_left_id(1);
root_split->set_right_id(left_branch_size + 1);
// Shift left/right branch's indices internally so that everything is
// consistent.
ShiftNodeIndex(left_branch, 1);
ShiftNodeIndex(right_branch, left_branch_size + 1);
// Complexity O(branch node size). No proto copying though.
AppendNodes(root, left_branch->mutable_nodes());
AppendNodes(root, right_branch->mutable_nodes());
}
void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
for (TreeNode& node : *(tree->mutable_nodes())) {
DenseFloatBinarySplit* split = nullptr;
switch (node.node_case()) {
case TreeNode::kLeaf:
break;
case TreeNode::kSparseFloatBinarySplitDefaultLeft:
split = node.mutable_sparse_float_binary_split_default_left()
->mutable_split();
break;
case TreeNode::kSparseFloatBinarySplitDefaultRight:
split = node.mutable_sparse_float_binary_split_default_right()
->mutable_split();
break;
case TreeNode::kDenseFloatBinarySplit:
split = node.mutable_dense_float_binary_split();
break;
default:
LOG(FATAL) << "Unknown node type encountered.";
}
if (split) {
split->set_left_id(shift + split->left_id());
split->set_right_id(shift + split->right_id());
}
}
}
void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
const double denseSplitProb =
sparse_feature_size_ == 0
? 1.0
: static_cast<double>(dense_feature_size_) /
(dense_feature_size_ + sparse_feature_size_);
// Generate the tree such that it has equal probability of going left and
// right when the feature is missing.
static constexpr float kLeftProb = 0.5;
DenseFloatBinarySplit* split;
int feature_size;
if (rng_->RandFloat() < denseSplitProb) {
feature_size = dense_feature_size_;
split = node->mutable_dense_float_binary_split();
} else {
feature_size = sparse_feature_size_;
if (rng_->RandFloat() < kLeftProb) {
split = node->mutable_sparse_float_binary_split_default_left()
->mutable_split();
} else {
split = node->mutable_sparse_float_binary_split_default_right()
->mutable_split();
}
}
split->set_threshold(rng_->RandFloat());
split->set_feature_column(rng_->Uniform(feature_size));
split->set_left_id(left_id);
split->set_right_id(right_id);
}
} // namespace testutil
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -0,0 +1,75 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
#include <memory>
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace boosted_trees {
namespace testutil {
// Randomly generate a balanced tree, for performance benchmarking purposes,
// that assume all features are sparse float features, for now.
class RandomTreeGen {
public:
RandomTreeGen(tensorflow::random::SimplePhilox* rng, int dense_feature_size,
int sparse_feature_size);
// Required: depth must be >= 1.
// If one wants to generate multiple trees with the same depth, see also the
// overload below.
boosted_trees::trees::DecisionTreeConfig Generate(int depth);
// Randomly generate a new tree with the same depth (and tree structure)
// as the given tree. This is faster.
boosted_trees::trees::DecisionTreeConfig Generate(
const boosted_trees::trees::DecisionTreeConfig& tree);
// Requried: depth >= 1; tree_count >= 1.
boosted_trees::trees::DecisionTreeEnsembleConfig GenerateEnsemble(
int dept, int tree_count);
private:
tensorflow::random::SimplePhilox* rng_;
const int dense_feature_size_;
const int sparse_feature_size_;
// Put together a deeper tree by combining two trees.
void Combine(boosted_trees::trees::DecisionTreeConfig* root,
boosted_trees::trees::DecisionTreeConfig* left_branch,
boosted_trees::trees::DecisionTreeConfig* right_branch);
// For each node in the provided tree, shift its referenced left/right index
// by shift.
void ShiftNodeIndex(boosted_trees::trees::DecisionTreeConfig* tree,
int shift);
// Generate a sparse split in the node.
void GenerateSplit(boosted_trees::trees::TreeNode* node, int left_id,
int right_id);
TF_DISALLOW_COPY_AND_ASSIGN(RandomTreeGen);
};
} // namespace testutil
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_

View File

@ -0,0 +1,67 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Randomly generate a tree ensemble and write to file.
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
using tensorflow::Flag;
using tensorflow::Flags;
using tensorflow::int32;
using tensorflow::string;
int main(int argc, char* argv[]) {
int32 dense_feature_size = 100;
int32 sparse_feature_size = 100;
int32 depth = 8;
int32 tree_count = 10;
string filename = "/tmp/trees.pb";
std::vector<Flag> flag_list = {
Flag("dense_feature_size", &dense_feature_size, "dense feature size"),
Flag("sparse_feature_size", &sparse_feature_size, "sparse_feature_size"),
Flag("depth", &depth, "tree depth"),
Flag("tree_count", &tree_count, "tree count"),
Flag("filename", &filename, "Output filename."),
};
string usage = Flags::Usage(argv[0], flag_list);
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
if (!parse_result) {
LOG(ERROR) << "\n" << usage;
return -1;
}
tensorflow::random::PhiloxRandom philox(1);
tensorflow::random::SimplePhilox rng(&philox);
tensorflow::boosted_trees::testutil::RandomTreeGen tree_gen(
&rng, dense_feature_size, sparse_feature_size);
const auto& trees = tree_gen.GenerateEnsemble(depth, tree_count);
tensorflow::Status status =
tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, trees);
if (!status.ok()) {
LOG(WARNING) << "Failed to write: " << filename << " : " << status;
} else {
LOG(INFO) << "Tree ensemble written to: " << filename;
}
return 0;
}

View File

@ -0,0 +1,170 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace boosted_trees {
namespace trees {
constexpr int kInvalidLeaf = -1;
int DecisionTree::Traverse(const DecisionTreeConfig& config,
const int32 sub_root_id,
const utils::Example& example) {
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
return kInvalidLeaf;
}
// Traverse tree starting at the provided sub-root.
int32 node_id = sub_root_id;
while (true) {
const auto& current_node = config.nodes(node_id);
switch (current_node.node_case()) {
case TreeNode::kLeaf: {
return node_id;
}
case TreeNode::kDenseFloatBinarySplit: {
const auto& split = current_node.dense_float_binary_split();
node_id = example.dense_float_features[split.feature_column()] <=
split.threshold()
? split.left_id()
: split.right_id();
break;
}
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
const auto& split =
current_node.sparse_float_binary_split_default_left().split();
auto sparse_feature =
example.sparse_float_features[split.feature_column()];
node_id = !sparse_feature.has_value() ||
sparse_feature.get_value() <= split.threshold()
? split.left_id()
: split.right_id();
break;
}
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
const auto& split =
current_node.sparse_float_binary_split_default_right().split();
auto sparse_feature =
example.sparse_float_features[split.feature_column()];
node_id = sparse_feature.has_value() &&
sparse_feature.get_value() <= split.threshold()
? split.left_id()
: split.right_id();
break;
}
case TreeNode::kCategoricalIdBinarySplit: {
const auto& split = current_node.categorical_id_binary_split();
node_id = example.sparse_int_features[split.feature_column()].count(
split.feature_id()) > 0
? split.left_id()
: split.right_id();
break;
}
case TreeNode::NODE_NOT_SET: {
QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
break;
}
}
DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
<< current_node.DebugString();
}
}
void DecisionTree::LinkChildren(const std::vector<int32>& children,
TreeNode* parent_node) {
// Decide how to link children depending on the parent node's type.
auto children_it = children.begin();
switch (parent_node->node_case()) {
case TreeNode::kLeaf: {
// Essentially no-op.
QCHECK(children.empty()) << "A leaf node cannot have children.";
break;
}
case TreeNode::kDenseFloatBinarySplit: {
QCHECK(children.size() == 2)
<< "A binary split node must have exactly two children.";
auto* split = parent_node->mutable_dense_float_binary_split();
split->set_left_id(*children_it);
split->set_right_id(*++children_it);
break;
}
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
QCHECK(children.size() == 2)
<< "A binary split node must have exactly two children.";
auto* split =
parent_node->mutable_sparse_float_binary_split_default_left()
->mutable_split();
split->set_left_id(*children_it);
split->set_right_id(*++children_it);
break;
}
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
QCHECK(children.size() == 2)
<< "A binary split node must have exactly two children.";
auto* split =
parent_node->mutable_sparse_float_binary_split_default_right()
->mutable_split();
split->set_left_id(*children_it);
split->set_right_id(*++children_it);
break;
}
case TreeNode::kCategoricalIdBinarySplit: {
QCHECK(children.size() == 2)
<< "A binary split node must have exactly two children.";
auto* split = parent_node->mutable_categorical_id_binary_split();
split->set_left_id(*children_it);
split->set_right_id(*++children_it);
break;
}
case TreeNode::NODE_NOT_SET: {
QCHECK(false) << "A non-set node cannot have children.";
break;
}
}
}
std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
// A node's children depend on its type.
switch (node.node_case()) {
case TreeNode::kLeaf: {
return {};
}
case TreeNode::kDenseFloatBinarySplit: {
const auto& split = node.dense_float_binary_split();
return {split.left_id(), split.right_id()};
}
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
const auto& split = node.sparse_float_binary_split_default_left().split();
return {split.left_id(), split.right_id()};
}
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
const auto& split =
node.sparse_float_binary_split_default_right().split();
return {split.left_id(), split.right_id()};
}
case TreeNode::kCategoricalIdBinarySplit: {
const auto& split = node.categorical_id_binary_split();
return {split.left_id(), split.right_id()};
}
case TreeNode::NODE_NOT_SET: {
return {};
}
}
}
} // namespace trees
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
namespace tensorflow {
namespace boosted_trees {
namespace trees {
// Decision tree class to encapsulate tree traversal and mutation logic.
// This class does not hold state and is thread safe.
class DecisionTree {
public:
// Traverse given an instance, a sub-root and its set of features
// and return the leaf index or -1 if the tree is empty or
// the sub-root is invalid.
static int Traverse(const DecisionTreeConfig& config, int32 sub_root_id,
const utils::Example& example);
// Links the specified children to the parent, the children must
// already be added to the decision tree config so this method
// just ensures nodes are re-linked.
static void LinkChildren(const std::vector<int32>& children,
TreeNode* parent_node);
// Retrieves node children indices if any.
static std::vector<int32> GetChildren(const TreeNode& node);
};
} // namespace trees
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_

View File

@ -0,0 +1,326 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace boosted_trees {
namespace trees {
namespace {
class DecisionTreeTest : public ::testing::Test {
protected:
DecisionTreeTest() : batch_features_(2) {
// Create a batch of two examples having one dense float, two sparse float
// and one sparse int features.
// The first example is missing the second sparse feature column and the
// second example is missing the first sparse feature column.
// This looks like the following:
// Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 |
// 0 | 7 | -3 | | 3 |
// 1 | -2 | | 4 | |
auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
auto sparse_float_values2 = test::AsTensor<float>({4.0f});
auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
auto sparse_int_values1 = test::AsTensor<int64>({3});
auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
TF_EXPECT_OK(batch_features_.Initialize(
{dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2},
{sparse_float_values1, sparse_float_values2},
{sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1},
{sparse_int_values1}, {sparse_int_shape1}));
}
template <typename SplitType>
void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
// Verify children were linked.
DecisionTree::LinkChildren({3, 8}, node);
EXPECT_EQ(3, split->left_id());
EXPECT_EQ(8, split->right_id());
// Invalid cases.
EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
"A binary split node must have exactly two children.");
EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
"A binary split node must have exactly two children.");
EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
"A binary split node must have exactly two children.");
}
void TestGetChildren(const TreeNode& node,
const std::vector<uint32>& expected_children) {
// Verify children were linked.
auto children = DecisionTree::GetChildren(node);
EXPECT_EQ(children.size(), expected_children.size());
for (size_t idx = 0; idx < children.size(); ++idx) {
EXPECT_EQ(children[idx], expected_children[idx]);
}
}
utils::BatchFeatures batch_features_;
};
TEST_F(DecisionTreeTest, TraverseEmpty) {
DecisionTreeConfig tree_config;
auto example = (*batch_features_.examples_iterable(0, 1).begin());
EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example));
}
TEST_F(DecisionTreeTest, TraverseBias) {
DecisionTreeConfig tree_config;
tree_config.add_nodes()->mutable_leaf();
auto example = (*batch_features_.examples_iterable(0, 1).begin());
EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example));
}
TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) {
DecisionTreeConfig tree_config;
tree_config.add_nodes()->mutable_leaf();
auto example = (*batch_features_.examples_iterable(0, 1).begin());
EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example));
}
TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) {
DecisionTreeConfig tree_config;
auto* split_node =
tree_config.add_nodes()->mutable_dense_float_binary_split();
split_node->set_feature_column(0);
split_node->set_threshold(0.0f);
split_node->set_left_id(1);
split_node->set_right_id(2);
tree_config.add_nodes()->mutable_leaf();
tree_config.add_nodes()->mutable_leaf();
auto example_iterable = batch_features_.examples_iterable(0, 2);
// Expect right child to be picked as !(7 <= 0);
auto example_it = example_iterable.begin();
EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
// Expect left child to be picked as (-2 <= 0);
EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
}
TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) {
// Test first sparse feature which is missing for the second example.
DecisionTreeConfig tree_config1;
auto* split_node1 = tree_config1.add_nodes()
->mutable_sparse_float_binary_split_default_left()
->mutable_split();
split_node1->set_feature_column(0);
split_node1->set_threshold(-20.0f);
split_node1->set_left_id(1);
split_node1->set_right_id(2);
tree_config1.add_nodes()->mutable_leaf();
tree_config1.add_nodes()->mutable_leaf();
auto example_iterable = batch_features_.examples_iterable(0, 2);
// Expect right child to be picked as !(-3 <= -20).
auto example_it = example_iterable.begin();
EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it));
// Expect left child to be picked as default direction.
EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it));
// Test second sparse feature which is missing for the first example.
DecisionTreeConfig tree_config2;
auto* split_node2 = tree_config2.add_nodes()
->mutable_sparse_float_binary_split_default_right()
->mutable_split();
split_node2->set_feature_column(1);
split_node2->set_threshold(4.0f);
split_node2->set_left_id(1);
split_node2->set_right_id(2);
tree_config2.add_nodes()->mutable_leaf();
tree_config2.add_nodes()->mutable_leaf();
// Expect right child to be picked as default direction.
example_it = example_iterable.begin();
EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it));
// Expect left child to be picked as (4 <= 4).
EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it));
}
TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) {
DecisionTreeConfig tree_config;
auto* split_node =
tree_config.add_nodes()->mutable_categorical_id_binary_split();
split_node->set_feature_column(0);
split_node->set_feature_id(3);
split_node->set_left_id(1);
split_node->set_right_id(2);
tree_config.add_nodes()->mutable_leaf();
tree_config.add_nodes()->mutable_leaf();
auto example_iterable = batch_features_.examples_iterable(0, 2);
// Expect left child to be picked as 3 == 3;
auto example_it = example_iterable.begin();
EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
// Expect right child to be picked as the feature is missing;
EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
}
TEST_F(DecisionTreeTest, TraverseHybridSplits) {
DecisionTreeConfig tree_config;
auto* split_node1 =
tree_config.add_nodes()->mutable_dense_float_binary_split();
split_node1->set_feature_column(0);
split_node1->set_threshold(9.0f);
split_node1->set_left_id(1); // sparse split.
split_node1->set_right_id(2); // leaf
auto* split_node2 = tree_config.add_nodes()
->mutable_sparse_float_binary_split_default_left()
->mutable_split();
tree_config.add_nodes()->mutable_leaf();
split_node2->set_feature_column(0);
split_node2->set_threshold(-20.0f);
split_node2->set_left_id(3);
split_node2->set_right_id(4);
auto* split_node3 =
tree_config.add_nodes()->mutable_categorical_id_binary_split();
split_node3->set_feature_column(0);
split_node3->set_feature_id(2);
split_node3->set_left_id(5);
split_node3->set_right_id(6);
tree_config.add_nodes()->mutable_leaf();
tree_config.add_nodes()->mutable_leaf();
tree_config.add_nodes()->mutable_leaf();
auto example_iterable = batch_features_.examples_iterable(0, 2);
// Expect will go left through the first dense split as (7.0f <= 9.0f),
// then will go right through the sparse split as !(-3 <= -20).
auto example_it = example_iterable.begin();
EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it));
// Expect will go left through the first dense split as (-2.0f <= 9.0f),
// then will go left the default direction as the sparse feature is missing,
// then will go right as 2 != 3 on the categorical split.
EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it));
}
TEST_F(DecisionTreeTest, LinkChildrenLeaf) {
// Create leaf node.
TreeNode node;
node.mutable_leaf();
// No-op.
DecisionTree::LinkChildren({}, &node);
// Invalid case.
EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
"A leaf node cannot have children.");
}
TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) {
TreeNode node;
auto* split = node.mutable_dense_float_binary_split();
split->set_left_id(-1);
split->set_right_id(-1);
TestLinkChildrenBinary(&node, split);
}
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) {
TreeNode node;
auto* split =
node.mutable_sparse_float_binary_split_default_left()->mutable_split();
split->set_left_id(-1);
split->set_right_id(-1);
TestLinkChildrenBinary(&node, split);
}
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) {
TreeNode node;
auto* split =
node.mutable_sparse_float_binary_split_default_right()->mutable_split();
split->set_left_id(-1);
split->set_right_id(-1);
TestLinkChildrenBinary(&node, split);
}
TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) {
TreeNode node;
auto* split = node.mutable_categorical_id_binary_split();
split->set_left_id(-1);
split->set_right_id(-1);
TestLinkChildrenBinary(&node, split);
}
TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) {
// Create unset node.
TreeNode node;
// Invalid case.
EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
"A non-set node cannot have children.");
}
TEST_F(DecisionTreeTest, GetChildrenLeaf) {
TreeNode node;
node.mutable_leaf();
TestGetChildren(node, {});
}
TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) {
TreeNode node;
auto* split = node.mutable_dense_float_binary_split();
split->set_left_id(23);
split->set_right_id(24);
TestGetChildren(node, {23, 24});
}
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) {
TreeNode node;
auto* split =
node.mutable_sparse_float_binary_split_default_left()->mutable_split();
split->set_left_id(12);
split->set_right_id(13);
TestGetChildren(node, {12, 13});
}
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) {
TreeNode node;
auto* split =
node.mutable_sparse_float_binary_split_default_right()->mutable_split();
split->set_left_id(1);
split->set_right_id(2);
TestGetChildren(node, {1, 2});
}
TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) {
TreeNode node;
auto* split = node.mutable_categorical_id_binary_split();
split->set_left_id(7);
split->set_right_id(8);
TestGetChildren(node, {7, 8});
}
TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) {
TreeNode node;
TestGetChildren(node, {});
}
} // namespace
} // namespace trees
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -24,6 +24,15 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
tf_proto_library(
name = "quantiles_proto",
srcs = [
"quantiles.proto",
],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
tf_proto_library(
name = "tree_config_proto",
srcs = ["tree_config.proto"],

View File

@ -0,0 +1,32 @@
syntax = "proto3";
option cc_enable_arenas = true;
package boosted_trees;
message QuantileConfig {
// Maximum eps error when computing quantile summaries.
double eps = 1;
// Number of quantiles to generate.
int64 num_quantiles = 2;
}
message QuantileEntry {
// Value for the entry.
float value = 1;
// Weight for the entry.
float weight = 2;
// We need the minimum and maximum rank possible for this entry.
// Rank is 0.0 for the absolute minimum and sum of the weights for the maximum
// value in the input.
float min_rank = 3;
float max_rank = 4;
}
message QuantileSummaryState {
repeated QuantileEntry entries = 1;
}
message QuantileStreamState {
repeated QuantileSummaryState summaries = 1;
}

View File

@ -0,0 +1,53 @@
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(
default_visibility = [
"//tensorflow/contrib/boosted_trees:__subpackages__",
"//tensorflow/contrib/boosted_trees:friends",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "stamped_resource",
hdrs = ["stamped_resource.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
)
cc_library(
name = "quantile_stream_resource",
hdrs = ["quantile_stream_resource.h"],
deps = [
":stamped_resource",
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
)
cc_library(
name = "decision_tree_ensemble_resource",
hdrs = ["decision_tree_ensemble_resource.h"],
deps = [
":stamped_resource",
"//tensorflow/contrib/boosted_trees/lib:trees",
"//tensorflow/core:framework_headers_lib",
],
alwayslink = 1,
)

View File

@ -0,0 +1,77 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace boosted_trees {
namespace models {
// Keep a tree ensemble in memory for efficient evaluation and mutation.
class DecisionTreeEnsembleResource : public StampedResource {
public:
// Constructor.
explicit DecisionTreeEnsembleResource()
: decision_tree_ensemble_(
protobuf::Arena::CreateMessage<
boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
string DebugString() override {
return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
decision_tree_ensemble_->trees_size(), "]");
}
const boosted_trees::trees::DecisionTreeEnsembleConfig&
decision_tree_ensemble() const {
return *decision_tree_ensemble_;
}
boosted_trees::trees::DecisionTreeEnsembleConfig*
mutable_decision_tree_ensemble() {
return decision_tree_ensemble_;
}
// Resets the resource and frees the protos in arena.
// Caller needs to hold the mutex lock while calling this.
void Reset() {
// Reset stamp.
set_stamp(-1);
// Clear tree ensemle.
arena_.Reset();
CHECK_EQ(0, arena_.SpaceAllocated());
decision_tree_ensemble_ = protobuf::Arena::CreateMessage<
boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
}
mutex* get_mutex() { return &mu_; }
private:
protobuf::Arena arena_;
mutex mu_;
boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
};
} // namespace models
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_

View File

@ -0,0 +1,104 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" // NOLINT
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace boosted_trees {
using QuantileStream =
boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
// Resource for accumulating summaries for multiple columns.
class QuantileStreamResource : public StampedResource {
public:
QuantileStreamResource(const float epsilon, const int32 num_quantiles,
const int64 max_elements, int64 stamp_token)
: stream_(epsilon, max_elements),
are_buckets_ready_(false),
epsilon_(epsilon),
num_quantiles_(num_quantiles),
max_elements_(max_elements) {
set_stamp(stamp_token);
}
string DebugString() override { return "QuantileStreamResource"; }
tensorflow::mutex* mutex() { return &mu_; }
QuantileStream* stream(int64 stamp) {
CHECK(is_stamp_valid(stamp));
return &stream_;
}
const std::vector<float>& boundaries(int64 stamp) {
CHECK(is_stamp_valid(stamp));
return boundaries_;
}
void set_boundaries(int64 stamp, const std::vector<float>& boundaries) {
CHECK(is_stamp_valid(stamp));
are_buckets_ready_ = true;
boundaries_ = boundaries;
}
float epsilon() const { return epsilon_; }
int32 num_quantiles() const { return num_quantiles_; }
void Reset(int64 stamp) {
set_stamp(stamp);
stream_ = QuantileStream(epsilon_, max_elements_);
}
bool are_buckets_ready() const { return are_buckets_ready_; }
void set_buckets_ready(bool are_buckets_ready) {
are_buckets_ready_ = are_buckets_ready;
}
private:
~QuantileStreamResource() override {}
// Mutex for the whole resource.
tensorflow::mutex mu_;
// Quantile stream.
QuantileStream stream_;
// Stores the boundaries from the previous iteration. Empty during the first
// iteration.
std::vector<float> boundaries_;
// Whether boundaries are created. Initially boundaries are empty until
// set_boundaries are called.
bool are_buckets_ready_;
const float epsilon_;
const int32 num_quantiles_;
// An upper-bound for the number of elements.
int64 max_elements_;
TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource);
};
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_

View File

@ -0,0 +1,42 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace boosted_trees {
// A StampedResource is a resource that has a stamp token associated with it.
// Before reading from or applying updates to the resource, the stamp should
// be checked to verify that the update is not stale.
class StampedResource : public ResourceBase {
public:
StampedResource() : stamp_(-1) {}
bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
int64 stamp() const { return stamp_; }
void set_stamp(int64 stamp) { stamp_ = stamp; }
private:
int64 stamp_;
};
} // namespace boosted_trees
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_

View File

@ -37,6 +37,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"

View File

@ -58,6 +58,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/i
GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc")
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc")
GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")

View File

@ -111,6 +111,7 @@ file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
"${tensorflow_source_dir}/tensorflow/python/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
"${tensorflow_source_dir}/tensorflow/tensorboard/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
)
@ -124,6 +125,7 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/python/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
"${tensorflow_source_dir}/tensorflow/tensorboard/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
)
@ -342,6 +344,9 @@ add_python_module("tensorflow/contrib/keras/python/keras/layers")
add_python_module("tensorflow/contrib/keras/python/keras/preprocessing")
add_python_module("tensorflow/contrib/keras/python/keras/utils")
add_python_module("tensorflow/contrib/keras/python/keras/wrappers")
add_python_module("tensorflow/contrib/kernel_methods")
add_python_module("tensorflow/contrib/kernel_methods/python")
add_python_module("tensorflow/contrib/kernel_methods/python/mappers")
add_python_module("tensorflow/contrib/labeled_tensor")
add_python_module("tensorflow/contrib/labeled_tensor/python")
add_python_module("tensorflow/contrib/labeled_tensor/python/ops")
@ -405,6 +410,11 @@ add_python_module("tensorflow/contrib/ndlstm/python")
add_python_module("tensorflow/contrib/nn")
add_python_module("tensorflow/contrib/nn/python")
add_python_module("tensorflow/contrib/nn/python/ops")
add_python_module("tensorflow/contrib/nccl")
add_python_module("tensorflow/contrib/nccl/kernels")
add_python_module("tensorflow/contrib/nccl/ops")
add_python_module("tensorflow/contrib/nccl/python")
add_python_module("tensorflow/contrib/nccl/python/ops")
add_python_module("tensorflow/contrib/opt")
add_python_module("tensorflow/contrib/opt/python")
add_python_module("tensorflow/contrib/opt/python/training")
@ -599,6 +609,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py)
GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_nccl_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops"

View File

@ -69,19 +69,25 @@ class BinomialTest(test.TestCase):
self.assertEqual((1, 3), binom.logits.get_shape())
self.assertAllClose(logits, binom.logits.eval())
def testPmfNandCountsAgree(self):
def testPmfAndCdfNandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
binom.prob([2., 3, 2]).eval()
binom.prob([3., 1, 2]).eval()
binom.cdf([2., 3, 2]).eval()
binom.cdf([3., 1, 2]).eval()
with self.assertRaisesOpError("Condition x >= 0.*"):
binom.prob([-1., 4, 2]).eval()
with self.assertRaisesOpError("Condition x <= y.*"):
binom.prob([7., 3, 0]).eval()
with self.assertRaisesOpError("Condition x >= 0.*"):
binom.cdf([-1., 4, 2]).eval()
with self.assertRaisesOpError("Condition x <= y.*"):
binom.cdf([7., 3, 0]).eval()
def testPmfNonIntegerCounts(self):
def testPmfAndCdfNonIntegerCounts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
with self.test_session():
@ -89,50 +95,72 @@ class BinomialTest(test.TestCase):
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
binom.prob([2., 3, 2]).eval()
binom.prob([3., 1, 2]).eval()
binom.cdf([2., 3, 2]).eval()
binom.cdf([3., 1, 2]).eval()
# Both equality and integer checking fail.
with self.assertRaisesOpError(
"cannot contain fractional components."):
binom.prob([1.0, 2.5, 1.5]).eval()
with self.assertRaisesOpError(
"cannot contain fractional components."):
binom.cdf([1.0, 2.5, 1.5]).eval()
binom = binomial.Binomial(total_count=n, probs=p, validate_args=False)
binom.prob([1., 2., 3.]).eval()
binom.cdf([1., 2., 3.]).eval()
# Non-integer arguments work.
binom.prob([1.0, 2.5, 1.5]).eval()
binom.cdf([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
def testPmfAndCdfBothZeroBatches(self):
with self.test_session():
# Both zero-batches. No broadcast
p = 0.5
counts = 1.
pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
binom = binomial.Binomial(total_count=1., probs=p)
pmf = binom.prob(counts)
cdf = binom.cdf(counts)
self.assertAllClose(0.5, pmf.eval())
self.assertAllClose(stats.binom.cdf(counts, n=1, p=p), cdf.eval())
self.assertEqual((), pmf.get_shape())
self.assertEqual((), cdf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
def testPmfAndCdfBothZeroBatchesNontrivialN(self):
with self.test_session():
# Both zero-batches. No broadcast
p = 0.1
counts = 3.
binom = binomial.Binomial(total_count=5., probs=p)
pmf = binom.prob(counts)
cdf = binom.cdf(counts)
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
self.assertAllClose(stats.binom.cdf(counts, n=5., p=p), cdf.eval())
self.assertEqual((), pmf.get_shape())
self.assertEqual((), cdf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
def testPmfAndCdfPStretchedInBroadcastWhenSameRank(self):
with self.test_session():
p = [[0.1, 0.9]]
counts = [[1., 2.]]
pmf = binomial.Binomial(total_count=3., probs=p).prob(counts)
binom = binomial.Binomial(total_count=3., probs=p)
pmf = binom.prob(counts)
cdf = binom.cdf(counts)
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
self.assertAllClose(stats.binom.cdf(counts, n=3., p=p), cdf.eval())
self.assertEqual((1, 2), pmf.get_shape())
self.assertEqual((1, 2), cdf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
def testPmfAndCdfPStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
p = [0.1, 0.4]
counts = [[1.], [0.]]
pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
binom = binomial.Binomial(total_count=1., probs=p)
pmf = binom.prob(counts)
cdf = binom.cdf(counts)
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
self.assertAllClose([[1.0, 1.0], [0.9, 0.6]], cdf.eval())
self.assertEqual((2, 2), pmf.get_shape())
self.assertEqual((2, 2), cdf.get_shape())
def testBinomialMean(self):
with self.test_session():

View File

@ -103,6 +103,31 @@ class MultivariateNormalDiagTest(test.TestCase):
self.assertAllClose(cov_mat, np.cov(samps.T),
atol=0.05, rtol=0.05)
def testSampleWithBroadcastScale(self):
# mu corresponds to a 2-batch of 3-variate normals
mu = np.zeros([2, 3])
# diag corresponds to no batches of 3-variate normals
diag = np.ones([3])
with self.test_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
mean = dist.mean()
self.assertAllEqual([2, 3], mean.get_shape())
self.assertAllClose(mu, mean.eval())
n = int(1e3)
samps = dist.sample(n, seed=0).eval()
cov_mat = array_ops.matrix_diag(diag).eval()**2
sample_cov = np.matmul(samps.transpose([1, 2, 0]),
samps.transpose([1, 0, 2])) / n
self.assertAllClose(mu, samps.mean(axis=0),
atol=0.10, rtol=0.05)
self.assertAllClose([cov_mat, cov_mat], sample_cov,
atol=0.10, rtol=0.05)
def testCovariance(self):
with self.test_session():
mvn = ds.MultivariateNormalDiag(

View File

@ -42,6 +42,28 @@ to integer values.
"""
def _bdtr(k, n, p):
"""The binomial cumulative distribution function.
Args:
k: floating point `Tensor`.
n: floating point `Tensor`.
p: floating point `Tensor`.
Returns:
`sum_{j=0}^k p^j (1 - p)^(n - j)`.
"""
# Trick for getting safe backprop/gradients into n, k when
# betainc(a = 0, ..) = nan
# Write:
# where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
ones = array_ops.ones_like(n - k)
k_eq_n = math_ops.equal(k, n)
safe_dn = array_ops.where(k_eq_n, ones, n - k)
dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p)
return array_ops.where(k_eq_n, ones, dk)
class Binomial(distribution.Distribution):
"""Binomial distribution.
@ -201,6 +223,18 @@ class Binomial(distribution.Distribution):
def _prob(self, counts):
return math_ops.exp(self._log_prob(counts))
def _cdf(self, counts):
counts = self._maybe_assert_valid_sample(counts)
probs = self.probs
if not (counts.shape.is_fully_defined()
and self.probs.shape.is_fully_defined()
and counts.shape.is_compatible_with(self.probs.shape)):
# If both shapes are well defined and equal, we skip broadcasting.
probs += array_ops.zeros_like(counts)
counts += array_ops.zeros_like(self.probs)
return _bdtr(k=counts, n=self.total_count, p=probs)
def _log_unnormalized_prob(self, counts):
counts = self._maybe_assert_valid_sample(counts)
return (counts * math_ops.log(self.probs) +

View File

@ -25,6 +25,7 @@ from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import normal
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
@ -53,6 +54,16 @@ or
"""
def _broadcast_shape(shape1, shape2):
"""Convenience function which statically broadcasts shape when possible."""
if (tensor_util.constant_value(shape1) is not None and
tensor_util.constant_value(shape2) is not None):
return array_ops.broadcast_static_shape(
tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
return array_ops.broadcast_dynamic_shape(shape1, shape2)
# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
class MultivariateNormalLinearOperator(
transformed_distribution.TransformedDistribution):
@ -179,12 +190,25 @@ class MultivariateNormalLinearOperator(
if not scale.dtype.is_floating:
raise TypeError("`scale` parameter must have floating-point dtype.")
with ops.name_scope(name, values=[loc] + scale.graph_parents):
# Since expand_dims doesn't preserve constant-ness, we obtain the
# non-dynamic value if possible.
event_shape = scale.domain_dimension_tensor()
event_shape = scale.range_dimension_tensor()
if tensor_util.constant_value(event_shape) is not None:
event_shape = tensor_util.constant_value(event_shape)
event_shape = tensor_util.constant_value(event_shape).reshape([1])
else:
event_shape = event_shape[array_ops.newaxis]
batch_shape = scale.batch_shape_tensor()
if loc is not None:
loc = ops.convert_to_tensor(loc, name="loc")
loc_batch_shape = loc.get_shape().with_rank_at_least(1)[:-1]
if (loc.get_shape().ndims is None or
not loc_batch_shape.is_fully_defined()):
loc_batch_shape = array_ops.shape(loc)[:-1]
else:
loc_batch_shape = ops.convert_to_tensor(loc_batch_shape,
name="loc_batch_shape")
batch_shape = _broadcast_shape(batch_shape, loc_batch_shape)
super(MultivariateNormalLinearOperator, self).__init__(
distribution=normal.Normal(
@ -192,7 +216,7 @@ class MultivariateNormalLinearOperator(
scale=array_ops.ones([], dtype=scale.dtype)),
bijector=bijectors.AffineLinearOperator(
shift=loc, scale=scale, validate_args=validate_args),
batch_shape=scale.batch_shape_tensor(),
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args,
name=name)

View File

@ -35,6 +35,7 @@ tf_custom_op_py_library(
],
srcs_version = "PY2AND3",
deps = [
":factorization_ops_test_utils_py",
":gen_clustering_ops",
":gen_factorization_ops",
"//tensorflow/contrib/framework:framework_py",
@ -161,12 +162,28 @@ tf_py_test(
],
)
py_library(
name = "factorization_ops_test_utils_py",
srcs = [
"python/ops/factorization_ops_test_utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
],
)
tf_py_test(
name = "factorization_ops_test",
srcs = ["python/ops/factorization_ops_test.py"],
additional_deps = [
":factorization_py",
":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
":factorization_ops_test_utils_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",

View File

@ -18,141 +18,40 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.factorization.python.ops import factorization_ops
from tensorflow.python.framework import constant_op
from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
INPUT_MATRIX = np.array(
[[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
[0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
[2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
[3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
[0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX
np_matrix_to_tf_sparse = factorization_ops_test_utils.np_matrix_to_tf_sparse
def np_matrix_to_tf_sparse(np_matrix,
row_slices=None,
col_slices=None,
transpose=False,
shuffle=False):
"""Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
indices = np.nonzero(np_matrix)
class WalsModelTest(test.TestCase):
# Only allow slices of whole rows or whole columns.
assert not (row_slices is not None and col_slices is not None)
if row_slices is not None:
selected_ind = np.concatenate(
[np.where(indices[0] == r)[0] for r in row_slices], 0)
indices = (indices[0][selected_ind], indices[1][selected_ind])
if col_slices is not None:
selected_ind = np.concatenate(
[np.where(indices[1] == c)[0] for c in col_slices], 0)
indices = (indices[0][selected_ind], indices[1][selected_ind])
if shuffle:
shuffled_ind = [x for x in range(len(indices[0]))]
random.shuffle(shuffled_ind)
indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
ind = (np.concatenate((np.expand_dims(indices[1], 1),
np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
transpose else np.concatenate((np.expand_dims(indices[0], 1),
np.expand_dims(indices[1], 1)),
1).astype(np.int64))
val = np_matrix[indices].astype(np.float32)
shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
if transpose else np.array(
[max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
return sparse_tensor.SparseTensor(ind, val, shape)
def sparse_input():
def sparse_input(self):
return np_matrix_to_tf_sparse(INPUT_MATRIX)
def count_rows(sp_input):
def count_rows(self, sp_input):
return math_ops.cast(
array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0],
dtypes.float32)
def count_cols(sp_input):
def count_cols(self, sp_input):
return math_ops.cast(
array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0],
dtypes.float32)
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
w0=1., row_weights=None, col_weights=None):
"""Calculates the loss of a given factorization.
Using a non distributed method, different than the one implemented in the
WALS model. The weight of an observed entry (i, j) (i.e. such that
input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]).
Args:
input_mat: The input matrix, a SparseTensor of rank 2.
row_factors: The row factors, a dense Tensor of rank 2.
col_factors: The col factors, a dense Tensor of rank 2.
regularization: the regularization coefficient, a scalar.
w0: the weight of unobserved entries. A scalar.
row_weights: A dense tensor of rank 1.
col_weights: A dense tensor of rank 1.
Returns:
The total loss.
"""
wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
else constant_op.constant(1.))
wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
else constant_op.constant(1.))
reg = (regularization if regularization is not None
else constant_op.constant(0.))
row_indices, col_indices = array_ops.split(input_mat.indices,
axis=1,
num_or_size_splits=2)
gathered_row_factors = array_ops.gather(row_factors, row_indices)
gathered_col_factors = array_ops.gather(col_factors, col_indices)
sp_approx_vals = array_ops.squeeze(math_ops.matmul(
gathered_row_factors, gathered_col_factors, adjoint_b=True))
sp_approx = sparse_tensor.SparseTensor(
indices=input_mat.indices,
values=sp_approx_vals,
dense_shape=input_mat.dense_shape)
sp_approx_sq = math_ops.square(sp_approx)
row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
row_factors, col_factors, transpose_b=True)))
resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
resid_sq = math_ops.square(resid)
loss = w0 * (
sparse_ops.sparse_reduce_sum(resid_sq) -
sparse_ops.sparse_reduce_sum(sp_approx_sq)
)
loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
w0 * row_col_norm + reg * (row_norm + col_norm))
return loss.eval()
def calculate_loss_from_wals_model(wals_model, sp_inputs):
def calculate_loss_from_wals_model(self, wals_model, sp_inputs):
current_rows = embedding_ops.embedding_lookup(
wals_model.row_factors, math_ops.range(wals_model._input_rows),
partition_strategy="div")
@ -165,13 +64,10 @@ def calculate_loss_from_wals_model(wals_model, sp_inputs):
col_wts = embedding_ops.embedding_lookup(
wals_model._col_weights, math_ops.range(wals_model._input_cols),
partition_strategy="div")
return calculate_loss(
return factorization_ops_test_utils.calculate_loss(
sp_inputs, current_rows, current_cols, wals_model._regularization,
wals_model._unobserved_weight, row_wts, col_wts)
class WalsModelTest(test.TestCase):
def setUp(self):
self.col_init = [
# shard 0
@ -208,7 +104,7 @@ class WalsModelTest(test.TestCase):
use_factors_weights_cache,
compute_loss=False):
with ops.Graph().as_default(), self.test_session() as sess:
self._wals_inputs = sparse_input()
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
num_cols = 7
@ -282,10 +178,10 @@ class WalsModelTest(test.TestCase):
if compute_loss:
# Test loss computation after the row update
loss = sum(
sess.run(factor_loss * count_rows(inp) / num_rows,
sess.run(factor_loss * self.count_rows(inp) / num_rows,
feed_dict={sp_feeder: inp})
for inp in input_scattered_rows)
true_loss = calculate_loss_from_wals_model(
true_loss = self.calculate_loss_from_wals_model(
wals_model, self._wals_inputs)
self.assertNear(
loss, true_loss, err=.001,
@ -355,10 +251,10 @@ class WalsModelTest(test.TestCase):
if compute_loss:
# Test loss computation after the column update.
loss = sum(
sess.run(factor_loss * count_cols(inp) / num_cols,
sess.run(factor_loss * self.count_cols(inp) / num_cols,
feed_dict={sp_feeder: inp})
for inp in input_scattered_cols_non_duplicate)
true_loss = calculate_loss_from_wals_model(
true_loss = self.calculate_loss_from_wals_model(
wals_model, self._wals_inputs)
self.assertNear(
loss, true_loss, err=.001,
@ -368,7 +264,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input_transposed(self, use_factors_weights_cache,
compute_loss=False):
with ops.Graph().as_default(), self.test_session() as sess:
self._wals_inputs = sparse_input()
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
num_cols = 7
@ -448,10 +344,10 @@ class WalsModelTest(test.TestCase):
if compute_loss:
# Test loss computation after the row update
loss = sum(
sess.run(factor_loss * count_cols(inp) / num_rows,
sess.run(factor_loss * self.count_cols(inp) / num_rows,
feed_dict={sp_feeder: inp})
for inp in input_scattered_rows_non_duplicate)
true_loss = calculate_loss_from_wals_model(
true_loss = self.calculate_loss_from_wals_model(
wals_model, self._wals_inputs)
self.assertNear(
loss, true_loss, err=.001,
@ -516,10 +412,10 @@ class WalsModelTest(test.TestCase):
if compute_loss:
# Test loss computation after the col update
loss = sum(
sess.run(factor_loss * count_rows(inp) / num_cols,
sess.run(factor_loss * self.count_rows(inp) / num_cols,
feed_dict={sp_feeder: inp})
for inp in input_scattered_cols_non_duplicate)
true_loss = calculate_loss_from_wals_model(
true_loss = self.calculate_loss_from_wals_model(
wals_model, self._wals_inputs)
self.assertNear(
loss, true_loss, err=.001,
@ -534,7 +430,7 @@ class WalsModelTest(test.TestCase):
# Here we test that those two give identical results.
def _run_test_als(self, use_factors_weights_cache):
with ops.Graph().as_default(), self.test_session():
self._wals_inputs = sparse_input()
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
5,
@ -613,7 +509,7 @@ class WalsModelTest(test.TestCase):
def _run_test_als_transposed(self, use_factors_weights_cache):
with ops.Graph().as_default(), self.test_session():
self._wals_inputs = sparse_input()
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
5,

View File

@ -0,0 +1,131 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utils for factorization_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
INPUT_MATRIX = np.array(
[[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
[0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
[2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
[3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
[0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
def np_matrix_to_tf_sparse(np_matrix,
row_slices=None,
col_slices=None,
transpose=False,
shuffle=False):
"""Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
indices = np.nonzero(np_matrix)
# Only allow slices of whole rows or whole columns.
assert not (row_slices is not None and col_slices is not None)
if row_slices is not None:
selected_ind = np.concatenate(
[np.where(indices[0] == r)[0] for r in row_slices], 0)
indices = (indices[0][selected_ind], indices[1][selected_ind])
if col_slices is not None:
selected_ind = np.concatenate(
[np.where(indices[1] == c)[0] for c in col_slices], 0)
indices = (indices[0][selected_ind], indices[1][selected_ind])
if shuffle:
shuffled_ind = [x for x in range(len(indices[0]))]
random.shuffle(shuffled_ind)
indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
ind = (np.concatenate((np.expand_dims(indices[1], 1),
np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
transpose else np.concatenate((np.expand_dims(indices[0], 1),
np.expand_dims(indices[1], 1)),
1).astype(np.int64))
val = np_matrix[indices].astype(np.float32)
shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
if transpose else np.array(
[max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
return sparse_tensor.SparseTensor(ind, val, shape)
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
w0=1., row_weights=None, col_weights=None):
"""Calculates the loss of a given factorization.
Using a non distributed method, different than the one implemented in the
WALS model. The weight of an observed entry (i, j) (i.e. such that
input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]).
Args:
input_mat: The input matrix, a SparseTensor of rank 2.
row_factors: The row factors, a dense Tensor of rank 2.
col_factors: The col factors, a dense Tensor of rank 2.
regularization: the regularization coefficient, a scalar.
w0: the weight of unobserved entries. A scalar.
row_weights: A dense tensor of rank 1.
col_weights: A dense tensor of rank 1.
Returns:
The total loss.
"""
wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
else constant_op.constant(1.))
wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
else constant_op.constant(1.))
reg = (regularization if regularization is not None
else constant_op.constant(0.))
row_indices, col_indices = array_ops.split(input_mat.indices,
axis=1,
num_or_size_splits=2)
gathered_row_factors = array_ops.gather(row_factors, row_indices)
gathered_col_factors = array_ops.gather(col_factors, col_indices)
sp_approx_vals = array_ops.squeeze(math_ops.matmul(
gathered_row_factors, gathered_col_factors, adjoint_b=True))
sp_approx = sparse_tensor.SparseTensor(
indices=input_mat.indices,
values=sp_approx_vals,
dense_shape=input_mat.dense_shape)
sp_approx_sq = math_ops.square(sp_approx)
row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
row_factors, col_factors, transpose_b=True)))
resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
resid_sq = math_ops.square(resid)
loss = w0 * (
sparse_ops.sparse_reduce_sum(resid_sq) -
sparse_ops.sparse_reduce_sum(sp_approx_sq)
)
loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
w0 * row_col_norm + reg * (row_norm + col_norm))
return loss.eval()

View File

@ -21,9 +21,14 @@ py_library(
":dense_kernel_mapper_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:framework",
"//tensorflow/python:ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
@ -31,6 +36,7 @@ py_library(
name = "dense_kernel_mapper_py",
srcs = ["python/mappers/dense_kernel_mapper.py"],
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)
py_test(
@ -40,12 +46,12 @@ py_test(
deps = [
":dense_kernel_mapper_py",
":kernel_methods",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
)
@ -55,10 +61,12 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":kernel_methods",
"//tensorflow/python:client_testlib",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:sparse_tensor",
"//third_party/py/numpy",
],
)

View File

@ -22,7 +22,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier
from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper
from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
from tensorflow.python.util.all_util import remove_undocumented

View File

@ -118,6 +118,7 @@ tf_custom_op_py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:common_shapes",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework",
@ -131,9 +132,11 @@ tf_custom_op_py_library(
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:standard_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:summary",
"//tensorflow/python:tensor_util",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
@ -36,6 +38,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
def _embeddings_from_arguments(column,
@ -136,6 +139,58 @@ def _embeddings_from_arguments(column,
max_norm=args.max_norm)
def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
"""Reshape the input tensor by the following rule.
1. If `output_rank > input_rank + 1`, raise a `ValueError`.
2. If `output_rank == input_rank + 1`, expand the tensor by one dimension.
3. If `output_rank == input_rank`, do nothing.
4. If `output_rank < input_rank`, flatten the inner dimensions of the tensor.
Args:
tensor: A Tensor or SparseTensor to be reshaped.
column_name: A string name of the feature column for the tensor.
output_rank: the desired rank of the tensor.
Returns:
A reshaped Tensor or SparseTensor.
Raises:
ValueError: if `output_rank > input_rank + 1` for the input tensor.
"""
input_rank = tensor.get_shape().ndims
if input_rank is None and isinstance(tensor, sparse_tensor_py.SparseTensor):
# Try to get the rank of a sparse tensor by its dense_shape's shape.
input_rank = tensor.dense_shape.get_shape().as_list()[0]
if input_rank is None:
raise ValueError('Error while processing column {}. Rank of input Tensor '
'can not be None.'.format(column_name))
if output_rank > input_rank + 1:
raise ValueError('Error while processing column {}. Rank of input Tensor '
'({}) should be the same as output_rank ({}). For '
'example, sequence data should typically be 3 '
'dimensional (rank 3) while non-sequence data is '
'typically 2 dimensional (rank 2).'.format(
column_name, input_rank, output_rank))
elif output_rank == input_rank + 1:
# Expand the tensor's shape by 1 dimension.
if isinstance(tensor, sparse_tensor_py.SparseTensor):
output_shape = array_ops.concat([tensor.dense_shape, [1]], 0)
return sparse_ops.sparse_reshape(tensor, output_shape)
else:
reshaped = array_ops.expand_dims(tensor, -1)
# Try to calculate the new shape.
static_shape = tensor.get_shape()
if static_shape is not None and static_shape.dims is not None:
reshaped.set_shape(static_shape.as_list() + [1])
return reshaped
elif output_rank < input_rank:
return layers._inner_flatten(tensor, output_rank) # pylint: disable=protected-access
else:
return tensor
def _input_from_feature_columns(columns_to_tensors,
feature_columns,
weight_collections,
@ -160,6 +215,12 @@ def _input_from_feature_columns(columns_to_tensors,
default_name=column.name,
values=columns_to_tensors.values()):
transformed_tensor = transformer.transform(column)
if output_rank == 3:
transformed_tensor = nest.map_structure(
functools.partial(
_maybe_reshape_input_tensor,
column_name=column.name,
output_rank=output_rank), transformed_tensor)
try:
# pylint: disable=protected-access
arguments = column._deep_embedding_lookup_arguments(
@ -548,7 +609,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
default_name=column.name,
values=columns_to_tensors.values()):
tensor = column._to_dense_tensor(transformed_tensor)
tensor = fc._reshape_real_valued_tensor(tensor, 2, column.name)
tensor = _maybe_reshape_input_tensor(
tensor, column.name, output_rank=2)
variable = [
contrib_variables.model_variable(
name='weight',

View File

@ -1350,6 +1350,35 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
self.assertAllEqual(expected_input_shape, model_input.shape)
def testEmbeddingColumnWithAutoReshape(self):
hash_buckets = 10
embedding_dimension = 5
ids_tensor = sparse_tensor.SparseTensor(
values=["c", "b",
"a", "c", "b",
"b"],
indices=[[0, 0], [0, 1],
[1, 0], [1, 1], [1, 2],
[3, 2]],
dense_shape=[4, 3])
expected_input_shape = np.array([4, 3, embedding_dimension])
hashed_ids_column = feature_column.sparse_column_with_hash_bucket(
"ids", hash_buckets)
embedded_column = feature_column.embedding_column(hashed_ids_column,
embedding_dimension)
columns_to_tensors = {"ids": ids_tensor}
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
def testEmbeddingColumnGradient(self):
hash_buckets = 1000
embedding_dimension = 3

View File

@ -836,6 +836,19 @@ py_test(
],
)
py_test(
name = "model_fn_test",
size = "small",
srcs = ["python/learn/estimators/model_fn_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
py_test(
name = "multioutput_test",
size = "small",

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.summary import summary
@ -816,7 +817,8 @@ class _BinaryLogisticHead(_SingleHead):
loss_fn=self._loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
create_output_alternatives_fn=self._create_output_alternatives,
create_output_alternatives_fn=_classification_output_alternatives(
self.head_name, self._problem_type),
labels=labels,
train_op_fn=train_op_fn,
logits=logits,
@ -885,6 +887,8 @@ class _BinaryLogisticHead(_SingleHead):
_indicator_labels_streaming_mean(labels, weights))
metrics[_summary_key(self.head_name, mkey.AUC)] = (
_streaming_auc(logistic, labels, weights))
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = (
_streaming_auc(logistic, labels, weights, curve="PR"))
for threshold in self._thresholds:
metrics[_summary_key(
@ -1009,7 +1013,8 @@ class _MultiClassHead(_SingleHead):
loss_fn=self._wrapped_loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
create_output_alternatives_fn=self._create_output_alternatives,
create_output_alternatives_fn=_classification_output_alternatives(
self.head_name, self._problem_type, self._label_keys),
labels=labels,
train_op_fn=train_op_fn,
logits=logits,
@ -1113,25 +1118,6 @@ class _MultiClassHead(_SingleHead):
return metrics
def _create_output_alternatives(self, predictions):
"""See superclass."""
probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
batch_size = array_ops.shape(probabilities)[0]
if self._label_keys:
classes = array_ops.tile(
input=array_ops.expand_dims(input=self._label_keys, axis=0),
multiples=[batch_size, 1])
else:
classes = array_ops.tile(
input=array_ops.expand_dims(
input=math_ops.range(self.logits_dimension), axis=0),
multiples=[batch_size, 1])
predictions_for_serving = {
prediction_key.PredictionKey.CLASSES: classes,
prediction_key.PredictionKey.PROBABILITIES: probabilities,
}
return {self._head_name: (self._problem_type, predictions_for_serving)}
def _to_labels_tensor(labels, label_name):
"""Returns label as a tensor.
@ -1226,6 +1212,7 @@ class _BinarySvmHead(_SingleHead):
loss_fn=self._loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
# TODO(zakaria): Handle labels for export.
create_output_alternatives_fn=self._create_output_alternatives,
labels=labels,
train_op_fn=train_op_fn,
@ -1325,7 +1312,8 @@ class _MultiLabelHead(_SingleHead):
loss_fn=self._loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
create_output_alternatives_fn=self._create_output_alternatives,
create_output_alternatives_fn=_classification_output_alternatives(
self.head_name, self._problem_type),
labels=labels,
train_op_fn=train_op_fn,
logits=logits,
@ -1374,6 +1362,8 @@ class _MultiLabelHead(_SingleHead):
metrics_lib.streaming_accuracy(classes, labels, weights))
metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
probabilities, labels, weights)
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc(
probabilities, labels, weights, curve="PR")
for class_id in self._metric_class_ids:
# TODO(ptucker): Add per-class accuracy, precision, recall.
@ -1391,6 +1381,9 @@ class _MultiLabelHead(_SingleHead):
_predictions_streaming_mean(logits, weights, class_id))
metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = (
_streaming_auc(probabilities, labels, weights, class_id))
metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = (
_streaming_auc(probabilities, labels, weights, class_id,
curve="PR"))
return metrics
@ -1857,7 +1850,8 @@ def _class_labels_streaming_mean(labels, weights, class_id):
weights=weights)
def _streaming_auc(predictions, labels, weights=None, class_id=None):
def _streaming_auc(predictions, labels, weights=None, class_id=None,
curve="ROC"):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
if class_id is not None:
@ -1866,7 +1860,8 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None):
return metrics_lib.streaming_auc(
predictions,
math_ops.cast(labels, dtypes.bool),
weights=_float_weights_or_none(weights))
weights=_float_weights_or_none(weights),
curve=curve)
def _assert_class_id(class_id, num_classes=None):
@ -1901,6 +1896,71 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
def _classification_output_alternatives(head_name, problem_type,
label_keys=None):
"""Creates a func to generate output alternatives for classification.
Servo expects classes to be a string tensor, and have the same dimensions
as the probabilities tensor. It should contain the labels of the corresponding
entries in probabilities. This function creates a new classes tensor that
satisfies these conditions and can be exported.
Args:
head_name: Name of the head.
problem_type: `ProblemType`
label_keys: Optional label keys
Returns:
A function to generate output alternatives.
"""
def _create_output_alternatives(predictions):
"""Creates output alternative for the Head.
Args:
predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
symbolic name for an output Tensor possibly but not necessarily taken
from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
itself.
Returns:
`dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
'submodel_name' is a submodel identifier that should be consistent across
the pipeline (here likely taken from the head_name),
'problem_type' is a `ProblemType`,
'tensor_name' is a symbolic name for an output Tensor possibly but not
necessarily taken from `PredictionKey`, and
'Tensor' is the corresponding output Tensor itself.
Raises:
ValueError: if predictions does not have PredictionKey.PROBABILITIES key.
"""
probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES)
if probabilities is None:
raise ValueError("%s missing in predictions" %
prediction_key.PredictionKey.PROBABILITIES)
with ops.name_scope(None, "_classification_output_alternatives",
(probabilities,)):
batch_size = array_ops.shape(probabilities)[0]
if label_keys:
classes = array_ops.tile(
input=array_ops.expand_dims(input=label_keys, axis=0),
multiples=[batch_size, 1],
name="classes_tensor")
else:
n = array_ops.shape(probabilities)[1]
classes = array_ops.tile(
input=array_ops.expand_dims(input=math_ops.range(n), axis=0),
multiples=[batch_size, 1])
classes = string_ops.as_string(classes, name="classes_tensor")
exported_predictions = {
prediction_key.PredictionKey.PROBABILITIES: probabilities,
prediction_key.PredictionKey.CLASSES: classes}
return {head_name: (problem_type, exported_predictions)}
return _create_output_alternatives
# Aliases
# TODO(zakaria): Remove these aliases, See b/34751732
_regression_head = regression_head

View File

@ -297,11 +297,15 @@ class MultiLabelHeadTest(test.TestCase):
def _expected_eval_metrics(self, expected_loss):
return {
"accuracy": 1. / 3,
"auc": 1. / 4,
"loss": expected_loss,
"auc": 1. / 4,
"auc/class0": 1.,
"auc/class1": 1.,
"auc/class2": 0.,
"auc_precision_recall": 0.166667,
"auc_precision_recall/class0": 0,
"auc_precision_recall/class1": 0.,
"auc_precision_recall/class2": 1.,
"labels/actual_label_mean/class0": self._labels[0][0],
"labels/actual_label_mean/class1": self._labels[0][1],
"labels/actual_label_mean/class2": self._labels[0][2],
@ -417,7 +421,7 @@ class MultiLabelHeadTest(test.TestCase):
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
logits_input=((0., 0.),), logits=self._logits)
def testMultiLabelEvalMode(self):
def testMultiLabelEval(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes))
@ -433,7 +437,7 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
def testMultiClassEvalModeWithLargeLogits(self):
def testMultiClassEvalWithLargeLogits(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes))
@ -472,6 +476,36 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
expected_eval_metrics, model_fn_ops)
def testMultiLabelInfer(self):
n_classes = 3
head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")
with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops(
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1)))
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
with session.Session():
self.assertListEqual(
[1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])
self.assertItemsEqual(
["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
self.assertEqual(
constants.ProblemType.CLASSIFICATION,
model_fn_ops.output_alternatives["head_name"][0])
predictions_for_serving = (
model_fn_ops.output_alternatives["head_name"][1])
self.assertIn("classes", six.iterkeys(predictions_for_serving))
self.assertAllEqual(
[[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
predictions_for_serving["classes"].eval())
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
self.assertAllClose(
[[0.731059, 0.5, 0.5],
[0.5, 0.5, 0.731059,]],
predictions_for_serving["probabilities"].eval())
def testMultiLabelWithLabelName(self):
n_classes = 3
label_name = "my_label"
@ -621,6 +655,7 @@ class BinaryClassificationHeadTest(test.TestCase):
"accuracy/baseline_label_mean": label_mean,
"accuracy/threshold_0.500000_mean": 1. / 2,
"auc": 1. / 2,
"auc_precision_recall": 0.749999,
"labels/actual_label_mean": label_mean,
"labels/prediction_mean": .731059, # softmax
"loss": expected_loss,
@ -691,7 +726,7 @@ class BinaryClassificationHeadTest(test.TestCase):
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
logits_input=((0., 0.), (0., 0.)), logits=self._logits)
def testBinaryClassificationEvalMode(self):
def testBinaryClassificationEval(self):
n_classes = 2
head = head_lib.multi_class_head(n_classes=n_classes)
with ops.Graph().as_default(), session.Session():
@ -708,18 +743,32 @@ class BinaryClassificationHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
def testBinaryClassificationInferMode(self):
def testBinaryClassificationInfer(self):
n_classes = 2
head = head_lib.multi_class_head(n_classes=n_classes)
head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")
with ops.Graph().as_default(), session.Session():
# logloss: z:label, x:logit
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
model_fn_ops = head.create_model_fn_ops(
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
with session.Session():
self.assertListEqual(
[1, 1], list(model_fn_ops.predictions["classes"].eval()))
self.assertItemsEqual(
["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
self.assertEqual(
constants.ProblemType.LOGISTIC_REGRESSION,
model_fn_ops.output_alternatives["head_name"][0])
predictions_for_serving = (
model_fn_ops.output_alternatives["head_name"][1])
self.assertIn("classes", six.iterkeys(predictions_for_serving))
predicted_classes = predictions_for_serving["classes"].eval().tolist()
self.assertListEqual(
[b"0", b"1"], predicted_classes[0])
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
def testBinaryClassificationInferMode_withWightColumn(self):
n_classes = 2
@ -1006,7 +1055,7 @@ class MultiClassHeadTest(test.TestCase):
"multi_class_head/centered_bias/bias_1",
"multi_class_head/centered_bias/bias_2"])
def testMultiClassEvalMode(self):
def testMultiClassEval(self):
n_classes = 3
head = head_lib.multi_class_head(
n_classes=n_classes, metric_class_ids=range(n_classes))
@ -1131,7 +1180,7 @@ class MultiClassHeadTest(test.TestCase):
model_fn_ops.output_alternatives["head_name"][1])
self.assertIn("classes", six.iterkeys(predictions_for_serving))
self.assertAllEqual(
[[0, 1, 2], [0, 1, 2]],
[[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
predictions_for_serving["classes"].eval())
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
self.assertAllClose(

View File

@ -22,7 +22,9 @@ class MetricKey(object):
"""Metric key strings."""
LOSS = "loss"
AUC = "auc"
AUC_PR = "auc_precision_recall"
CLASS_AUC = "auc/class%d"
CLASS_AUC_PR = "auc_precision_recall/class%d"
PREDICTION_MEAN = "labels/prediction_mean"
CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d"
CLASS_LOGITS_MEAN = "labels/logits_mean/class%d"

View File

@ -25,10 +25,16 @@ import six
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import get_graph_from_inputs
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.python.estimator import model_fn as core_model_fn_lib
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import session_run_hook
@ -177,3 +183,85 @@ class ModelFnOps(
training_chief_hooks=training_chief_hooks,
training_hooks=training_hooks,
scaffold=scaffold)
def estimator_spec(self, mode, default_serving_output_alternative_key=None):
"""Creates an equivalent `EstimatorSpec`.
Args:
mode: One of `ModeKeys`. Specifies if this training, evaluation or
prediction.
default_serving_output_alternative_key: Required for multiple heads. If
you have multiple entries in `output_alternatives` dict (comparable to
multiple heads), `EstimatorSpec` requires a default head that will be
used if a Servo request does not explicitly mention which head to infer
on. Pass the key of the output alternative here that you want to
designate as default. A separate ExportOutpout for this default head
wil be added to the export_outputs dict with the special key
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
already an enry in output_alternatives with this special key.
Returns:
Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`
Raises:
ValueError: If problem type is unknown.
"""
def _scores(output_tensors):
scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
if scores is None:
scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
return scores
def _classes(output_tensors): # pylint: disable=missing-docstring
classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
if classes is None:
logging.warning(
'classes is None, Servo inference will not have class ids.')
return None
elif classes.dtype != dtypes.string:
# Servo classification can only serve string classes
logging.warning(
'classes is not string, Servo inference will not have class ids.')
return None
return classes
def _export_output(problem_type, predictions): # pylint: disable=missing-docstring
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
return core_export_lib.RegressionOutput(_scores(predictions))
if (problem_type == constants.ProblemType.CLASSIFICATION or
problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
return core_export_lib.ClassificationOutput(
scores=_scores(predictions), classes=_classes(predictions))
if problem_type == constants.ProblemType.UNSPECIFIED:
return core_export_lib.PredictOutput(predictions)
raise ValueError('Unknown problem_type=%s' % problem_type)
# Converts output_alternatives
export_outputs_dict = None
if self.output_alternatives:
output_alternatives = self.output_alternatives
# Adds default output_alternative if needed.
if (len(output_alternatives) > 1 and
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
output_alternatives):
output_alternatives = output_alternatives.copy()
output_alternatives[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
output_alternatives[default_serving_output_alternative_key])
export_outputs_dict = {key: _export_output(*val) for key, val in
output_alternatives.items()}
return core_model_fn_lib.EstimatorSpec(
mode=mode,
predictions=self.predictions,
loss=self.loss,
train_op=self.train_op,
eval_metric_ops=self.eval_metric_ops,
export_outputs=export_outputs_dict,
training_chief_hooks=self.training_chief_hooks,
training_hooks=self.training_hooks,
scaffold=self.scaffold)

View File

@ -0,0 +1,279 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelFnOps tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.client import session
from tensorflow.python.estimator.export import export_output as core_export_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
class ModelFnopsTest(test.TestCase):
"""Multi-output tests."""
def create_predictions(self):
probabilities = constant_op.constant([1., 1., 1.])
scores = constant_op.constant([1., 2., 3.])
classes = constant_op.constant([b"0", b"1", b"2"])
return {
"probabilities": probabilities,
"scores": scores,
"classes": classes}
def create_model_fn_ops(self, predictions, output_alternatives,
mode=model_fn.ModeKeys.INFER):
return model_fn.ModelFnOps(
model_fn.ModeKeys.INFER,
predictions=predictions,
loss=constant_op.constant([1]),
train_op=control_flow_ops.no_op(),
eval_metric_ops={"metric_key": (control_flow_ops.no_op(),
control_flow_ops.no_op())},
# zzz
training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
training_hooks=[basic_session_run_hooks.StepCounterHook()],
output_alternatives=output_alternatives,
scaffold=monitored_session.Scaffold())
def assertEquals_except_export(self, model_fn_ops, estimator_spec):
self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
self.assertEqual(model_fn_ops.eval_metric_ops,
estimator_spec.eval_metric_ops)
self.assertEqual(model_fn_ops.training_chief_hooks,
estimator_spec.training_chief_hooks)
self.assertEqual(model_fn_ops.training_hooks, estimator_spec.training_hooks)
self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)
def testEstimatorSpec_except_export(self):
predictions = self.create_predictions()
model_fn_ops = self.create_model_fn_ops(predictions, None)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
def testEstimatorSpec_export_regression_with_scores(self):
predictions = self.create_predictions()
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
self.assertTrue(isinstance(
regression_output, core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["scores"].eval(),
regression_output.value.eval())
def testEstimatorSpec_export_regression_with_probabilities(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["scores"]
output_alternatives = {"regression_head": (
constants.ProblemType.LINEAR_REGRESSION,
output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
self.assertTrue(isinstance(
regression_output, core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["probabilities"].eval(),
regression_output.value.eval())
def testEstimatorSpec_export_classsification(self):
predictions = self.create_predictions()
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
classification_output.scores.eval())
self.assertAllEqual(predictions["classes"].eval(),
classification_output.classes.eval())
def testEstimatorSpec_export_classsification_with_missing_scores(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["scores"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["probabilities"].eval(),
classification_output.scores.eval())
self.assertAllEqual(predictions["classes"].eval(),
classification_output.classes.eval())
def testEstimatorSpec_export_classsification_with_missing_scores_proba(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["scores"]
del output_alternatives_predictions["probabilities"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertIsNone(classification_output.scores)
self.assertAllEqual(predictions["classes"].eval(),
classification_output.classes.eval())
def testEstimatorSpec_export_classsification_with_missing_classes(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
del output_alternatives_predictions["classes"]
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
classification_output.scores.eval())
self.assertIsNone(classification_output.classes)
def testEstimatorSpec_export_classsification_with_nonstring_classes(self):
predictions = self.create_predictions()
output_alternatives_predictions = predictions.copy()
output_alternatives_predictions["classes"] = constant_op.constant(
[1, 2, 3])
output_alternatives = {"classification_head": (
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
classification_output = estimator_spec.export_outputs[
"classification_head"]
self.assertTrue(isinstance(classification_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
classification_output.scores.eval())
self.assertIsNone(classification_output.classes)
def testEstimatorSpec_export_logistic(self):
predictions = self.create_predictions()
output_alternatives = {"logistic_head": (
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
logistic_output = estimator_spec.export_outputs["logistic_head"]
self.assertTrue(isinstance(logistic_output,
core_export_lib.ClassificationOutput))
self.assertAllEqual(predictions["scores"].eval(),
logistic_output.scores.eval())
self.assertAllEqual(predictions["classes"].eval(),
logistic_output.classes.eval())
def testEstimatorSpec_export_unspecified(self):
predictions = self.create_predictions()
output_alternatives = {"unspecified_head": (
constants.ProblemType.UNSPECIFIED, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
unspecified_output = estimator_spec.export_outputs["unspecified_head"]
self.assertTrue(isinstance(unspecified_output,
core_export_lib.PredictOutput))
self.assertEqual(predictions, unspecified_output.outputs)
def testEstimatorSpec_export_multihead(self):
predictions = self.create_predictions()
output_alternatives = {
"regression_head": (
constants.ProblemType.LINEAR_REGRESSION, predictions),
"classification_head": (
constants.ProblemType.CLASSIFICATION, predictions)}
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
"regression_head")
self.assertEquals_except_export(model_fn_ops, estimator_spec)
with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
self.assertTrue(isinstance(
regression_output, core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["scores"].eval(),
regression_output.value.eval())
default_output = estimator_spec.export_outputs[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
self.assertTrue(isinstance(default_output,
core_export_lib.RegressionOutput))
self.assertAllEqual(predictions["scores"].eval(),
default_output.value.eval())
if __name__ == "__main__":
test.main()

View File

@ -66,8 +66,8 @@ def _get_single_cell(cell_type, num_units):
ValueError: `cell_type` is an invalid `RNNCell` name.
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
"""
cell_type = _CELL_TYPES.get(cell_type)
if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell):
cell_type = _CELL_TYPES.get(cell_type, cell_type)
if not cell_type or not issubclass(cell_type, contrib_rnn.RNNCell):
raise ValueError('The supported cell types are {}; got {}'.format(
list(_CELL_TYPES.keys()), cell_type))
return cell_type(num_units=num_units)

View File

@ -97,7 +97,8 @@ class Experiment(object):
finite number of batches (generally, 1 epoch over the evaluation data).
eval_metrics: `dict` of string, metric function. If `None`, default set
is used. This should be `None` if the `estimator` is
${tf.estimator.Estimator}.
${tf.estimator.Estimator}. If metrics are provided they will be
*appended* to the default set.
train_steps: Perform this many steps of training. `None`, the default,
means train forever.
eval_steps: `evaluate` runs until input is exhausted (or another exception

View File

@ -45,7 +45,7 @@ class SquareLinearOperatorFullMatrixTest(
# values are random and we want the same value used for both mat and
# feed_dict.
matrix = matrix.eval()
operator = linalg.LinearOperatorFullMatrix(matrix)
operator = linalg.LinearOperatorFullMatrix(matrix_ph)
feed_dict = {matrix_ph: matrix}
else:
operator = linalg.LinearOperatorFullMatrix(matrix)
@ -105,7 +105,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
# feed_dict.
matrix = matrix.eval()
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
matrix_ph, is_self_adjoint=True, is_positive_definite=True)
feed_dict = {matrix_ph: matrix}
else:
operator = linalg.LinearOperatorFullMatrix(
@ -144,7 +144,7 @@ class NonSquareLinearOperatorFullMatrixTest(
# values are random and we want the same value used for both mat and
# feed_dict.
matrix = matrix.eval()
operator = linalg.LinearOperatorFullMatrix(matrix)
operator = linalg.LinearOperatorFullMatrix(matrix_ph)
feed_dict = {matrix_ph: matrix}
else:
operator = linalg.LinearOperatorFullMatrix(matrix)

View File

@ -12,13 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for nccl AllReduce."""
"""Functions for using NVIDIA nccl collective ops.
@@all_max
@@all_min
@@all_prod
@@all_sum
@@broadcast
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.nccl.python.ops.nccl_ops import *
# pylint: enable=wildcard-import
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File

@ -66,7 +66,7 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
num_units=num_units, forget_bias=1.0)(x, m)
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
x.name:

View File

@ -466,12 +466,13 @@ class OutputProjectionWrapper(RNNCell):
if needed or directly feed into a softmax.
"""
def __init__(self, cell, output_size, reuse=None):
def __init__(self, cell, output_size, activation=None, reuse=None):
"""Create a cell with output projection.
Args:
cell: an RNNCell, a projection to output_size is added to it.
output_size: integer, the size of the output after projection.
activation: (optional) an optional activation function.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
@ -487,6 +488,7 @@ class OutputProjectionWrapper(RNNCell):
self._cell = cell
self._output_size = output_size
self._reuse = reuse
self._activation = activation
@property
def state_size(self):
@ -507,6 +509,8 @@ class OutputProjectionWrapper(RNNCell):
with _checked_scope(self, scope or "output_projection_wrapper",
reuse=self._reuse):
projected = _linear(output, self._output_size, True)
if self._activation:
projected = self._activation(projected)
return projected, res_state
@ -518,12 +522,13 @@ class InputProjectionWrapper(RNNCell):
do the projection on this batch-concatenated sequence, then split it.
"""
def __init__(self, cell, num_proj, input_size=None):
def __init__(self, cell, num_proj, activation=None, input_size=None):
"""Create a cell with input projection.
Args:
cell: an RNNCell, a projection of inputs is added before it.
num_proj: Python integer. The dimension to project to.
activation: (optional) an optional activation function.
input_size: Deprecated and unused.
Raises:
@ -535,6 +540,7 @@ class InputProjectionWrapper(RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
self._cell = cell
self._num_proj = num_proj
self._activation = activation
@property
def state_size(self):
@ -553,6 +559,8 @@ class InputProjectionWrapper(RNNCell):
# Default scope: "InputProjectionWrapper"
with vs.variable_scope(scope or "input_projection_wrapper"):
projected = _linear(inputs, self._num_proj, True)
if self._activation:
projected = self._activation(projected)
return self._cell(projected, state)

View File

@ -109,7 +109,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
def __init__(self, num_units, use_peepholes=False,
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=1, num_proj_shards=1,
forget_bias=1.0, state_is_tuple=False,
forget_bias=1.0, state_is_tuple=True,
activation=math_ops.tanh, reuse=None):
"""Initialize the parameters for an LSTM cell.
@ -457,7 +457,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
start_freqindex_list=None,
end_freqindex_list=None,
couple_input_forget_gates=False,
state_is_tuple=False,
state_is_tuple=True,
reuse=None):
"""Initialize the parameters for an LSTM cell.
@ -571,7 +571,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
ValueError: if an input_size was specified and the provided inputs have
a different dimension.
"""
batch_size = int(inputs.get_shape()[0])
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
freq_inputs = self._make_tf_features(inputs)
with _checked_scope(self, scope or "grid_lstm_cell",
initializer=self._initializer, reuse=self._reuse):
@ -994,7 +994,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
ValueError: if an input_size was specified and the provided inputs have
a different dimension.
"""
batch_size = int(inputs.get_shape()[0])
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
fwd_inputs = self._make_tf_features(inputs)
if self._backward_slice_offset:
bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
@ -1043,7 +1043,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
"""
def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
input_size=None, state_is_tuple=False, reuse=None):
input_size=None, state_is_tuple=True, reuse=None):
"""Create a cell with attention.
Args:

View File

@ -56,6 +56,13 @@ class AttentionWrapperTest(test.TestCase):
return super(AttentionWrapperTest, self).assertAllClose(
*args, **kwargs)
def testAttentionWrapperState(self):
num_fields = len(wrapper.AttentionWrapperState._fields) # pylint: disable=protected-access
state = wrapper.AttentionWrapperState(*([None] * num_fields))
new_state = state.clone(time=1)
self.assertEqual(state.time, None)
self.assertEqual(new_state.time, 1)
def _testWithAttention(self,
create_attention_mechanism,
expected_final_output,

View File

@ -369,7 +369,26 @@ class AttentionWrapperState(
- `attention_history`: (if enabled) a `TensorArray` containing attention
matrices from all time steps. Call `stack()` to convert to a `Tensor`.
"""
pass
def clone(self, **kwargs):
"""Clone this object, overriding components provided by kwargs.
Example:
```python
initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
initial_state = initial_state.clone(cell_state=encoder_state)
```
Args:
**kwargs: Any properties of the state object to replace in the returned
`AttentionWrapperState`.
Returns:
A new `AttentionWrapperState` whose properties are the same as
this one, except any overriden properties as provided in `kwargs`.
"""
return super(AttentionWrapperState, self)._replace(**kwargs)
def hardmax(logits, name=None):

View File

@ -431,7 +431,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
shape=base_shape))
all_finished = math_ops.reduce_all(finished)
no_samples = math_ops.equal(array_ops.shape(sample_ids)[0], 0)
no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
next_inputs = control_flow_ops.cond(
math_ops.logical_or(all_finished, no_samples),
lambda: base_next_inputs, maybe_sample)

View File

@ -31,6 +31,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
@ -95,6 +97,22 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
run_context.request_stop()
class EveryCheckpointPreSaveListener(
basic_session_run_hooks.CheckpointSaverListener):
"""Runs a given op before each checkpoint save."""
def __init__(self, op):
"""Initializes the object.
Args:
op: An op to run before each checkpoint save.
"""
self._op = op
def before_save(self, session, global_step_value):
session.run(self._op)
def get_model_fn(params,
graph_builder_class,
device_assigner,
@ -103,6 +121,7 @@ def get_model_fn(params,
num_trainers=1,
trainer_id=0,
report_feature_importances=False,
model_dir=None,
local_eval=False):
"""Return a model function given a way to construct a graph builder."""
def _model_fn(features, labels, mode):
@ -138,6 +157,8 @@ def get_model_fn(params,
# question of why we force everything to adhere to a single model_fn).
loss_deps = []
training_graph = None
training_hooks = []
scaffold = None
if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
training_graph = control_flow_ops.group(
graph_builder.training_graph(
@ -146,6 +167,15 @@ def get_model_fn(params,
trainer_id=trainer_id),
state_ops.assign_add(contrib_framework.get_global_step(), 1))
loss_deps.append(training_graph)
if hasattr(graph_builder, 'finalize_training'):
finalize_listener = EveryCheckpointPreSaveListener(
graph_builder.finalize_training())
scaffold = monitored_session.Scaffold()
training_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
model_dir, save_secs=600, save_steps=None,
scaffold=scaffold,
listeners=[finalize_listener]))
training_loss = None
if (mode == model_fn_lib.ModeKeys.EVAL or
@ -158,7 +188,6 @@ def get_model_fn(params,
if weights is not None:
features[weights_name] = weights
training_hooks = []
if early_stopping_rounds:
training_hooks.append(TensorForestLossHook(early_stopping_rounds))
@ -167,7 +196,9 @@ def get_model_fn(params,
predictions=inference,
loss=training_loss,
train_op=training_graph,
training_hooks=training_hooks)
training_hooks=training_hooks,
scaffold=scaffold)
return _model_fn
@ -257,6 +288,7 @@ class TensorForestEstimator(estimator.Estimator):
num_trainers=num_trainers,
trainer_id=trainer_id,
report_feature_importances=report_feature_importances,
model_dir=model_dir,
local_eval=local_eval),
model_dir=model_dir,
config=config,

View File

@ -43,9 +43,9 @@ py_library(
srcs = ["plugins/projector/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":protos_all_py",
"//tensorflow/python:lib",
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
"//tensorflow/tensorboard/plugins/projector:protos_all_py",
],
)
@ -56,10 +56,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":projector",
":protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
"//tensorflow/tensorboard/plugins/projector:protos_all_py",
],
)

View File

@ -28,11 +28,10 @@ from __future__ import print_function
import os
from google.protobuf import text_format
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
from tensorflow.python.lib.io import file_io
from tensorflow.tensorboard.plugins.projector import projector_plugin
# pylint: disable=wildcard-import
from tensorflow.tensorboard.plugins.projector.projector_config_pb2 import *
from tensorflow.tensorboard.plugins.projector.projector_plugin import *
# pylint: enable=wildcard-import

View File

@ -24,10 +24,10 @@ import shutil
from google.protobuf import text_format
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer as writer_lib
from tensorflow.tensorboard.plugins.projector import projector_config_pb2
class ProjectorApiTest(test.TestCase):

View File

@ -0,0 +1,62 @@
# Description:
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)
cc_library(
name = "xla_tf_graph_util",
srcs = [
"xla_tf_graph_util.cc",
],
hdrs = [
"xla_tf_graph_util.h",
],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
cc_test(
name = "xla_tf_graph_util_test",
srcs = ["xla_tf_graph_util_test.cc"],
deps = [
":xla_tf_graph_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:scope",
"//tensorflow/compiler/jit:xla_cpu_jit",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:ops",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:cwise_op",
],
)

View File

@ -0,0 +1,8 @@
# Xla Tf Graph
## Description
This module contains utilities to treat xla representation as tf graph to support mobile SOC experiments and leverage tf tools.
Maintainers:
- Satoshi Kataoka (satok@google.com, github.com/satok16)

View File

@ -0,0 +1,71 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
namespace tensorflow {
namespace xla_tf_graph {
namespace {
constexpr const char* const GRAPH_NAME = "xla_tf_graph_util";
void SetupXlaCpuClient(std::unique_ptr<FunctionLibraryDefinition>* flib_def,
std::unique_ptr<FunctionLibraryRuntime>* flr,
std::unique_ptr<XlaCompiler>* compiler) {
xla::Client* client = xla::ClientLibrary::LocalClientOrDie();
XlaOpRegistry::RegisterCompilationKernels();
FunctionDefLibrary flib;
flib_def->reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
// Setup compiler options
XlaCompiler::Options options;
options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
options.client = client;
compiler->reset(new XlaCompiler(options));
flr->reset(NewFunctionLibraryRuntime(
compiler->get()->device_mgr(), /*env=*/nullptr, compiler->get()->device(),
TF_GRAPH_DEF_VERSION, flib_def->get(), OptimizerOptions(),
/*custom_kernel_creator=*/nullptr));
}
} // namespace
xla::StatusOr<std::unique_ptr<xla::SessionModule>>
ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
std::unique_ptr<Graph> graph) {
CHECK(graph);
std::unique_ptr<FunctionLibraryDefinition> flib_def;
std::unique_ptr<FunctionLibraryRuntime> flr;
std::unique_ptr<XlaCompiler> compiler;
SetupXlaCpuClient(&flib_def, &flr, &compiler);
// Compile graph and build computation
XlaCompiler::CompilationResult result;
TF_CHECK_OK(compiler->CompileGraph(GRAPH_NAME, std::move(graph), flr.get(),
args, &result));
return result.computation.Snapshot();
}
} // namespace xla_tf_graph
} // namespace tensorflow

View File

@ -0,0 +1,43 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
#define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace xla_tf_graph {
// A set of utilities to handle xla computation requests.
// These utilities help developers leverage existing tools to work with
// xla computations, also provide a way to support TensorFlow ops by
// implementing xla computations so that they can do experiments on their
// specialized environments.
// Convert a tf graph to a xla session module
xla::StatusOr<std::unique_ptr<xla::SessionModule>>
ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
std::unique_ptr<Graph> graph);
} // namespace xla_tf_graph
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace xla_tf_graph {
static std::unique_ptr<Graph> BuildAddGraph() {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
auto c = ops::Add(scope.WithOpName("C"), a, b);
auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_CHECK_OK(scope.ToGraph(graph.get()));
return graph;
}
TEST(XlaTfGraphUtil, ConvertTfGraphToHloModule) {
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2});
std::unique_ptr<Graph> graph = BuildAddGraph();
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<xla::SessionModule> session_module,
ConvertTfGraphToXlaSessionModule(args, std::move(graph)));
ASSERT_EQ(5, session_module->entry().requests_size());
}
} // namespace xla_tf_graph
} // namespace tensorflow

View File

@ -453,8 +453,8 @@ void BFCAllocator::RemoveFreeChunkIterFromBin(
void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) {
Chunk* c = ChunkFromHandle(h);
CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
int count = BinFromIndex(c->bin_num)->free_chunks.erase(h);
CHECK(count > 0) << "Could not find chunk in bin";
CHECK_GT(BinFromIndex(c->bin_num)->free_chunks.erase(h), 0)
<< "Could not find chunk in bin";
c->bin_num = kInvalidBinNum;
}

View File

@ -78,7 +78,7 @@ class BFCAllocator : public VisitableAllocator {
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
// kInvalidChunkHandle means an invalid chunk
typedef int ChunkHandle;
typedef size_t ChunkHandle;
static const int kInvalidChunkHandle = -1;
typedef int BinNum;

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
@ -304,10 +305,18 @@ Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
tensors_to_replace.push_back({n.second, n.first.second});
}
auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
// Evaluate the constant foldable nodes.
std::vector<Tensor> outputs;
Status s = GraphRunner::Run(constant_graph.get(), function_library, env,
{} /* inputs*/, tensors_to_fetch_names, &outputs);
auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
// Output tensors need to be cleared before the GraphRunner is deleted.
outputs.clear();
graph_runner.reset(nullptr);
});
Status s =
graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/,
tensors_to_fetch_names, &outputs);
if (!s.ok()) {
VLOG(1) << "Could not fetch constants: " << s;
*was_mutated = false;

View File

@ -44,7 +44,7 @@ DeviceMgr::~DeviceMgr() {
}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
int n = s.size();
size_t n = s.size();
char* space = name_backing_store_.Alloc(n);
memcpy(space, s.data(), n);
return StringPiece(space, n);

View File

@ -427,7 +427,7 @@ Status DirectSession::Run(const RunOptions& run_options,
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
// Start parallel Executors.
const int num_executors = executors_and_keys->items.size();
const size_t num_executors = executors_and_keys->items.size();
ExecutorBarrier* barrier = new ExecutorBarrier(
num_executors, run_state.rendez, [&run_state](const Status& ret) {
{
@ -458,7 +458,7 @@ Status DirectSession::Run(const RunOptions& run_options,
options_.config.graph_options().build_cost_model();
const int64 build_cost_model_after =
options_.config.graph_options().build_cost_model_after();
int measure_step_count = executor_step_count - build_cost_model_after;
int64 measure_step_count = executor_step_count - build_cost_model_after;
if (measure_step_count >= 0) {
update_cost_model =
((measure_step_count + 1) % build_cost_model_every == 0);
@ -611,7 +611,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
}
// Start parallel Executors.
const int num_executors = executors_and_keys->items.size();
const size_t num_executors = executors_and_keys->items.size();
ExecutorBarrier* barrier = new ExecutorBarrier(
num_executors, run_state->rendez, [run_state](const Status& ret) {
if (!ret.ok()) {

View File

@ -232,7 +232,7 @@ struct NodeItem {
int input_start = 0;
// Number of output edges.
int num_output_edges;
size_t num_output_edges;
PendingCounts::Handle pending_id;
@ -307,7 +307,7 @@ class GraphView {
void Initialize(const Graph* g);
Status SetAllocAttrs(const Graph* g, const Device* device);
NodeItem* node(int id) const {
NodeItem* node(size_t id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
uint32 offset = node_offsets_[id];
@ -454,7 +454,7 @@ GraphView::~GraphView() {
}
size_t GraphView::NodeItemBytes(const Node* n) {
const int num_output_edges = n->out_edges().size();
const size_t num_output_edges = n->out_edges().size();
const int num_inputs = n->num_inputs();
const int num_outputs = n->num_outputs();
@ -500,11 +500,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
// values as "int" vs "size_t" in CHECK_LE.
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
const uint32 offset = ptr - space_;
const uint32 offset = static_cast<uint32>(ptr - space_);
node_offsets_[id] = offset;
ptr += bytes;
const int num_output_edges = n->out_edges().size();
const size_t num_output_edges = n->out_edges().size();
const int num_inputs = n->num_inputs();
const int num_outputs = n->num_outputs();
@ -580,9 +580,10 @@ void GraphView::Initialize(const Graph* g) {
CHECK_EQ(ptr, space_ + total_bytes);
}
void GetMaxPendingCounts(const Node* n, int* max_pending, int* max_dead_count) {
const int num_in_edges = n->in_edges().size();
int initial_count;
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
size_t* max_dead_count) {
const size_t num_in_edges = n->in_edges().size();
size_t initial_count;
if (IsMerge(n)) {
// merge waits all control inputs so we initialize the pending
// count to be the number of control edges.
@ -626,8 +627,7 @@ Status ExecutorImpl::Initialize() {
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
// See if this node is a root node, and if so, add to root_nodes_.
const int num_in_edges = n->in_edges().size();
if (num_in_edges == 0) {
if (n->in_edges().empty()) {
root_nodes_.push_back(n);
}
@ -659,7 +659,7 @@ Status ExecutorImpl::Initialize() {
// pending counts data structure, and allocate a handle in
// that frame's pending counts data structure that has enough
// space to store these maximal count values.
int max_pending, max_dead;
size_t max_pending, max_dead;
GetMaxPendingCounts(n, &max_pending, &max_dead);
item->pending_id =
frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
@ -896,7 +896,7 @@ class ExecutorState {
Entry* input_tensors;
// The number of outstanding ops for each iteration.
int outstanding_ops;
size_t outstanding_ops;
// The number of outstanding frames for each iteration.
int outstanding_frame_count;
@ -1037,13 +1037,13 @@ class ExecutorState {
inline IterationState* GetIteration(int64 iter)
EXCLUSIVE_LOCKS_REQUIRED(mu) {
int index = iter % iterations.size();
size_t index = iter % iterations.size();
return iterations[index];
}
inline void SetIteration(int64 iter, IterationState* state)
EXCLUSIVE_LOCKS_REQUIRED(mu) {
int index = iter % iterations.size();
size_t index = iter % iterations.size();
DCHECK(state == nullptr || iterations[index] == nullptr);
iterations[index] = state;
}
@ -1404,7 +1404,7 @@ void ExecutorImpl::InitializePending(const Graph* graph,
for (const Node* n : graph->nodes()) {
const int id = n->id();
const string& name = cf_info.frame_names[id];
int max_pending, max_dead;
size_t max_pending, max_dead;
GetMaxPendingCounts(n, &max_pending, &max_dead);
const NodeItem* item = gview_.node(id);
PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
@ -2027,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
}
bool completed = false;
int ready_size = ready.size();
size_t ready_size = ready.size();
if (ready_size == 0 || !s.ok()) {
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
} else if (ready_size > 1) {
@ -2375,10 +2375,10 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
TaggedNodeSeq* ready) {
const GraphView& gview = executor->gview_;
IterationState* iter_state = GetIteration(iter);
const int num_output_edges = item->num_output_edges;
const size_t num_output_edges = item->num_output_edges;
const EdgeInfo* edges = item->output_edge_list();
Entry* input_tensors = iter_state->input_tensors;
for (int out_index = 0; out_index < num_output_edges; out_index++) {
for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
const EdgeInfo& e = edges[out_index];
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);

View File

@ -162,7 +162,7 @@ class ExecutorBarrier {
//
// 'done' is called after the last executor completes, and
// ExecutorBarrier is deleted.
ExecutorBarrier(int num, Rendezvous* r, StatusCallback done)
ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
: rendez_(r), done_cb_(done), pending_(num) {}
~ExecutorBarrier() {}

View File

@ -274,8 +274,9 @@ class CallOp : public AsyncOpKernel {
if (!status.ok()) {
ctx->SetStatus(status);
} else {
CHECK_EQ(rets->size(), ctx->num_outputs());
for (size_t i = 0; i < rets->size(); ++i) {
const int ret_size = static_cast<int>(rets->size());
CHECK_EQ(ret_size, ctx->num_outputs());
for (int i = 0; i < ret_size; ++i) {
ctx->set_output(i, (*rets)[i]);
}
}
@ -1000,7 +1001,7 @@ string NewName(const Node* n, bool pretty) {
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
// We visit nodes in forward topological sort order, which is a
// possible execution order of the graph.
std::vector<int> pending(g->num_node_ids());
std::vector<size_t> pending(g->num_node_ids());
std::deque<const Node*> ready;
for (const Node* n : g->nodes()) {
pending[n->id()] = n->in_edges().size();
@ -1154,7 +1155,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
Graph* g = gbody_->graph;
const int num_y = gbody_->ret_nodes.size();
const int num_y = static_cast<int>(gbody_->ret_nodes.size());
// Populate 'y_node_outputs_' with node function body outputs.
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
@ -1169,7 +1170,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
y_node_outputs.push_back({y, 0});
DCHECK_EQ(y->type_string(), kRetOp);
const DataType dtype = y->input_type(0);
const int index = gbody_->arg_nodes.size();
const int index = static_cast<int>(gbody_->arg_nodes.size());
Node* dy = AddArg(g, dtype, index);
gbody_->arg_types.push_back(dtype);
gbody_->arg_nodes.push_back(dy);
@ -1177,7 +1178,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
}
// Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
const int num_x = fbody_->arg_nodes.size();
const size_t num_x = fbody_->arg_nodes.size();
std::vector<NodeOut> x_node_outputs;
x_node_outputs.reserve(num_x);
for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
@ -1200,7 +1201,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
gbody_->ret_nodes.clear();
// Add new return nodes to the function gradient body for each node
// in 'x_grad_nodes'.
for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
for (int i = 0; i < arg_types_size; ++i) {
Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
Node* ret = AddRet(g, grad, i);
gbody_->ret_nodes.push_back(ret);

View File

@ -82,7 +82,7 @@ Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
// Determine a suitable stream to use.
int stream_id = highest_stream_id + 1;
for (const Edge* e : n->in_edges()) {
const int fanout = e->src()->out_edges().size();
const size_t fanout = e->src()->out_edges().size();
if (fanout == 1) {
stream_id = (*node_to_stream_id)[e->src()->id()];
break;

View File

@ -191,7 +191,7 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
// example, process_state could maybe save the first stream executor
// it knows is valid.
gpu::StreamExecutor* se = nullptr;
for (size_t i = 0; i < gpu_allocators_.size(); ++i) {
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
if (gpu_allocators_[i] != nullptr) {
se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie();
break;

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
@ -95,22 +94,24 @@ class SimpleRendezvous : public Rendezvous {
} // namespace
// static
GraphRunner::GraphRunner(Env* env) : cpu_device_(GetCPUDevice(env)) {}
GraphRunner::~GraphRunner() {}
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
Env* env, const NamedTensorList& inputs,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) {
if (cpu_device_ == nullptr) {
return errors::NotFound("Cannot find a device for GraphRunner.");
}
// TODO(vrv): Instead of copying the entire graph, consider modifying
// the existing graph, and then removing those removed edges.
// prior to returning.
std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
CopyGraph(*graph, graph_to_run.get());
std::unique_ptr<Device> device = GetCPUDevice(env);
if (!device) {
return errors::NotFound("Cannot find a device for GraphRunner.");
}
SimpleRendezvous* rendez = new SimpleRendezvous;
core::ScopedUnref rendez_unref(rendez);
@ -130,7 +131,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
// Call RewriteGraphForExecution
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
device->attributes()));
cpu_device_->attributes()));
// Create the local executor and the Rendezvous for fetching back the
// constants.
@ -143,10 +144,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
Graph* g = graph_to_run.release();
LocalExecutorParams params;
params.device = device.get();
// The ownership of the output tensors are bound to this device's lifetime.
params.device = cpu_device_.get();
params.function_library = function_library;
params.create_kernel = [&device, g](const NodeDef& ndef, OpKernel** kernel) {
return CreateNonCachedKernel(device.get(), nullptr, ndef,
params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) {
return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef,
g->versions().producer(), kernel);
};
params.delete_kernel = [](OpKernel* kernel) { delete kernel; };

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
@ -44,16 +45,26 @@ namespace tensorflow {
// to be particularly lightweight, fast, or efficient.
class GraphRunner {
public:
// REQUIRES: `env` is not nullptr.
GraphRunner(Env* env);
~GraphRunner();
// Function semantics for `inputs`, `output_names` and `outputs`
// matches those from Session::Run().
//
// NOTE: The output tensors share lifetime with the GraphRunner, and could
// be destroyed once the GraphRunner is destroyed.
//
// REQUIRES: `graph`, `env`, and `outputs` are not nullptr.
// `function_library` may be nullptr.
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
static Status Run(Graph* graph, FunctionLibraryRuntime* function_library,
Env* env, const NamedTensorList& inputs,
Status Run(Graph* graph, FunctionLibraryRuntime* function_library,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs);
private:
std::unique_ptr<Device> cpu_device_;
};
} // namespace tensorflow

View File

@ -46,9 +46,9 @@ using test::internal::ExpectEqual;
TEST(GraphRunnerTest, SingleConst) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, 42.0f);
GraphRunner graph_runner(Env::Default());
std::vector<Tensor> outputs;
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {},
{c.name()}, &outputs);
Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs);
TF_ASSERT_OK(s);
ExpectEqual(42.0f, outputs[0].scalar<float>()());
}
@ -57,9 +57,10 @@ TEST(GraphRunnerTest, MultiFetchConst) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, 42.0f);
auto pi = ops::Const(root, 3.14f);
GraphRunner graph_runner(Env::Default());
std::vector<Tensor> outputs;
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {},
{c.name(), pi.name()}, &outputs);
Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name(), pi.name()},
&outputs);
TF_ASSERT_OK(s);
ExpectEqual(42.0f, outputs[0].scalar<float>()());
ExpectEqual(3.14f, outputs[1].scalar<float>()());
@ -78,9 +79,10 @@ TEST(GraphRunnerTest, FeedAndFetch) {
std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
{"p2:0", p2_data}};
GraphRunner graph_runner(Env::Default());
std::vector<Tensor> outputs;
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), inputs,
{"add:0"}, &outputs);
Status s =
graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs);
TF_ASSERT_OK(s);
ExpectEqual(3.0f, outputs[0].scalar<float>()());
}

Some files were not shown because too many files have changed in this diff Show More