Merge pull request #8958 from rohan100jain/branch_152141388
Branch 152141388
This commit is contained in:
commit
efe5376f3d
@ -202,6 +202,7 @@ filegroup(
|
||||
"//tensorflow/contrib/boosted_trees:all_files",
|
||||
"//tensorflow/contrib/boosted_trees/lib:all_files",
|
||||
"//tensorflow/contrib/boosted_trees/proto:all_files",
|
||||
"//tensorflow/contrib/boosted_trees/resources:all_files",
|
||||
"//tensorflow/contrib/cloud:all_files",
|
||||
"//tensorflow/contrib/cloud/kernels:all_files",
|
||||
"//tensorflow/contrib/compiler:all_files",
|
||||
@ -256,6 +257,7 @@ filegroup(
|
||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
|
||||
"//tensorflow/contrib/training:all_files",
|
||||
"//tensorflow/contrib/util:all_files",
|
||||
"//tensorflow/contrib/xla_tf_graph:all_files",
|
||||
"//tensorflow/core:all_files",
|
||||
"//tensorflow/core/debug:all_files",
|
||||
"//tensorflow/core/distributed_runtime:all_files",
|
||||
|
@ -51,6 +51,7 @@ genrule(
|
||||
"test_graph_tfgather.pb",
|
||||
"test_graph_tfmatmul.pb",
|
||||
"test_graph_tfmatmulandadd.pb",
|
||||
"test_graph_tffunction.pb",
|
||||
],
|
||||
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
|
||||
tags = ["manual"],
|
||||
@ -114,6 +115,15 @@ tf_library(
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
name = "test_graph_tffunction",
|
||||
testonly = 1,
|
||||
config = "test_graph_tffunction.config.pbtxt",
|
||||
cpp_class = "FunctionComp",
|
||||
graph = "test_graph_tffunction.pb",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tfcompile_test",
|
||||
srcs = ["tfcompile_test.cc"],
|
||||
@ -122,6 +132,7 @@ cc_test(
|
||||
":test_graph_tfadd",
|
||||
":test_graph_tfadd_with_ckpt",
|
||||
":test_graph_tfadd_with_ckpt_saver",
|
||||
":test_graph_tffunction",
|
||||
":test_graph_tfgather",
|
||||
":test_graph_tfmatmul",
|
||||
":test_graph_tfmatmulandadd",
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -95,6 +96,17 @@ def tfmatmulandadd(_):
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
|
||||
|
||||
def tffunction(_):
|
||||
|
||||
@function.Defun(dtypes.int32, dtypes.int32)
|
||||
def test_func(a, b):
|
||||
return a + b
|
||||
|
||||
x = constant_op.constant([1], name='x_const')
|
||||
y = constant_op.constant([2], name='y_const')
|
||||
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
@ -112,6 +124,7 @@ def main(_):
|
||||
write_graph(tfgather, FLAGS.out_dir)
|
||||
write_graph(tfmatmul, FLAGS.out_dir)
|
||||
write_graph(tfmatmulandadd, FLAGS.out_dir)
|
||||
write_graph(tffunction, FLAGS.out_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -121,7 +134,6 @@ if __name__ == '__main__':
|
||||
'--out_dir',
|
||||
type=str,
|
||||
default='',
|
||||
help='Output directory for graphs, checkpoints and savers.'
|
||||
)
|
||||
help='Output directory for graphs, checkpoints and savers.')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -0,0 +1,16 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "func_call" }
|
||||
}
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
|
||||
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
|
||||
@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TFCompileTest, Function) {
|
||||
// The function is equivalent to an addition
|
||||
FunctionComp add_fn;
|
||||
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
|
||||
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
|
||||
|
||||
add_fn.arg0() = 1;
|
||||
add_fn.arg1() = 2;
|
||||
EXPECT_TRUE(add_fn.Run());
|
||||
EXPECT_EQ(add_fn.error_msg(), "");
|
||||
EXPECT_EQ(add_fn.result0(), 3);
|
||||
EXPECT_EQ(add_fn.result0_data()[0], 3);
|
||||
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
@ -50,7 +50,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
|
||||
}
|
||||
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
const int kMaxRecursionDepth = 5;
|
||||
const int kMaxRecursionDepth = 10;
|
||||
|
||||
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
|
||||
int depth, FunctionLibraryRuntime* lib_runtime);
|
||||
|
@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, OnesLike) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -257,6 +257,11 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([[4, 3], [2, 1]], dtype=dtype),
|
||||
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.ones_like,
|
||||
np.array([[4, 3], [2, 1]], dtype=dtype),
|
||||
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
|
||||
|
||||
def testLogicalOps(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.logical_not,
|
||||
|
@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {
|
||||
|
||||
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
|
||||
|
||||
class OnesLikeOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
|
||||
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
|
||||
ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,
|
||||
|
||||
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||
: options_(std::move(options)),
|
||||
initialization_status_(Status::OK()),
|
||||
next_step_id_(1),
|
||||
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
|
||||
device_mgr_({device_}) {}
|
||||
device_mgr_({device_}) {
|
||||
if (options_.populate_resource_manager) {
|
||||
initialization_status_ =
|
||||
(*options_.populate_resource_manager)(device_->resource_manager());
|
||||
}
|
||||
}
|
||||
|
||||
XlaCompiler::~XlaCompiler() = default;
|
||||
|
||||
@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
|
||||
CompilationResult* result) {
|
||||
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
|
||||
|
||||
// Report the error here if initialization failed.
|
||||
TF_RETURN_IF_ERROR(initialization_status_);
|
||||
|
||||
xla::ComputationBuilder builder(client(), name);
|
||||
XlaContext* context =
|
||||
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
|
||||
|
@ -214,6 +214,12 @@ class XlaCompiler {
|
||||
// This is useful to prune stateful operators that should not be executed
|
||||
// from a function body.
|
||||
bool prune_unreachable_nodes = false;
|
||||
|
||||
// If not nullptr, populate_resource_manager is called with the
|
||||
// compilation device's resource manager when the compilation
|
||||
// device is created, and can be used to create metadata objects
|
||||
// that can be accessed by XLA op kernels.
|
||||
std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
|
||||
};
|
||||
|
||||
explicit XlaCompiler(Options options);
|
||||
@ -247,6 +253,7 @@ class XlaCompiler {
|
||||
Status BuildExecutable(const CompilationResult& result,
|
||||
std::unique_ptr<xla::LocalExecutable>* executable);
|
||||
|
||||
const Options& options() const { return options_; }
|
||||
xla::Client* client() const { return options_.client; }
|
||||
XlaCompilationDevice* device() const { return device_; }
|
||||
const DeviceMgr* device_mgr() const { return &device_mgr_; }
|
||||
@ -260,6 +267,9 @@ class XlaCompiler {
|
||||
private:
|
||||
Options options_;
|
||||
|
||||
// Status set to non-OK in the constructor if initialization fails.
|
||||
Status initialization_status_;
|
||||
|
||||
// Returns the next step sequence number.
|
||||
int64 NextStepId();
|
||||
|
||||
|
@ -17,12 +17,14 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
@ -33,6 +35,65 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Helper class to test the ability to pass resources through to XLA
|
||||
// compiled kernels.
|
||||
class DummyResourceForTest : public ResourceBase {
|
||||
public:
|
||||
string DebugString() override { return "dummy"; }
|
||||
void Increment() { ++value_; }
|
||||
int Get() { return value_; }
|
||||
|
||||
private:
|
||||
int value_ = 0;
|
||||
};
|
||||
|
||||
class DummyReadResourceOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
|
||||
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
|
||||
DummyResourceForTest* dummy;
|
||||
OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
|
||||
rm->default_container(), "dummy", &dummy));
|
||||
dummy->Increment();
|
||||
dummy->Unref();
|
||||
|
||||
ctx->SetOutput(0, ctx->Input(0));
|
||||
}
|
||||
};
|
||||
|
||||
class DummyReadResourceCC {
|
||||
public:
|
||||
DummyReadResourceCC(const Scope& scope, const Input& value) {
|
||||
if (!scope.ok()) return;
|
||||
auto _value = ops::AsNodeOut(scope, value);
|
||||
if (!scope.ok()) return;
|
||||
Node* ret;
|
||||
const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
|
||||
auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
|
||||
scope.UpdateBuilder(&builder);
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||
if (!scope.ok()) return;
|
||||
this->output_ = Output(ret, 0);
|
||||
}
|
||||
Node* node() const { return output_.node(); }
|
||||
|
||||
Output output_;
|
||||
};
|
||||
|
||||
REGISTER_OP("DummyReadResource")
|
||||
.Input("input: int32")
|
||||
.Output("output: int32")
|
||||
.Doc(R"doc(
|
||||
A dummy Op.
|
||||
|
||||
input: dummy input.
|
||||
output: dummy output.
|
||||
)doc");
|
||||
|
||||
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
|
||||
|
||||
class XlaCompilerTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
@ -224,5 +285,45 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tests compilation and execution of a graph that adds two tensors.
|
||||
TEST_F(XlaCompilerTest, ResourceManager) {
|
||||
// Builds a graph that calls the dummy resource Op.
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
|
||||
auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Builds a description of the argument.
|
||||
std::vector<XlaCompiler::Argument> args(1);
|
||||
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2});
|
||||
|
||||
DummyResourceForTest* resource = new DummyResourceForTest();
|
||||
|
||||
// Compiles the graph.
|
||||
auto options = DefaultOptions();
|
||||
std::function<Status(ResourceMgr*)> populate_function =
|
||||
[resource](ResourceMgr* rm) {
|
||||
resource->Ref();
|
||||
return rm->Create(rm->default_container(), "dummy", resource);
|
||||
};
|
||||
options.populate_resource_manager = &populate_function;
|
||||
XlaCompiler compiler(options);
|
||||
auto flr = BuildFunctionLibraryRuntime(compiler);
|
||||
|
||||
EXPECT_EQ(0, resource->Get());
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args,
|
||||
&result));
|
||||
|
||||
EXPECT_EQ(1, resource->Get());
|
||||
|
||||
resource->Unref();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -354,6 +354,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() {
|
||||
XlaContext::Get(context_).AddSideEffects();
|
||||
}
|
||||
|
||||
const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const {
|
||||
return XlaContext::Get(context_).compiler()->options();
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
|
||||
void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
|
||||
context_->CtxFailureWithWarning(s);
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -182,6 +183,11 @@ class XlaOpKernelContext {
|
||||
// Returns the underlying OpKernelContext. Use rarely.
|
||||
OpKernelContext* op_kernel_context() const { return context_; }
|
||||
|
||||
// Returns the options passed to the XlaCompiler that is being
|
||||
// run. Used for, e.g., While to inherit options needed for nested
|
||||
// computation.
|
||||
const XlaCompiler::Options& GetCompilerOptions() const;
|
||||
|
||||
// TODO(phawkins): find a better home for these helpers.
|
||||
|
||||
// Get an XLA lambda to compute Max. This is cached in the
|
||||
|
@ -167,6 +167,8 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
!backend.second.op_filter(kdef.get())) {
|
||||
continue;
|
||||
}
|
||||
VLOG(2) << "XLA op registration: device: " << backend.first
|
||||
<< " op: " << op.first;
|
||||
registry.kernel_registrars_.emplace_back(
|
||||
new kernel_factory::OpKernelRegistrar(
|
||||
new KernelDef(*kdef), "XlaJitOp", op.second->factory));
|
||||
|
@ -6,6 +6,7 @@ package_group(
|
||||
name = "friends",
|
||||
packages = [
|
||||
"//tensorflow/compiler/...",
|
||||
"//tensorflow/contrib/xla_tf_graph/...",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1229,8 +1229,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
|
||||
VLOG(2) << "done with request";
|
||||
|
||||
if (!s.ok()) {
|
||||
NoteError(s);
|
||||
return first_error_;
|
||||
return s;
|
||||
}
|
||||
return response.is_constant();
|
||||
}
|
||||
@ -1255,8 +1254,7 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
|
||||
VLOG(2) << "done with request";
|
||||
|
||||
if (!s.ok()) {
|
||||
NoteError(s);
|
||||
return first_error_;
|
||||
return s;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(response.output().handle() != 0);
|
||||
|
@ -120,6 +120,7 @@ class HloComputation {
|
||||
}
|
||||
|
||||
const string& name() const { return name_; }
|
||||
void set_name(const string& name) { name_ = name; }
|
||||
|
||||
// Return a string representation of the computation.
|
||||
string ToString() const;
|
||||
@ -257,7 +258,7 @@ class HloComputation {
|
||||
// Internal helper to collect unreachable roots.
|
||||
std::vector<HloInstruction*> CollectUnreachableRoots() const;
|
||||
|
||||
const string name_;
|
||||
string name_;
|
||||
HloInstruction* root_instruction_;
|
||||
|
||||
// Module containing this computation.
|
||||
|
@ -357,7 +357,9 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random,
|
||||
Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
|
||||
// Compute the cost of the fused expression.
|
||||
HloInstruction* fused_expression_root = fusion->fused_expression_root();
|
||||
HloCostAnalysis visitor(shape_size_);
|
||||
// Don't compute sizes inside of fused ops. We don't use the size here and the
|
||||
// operations inside might not have a layout.
|
||||
HloCostAnalysis visitor([](const Shape&) { return 0; });
|
||||
TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
|
||||
|
||||
// Attribute the cost of the fused expression to the fusion node.
|
||||
|
@ -375,6 +375,33 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
|
||||
EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
|
||||
}
|
||||
|
||||
TEST_F(FusionCostAnalysis, NoLayout) {
|
||||
Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
|
||||
// Instructions within a fused op may have no layout.
|
||||
Shape shape_without_layout = shape_with_layout;
|
||||
shape_without_layout.clear_layout();
|
||||
|
||||
auto c1 = HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
|
||||
auto c2 =
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
|
||||
|
||||
auto broadcast =
|
||||
HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
|
||||
auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
|
||||
c1.get(), broadcast.get());
|
||||
|
||||
auto fusion = HloInstruction::CreateFusion(
|
||||
shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
|
||||
fusion->FuseInstruction(broadcast.get());
|
||||
|
||||
HloCostAnalysis fusion_analysis(ShapeSize);
|
||||
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
|
||||
|
||||
EXPECT_EQ(fusion_analysis.flop_count(), 120);
|
||||
EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
|
||||
}
|
||||
|
||||
TEST_F(HloCostAnalysisTest, TupleCost) {
|
||||
HloCostAnalysis analysis(ShapeSize);
|
||||
{
|
||||
|
@ -31,20 +31,38 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
HloComputation* HloModule::AddEntryComputation(
|
||||
HloModule::HloModule(const string& name,
|
||||
const VersionedComputationHandle& entry_computation_handle)
|
||||
: name_(name),
|
||||
entry_computation_(nullptr),
|
||||
has_entry_computation_handle_(true),
|
||||
entry_computation_handle_(entry_computation_handle),
|
||||
computation_name_uniquer_(/*separator=*/".") {}
|
||||
|
||||
HloModule::HloModule(const string& name)
|
||||
: name_(name),
|
||||
entry_computation_(nullptr),
|
||||
computation_name_uniquer_(/*separator=*/".") {}
|
||||
|
||||
HloComputation* HloModule::AddComputationInternal(
|
||||
std::unique_ptr<HloComputation> computation) {
|
||||
CHECK_EQ(nullptr, entry_computation_);
|
||||
entry_computation_ = computation.get();
|
||||
computation->set_name(
|
||||
computation_name_uniquer_.GetUniqueName(computation->name()));
|
||||
computation->set_parent(this);
|
||||
computations_.push_back(std::move(computation));
|
||||
return computations_.back().get();
|
||||
}
|
||||
|
||||
HloComputation* HloModule::AddEntryComputation(
|
||||
std::unique_ptr<HloComputation> computation) {
|
||||
CHECK_EQ(nullptr, entry_computation_);
|
||||
entry_computation_ = computation.get();
|
||||
return AddComputationInternal(std::move(computation));
|
||||
}
|
||||
|
||||
HloComputation* HloModule::AddEmbeddedComputation(
|
||||
std::unique_ptr<HloComputation> computation) {
|
||||
computation->set_parent(this);
|
||||
computations_.push_back(std::move(computation));
|
||||
return computations_.back().get();
|
||||
return AddComputationInternal(std::move(computation));
|
||||
}
|
||||
|
||||
void HloModule::ReplaceComputations(
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
@ -41,19 +42,14 @@ namespace xla {
|
||||
// computations are owned by the module.
|
||||
class HloModule {
|
||||
public:
|
||||
explicit HloModule(const string& name,
|
||||
const VersionedComputationHandle& entry_computation_handle)
|
||||
: name_(name),
|
||||
entry_computation_(nullptr),
|
||||
has_entry_computation_handle_(true),
|
||||
entry_computation_handle_(entry_computation_handle) {}
|
||||
HloModule(const string& name,
|
||||
const VersionedComputationHandle& entry_computation_handle);
|
||||
|
||||
// Constructor without a versioned computation handle. This constructor should
|
||||
// only be used for HloModules used outside of the XLA service (eg
|
||||
// tests). The versioned handle is used by the service in the compilation
|
||||
// cache.
|
||||
explicit HloModule(const string& name)
|
||||
: name_(name), entry_computation_(nullptr) {}
|
||||
explicit HloModule(const string& name);
|
||||
|
||||
// Adds an entry computation to the module. A module can only have one entry
|
||||
// computation. Returns a pointer to the newly added computation.
|
||||
@ -111,6 +107,9 @@ class HloModule {
|
||||
uint64 RandomNew64() const;
|
||||
|
||||
private:
|
||||
HloComputation* AddComputationInternal(
|
||||
std::unique_ptr<HloComputation> computation);
|
||||
|
||||
const string name_;
|
||||
HloComputation* entry_computation_;
|
||||
std::vector<std::unique_ptr<HloComputation>> computations_;
|
||||
@ -125,6 +124,9 @@ class HloModule {
|
||||
// Versioned handle of the entry computation of the module.
|
||||
bool has_entry_computation_handle_ = false;
|
||||
VersionedComputationHandle entry_computation_handle_;
|
||||
|
||||
// Unique name generator for computation names, which are unique per module.
|
||||
NameUniquer computation_name_uniquer_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -74,6 +74,11 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) {
|
||||
EXPECT_MATCH(
|
||||
testing::ListToVec<HloComputation*>(module->MakeComputationPostOrder()),
|
||||
testing::UnorderedMatcher<HloComputation*>(computation1, computation2));
|
||||
|
||||
// We specified the same name for both computations, but the HloModule should
|
||||
// have made the names unique.
|
||||
EXPECT_EQ(computation1->name(), "Constant");
|
||||
EXPECT_EQ(computation2->name(), "Constant.1");
|
||||
}
|
||||
|
||||
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
|
||||
|
@ -633,26 +633,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
|
||||
switch (operation) {
|
||||
case TRIOP_CLAMP:
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
|
||||
if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) &&
|
||||
(ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) {
|
||||
return rhs;
|
||||
}
|
||||
if (ShapeUtil::Rank(rhs) == 0) {
|
||||
if (ShapeUtil::Compatible(lhs, ehs)) {
|
||||
return lhs;
|
||||
}
|
||||
return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs;
|
||||
}
|
||||
return Unimplemented("not yet implemented: %s, %s <clamp> %s",
|
||||
lhs.ShortDebugString().c_str(),
|
||||
ehs.ShortDebugString().c_str(),
|
||||
rhs.ShortDebugString().c_str());
|
||||
return InferClampShape(lhs, rhs, ehs);
|
||||
case TRIOP_SELECT:
|
||||
return InferSelectShape(lhs, rhs, ehs);
|
||||
case TRIOP_UPDATE:
|
||||
@ -1332,6 +1313,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
||||
}
|
||||
|
||||
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
|
||||
// "degenerate" cases, as with binary elementwise ops.
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
|
||||
const Shape& min, const Shape& operand, const Shape& max) {
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
|
||||
if (!ShapeUtil::SameElementType(min, operand) ||
|
||||
!ShapeUtil::SameElementType(max, operand)) {
|
||||
return InvalidArgument("clamp op with different operand types: %s, %s, %s",
|
||||
ShapeUtil::HumanString(min).c_str(),
|
||||
ShapeUtil::HumanString(operand).c_str(),
|
||||
ShapeUtil::HumanString(max).c_str());
|
||||
}
|
||||
if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
|
||||
(ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
|
||||
return operand;
|
||||
}
|
||||
if (ShapeUtil::IsScalar(operand)) {
|
||||
if (ShapeUtil::Compatible(min, max)) {
|
||||
return min;
|
||||
} else if (ShapeUtil::IsScalar(min)) {
|
||||
return max;
|
||||
} else if (ShapeUtil::IsScalar(max)) {
|
||||
return min;
|
||||
}
|
||||
}
|
||||
return Unimplemented(
|
||||
"not yet implemented: %s, %s <clamp> %s", min.ShortDebugString().c_str(),
|
||||
max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
|
||||
}
|
||||
|
||||
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
|
||||
// "degenerate" cases, as with binary elementwise ops, as well as scalar
|
||||
// broadcast from all operands, not just the predicate.
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
||||
const Shape& pred, const Shape& on_true, const Shape& on_false) {
|
||||
if (!ShapeUtil::Compatible(on_true, on_false)) {
|
||||
|
@ -190,6 +190,10 @@ class ShapeInference {
|
||||
BinaryOperation operation, const Shape& lhs, const Shape& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
|
||||
|
||||
// Helper for inferring the shape of Clamp ops.
|
||||
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
|
||||
const Shape& max);
|
||||
|
||||
// Helper for inferring the shape of Select ops.
|
||||
static StatusOr<Shape> InferSelectShape(const Shape& pred,
|
||||
const Shape& on_true,
|
||||
|
@ -157,6 +157,99 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
|
||||
testing::ContainsRegex("pred operand must have PRED element type"));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
|
||||
matrix_64_48_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampAllScalar) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampMinScalar) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
|
||||
auto inferred_status = ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
|
||||
ASSERT_IS_OK(inferred_status.status());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, ClampBadShapes) {
|
||||
// Type mismatch
|
||||
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
|
||||
.ok());
|
||||
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
|
||||
.ok());
|
||||
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
|
||||
.ok());
|
||||
// Dimension mismatch
|
||||
ASSERT_FALSE(
|
||||
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
|
||||
vector_64_, vector_32_, vector_32_)
|
||||
.ok());
|
||||
ASSERT_FALSE(
|
||||
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
|
||||
vector_32_, vector_64_, vector_32_)
|
||||
.ok());
|
||||
ASSERT_FALSE(
|
||||
ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
|
||||
vector_32_, vector_32_, vector_64_)
|
||||
.ok());
|
||||
// Dimension mismatch, where one operand is a scalar
|
||||
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
|
||||
.ok());
|
||||
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
|
||||
.ok());
|
||||
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
|
||||
TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
|
||||
.ok());
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
||||
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
|
||||
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -245,37 +246,69 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
|
||||
ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
builder.Div(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
|
||||
struct DivS32Params {
|
||||
int32 dividend;
|
||||
int32 divisor;
|
||||
int32 quotient;
|
||||
int32 remainder;
|
||||
};
|
||||
|
||||
ComputeAndCompareR0<int32>(&builder, -2, {});
|
||||
void PrintTo(const DivS32Params& p, std::ostream* os) {
|
||||
*os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", "
|
||||
<< p.remainder << "}";
|
||||
}
|
||||
|
||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
builder.Rem(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
|
||||
class DivS32Test : public ClientLibraryTestBase,
|
||||
public ::testing::WithParamInterface<DivS32Params> {};
|
||||
|
||||
ComputeAndCompareR0<int32>(&builder, -1, {});
|
||||
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
|
||||
DivS32Params p = GetParam();
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
builder.Div(builder.ConstantR0<int32>(p.dividend),
|
||||
builder.ConstantR0<int32>(p.divisor));
|
||||
|
||||
ComputeAndCompareR0<int32>(&builder, p.quotient, {});
|
||||
}
|
||||
|
||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) {
|
||||
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
|
||||
DivS32Params p = GetParam();
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
builder.Rem(builder.ConstantR0<int32>(INT_MIN),
|
||||
builder.ConstantR0<int32>(7919));
|
||||
builder.Rem(builder.ConstantR0<int32>(p.dividend),
|
||||
builder.ConstantR0<int32>(p.divisor));
|
||||
|
||||
ComputeAndCompareR0<int32>(&builder, -1309, {});
|
||||
ComputeAndCompareR0<int32>(&builder, p.remainder, {});
|
||||
}
|
||||
|
||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
builder.Rem(builder.ConstantR0<int32>(INT_MIN),
|
||||
builder.ConstantR0<int32>(INT_MAX));
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
DivS32Test_Instantiation, DivS32Test,
|
||||
::testing::Values(
|
||||
// Positive divisors.
|
||||
DivS32Params{5, 2, 2, 1}, //
|
||||
DivS32Params{-5, 2, -2, -1}, //
|
||||
DivS32Params{17, 3, 5, 2}, //
|
||||
DivS32Params{-17, 3, -5, -2}, //
|
||||
// Negative divisors.
|
||||
DivS32Params{5, -2, -2, 1}, //
|
||||
DivS32Params{-5, -2, 2, -1}, //
|
||||
DivS32Params{17, -3, -5, 2}, //
|
||||
DivS32Params{-17, -3, 5, -2}, //
|
||||
// Large positive divisors.
|
||||
DivS32Params{INT32_MIN, 7919, -271181, -1309}, //
|
||||
DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, //
|
||||
DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, //
|
||||
DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, //
|
||||
DivS32Params{INT32_MIN, 0x40000000, -2, 0}, //
|
||||
DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, //
|
||||
// Large negative divisors.
|
||||
DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, //
|
||||
DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, //
|
||||
DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, //
|
||||
DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, //
|
||||
DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, //
|
||||
DivS32Params{INT32_MIN, -0x40000000, 2, 0}, //
|
||||
DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
|
||||
|
||||
ComputeAndCompareR0<int32>(&builder, -1, {});
|
||||
}
|
||||
|
||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
|
||||
TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
|
||||
builder.Rem(x, builder.ConstantR0<int32>(80000));
|
||||
|
@ -7,8 +7,6 @@ exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
|
||||
py_library(
|
||||
name = "contrib_py",
|
||||
srcs = glob(["**/*.py"]),
|
||||
@ -46,6 +44,7 @@ py_library(
|
||||
"//tensorflow/contrib/losses:losses_py",
|
||||
"//tensorflow/contrib/memory_stats:memory_stats_py",
|
||||
"//tensorflow/contrib/metrics:metrics_py",
|
||||
"//tensorflow/contrib/nccl:nccl_py",
|
||||
"//tensorflow/contrib/ndlstm",
|
||||
"//tensorflow/contrib/nn:nn_py",
|
||||
"//tensorflow/contrib/opt:opt_py",
|
||||
@ -65,9 +64,7 @@ py_library(
|
||||
"//tensorflow/contrib/tfprof",
|
||||
"//tensorflow/contrib/training:training_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
] + if_not_windows([
|
||||
"//tensorflow/contrib/nccl:nccl_py",
|
||||
]),
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.contrib import image
|
||||
from tensorflow.contrib import input_pipeline
|
||||
from tensorflow.contrib import integrate
|
||||
from tensorflow.contrib import keras
|
||||
from tensorflow.contrib import kernel_methods
|
||||
from tensorflow.contrib import labeled_tensor
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
@ -45,6 +46,7 @@ from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib import losses
|
||||
from tensorflow.contrib import memory_stats
|
||||
from tensorflow.contrib import metrics
|
||||
from tensorflow.contrib import nccl
|
||||
from tensorflow.contrib import nn
|
||||
from tensorflow.contrib import opt
|
||||
from tensorflow.contrib import quantization
|
||||
|
@ -160,3 +160,90 @@ cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "models",
|
||||
srcs = ["models/multiple_additive_trees.cc"],
|
||||
hdrs = ["models/multiple_additive_trees.h"],
|
||||
deps = [
|
||||
":trees",
|
||||
":utils",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "multiple_additive_trees_test",
|
||||
size = "small",
|
||||
srcs = ["models/multiple_additive_trees_test.cc"],
|
||||
deps = [
|
||||
":batch_features_testutil",
|
||||
":models",
|
||||
":random_tree_gen",
|
||||
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "trees",
|
||||
srcs = ["trees/decision_tree.cc"],
|
||||
hdrs = ["trees/decision_tree.h"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "trees_test",
|
||||
size = "small",
|
||||
srcs = ["trees/decision_tree_test.cc"],
|
||||
deps = [
|
||||
":trees",
|
||||
":utils",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "batch_features_testutil",
|
||||
testonly = 1,
|
||||
srcs = ["testutil/batch_features_testutil.cc"],
|
||||
hdrs = ["testutil/batch_features_testutil.h"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "random_tree_gen",
|
||||
srcs = ["testutil/random_tree_gen.cc"],
|
||||
hdrs = ["testutil/random_tree_gen.h"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "random_tree_gen_main",
|
||||
srcs = ["testutil/random_tree_gen_main.cc"],
|
||||
deps = [
|
||||
":random_tree_gen",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,140 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace models {
|
||||
|
||||
namespace {
|
||||
void CalculateTreesToKeep(
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||
const std::vector<int32>& trees_to_drop, const int32 num_trees,
|
||||
const bool only_finalized, std::vector<int32>* trees_to_keep) {
|
||||
trees_to_keep->reserve(num_trees - trees_to_drop.size());
|
||||
|
||||
int32 index = 0;
|
||||
// This assumes that trees_to_drop is a sorted list of tree ids.
|
||||
for (int32 tree = 0; tree < num_trees; ++tree) {
|
||||
if ((!trees_to_drop.empty() && index < trees_to_drop.size() &&
|
||||
trees_to_drop[index] == tree) ||
|
||||
(only_finalized && config.tree_metadata_size() > 0 &&
|
||||
!config.tree_metadata(tree).is_finalized())) {
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
trees_to_keep->push_back(tree);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdatePredictions(
|
||||
const int32 index_1, const int32 index_2, const float value,
|
||||
tensorflow::TTypes<float>::Matrix* output_predictions,
|
||||
tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
|
||||
(*output_predictions)(index_1, index_2) += value;
|
||||
|
||||
if (additional_output_predictions != nullptr) {
|
||||
(*additional_output_predictions)(index_1, index_2) += value;
|
||||
}
|
||||
}
|
||||
|
||||
void UpdatePredictionsBasedOnTree(
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||
const int32 tree_idx, const boosted_trees::utils::Example& example,
|
||||
tensorflow::TTypes<float>::Matrix* output_predictions,
|
||||
tensorflow::TTypes<float>::Matrix* additional_output_predictions) {
|
||||
const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx);
|
||||
const float tree_weight = config.tree_weights(tree_idx);
|
||||
const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
|
||||
QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
|
||||
const auto& leaf_node = tree.nodes(leaf_idx);
|
||||
QCHECK(leaf_node.has_leaf())
|
||||
<< "Invalid leaf node: " << leaf_node.DebugString();
|
||||
if (leaf_node.leaf().has_sparse_vector()) {
|
||||
const auto& leaf = leaf_node.leaf().sparse_vector();
|
||||
QCHECK_EQ(leaf.index_size(), leaf.value_size());
|
||||
for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) {
|
||||
const float value = tree_weight * leaf.value(class_idx);
|
||||
|
||||
UpdatePredictions(example.example_idx, leaf.index(class_idx), value,
|
||||
output_predictions, additional_output_predictions);
|
||||
}
|
||||
} else {
|
||||
QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
|
||||
const auto& leaf = leaf_node.leaf().vector();
|
||||
for (size_t i = 0; i < leaf.value_size(); ++i) {
|
||||
const float value = tree_weight * leaf.value(i);
|
||||
UpdatePredictions(example.example_idx, i, value, output_predictions,
|
||||
additional_output_predictions);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void MultipleAdditiveTrees::Predict(
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||
const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
|
||||
const boosted_trees::utils::BatchFeatures& features,
|
||||
tensorflow::thread::ThreadPool* worker_threads,
|
||||
tensorflow::TTypes<float>::Matrix output_predictions,
|
||||
tensorflow::TTypes<float>::Matrix no_dropout_predictions) {
|
||||
// Zero out predictions as the model is additive.
|
||||
output_predictions.setZero();
|
||||
no_dropout_predictions.setZero();
|
||||
|
||||
// Get batch size.
|
||||
const int64 batch_size = features.batch_size();
|
||||
if (batch_size <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare the list of trees to keep.
|
||||
std::vector<int32> trees_to_keep;
|
||||
CalculateTreesToKeep(config, trees_to_drop, config.trees_size(),
|
||||
only_finalized_trees, &trees_to_keep);
|
||||
|
||||
// Lambda for doing a block of work.
|
||||
auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop,
|
||||
&output_predictions,
|
||||
&no_dropout_predictions](int64 start, int64 end) {
|
||||
auto examples_iterable = features.examples_iterable(start, end);
|
||||
for (const auto& example : examples_iterable) {
|
||||
for (const int32 tree_idx : trees_to_keep) {
|
||||
UpdatePredictionsBasedOnTree(config, tree_idx, example,
|
||||
&output_predictions,
|
||||
&no_dropout_predictions);
|
||||
}
|
||||
|
||||
// Now do predictions for dropped trees
|
||||
for (const int32 tree_idx : trees_to_drop) {
|
||||
UpdatePredictionsBasedOnTree(config, tree_idx, example,
|
||||
&no_dropout_predictions, nullptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(salehay): parallelize this for low latency in serving path where
|
||||
// batch size tends to be small but ensemble size tends to be large.
|
||||
boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
|
||||
worker_threads, update_predictions);
|
||||
}
|
||||
|
||||
} // namespace models
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -0,0 +1,50 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace models {
|
||||
|
||||
// Multiple additive trees prediction model.
|
||||
// This class does not hold state and is thread safe.
|
||||
class MultipleAdditiveTrees {
|
||||
public:
|
||||
// Predict runs tree ensemble on the given batch and updates
|
||||
// output predictions accordingly. The method also returns predictions that
|
||||
// we would get if no dropout was applied.
|
||||
static void Predict(
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||
const bool only_finalized_trees, const std::vector<int32>& trees_to_drop,
|
||||
const boosted_trees::utils::BatchFeatures& features,
|
||||
thread::ThreadPool* const thread_pool,
|
||||
TTypes<float>::Matrix output_predictions,
|
||||
TTypes<float>::Matrix no_dropout_predictions);
|
||||
};
|
||||
|
||||
} // namespace models
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
|
@ -0,0 +1,381 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
|
||||
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using boosted_trees::trees::DecisionTreeEnsembleConfig;
|
||||
using test::AsTensor;
|
||||
|
||||
namespace boosted_trees {
|
||||
namespace models {
|
||||
namespace {
|
||||
|
||||
const int32 kNumThreadsMultiThreaded = 6;
|
||||
const int32 kNumThreadsSingleThreaded = 1;
|
||||
|
||||
class MultipleAdditiveTreesTest : public ::testing::Test {
|
||||
protected:
|
||||
MultipleAdditiveTreesTest() : batch_features_(2) {
|
||||
// Create a batch of two examples having one dense feature each.
|
||||
// The shape of the dense matrix is therefore 2x1 as in one row per example
|
||||
// and one column per feature per example.
|
||||
auto dense_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
|
||||
TF_EXPECT_OK(
|
||||
batch_features_.Initialize({dense_matrix}, {}, {}, {}, {}, {}, {}));
|
||||
}
|
||||
|
||||
boosted_trees::utils::BatchFeatures batch_features_;
|
||||
};
|
||||
|
||||
TEST_F(MultipleAdditiveTreesTest, Empty) {
|
||||
// Create empty tree ensemble.
|
||||
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||
auto output_tensor = AsTensor<float>({9.0f, 23.0f}, {2, 1});
|
||||
auto output_matrix = output_tensor.matrix<float>();
|
||||
auto no_dropout_output_matrix = output_tensor.matrix<float>();
|
||||
|
||||
// Predict for both instances.
|
||||
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||
kNumThreadsSingleThreaded);
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{}, batch_features_, &threads, output_matrix,
|
||||
no_dropout_output_matrix);
|
||||
EXPECT_EQ(0, output_matrix(0, 0));
|
||||
EXPECT_EQ(0, output_matrix(1, 0));
|
||||
|
||||
// There was no dropout
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MultipleAdditiveTreesTest, SingleClass) {
|
||||
// Add one bias and one stump to ensemble for a single class.
|
||||
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||
auto* tree1 = tree_ensemble_config.add_trees();
|
||||
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||
bias_leaf->add_index(0);
|
||||
bias_leaf->add_value(-0.4f);
|
||||
auto* tree2 = tree_ensemble_config.add_trees();
|
||||
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
|
||||
dense_split->set_feature_column(0);
|
||||
dense_split->set_threshold(5.0f);
|
||||
dense_split->set_left_id(1);
|
||||
dense_split->set_right_id(2);
|
||||
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||
leaf1->add_index(0);
|
||||
leaf1->add_value(0.9f);
|
||||
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||
leaf2->add_index(0);
|
||||
leaf2->add_value(0.2f);
|
||||
|
||||
tree_ensemble_config.add_tree_weights(1.0);
|
||||
tree_ensemble_config.add_tree_weights(1.0);
|
||||
|
||||
auto output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
|
||||
auto output_matrix = output_tensor.matrix<float>();
|
||||
|
||||
auto no_dropout_output_tensor = AsTensor<float>({0.0f, 0.0f}, {2, 1});
|
||||
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
|
||||
|
||||
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||
kNumThreadsSingleThreaded);
|
||||
|
||||
// Normal case.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{}, batch_features_, &threads, output_matrix,
|
||||
no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||
|
||||
// No dropout predictions are the same.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
|
||||
}
|
||||
}
|
||||
// Weighted case
|
||||
{
|
||||
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
|
||||
weighted.set_tree_weights(0, 6.0);
|
||||
weighted.set_tree_weights(1, 3.2);
|
||||
MultipleAdditiveTrees::Predict(weighted,
|
||||
false, // include non-finalized trees
|
||||
{}, batch_features_, &threads, output_matrix,
|
||||
no_dropout_output_matrix);
|
||||
// -0.4 (bias) + 0.2 (leaf 2).
|
||||
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
|
||||
// -0.4 (bias) + 0.9 (leaf 1).
|
||||
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0));
|
||||
|
||||
// No dropout predictions are the same.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0));
|
||||
}
|
||||
}
|
||||
// Drop first tree.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{0}, batch_features_, &threads,
|
||||
output_matrix, no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2).
|
||||
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1).
|
||||
|
||||
// No dropout predictions
|
||||
EXPECT_FLOAT_EQ(
|
||||
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||
EXPECT_FLOAT_EQ(
|
||||
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||
}
|
||||
// Drop second tree.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{1}, batch_features_, &threads,
|
||||
output_matrix, no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias).
|
||||
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias).
|
||||
|
||||
// No dropout predictions
|
||||
EXPECT_FLOAT_EQ(
|
||||
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||
EXPECT_FLOAT_EQ(
|
||||
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||
}
|
||||
// Drop all trees.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{0, 1}, batch_features_, &threads,
|
||||
output_matrix, no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
|
||||
EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
|
||||
|
||||
// No dropout predictions
|
||||
EXPECT_FLOAT_EQ(
|
||||
-0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
|
||||
EXPECT_FLOAT_EQ(
|
||||
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MultipleAdditiveTreesTest, MultiClass) {
|
||||
// Add one bias and one stump to ensemble for two classes.
|
||||
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||
auto* tree1 = tree_ensemble_config.add_trees();
|
||||
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||
bias_leaf->add_index(0);
|
||||
bias_leaf->add_value(-0.4f);
|
||||
bias_leaf->add_index(1);
|
||||
bias_leaf->add_value(-0.7f);
|
||||
auto* tree2 = tree_ensemble_config.add_trees();
|
||||
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
|
||||
dense_split->set_feature_column(0);
|
||||
dense_split->set_threshold(5.0f);
|
||||
dense_split->set_left_id(1);
|
||||
dense_split->set_right_id(2);
|
||||
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||
leaf1->add_index(0);
|
||||
leaf1->add_value(0.9f);
|
||||
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_sparse_vector();
|
||||
leaf2->add_index(1);
|
||||
leaf2->add_value(0.2f);
|
||||
|
||||
tree_ensemble_config.add_tree_weights(1.0);
|
||||
tree_ensemble_config.add_tree_weights(1.0);
|
||||
|
||||
// Predict for both instances.
|
||||
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||
kNumThreadsSingleThreaded);
|
||||
auto output_tensor = AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
|
||||
auto output_matrix = output_tensor.matrix<float>();
|
||||
|
||||
auto no_dropout_output_tensor =
|
||||
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2});
|
||||
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
|
||||
|
||||
// Normal case.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{}, batch_features_, &threads, output_matrix,
|
||||
no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
|
||||
EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1)
|
||||
EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias)
|
||||
|
||||
// No dropout predictions are the same.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Weighted case.
|
||||
{
|
||||
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
|
||||
weighted.set_tree_weights(0, 6.0);
|
||||
weighted.set_tree_weights(1, 3.2);
|
||||
MultipleAdditiveTrees::Predict(weighted,
|
||||
false, // include non-finalized trees
|
||||
{}, batch_features_, &threads, output_matrix,
|
||||
no_dropout_output_matrix);
|
||||
// bias
|
||||
EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
|
||||
// bias + leaf 2
|
||||
EXPECT_FLOAT_EQ(-0.7f * 6 + 0.2f * 3.2, output_matrix(0, 1));
|
||||
// bias + leaf 2
|
||||
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9f * 3.2f, output_matrix(1, 0));
|
||||
// bias
|
||||
EXPECT_FLOAT_EQ(-0.7f * 6, output_matrix(1, 1));
|
||||
}
|
||||
// Dropout first tree.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{0}, batch_features_, &threads,
|
||||
output_matrix, no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
|
||||
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
|
||||
|
||||
// No dropout predictions
|
||||
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
|
||||
EXPECT_FLOAT_EQ(
|
||||
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(
|
||||
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
|
||||
}
|
||||
// Dropout second tree.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{1}, batch_features_, &threads,
|
||||
output_matrix, no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
|
||||
EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias)
|
||||
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias)
|
||||
EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias)
|
||||
|
||||
// No dropout predictions
|
||||
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
|
||||
EXPECT_FLOAT_EQ(
|
||||
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(
|
||||
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
|
||||
}
|
||||
// Drop both trees.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{0, 1}, batch_features_, &threads,
|
||||
output_matrix, no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
|
||||
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
|
||||
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
|
||||
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1));
|
||||
|
||||
// No dropout predictions
|
||||
EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias)
|
||||
EXPECT_FLOAT_EQ(
|
||||
-0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(
|
||||
0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias)
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
|
||||
DecisionTreeEnsembleConfig tree_ensemble_config;
|
||||
auto* tree1 = tree_ensemble_config.add_trees();
|
||||
auto* bias_leaf = tree1->add_nodes()->mutable_leaf()->mutable_vector();
|
||||
bias_leaf->add_value(-0.4f);
|
||||
bias_leaf->add_value(-0.7f);
|
||||
bias_leaf->add_value(3.0f);
|
||||
auto* tree2 = tree_ensemble_config.add_trees();
|
||||
auto* dense_split = tree2->add_nodes()->mutable_dense_float_binary_split();
|
||||
dense_split->set_feature_column(0);
|
||||
dense_split->set_threshold(5.0f);
|
||||
dense_split->set_left_id(1);
|
||||
dense_split->set_right_id(2);
|
||||
auto* leaf1 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
|
||||
leaf1->add_value(0.9f);
|
||||
leaf1->add_value(0.8f);
|
||||
leaf1->add_value(0.7f);
|
||||
auto* leaf2 = tree2->add_nodes()->mutable_leaf()->mutable_vector();
|
||||
leaf2->add_value(0.2f);
|
||||
leaf2->add_value(0.3f);
|
||||
leaf2->add_value(0.4f);
|
||||
|
||||
tree_ensemble_config.add_tree_weights(1.0);
|
||||
tree_ensemble_config.add_tree_weights(1.0);
|
||||
|
||||
// Predict for both instances.
|
||||
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
|
||||
kNumThreadsSingleThreaded);
|
||||
auto output_tensor =
|
||||
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
|
||||
auto output_matrix = output_tensor.matrix<float>();
|
||||
|
||||
auto no_dropout_output_tensor =
|
||||
AsTensor<float>({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3});
|
||||
auto no_dropout_output_matrix = no_dropout_output_tensor.matrix<float>();
|
||||
|
||||
// Normal case.
|
||||
{
|
||||
MultipleAdditiveTrees::Predict(tree_ensemble_config,
|
||||
false, // include non-finalized trees
|
||||
{}, batch_features_, &threads, output_matrix,
|
||||
no_dropout_output_matrix);
|
||||
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2)
|
||||
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (tree1) + 0.9 (leaf 1)
|
||||
EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1)); // -0.7 (tree1) + 0.8 (leaf 1)
|
||||
EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2)); // 3.0 (tree1) + 0.7 (leaf 1)
|
||||
|
||||
// No dropout predictions are the same.
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace models
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -0,0 +1,88 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include "tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace testutil {
|
||||
|
||||
using tensorflow::Tensor;
|
||||
|
||||
void RandomlyInitializeBatchFeatures(
|
||||
tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
|
||||
uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
|
||||
boosted_trees::utils::BatchFeatures* batch_features) {
|
||||
const int64 batch_size = static_cast<int64>(batch_features->batch_size());
|
||||
|
||||
// Populate dense features.
|
||||
std::vector<tensorflow::Tensor> dense_float_features_list;
|
||||
for (int i = 0; i < num_dense_float_features; ++i) {
|
||||
std::vector<float> values;
|
||||
for (int64 j = 0; j < batch_size; ++j) {
|
||||
values.push_back(rng->RandFloat());
|
||||
}
|
||||
auto dense_tensor = Tensor(tensorflow::DT_FLOAT, {batch_size, 1});
|
||||
tensorflow::test::FillValues<float>(&dense_tensor, values);
|
||||
dense_float_features_list.push_back(dense_tensor);
|
||||
}
|
||||
|
||||
// Populate sparse features.
|
||||
std::vector<tensorflow::Tensor> sparse_float_feature_indices_list;
|
||||
std::vector<tensorflow::Tensor> sparse_float_feature_values_list;
|
||||
std::vector<tensorflow::Tensor> sparse_float_feature_shapes_list;
|
||||
for (int i = 0; i < num_sparse_float_features; ++i) {
|
||||
std::set<uint64> indices;
|
||||
const double sparsity =
|
||||
sparsity_lo + rng->RandDouble() * (sparsity_hi - sparsity_lo);
|
||||
const double density = 1 - sparsity;
|
||||
for (int64 k = 0; k < static_cast<int64>(density * batch_size) + 1; ++k) {
|
||||
indices.insert(rng->Uniform64(batch_size));
|
||||
}
|
||||
const int64 sparse_values_size = indices.size();
|
||||
std::vector<int64> indices_vector;
|
||||
for (auto idx : indices) {
|
||||
indices_vector.push_back(idx);
|
||||
indices_vector.push_back(0);
|
||||
}
|
||||
auto indices_tensor = Tensor(tensorflow::DT_INT64, {sparse_values_size, 2});
|
||||
tensorflow::test::FillValues<int64>(&indices_tensor, indices_vector);
|
||||
sparse_float_feature_indices_list.push_back(indices_tensor);
|
||||
|
||||
std::vector<float> values;
|
||||
for (int64 j = 0; j < sparse_values_size; ++j) {
|
||||
values.push_back(rng->RandFloat());
|
||||
}
|
||||
auto values_tensor = Tensor(tensorflow::DT_FLOAT, {sparse_values_size});
|
||||
tensorflow::test::FillValues<float>(&values_tensor, values);
|
||||
sparse_float_feature_values_list.push_back(values_tensor);
|
||||
|
||||
auto shape_tensor = Tensor(tensorflow::DT_INT64, {2});
|
||||
tensorflow::test::FillValues<int64>(&shape_tensor, {batch_size, 1});
|
||||
sparse_float_feature_shapes_list.push_back(shape_tensor);
|
||||
}
|
||||
|
||||
// TODO(salehay): Add categorical feature generation support.
|
||||
TF_EXPECT_OK(batch_features->Initialize(
|
||||
dense_float_features_list, sparse_float_feature_indices_list,
|
||||
sparse_float_feature_values_list, sparse_float_feature_shapes_list, {},
|
||||
{}, {}));
|
||||
}
|
||||
|
||||
} // namespace testutil
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -0,0 +1,45 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace testutil {
|
||||
|
||||
// This method calls Initialize on the given 'batch_features', which will be
|
||||
// populated with randomly generated feature values when the call returns.
|
||||
// 'tensors' returns a vector of all tensors used in the initialization,
|
||||
// because they must outlive 'batch_features'.
|
||||
//
|
||||
// All float features will be either missing or uniformly randomly chosen
|
||||
// from [0, 1). For sparse (float) features, a sparsity is uniformly randomly
|
||||
// chosen from ['sparsity_lo', 'sparsity_hi') per feature, and each instance
|
||||
// will have a probability of sparsity of missing that feature, in other words,
|
||||
// sparsity = 1 - density.
|
||||
void RandomlyInitializeBatchFeatures(
|
||||
tensorflow::random::SimplePhilox* rng, uint32 num_dense_float_features,
|
||||
uint32 num_sparse_float_features, double sparsity_lo, double sparsity_hi,
|
||||
boosted_trees::utils::BatchFeatures* batch_features);
|
||||
|
||||
} // namespace testutil
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
|
211
tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc
Normal file
211
tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc
Normal file
@ -0,0 +1,211 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
|
||||
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace testutil {
|
||||
|
||||
using tensorflow::boosted_trees::trees::DecisionTreeConfig;
|
||||
using tensorflow::boosted_trees::trees::TreeNode;
|
||||
using boosted_trees::trees::DenseFloatBinarySplit;
|
||||
|
||||
namespace {
|
||||
|
||||
// Append the given nodes to tree with transfer of pointer ownership.
|
||||
// nodes will not be usable upon return.
|
||||
template <typename T>
|
||||
void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
|
||||
std::reverse(nodes->pointer_begin(), nodes->pointer_end());
|
||||
while (!nodes->empty()) {
|
||||
tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
|
||||
}
|
||||
}
|
||||
|
||||
DenseFloatBinarySplit* GetSplit(TreeNode* node) {
|
||||
switch (node->node_case()) {
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultLeft:
|
||||
return node->mutable_sparse_float_binary_split_default_left()
|
||||
->mutable_split();
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultRight:
|
||||
return node->mutable_sparse_float_binary_split_default_right()
|
||||
->mutable_split();
|
||||
case TreeNode::kDenseFloatBinarySplit:
|
||||
return node->mutable_dense_float_binary_split();
|
||||
default:
|
||||
LOG(FATAL) << "Unknown node type encountered.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
|
||||
int dense_feature_size, int sparse_feature_size)
|
||||
: rng_(rng),
|
||||
dense_feature_size_(dense_feature_size),
|
||||
sparse_feature_size_(sparse_feature_size) {}
|
||||
|
||||
namespace {
|
||||
void AddWeightAndMetadata(
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
|
||||
// Assign the weight of the tree to 1 and say that this weight was updated
|
||||
// only once.
|
||||
ret->add_tree_weights(1.0);
|
||||
auto* meta = ret->add_tree_metadata();
|
||||
meta->set_num_tree_weight_updates(1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig
|
||||
RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig ret;
|
||||
*(ret.add_trees()) = Generate(depth);
|
||||
AddWeightAndMetadata(&ret);
|
||||
for (int i = 1; i < tree_count; ++i) {
|
||||
*(ret.add_trees()) = Generate(ret.trees(0));
|
||||
AddWeightAndMetadata(&ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
|
||||
DecisionTreeConfig ret = tree;
|
||||
for (auto& node : *ret.mutable_nodes()) {
|
||||
if (node.node_case() == TreeNode::kLeaf) {
|
||||
node.mutable_leaf()->mutable_sparse_vector()->set_value(
|
||||
0, rng_->RandFloat());
|
||||
continue;
|
||||
}
|
||||
// Original node is a split. Re-generate it's type but retain the split node
|
||||
// indices.
|
||||
DenseFloatBinarySplit* split = GetSplit(&node);
|
||||
const int left_id = split->left_id();
|
||||
const int right_id = split->right_id();
|
||||
GenerateSplit(&node, left_id, right_id);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
DecisionTreeConfig RandomTreeGen::Generate(int depth) {
|
||||
DecisionTreeConfig ret;
|
||||
// Add root,
|
||||
TreeNode* node = ret.add_nodes();
|
||||
GenerateSplit(node, 1, 2);
|
||||
if (depth == 1) {
|
||||
// Add left and right leaves.
|
||||
TreeNode* left = ret.add_nodes();
|
||||
left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
|
||||
left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
|
||||
TreeNode* right = ret.add_nodes();
|
||||
right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
|
||||
right->mutable_leaf()->mutable_sparse_vector()->add_value(
|
||||
rng_->RandFloat());
|
||||
return ret;
|
||||
} else {
|
||||
DecisionTreeConfig left_branch = Generate(depth - 1);
|
||||
DecisionTreeConfig right_branch = Generate(depth - 1);
|
||||
Combine(&ret, &left_branch, &right_branch);
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
void RandomTreeGen::Combine(DecisionTreeConfig* root,
|
||||
DecisionTreeConfig* left_branch,
|
||||
DecisionTreeConfig* right_branch) {
|
||||
const int left_branch_size = left_branch->nodes_size();
|
||||
CHECK_EQ(1, root->nodes_size());
|
||||
// left_branch starts its index at 1. right_branch starts its index at
|
||||
// (left_branch_size + 1).
|
||||
auto* root_node = root->mutable_nodes(0);
|
||||
DenseFloatBinarySplit* root_split = GetSplit(root_node);
|
||||
root_split->set_left_id(1);
|
||||
root_split->set_right_id(left_branch_size + 1);
|
||||
// Shift left/right branch's indices internally so that everything is
|
||||
// consistent.
|
||||
ShiftNodeIndex(left_branch, 1);
|
||||
ShiftNodeIndex(right_branch, left_branch_size + 1);
|
||||
|
||||
// Complexity O(branch node size). No proto copying though.
|
||||
AppendNodes(root, left_branch->mutable_nodes());
|
||||
AppendNodes(root, right_branch->mutable_nodes());
|
||||
}
|
||||
|
||||
void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
|
||||
for (TreeNode& node : *(tree->mutable_nodes())) {
|
||||
DenseFloatBinarySplit* split = nullptr;
|
||||
switch (node.node_case()) {
|
||||
case TreeNode::kLeaf:
|
||||
break;
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultLeft:
|
||||
split = node.mutable_sparse_float_binary_split_default_left()
|
||||
->mutable_split();
|
||||
break;
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultRight:
|
||||
split = node.mutable_sparse_float_binary_split_default_right()
|
||||
->mutable_split();
|
||||
break;
|
||||
case TreeNode::kDenseFloatBinarySplit:
|
||||
split = node.mutable_dense_float_binary_split();
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown node type encountered.";
|
||||
}
|
||||
if (split) {
|
||||
split->set_left_id(shift + split->left_id());
|
||||
split->set_right_id(shift + split->right_id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
|
||||
const double denseSplitProb =
|
||||
sparse_feature_size_ == 0
|
||||
? 1.0
|
||||
: static_cast<double>(dense_feature_size_) /
|
||||
(dense_feature_size_ + sparse_feature_size_);
|
||||
// Generate the tree such that it has equal probability of going left and
|
||||
// right when the feature is missing.
|
||||
static constexpr float kLeftProb = 0.5;
|
||||
|
||||
DenseFloatBinarySplit* split;
|
||||
int feature_size;
|
||||
if (rng_->RandFloat() < denseSplitProb) {
|
||||
feature_size = dense_feature_size_;
|
||||
split = node->mutable_dense_float_binary_split();
|
||||
} else {
|
||||
feature_size = sparse_feature_size_;
|
||||
if (rng_->RandFloat() < kLeftProb) {
|
||||
split = node->mutable_sparse_float_binary_split_default_left()
|
||||
->mutable_split();
|
||||
} else {
|
||||
split = node->mutable_sparse_float_binary_split_default_right()
|
||||
->mutable_split();
|
||||
}
|
||||
}
|
||||
split->set_threshold(rng_->RandFloat());
|
||||
split->set_feature_column(rng_->Uniform(feature_size));
|
||||
split->set_left_id(left_id);
|
||||
split->set_right_id(right_id);
|
||||
}
|
||||
|
||||
} // namespace testutil
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -0,0 +1,75 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace testutil {
|
||||
|
||||
// Randomly generate a balanced tree, for performance benchmarking purposes,
|
||||
// that assume all features are sparse float features, for now.
|
||||
class RandomTreeGen {
|
||||
public:
|
||||
RandomTreeGen(tensorflow::random::SimplePhilox* rng, int dense_feature_size,
|
||||
int sparse_feature_size);
|
||||
|
||||
// Required: depth must be >= 1.
|
||||
// If one wants to generate multiple trees with the same depth, see also the
|
||||
// overload below.
|
||||
boosted_trees::trees::DecisionTreeConfig Generate(int depth);
|
||||
|
||||
// Randomly generate a new tree with the same depth (and tree structure)
|
||||
// as the given tree. This is faster.
|
||||
boosted_trees::trees::DecisionTreeConfig Generate(
|
||||
const boosted_trees::trees::DecisionTreeConfig& tree);
|
||||
|
||||
// Requried: depth >= 1; tree_count >= 1.
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig GenerateEnsemble(
|
||||
int dept, int tree_count);
|
||||
|
||||
private:
|
||||
tensorflow::random::SimplePhilox* rng_;
|
||||
const int dense_feature_size_;
|
||||
const int sparse_feature_size_;
|
||||
|
||||
// Put together a deeper tree by combining two trees.
|
||||
void Combine(boosted_trees::trees::DecisionTreeConfig* root,
|
||||
boosted_trees::trees::DecisionTreeConfig* left_branch,
|
||||
boosted_trees::trees::DecisionTreeConfig* right_branch);
|
||||
|
||||
// For each node in the provided tree, shift its referenced left/right index
|
||||
// by shift.
|
||||
void ShiftNodeIndex(boosted_trees::trees::DecisionTreeConfig* tree,
|
||||
int shift);
|
||||
|
||||
// Generate a sparse split in the node.
|
||||
void GenerateSplit(boosted_trees::trees::TreeNode* node, int left_id,
|
||||
int right_id);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RandomTreeGen);
|
||||
};
|
||||
|
||||
} // namespace testutil
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
|
@ -0,0 +1,67 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
// Randomly generate a tree ensemble and write to file.
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
using tensorflow::Flag;
|
||||
using tensorflow::Flags;
|
||||
using tensorflow::int32;
|
||||
using tensorflow::string;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
int32 dense_feature_size = 100;
|
||||
int32 sparse_feature_size = 100;
|
||||
int32 depth = 8;
|
||||
int32 tree_count = 10;
|
||||
string filename = "/tmp/trees.pb";
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("dense_feature_size", &dense_feature_size, "dense feature size"),
|
||||
Flag("sparse_feature_size", &sparse_feature_size, "sparse_feature_size"),
|
||||
Flag("depth", &depth, "tree depth"),
|
||||
Flag("tree_count", &tree_count, "tree count"),
|
||||
Flag("filename", &filename, "Output filename."),
|
||||
};
|
||||
string usage = Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
|
||||
// We need to call this to set up global state for TensorFlow.
|
||||
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
|
||||
if (!parse_result) {
|
||||
LOG(ERROR) << "\n" << usage;
|
||||
return -1;
|
||||
}
|
||||
|
||||
tensorflow::random::PhiloxRandom philox(1);
|
||||
tensorflow::random::SimplePhilox rng(&philox);
|
||||
tensorflow::boosted_trees::testutil::RandomTreeGen tree_gen(
|
||||
&rng, dense_feature_size, sparse_feature_size);
|
||||
const auto& trees = tree_gen.GenerateEnsemble(depth, tree_count);
|
||||
tensorflow::Status status =
|
||||
tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, trees);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Failed to write: " << filename << " : " << status;
|
||||
} else {
|
||||
LOG(INFO) << "Tree ensemble written to: " << filename;
|
||||
}
|
||||
return 0;
|
||||
}
|
170
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
Normal file
170
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
Normal file
@ -0,0 +1,170 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace trees {
|
||||
|
||||
constexpr int kInvalidLeaf = -1;
|
||||
int DecisionTree::Traverse(const DecisionTreeConfig& config,
|
||||
const int32 sub_root_id,
|
||||
const utils::Example& example) {
|
||||
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
|
||||
return kInvalidLeaf;
|
||||
}
|
||||
|
||||
// Traverse tree starting at the provided sub-root.
|
||||
int32 node_id = sub_root_id;
|
||||
while (true) {
|
||||
const auto& current_node = config.nodes(node_id);
|
||||
switch (current_node.node_case()) {
|
||||
case TreeNode::kLeaf: {
|
||||
return node_id;
|
||||
}
|
||||
case TreeNode::kDenseFloatBinarySplit: {
|
||||
const auto& split = current_node.dense_float_binary_split();
|
||||
node_id = example.dense_float_features[split.feature_column()] <=
|
||||
split.threshold()
|
||||
? split.left_id()
|
||||
: split.right_id();
|
||||
break;
|
||||
}
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
|
||||
const auto& split =
|
||||
current_node.sparse_float_binary_split_default_left().split();
|
||||
auto sparse_feature =
|
||||
example.sparse_float_features[split.feature_column()];
|
||||
node_id = !sparse_feature.has_value() ||
|
||||
sparse_feature.get_value() <= split.threshold()
|
||||
? split.left_id()
|
||||
: split.right_id();
|
||||
break;
|
||||
}
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
|
||||
const auto& split =
|
||||
current_node.sparse_float_binary_split_default_right().split();
|
||||
auto sparse_feature =
|
||||
example.sparse_float_features[split.feature_column()];
|
||||
node_id = sparse_feature.has_value() &&
|
||||
sparse_feature.get_value() <= split.threshold()
|
||||
? split.left_id()
|
||||
: split.right_id();
|
||||
break;
|
||||
}
|
||||
case TreeNode::kCategoricalIdBinarySplit: {
|
||||
const auto& split = current_node.categorical_id_binary_split();
|
||||
node_id = example.sparse_int_features[split.feature_column()].count(
|
||||
split.feature_id()) > 0
|
||||
? split.left_id()
|
||||
: split.right_id();
|
||||
break;
|
||||
}
|
||||
case TreeNode::NODE_NOT_SET: {
|
||||
QCHECK(false) << "Invalid node in tree: " << current_node.DebugString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
|
||||
<< current_node.DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
void DecisionTree::LinkChildren(const std::vector<int32>& children,
|
||||
TreeNode* parent_node) {
|
||||
// Decide how to link children depending on the parent node's type.
|
||||
auto children_it = children.begin();
|
||||
switch (parent_node->node_case()) {
|
||||
case TreeNode::kLeaf: {
|
||||
// Essentially no-op.
|
||||
QCHECK(children.empty()) << "A leaf node cannot have children.";
|
||||
break;
|
||||
}
|
||||
case TreeNode::kDenseFloatBinarySplit: {
|
||||
QCHECK(children.size() == 2)
|
||||
<< "A binary split node must have exactly two children.";
|
||||
auto* split = parent_node->mutable_dense_float_binary_split();
|
||||
split->set_left_id(*children_it);
|
||||
split->set_right_id(*++children_it);
|
||||
break;
|
||||
}
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
|
||||
QCHECK(children.size() == 2)
|
||||
<< "A binary split node must have exactly two children.";
|
||||
auto* split =
|
||||
parent_node->mutable_sparse_float_binary_split_default_left()
|
||||
->mutable_split();
|
||||
split->set_left_id(*children_it);
|
||||
split->set_right_id(*++children_it);
|
||||
break;
|
||||
}
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
|
||||
QCHECK(children.size() == 2)
|
||||
<< "A binary split node must have exactly two children.";
|
||||
auto* split =
|
||||
parent_node->mutable_sparse_float_binary_split_default_right()
|
||||
->mutable_split();
|
||||
split->set_left_id(*children_it);
|
||||
split->set_right_id(*++children_it);
|
||||
break;
|
||||
}
|
||||
case TreeNode::kCategoricalIdBinarySplit: {
|
||||
QCHECK(children.size() == 2)
|
||||
<< "A binary split node must have exactly two children.";
|
||||
auto* split = parent_node->mutable_categorical_id_binary_split();
|
||||
split->set_left_id(*children_it);
|
||||
split->set_right_id(*++children_it);
|
||||
break;
|
||||
}
|
||||
case TreeNode::NODE_NOT_SET: {
|
||||
QCHECK(false) << "A non-set node cannot have children.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
|
||||
// A node's children depend on its type.
|
||||
switch (node.node_case()) {
|
||||
case TreeNode::kLeaf: {
|
||||
return {};
|
||||
}
|
||||
case TreeNode::kDenseFloatBinarySplit: {
|
||||
const auto& split = node.dense_float_binary_split();
|
||||
return {split.left_id(), split.right_id()};
|
||||
}
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
|
||||
const auto& split = node.sparse_float_binary_split_default_left().split();
|
||||
return {split.left_id(), split.right_id()};
|
||||
}
|
||||
case TreeNode::kSparseFloatBinarySplitDefaultRight: {
|
||||
const auto& split =
|
||||
node.sparse_float_binary_split_default_right().split();
|
||||
return {split.left_id(), split.right_id()};
|
||||
}
|
||||
case TreeNode::kCategoricalIdBinarySplit: {
|
||||
const auto& split = node.categorical_id_binary_split();
|
||||
return {split.left_id(), split.right_id()};
|
||||
}
|
||||
case TreeNode::NODE_NOT_SET: {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace trees
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
49
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
Normal file
49
tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
|
||||
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace trees {
|
||||
|
||||
// Decision tree class to encapsulate tree traversal and mutation logic.
|
||||
// This class does not hold state and is thread safe.
|
||||
class DecisionTree {
|
||||
public:
|
||||
// Traverse given an instance, a sub-root and its set of features
|
||||
// and return the leaf index or -1 if the tree is empty or
|
||||
// the sub-root is invalid.
|
||||
static int Traverse(const DecisionTreeConfig& config, int32 sub_root_id,
|
||||
const utils::Example& example);
|
||||
|
||||
// Links the specified children to the parent, the children must
|
||||
// already be added to the decision tree config so this method
|
||||
// just ensures nodes are re-linked.
|
||||
static void LinkChildren(const std::vector<int32>& children,
|
||||
TreeNode* parent_node);
|
||||
|
||||
// Retrieves node children indices if any.
|
||||
static std::vector<int32> GetChildren(const TreeNode& node);
|
||||
};
|
||||
|
||||
} // namespace trees
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
|
326
tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
Normal file
326
tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc
Normal file
@ -0,0 +1,326 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace trees {
|
||||
namespace {
|
||||
|
||||
class DecisionTreeTest : public ::testing::Test {
|
||||
protected:
|
||||
DecisionTreeTest() : batch_features_(2) {
|
||||
// Create a batch of two examples having one dense float, two sparse float
|
||||
// and one sparse int features.
|
||||
// The first example is missing the second sparse feature column and the
|
||||
// second example is missing the first sparse feature column.
|
||||
// This looks like the following:
|
||||
// Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 |
|
||||
// 0 | 7 | -3 | | 3 |
|
||||
// 1 | -2 | | 4 | |
|
||||
auto dense_float_matrix = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
|
||||
auto sparse_float_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
|
||||
auto sparse_float_values1 = test::AsTensor<float>({-3.0f});
|
||||
auto sparse_float_shape1 = test::AsTensor<int64>({2, 1});
|
||||
auto sparse_float_indices2 = test::AsTensor<int64>({1, 0}, {1, 2});
|
||||
auto sparse_float_values2 = test::AsTensor<float>({4.0f});
|
||||
auto sparse_float_shape2 = test::AsTensor<int64>({2, 1});
|
||||
auto sparse_int_indices1 = test::AsTensor<int64>({0, 0}, {1, 2});
|
||||
auto sparse_int_values1 = test::AsTensor<int64>({3});
|
||||
auto sparse_int_shape1 = test::AsTensor<int64>({2, 1});
|
||||
TF_EXPECT_OK(batch_features_.Initialize(
|
||||
{dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2},
|
||||
{sparse_float_values1, sparse_float_values2},
|
||||
{sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1},
|
||||
{sparse_int_values1}, {sparse_int_shape1}));
|
||||
}
|
||||
|
||||
template <typename SplitType>
|
||||
void TestLinkChildrenBinary(TreeNode* node, SplitType* split) {
|
||||
// Verify children were linked.
|
||||
DecisionTree::LinkChildren({3, 8}, node);
|
||||
EXPECT_EQ(3, split->left_id());
|
||||
EXPECT_EQ(8, split->right_id());
|
||||
|
||||
// Invalid cases.
|
||||
EXPECT_DEATH(DecisionTree::LinkChildren({}, node),
|
||||
"A binary split node must have exactly two children.");
|
||||
EXPECT_DEATH(DecisionTree::LinkChildren({3}, node),
|
||||
"A binary split node must have exactly two children.");
|
||||
EXPECT_DEATH(DecisionTree::LinkChildren({1, 2, 3}, node),
|
||||
"A binary split node must have exactly two children.");
|
||||
}
|
||||
|
||||
void TestGetChildren(const TreeNode& node,
|
||||
const std::vector<uint32>& expected_children) {
|
||||
// Verify children were linked.
|
||||
auto children = DecisionTree::GetChildren(node);
|
||||
EXPECT_EQ(children.size(), expected_children.size());
|
||||
for (size_t idx = 0; idx < children.size(); ++idx) {
|
||||
EXPECT_EQ(children[idx], expected_children[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
utils::BatchFeatures batch_features_;
|
||||
};
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseEmpty) {
|
||||
DecisionTreeConfig tree_config;
|
||||
auto example = (*batch_features_.examples_iterable(0, 1).begin());
|
||||
EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 0, example));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseBias) {
|
||||
DecisionTreeConfig tree_config;
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
auto example = (*batch_features_.examples_iterable(0, 1).begin());
|
||||
EXPECT_EQ(0, DecisionTree::Traverse(tree_config, 0, example));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseInvalidSubRoot) {
|
||||
DecisionTreeConfig tree_config;
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
auto example = (*batch_features_.examples_iterable(0, 1).begin());
|
||||
EXPECT_EQ(-1, DecisionTree::Traverse(tree_config, 10, example));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) {
|
||||
DecisionTreeConfig tree_config;
|
||||
auto* split_node =
|
||||
tree_config.add_nodes()->mutable_dense_float_binary_split();
|
||||
split_node->set_feature_column(0);
|
||||
split_node->set_threshold(0.0f);
|
||||
split_node->set_left_id(1);
|
||||
split_node->set_right_id(2);
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||
|
||||
// Expect right child to be picked as !(7 <= 0);
|
||||
auto example_it = example_iterable.begin();
|
||||
EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it));
|
||||
|
||||
// Expect left child to be picked as (-2 <= 0);
|
||||
EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) {
|
||||
// Test first sparse feature which is missing for the second example.
|
||||
DecisionTreeConfig tree_config1;
|
||||
auto* split_node1 = tree_config1.add_nodes()
|
||||
->mutable_sparse_float_binary_split_default_left()
|
||||
->mutable_split();
|
||||
split_node1->set_feature_column(0);
|
||||
split_node1->set_threshold(-20.0f);
|
||||
split_node1->set_left_id(1);
|
||||
split_node1->set_right_id(2);
|
||||
tree_config1.add_nodes()->mutable_leaf();
|
||||
tree_config1.add_nodes()->mutable_leaf();
|
||||
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||
|
||||
// Expect right child to be picked as !(-3 <= -20).
|
||||
auto example_it = example_iterable.begin();
|
||||
EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it));
|
||||
|
||||
// Expect left child to be picked as default direction.
|
||||
EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it));
|
||||
|
||||
// Test second sparse feature which is missing for the first example.
|
||||
DecisionTreeConfig tree_config2;
|
||||
auto* split_node2 = tree_config2.add_nodes()
|
||||
->mutable_sparse_float_binary_split_default_right()
|
||||
->mutable_split();
|
||||
split_node2->set_feature_column(1);
|
||||
split_node2->set_threshold(4.0f);
|
||||
split_node2->set_left_id(1);
|
||||
split_node2->set_right_id(2);
|
||||
tree_config2.add_nodes()->mutable_leaf();
|
||||
tree_config2.add_nodes()->mutable_leaf();
|
||||
|
||||
// Expect right child to be picked as default direction.
|
||||
example_it = example_iterable.begin();
|
||||
EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it));
|
||||
|
||||
// Expect left child to be picked as (4 <= 4).
|
||||
EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) {
|
||||
DecisionTreeConfig tree_config;
|
||||
auto* split_node =
|
||||
tree_config.add_nodes()->mutable_categorical_id_binary_split();
|
||||
split_node->set_feature_column(0);
|
||||
split_node->set_feature_id(3);
|
||||
split_node->set_left_id(1);
|
||||
split_node->set_right_id(2);
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||
|
||||
// Expect left child to be picked as 3 == 3;
|
||||
auto example_it = example_iterable.begin();
|
||||
EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it));
|
||||
|
||||
// Expect right child to be picked as the feature is missing;
|
||||
EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, TraverseHybridSplits) {
|
||||
DecisionTreeConfig tree_config;
|
||||
auto* split_node1 =
|
||||
tree_config.add_nodes()->mutable_dense_float_binary_split();
|
||||
split_node1->set_feature_column(0);
|
||||
split_node1->set_threshold(9.0f);
|
||||
split_node1->set_left_id(1); // sparse split.
|
||||
split_node1->set_right_id(2); // leaf
|
||||
auto* split_node2 = tree_config.add_nodes()
|
||||
->mutable_sparse_float_binary_split_default_left()
|
||||
->mutable_split();
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
split_node2->set_feature_column(0);
|
||||
split_node2->set_threshold(-20.0f);
|
||||
split_node2->set_left_id(3);
|
||||
split_node2->set_right_id(4);
|
||||
auto* split_node3 =
|
||||
tree_config.add_nodes()->mutable_categorical_id_binary_split();
|
||||
split_node3->set_feature_column(0);
|
||||
split_node3->set_feature_id(2);
|
||||
split_node3->set_left_id(5);
|
||||
split_node3->set_right_id(6);
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
tree_config.add_nodes()->mutable_leaf();
|
||||
auto example_iterable = batch_features_.examples_iterable(0, 2);
|
||||
|
||||
// Expect will go left through the first dense split as (7.0f <= 9.0f),
|
||||
// then will go right through the sparse split as !(-3 <= -20).
|
||||
auto example_it = example_iterable.begin();
|
||||
EXPECT_EQ(4, DecisionTree::Traverse(tree_config, 0, *example_it));
|
||||
|
||||
// Expect will go left through the first dense split as (-2.0f <= 9.0f),
|
||||
// then will go left the default direction as the sparse feature is missing,
|
||||
// then will go right as 2 != 3 on the categorical split.
|
||||
EXPECT_EQ(6, DecisionTree::Traverse(tree_config, 0, *++example_it));
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, LinkChildrenLeaf) {
|
||||
// Create leaf node.
|
||||
TreeNode node;
|
||||
node.mutable_leaf();
|
||||
|
||||
// No-op.
|
||||
DecisionTree::LinkChildren({}, &node);
|
||||
|
||||
// Invalid case.
|
||||
EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
|
||||
"A leaf node cannot have children.");
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, LinkChildrenDenseFloatBinarySplit) {
|
||||
TreeNode node;
|
||||
auto* split = node.mutable_dense_float_binary_split();
|
||||
split->set_left_id(-1);
|
||||
split->set_right_id(-1);
|
||||
TestLinkChildrenBinary(&node, split);
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultLeft) {
|
||||
TreeNode node;
|
||||
auto* split =
|
||||
node.mutable_sparse_float_binary_split_default_left()->mutable_split();
|
||||
split->set_left_id(-1);
|
||||
split->set_right_id(-1);
|
||||
TestLinkChildrenBinary(&node, split);
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, LinkChildrenSparseFloatBinarySplitDefaultRight) {
|
||||
TreeNode node;
|
||||
auto* split =
|
||||
node.mutable_sparse_float_binary_split_default_right()->mutable_split();
|
||||
split->set_left_id(-1);
|
||||
split->set_right_id(-1);
|
||||
TestLinkChildrenBinary(&node, split);
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, LinkChildrenCategoricalSingleIdBinarySplit) {
|
||||
TreeNode node;
|
||||
auto* split = node.mutable_categorical_id_binary_split();
|
||||
split->set_left_id(-1);
|
||||
split->set_right_id(-1);
|
||||
TestLinkChildrenBinary(&node, split);
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, LinkChildrenNodeNotSet) {
|
||||
// Create unset node.
|
||||
TreeNode node;
|
||||
|
||||
// Invalid case.
|
||||
EXPECT_DEATH(DecisionTree::LinkChildren({1}, &node),
|
||||
"A non-set node cannot have children.");
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, GetChildrenLeaf) {
|
||||
TreeNode node;
|
||||
node.mutable_leaf();
|
||||
TestGetChildren(node, {});
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, GetChildrenDenseFloatBinarySplit) {
|
||||
TreeNode node;
|
||||
auto* split = node.mutable_dense_float_binary_split();
|
||||
split->set_left_id(23);
|
||||
split->set_right_id(24);
|
||||
TestGetChildren(node, {23, 24});
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultLeft) {
|
||||
TreeNode node;
|
||||
auto* split =
|
||||
node.mutable_sparse_float_binary_split_default_left()->mutable_split();
|
||||
split->set_left_id(12);
|
||||
split->set_right_id(13);
|
||||
TestGetChildren(node, {12, 13});
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, GetChildrenSparseFloatBinarySplitDefaultRight) {
|
||||
TreeNode node;
|
||||
auto* split =
|
||||
node.mutable_sparse_float_binary_split_default_right()->mutable_split();
|
||||
split->set_left_id(1);
|
||||
split->set_right_id(2);
|
||||
TestGetChildren(node, {1, 2});
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, GetChildrenCategoricalSingleIdBinarySplit) {
|
||||
TreeNode node;
|
||||
auto* split = node.mutable_categorical_id_binary_split();
|
||||
split->set_left_id(7);
|
||||
split->set_right_id(8);
|
||||
TestGetChildren(node, {7, 8});
|
||||
}
|
||||
|
||||
TEST_F(DecisionTreeTest, GetChildrenNodeNotSet) {
|
||||
TreeNode node;
|
||||
TestGetChildren(node, {});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace trees
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -24,6 +24,15 @@ tf_proto_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "quantiles_proto",
|
||||
srcs = [
|
||||
"quantiles.proto",
|
||||
],
|
||||
cc_api_version = 2,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "tree_config_proto",
|
||||
srcs = ["tree_config.proto"],
|
||||
|
32
tensorflow/contrib/boosted_trees/proto/quantiles.proto
Normal file
32
tensorflow/contrib/boosted_trees/proto/quantiles.proto
Normal file
@ -0,0 +1,32 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
package boosted_trees;
|
||||
|
||||
message QuantileConfig {
|
||||
// Maximum eps error when computing quantile summaries.
|
||||
double eps = 1;
|
||||
// Number of quantiles to generate.
|
||||
int64 num_quantiles = 2;
|
||||
}
|
||||
|
||||
message QuantileEntry {
|
||||
// Value for the entry.
|
||||
float value = 1;
|
||||
// Weight for the entry.
|
||||
float weight = 2;
|
||||
// We need the minimum and maximum rank possible for this entry.
|
||||
// Rank is 0.0 for the absolute minimum and sum of the weights for the maximum
|
||||
// value in the input.
|
||||
float min_rank = 3;
|
||||
float max_rank = 4;
|
||||
}
|
||||
|
||||
message QuantileSummaryState {
|
||||
repeated QuantileEntry entries = 1;
|
||||
}
|
||||
|
||||
message QuantileStreamState {
|
||||
repeated QuantileSummaryState summaries = 1;
|
||||
}
|
53
tensorflow/contrib/boosted_trees/resources/BUILD
Normal file
53
tensorflow/contrib/boosted_trees/resources/BUILD
Normal file
@ -0,0 +1,53 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/contrib/boosted_trees:__subpackages__",
|
||||
"//tensorflow/contrib/boosted_trees:friends",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stamped_resource",
|
||||
hdrs = ["stamped_resource.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "quantile_stream_resource",
|
||||
hdrs = ["quantile_stream_resource.h"],
|
||||
deps = [
|
||||
":stamped_resource",
|
||||
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
|
||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "decision_tree_ensemble_resource",
|
||||
hdrs = ["decision_tree_ensemble_resource.h"],
|
||||
deps = [
|
||||
":stamped_resource",
|
||||
"//tensorflow/contrib/boosted_trees/lib:trees",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
@ -0,0 +1,77 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
namespace models {
|
||||
|
||||
// Keep a tree ensemble in memory for efficient evaluation and mutation.
|
||||
class DecisionTreeEnsembleResource : public StampedResource {
|
||||
public:
|
||||
// Constructor.
|
||||
explicit DecisionTreeEnsembleResource()
|
||||
: decision_tree_ensemble_(
|
||||
protobuf::Arena::CreateMessage<
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
|
||||
|
||||
string DebugString() override {
|
||||
return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
|
||||
decision_tree_ensemble_->trees_size(), "]");
|
||||
}
|
||||
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig&
|
||||
decision_tree_ensemble() const {
|
||||
return *decision_tree_ensemble_;
|
||||
}
|
||||
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig*
|
||||
mutable_decision_tree_ensemble() {
|
||||
return decision_tree_ensemble_;
|
||||
}
|
||||
|
||||
// Resets the resource and frees the protos in arena.
|
||||
// Caller needs to hold the mutex lock while calling this.
|
||||
void Reset() {
|
||||
// Reset stamp.
|
||||
set_stamp(-1);
|
||||
|
||||
// Clear tree ensemle.
|
||||
arena_.Reset();
|
||||
CHECK_EQ(0, arena_.SpaceAllocated());
|
||||
decision_tree_ensemble_ = protobuf::Arena::CreateMessage<
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
|
||||
}
|
||||
|
||||
mutex* get_mutex() { return &mu_; }
|
||||
|
||||
private:
|
||||
protobuf::Arena arena_;
|
||||
mutex mu_;
|
||||
boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
|
||||
};
|
||||
|
||||
} // namespace models
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
|
@ -0,0 +1,104 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
|
||||
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" // NOLINT
|
||||
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
|
||||
using QuantileStream =
|
||||
boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
|
||||
|
||||
// Resource for accumulating summaries for multiple columns.
|
||||
class QuantileStreamResource : public StampedResource {
|
||||
public:
|
||||
QuantileStreamResource(const float epsilon, const int32 num_quantiles,
|
||||
const int64 max_elements, int64 stamp_token)
|
||||
: stream_(epsilon, max_elements),
|
||||
are_buckets_ready_(false),
|
||||
epsilon_(epsilon),
|
||||
num_quantiles_(num_quantiles),
|
||||
max_elements_(max_elements) {
|
||||
set_stamp(stamp_token);
|
||||
}
|
||||
|
||||
string DebugString() override { return "QuantileStreamResource"; }
|
||||
|
||||
tensorflow::mutex* mutex() { return &mu_; }
|
||||
|
||||
QuantileStream* stream(int64 stamp) {
|
||||
CHECK(is_stamp_valid(stamp));
|
||||
return &stream_;
|
||||
}
|
||||
|
||||
const std::vector<float>& boundaries(int64 stamp) {
|
||||
CHECK(is_stamp_valid(stamp));
|
||||
return boundaries_;
|
||||
}
|
||||
|
||||
void set_boundaries(int64 stamp, const std::vector<float>& boundaries) {
|
||||
CHECK(is_stamp_valid(stamp));
|
||||
are_buckets_ready_ = true;
|
||||
boundaries_ = boundaries;
|
||||
}
|
||||
|
||||
float epsilon() const { return epsilon_; }
|
||||
int32 num_quantiles() const { return num_quantiles_; }
|
||||
|
||||
void Reset(int64 stamp) {
|
||||
set_stamp(stamp);
|
||||
stream_ = QuantileStream(epsilon_, max_elements_);
|
||||
}
|
||||
|
||||
bool are_buckets_ready() const { return are_buckets_ready_; }
|
||||
void set_buckets_ready(bool are_buckets_ready) {
|
||||
are_buckets_ready_ = are_buckets_ready;
|
||||
}
|
||||
|
||||
private:
|
||||
~QuantileStreamResource() override {}
|
||||
|
||||
// Mutex for the whole resource.
|
||||
tensorflow::mutex mu_;
|
||||
|
||||
// Quantile stream.
|
||||
QuantileStream stream_;
|
||||
|
||||
// Stores the boundaries from the previous iteration. Empty during the first
|
||||
// iteration.
|
||||
std::vector<float> boundaries_;
|
||||
|
||||
// Whether boundaries are created. Initially boundaries are empty until
|
||||
// set_boundaries are called.
|
||||
bool are_buckets_ready_;
|
||||
|
||||
const float epsilon_;
|
||||
const int32 num_quantiles_;
|
||||
// An upper-bound for the number of elements.
|
||||
int64 max_elements_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource);
|
||||
};
|
||||
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
|
@ -0,0 +1,42 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
|
||||
// A StampedResource is a resource that has a stamp token associated with it.
|
||||
// Before reading from or applying updates to the resource, the stamp should
|
||||
// be checked to verify that the update is not stale.
|
||||
class StampedResource : public ResourceBase {
|
||||
public:
|
||||
StampedResource() : stamp_(-1) {}
|
||||
|
||||
bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
|
||||
|
||||
int64 stamp() const { return stamp_; }
|
||||
void set_stamp(int64 stamp) { stamp_ = stamp; }
|
||||
|
||||
private:
|
||||
int64 stamp_;
|
||||
};
|
||||
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
|
@ -37,6 +37,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
|
||||
|
@ -58,6 +58,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/i
|
||||
GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
|
||||
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
|
||||
|
@ -111,6 +111,7 @@ file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
|
||||
"${tensorflow_source_dir}/tensorflow/core/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
|
||||
)
|
||||
@ -124,6 +125,7 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
|
||||
file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
|
||||
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
|
||||
)
|
||||
@ -342,6 +344,9 @@ add_python_module("tensorflow/contrib/keras/python/keras/layers")
|
||||
add_python_module("tensorflow/contrib/keras/python/keras/preprocessing")
|
||||
add_python_module("tensorflow/contrib/keras/python/keras/utils")
|
||||
add_python_module("tensorflow/contrib/keras/python/keras/wrappers")
|
||||
add_python_module("tensorflow/contrib/kernel_methods")
|
||||
add_python_module("tensorflow/contrib/kernel_methods/python")
|
||||
add_python_module("tensorflow/contrib/kernel_methods/python/mappers")
|
||||
add_python_module("tensorflow/contrib/labeled_tensor")
|
||||
add_python_module("tensorflow/contrib/labeled_tensor/python")
|
||||
add_python_module("tensorflow/contrib/labeled_tensor/python/ops")
|
||||
@ -405,6 +410,11 @@ add_python_module("tensorflow/contrib/ndlstm/python")
|
||||
add_python_module("tensorflow/contrib/nn")
|
||||
add_python_module("tensorflow/contrib/nn/python")
|
||||
add_python_module("tensorflow/contrib/nn/python/ops")
|
||||
add_python_module("tensorflow/contrib/nccl")
|
||||
add_python_module("tensorflow/contrib/nccl/kernels")
|
||||
add_python_module("tensorflow/contrib/nccl/ops")
|
||||
add_python_module("tensorflow/contrib/nccl/python")
|
||||
add_python_module("tensorflow/contrib/nccl/python/ops")
|
||||
add_python_module("tensorflow/contrib/opt")
|
||||
add_python_module("tensorflow/contrib/opt/python")
|
||||
add_python_module("tensorflow/contrib/opt/python/training")
|
||||
@ -599,6 +609,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_nccl_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops"
|
||||
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py)
|
||||
GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops"
|
||||
|
@ -69,19 +69,25 @@ class BinomialTest(test.TestCase):
|
||||
self.assertEqual((1, 3), binom.logits.get_shape())
|
||||
self.assertAllClose(logits, binom.logits.eval())
|
||||
|
||||
def testPmfNandCountsAgree(self):
|
||||
def testPmfAndCdfNandCountsAgree(self):
|
||||
p = [[0.1, 0.2, 0.7]]
|
||||
n = [[5.]]
|
||||
with self.test_session():
|
||||
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
|
||||
binom.prob([2., 3, 2]).eval()
|
||||
binom.prob([3., 1, 2]).eval()
|
||||
binom.cdf([2., 3, 2]).eval()
|
||||
binom.cdf([3., 1, 2]).eval()
|
||||
with self.assertRaisesOpError("Condition x >= 0.*"):
|
||||
binom.prob([-1., 4, 2]).eval()
|
||||
with self.assertRaisesOpError("Condition x <= y.*"):
|
||||
binom.prob([7., 3, 0]).eval()
|
||||
with self.assertRaisesOpError("Condition x >= 0.*"):
|
||||
binom.cdf([-1., 4, 2]).eval()
|
||||
with self.assertRaisesOpError("Condition x <= y.*"):
|
||||
binom.cdf([7., 3, 0]).eval()
|
||||
|
||||
def testPmfNonIntegerCounts(self):
|
||||
def testPmfAndCdfNonIntegerCounts(self):
|
||||
p = [[0.1, 0.2, 0.7]]
|
||||
n = [[5.]]
|
||||
with self.test_session():
|
||||
@ -89,50 +95,72 @@ class BinomialTest(test.TestCase):
|
||||
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
|
||||
binom.prob([2., 3, 2]).eval()
|
||||
binom.prob([3., 1, 2]).eval()
|
||||
binom.cdf([2., 3, 2]).eval()
|
||||
binom.cdf([3., 1, 2]).eval()
|
||||
# Both equality and integer checking fail.
|
||||
with self.assertRaisesOpError(
|
||||
"cannot contain fractional components."):
|
||||
binom.prob([1.0, 2.5, 1.5]).eval()
|
||||
with self.assertRaisesOpError(
|
||||
"cannot contain fractional components."):
|
||||
binom.cdf([1.0, 2.5, 1.5]).eval()
|
||||
|
||||
binom = binomial.Binomial(total_count=n, probs=p, validate_args=False)
|
||||
binom.prob([1., 2., 3.]).eval()
|
||||
binom.cdf([1., 2., 3.]).eval()
|
||||
# Non-integer arguments work.
|
||||
binom.prob([1.0, 2.5, 1.5]).eval()
|
||||
binom.cdf([1.0, 2.5, 1.5]).eval()
|
||||
|
||||
def testPmfBothZeroBatches(self):
|
||||
def testPmfAndCdfBothZeroBatches(self):
|
||||
with self.test_session():
|
||||
# Both zero-batches. No broadcast
|
||||
p = 0.5
|
||||
counts = 1.
|
||||
pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
|
||||
binom = binomial.Binomial(total_count=1., probs=p)
|
||||
pmf = binom.prob(counts)
|
||||
cdf = binom.cdf(counts)
|
||||
self.assertAllClose(0.5, pmf.eval())
|
||||
self.assertAllClose(stats.binom.cdf(counts, n=1, p=p), cdf.eval())
|
||||
self.assertEqual((), pmf.get_shape())
|
||||
self.assertEqual((), cdf.get_shape())
|
||||
|
||||
def testPmfBothZeroBatchesNontrivialN(self):
|
||||
def testPmfAndCdfBothZeroBatchesNontrivialN(self):
|
||||
with self.test_session():
|
||||
# Both zero-batches. No broadcast
|
||||
p = 0.1
|
||||
counts = 3.
|
||||
binom = binomial.Binomial(total_count=5., probs=p)
|
||||
pmf = binom.prob(counts)
|
||||
cdf = binom.cdf(counts)
|
||||
self.assertAllClose(stats.binom.pmf(counts, n=5., p=p), pmf.eval())
|
||||
self.assertAllClose(stats.binom.cdf(counts, n=5., p=p), cdf.eval())
|
||||
self.assertEqual((), pmf.get_shape())
|
||||
self.assertEqual((), cdf.get_shape())
|
||||
|
||||
def testPmfPStretchedInBroadcastWhenSameRank(self):
|
||||
def testPmfAndCdfPStretchedInBroadcastWhenSameRank(self):
|
||||
with self.test_session():
|
||||
p = [[0.1, 0.9]]
|
||||
counts = [[1., 2.]]
|
||||
pmf = binomial.Binomial(total_count=3., probs=p).prob(counts)
|
||||
binom = binomial.Binomial(total_count=3., probs=p)
|
||||
pmf = binom.prob(counts)
|
||||
cdf = binom.cdf(counts)
|
||||
self.assertAllClose(stats.binom.pmf(counts, n=3., p=p), pmf.eval())
|
||||
self.assertAllClose(stats.binom.cdf(counts, n=3., p=p), cdf.eval())
|
||||
self.assertEqual((1, 2), pmf.get_shape())
|
||||
self.assertEqual((1, 2), cdf.get_shape())
|
||||
|
||||
def testPmfPStretchedInBroadcastWhenLowerRank(self):
|
||||
def testPmfAndCdfPStretchedInBroadcastWhenLowerRank(self):
|
||||
with self.test_session():
|
||||
p = [0.1, 0.4]
|
||||
counts = [[1.], [0.]]
|
||||
pmf = binomial.Binomial(total_count=1., probs=p).prob(counts)
|
||||
binom = binomial.Binomial(total_count=1., probs=p)
|
||||
pmf = binom.prob(counts)
|
||||
cdf = binom.cdf(counts)
|
||||
self.assertAllClose([[0.1, 0.4], [0.9, 0.6]], pmf.eval())
|
||||
self.assertAllClose([[1.0, 1.0], [0.9, 0.6]], cdf.eval())
|
||||
self.assertEqual((2, 2), pmf.get_shape())
|
||||
self.assertEqual((2, 2), cdf.get_shape())
|
||||
|
||||
def testBinomialMean(self):
|
||||
with self.test_session():
|
||||
|
@ -103,6 +103,31 @@ class MultivariateNormalDiagTest(test.TestCase):
|
||||
self.assertAllClose(cov_mat, np.cov(samps.T),
|
||||
atol=0.05, rtol=0.05)
|
||||
|
||||
def testSampleWithBroadcastScale(self):
|
||||
# mu corresponds to a 2-batch of 3-variate normals
|
||||
mu = np.zeros([2, 3])
|
||||
|
||||
# diag corresponds to no batches of 3-variate normals
|
||||
diag = np.ones([3])
|
||||
|
||||
with self.test_session():
|
||||
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
|
||||
|
||||
mean = dist.mean()
|
||||
self.assertAllEqual([2, 3], mean.get_shape())
|
||||
self.assertAllClose(mu, mean.eval())
|
||||
|
||||
n = int(1e3)
|
||||
samps = dist.sample(n, seed=0).eval()
|
||||
cov_mat = array_ops.matrix_diag(diag).eval()**2
|
||||
sample_cov = np.matmul(samps.transpose([1, 2, 0]),
|
||||
samps.transpose([1, 0, 2])) / n
|
||||
|
||||
self.assertAllClose(mu, samps.mean(axis=0),
|
||||
atol=0.10, rtol=0.05)
|
||||
self.assertAllClose([cov_mat, cov_mat], sample_cov,
|
||||
atol=0.10, rtol=0.05)
|
||||
|
||||
def testCovariance(self):
|
||||
with self.test_session():
|
||||
mvn = ds.MultivariateNormalDiag(
|
||||
|
@ -42,6 +42,28 @@ to integer values.
|
||||
"""
|
||||
|
||||
|
||||
def _bdtr(k, n, p):
|
||||
"""The binomial cumulative distribution function.
|
||||
|
||||
Args:
|
||||
k: floating point `Tensor`.
|
||||
n: floating point `Tensor`.
|
||||
p: floating point `Tensor`.
|
||||
|
||||
Returns:
|
||||
`sum_{j=0}^k p^j (1 - p)^(n - j)`.
|
||||
"""
|
||||
# Trick for getting safe backprop/gradients into n, k when
|
||||
# betainc(a = 0, ..) = nan
|
||||
# Write:
|
||||
# where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
|
||||
ones = array_ops.ones_like(n - k)
|
||||
k_eq_n = math_ops.equal(k, n)
|
||||
safe_dn = array_ops.where(k_eq_n, ones, n - k)
|
||||
dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p)
|
||||
return array_ops.where(k_eq_n, ones, dk)
|
||||
|
||||
|
||||
class Binomial(distribution.Distribution):
|
||||
"""Binomial distribution.
|
||||
|
||||
@ -201,6 +223,18 @@ class Binomial(distribution.Distribution):
|
||||
def _prob(self, counts):
|
||||
return math_ops.exp(self._log_prob(counts))
|
||||
|
||||
def _cdf(self, counts):
|
||||
counts = self._maybe_assert_valid_sample(counts)
|
||||
probs = self.probs
|
||||
if not (counts.shape.is_fully_defined()
|
||||
and self.probs.shape.is_fully_defined()
|
||||
and counts.shape.is_compatible_with(self.probs.shape)):
|
||||
# If both shapes are well defined and equal, we skip broadcasting.
|
||||
probs += array_ops.zeros_like(counts)
|
||||
counts += array_ops.zeros_like(self.probs)
|
||||
|
||||
return _bdtr(k=counts, n=self.total_count, p=probs)
|
||||
|
||||
def _log_unnormalized_prob(self, counts):
|
||||
counts = self._maybe_assert_valid_sample(counts)
|
||||
return (counts * math_ops.log(self.probs) +
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.contrib.distributions.python.ops import kullback_leibler
|
||||
from tensorflow.contrib.distributions.python.ops import normal
|
||||
from tensorflow.contrib.distributions.python.ops import transformed_distribution
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
@ -53,6 +54,16 @@ or
|
||||
"""
|
||||
|
||||
|
||||
def _broadcast_shape(shape1, shape2):
|
||||
"""Convenience function which statically broadcasts shape when possible."""
|
||||
if (tensor_util.constant_value(shape1) is not None and
|
||||
tensor_util.constant_value(shape2) is not None):
|
||||
return array_ops.broadcast_static_shape(
|
||||
tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
|
||||
tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
|
||||
return array_ops.broadcast_dynamic_shape(shape1, shape2)
|
||||
|
||||
|
||||
# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
|
||||
class MultivariateNormalLinearOperator(
|
||||
transformed_distribution.TransformedDistribution):
|
||||
@ -179,12 +190,25 @@ class MultivariateNormalLinearOperator(
|
||||
if not scale.dtype.is_floating:
|
||||
raise TypeError("`scale` parameter must have floating-point dtype.")
|
||||
|
||||
# Since expand_dims doesn't preserve constant-ness, we obtain the
|
||||
# non-dynamic value if possible.
|
||||
event_shape = scale.domain_dimension_tensor()
|
||||
if tensor_util.constant_value(event_shape) is not None:
|
||||
event_shape = tensor_util.constant_value(event_shape)
|
||||
event_shape = event_shape[array_ops.newaxis]
|
||||
with ops.name_scope(name, values=[loc] + scale.graph_parents):
|
||||
# Since expand_dims doesn't preserve constant-ness, we obtain the
|
||||
# non-dynamic value if possible.
|
||||
event_shape = scale.range_dimension_tensor()
|
||||
if tensor_util.constant_value(event_shape) is not None:
|
||||
event_shape = tensor_util.constant_value(event_shape).reshape([1])
|
||||
else:
|
||||
event_shape = event_shape[array_ops.newaxis]
|
||||
batch_shape = scale.batch_shape_tensor()
|
||||
if loc is not None:
|
||||
loc = ops.convert_to_tensor(loc, name="loc")
|
||||
loc_batch_shape = loc.get_shape().with_rank_at_least(1)[:-1]
|
||||
if (loc.get_shape().ndims is None or
|
||||
not loc_batch_shape.is_fully_defined()):
|
||||
loc_batch_shape = array_ops.shape(loc)[:-1]
|
||||
else:
|
||||
loc_batch_shape = ops.convert_to_tensor(loc_batch_shape,
|
||||
name="loc_batch_shape")
|
||||
batch_shape = _broadcast_shape(batch_shape, loc_batch_shape)
|
||||
|
||||
super(MultivariateNormalLinearOperator, self).__init__(
|
||||
distribution=normal.Normal(
|
||||
@ -192,7 +216,7 @@ class MultivariateNormalLinearOperator(
|
||||
scale=array_ops.ones([], dtype=scale.dtype)),
|
||||
bijector=bijectors.AffineLinearOperator(
|
||||
shift=loc, scale=scale, validate_args=validate_args),
|
||||
batch_shape=scale.batch_shape_tensor(),
|
||||
batch_shape=batch_shape,
|
||||
event_shape=event_shape,
|
||||
validate_args=validate_args,
|
||||
name=name)
|
||||
|
@ -35,6 +35,7 @@ tf_custom_op_py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":factorization_ops_test_utils_py",
|
||||
":gen_clustering_ops",
|
||||
":gen_factorization_ops",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
@ -161,12 +162,28 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "factorization_ops_test_utils_py",
|
||||
srcs = [
|
||||
"python/ops/factorization_ops_test_utils.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:embedding_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "factorization_ops_test",
|
||||
srcs = ["python/ops/factorization_ops_test.py"],
|
||||
additional_deps = [
|
||||
":factorization_py",
|
||||
":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
|
||||
":factorization_ops_test_utils_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -18,160 +18,56 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.factorization.python.ops import factorization_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
INPUT_MATRIX = np.array(
|
||||
[[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
|
||||
[0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
|
||||
[2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
|
||||
[3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
|
||||
[0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
|
||||
|
||||
|
||||
def np_matrix_to_tf_sparse(np_matrix,
|
||||
row_slices=None,
|
||||
col_slices=None,
|
||||
transpose=False,
|
||||
shuffle=False):
|
||||
"""Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
|
||||
indices = np.nonzero(np_matrix)
|
||||
|
||||
# Only allow slices of whole rows or whole columns.
|
||||
assert not (row_slices is not None and col_slices is not None)
|
||||
|
||||
if row_slices is not None:
|
||||
selected_ind = np.concatenate(
|
||||
[np.where(indices[0] == r)[0] for r in row_slices], 0)
|
||||
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
||||
|
||||
if col_slices is not None:
|
||||
selected_ind = np.concatenate(
|
||||
[np.where(indices[1] == c)[0] for c in col_slices], 0)
|
||||
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
||||
|
||||
if shuffle:
|
||||
shuffled_ind = [x for x in range(len(indices[0]))]
|
||||
random.shuffle(shuffled_ind)
|
||||
indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
|
||||
|
||||
ind = (np.concatenate((np.expand_dims(indices[1], 1),
|
||||
np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
|
||||
transpose else np.concatenate((np.expand_dims(indices[0], 1),
|
||||
np.expand_dims(indices[1], 1)),
|
||||
1).astype(np.int64))
|
||||
val = np_matrix[indices].astype(np.float32)
|
||||
shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
|
||||
if transpose else np.array(
|
||||
[max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
|
||||
return sparse_tensor.SparseTensor(ind, val, shape)
|
||||
|
||||
|
||||
def sparse_input():
|
||||
return np_matrix_to_tf_sparse(INPUT_MATRIX)
|
||||
|
||||
|
||||
def count_rows(sp_input):
|
||||
return math_ops.cast(
|
||||
array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0],
|
||||
dtypes.float32)
|
||||
|
||||
|
||||
def count_cols(sp_input):
|
||||
return math_ops.cast(
|
||||
array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0],
|
||||
dtypes.float32)
|
||||
|
||||
|
||||
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
|
||||
w0=1., row_weights=None, col_weights=None):
|
||||
"""Calculates the loss of a given factorization.
|
||||
|
||||
Using a non distributed method, different than the one implemented in the
|
||||
WALS model. The weight of an observed entry (i, j) (i.e. such that
|
||||
input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]).
|
||||
|
||||
Args:
|
||||
input_mat: The input matrix, a SparseTensor of rank 2.
|
||||
row_factors: The row factors, a dense Tensor of rank 2.
|
||||
col_factors: The col factors, a dense Tensor of rank 2.
|
||||
regularization: the regularization coefficient, a scalar.
|
||||
w0: the weight of unobserved entries. A scalar.
|
||||
row_weights: A dense tensor of rank 1.
|
||||
col_weights: A dense tensor of rank 1.
|
||||
|
||||
Returns:
|
||||
The total loss.
|
||||
"""
|
||||
wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
|
||||
else constant_op.constant(1.))
|
||||
wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
|
||||
else constant_op.constant(1.))
|
||||
reg = (regularization if regularization is not None
|
||||
else constant_op.constant(0.))
|
||||
|
||||
row_indices, col_indices = array_ops.split(input_mat.indices,
|
||||
axis=1,
|
||||
num_or_size_splits=2)
|
||||
gathered_row_factors = array_ops.gather(row_factors, row_indices)
|
||||
gathered_col_factors = array_ops.gather(col_factors, col_indices)
|
||||
sp_approx_vals = array_ops.squeeze(math_ops.matmul(
|
||||
gathered_row_factors, gathered_col_factors, adjoint_b=True))
|
||||
sp_approx = sparse_tensor.SparseTensor(
|
||||
indices=input_mat.indices,
|
||||
values=sp_approx_vals,
|
||||
dense_shape=input_mat.dense_shape)
|
||||
|
||||
sp_approx_sq = math_ops.square(sp_approx)
|
||||
row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
|
||||
col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
|
||||
row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
|
||||
row_factors, col_factors, transpose_b=True)))
|
||||
|
||||
resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
|
||||
resid_sq = math_ops.square(resid)
|
||||
loss = w0 * (
|
||||
sparse_ops.sparse_reduce_sum(resid_sq) -
|
||||
sparse_ops.sparse_reduce_sum(sp_approx_sq)
|
||||
)
|
||||
loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
|
||||
w0 * row_col_norm + reg * (row_norm + col_norm))
|
||||
return loss.eval()
|
||||
|
||||
|
||||
def calculate_loss_from_wals_model(wals_model, sp_inputs):
|
||||
current_rows = embedding_ops.embedding_lookup(
|
||||
wals_model.row_factors, math_ops.range(wals_model._input_rows),
|
||||
partition_strategy="div")
|
||||
current_cols = embedding_ops.embedding_lookup(
|
||||
wals_model.col_factors, math_ops.range(wals_model._input_cols),
|
||||
partition_strategy="div")
|
||||
row_wts = embedding_ops.embedding_lookup(
|
||||
wals_model._row_weights, math_ops.range(wals_model._input_rows),
|
||||
partition_strategy="div")
|
||||
col_wts = embedding_ops.embedding_lookup(
|
||||
wals_model._col_weights, math_ops.range(wals_model._input_cols),
|
||||
partition_strategy="div")
|
||||
return calculate_loss(
|
||||
sp_inputs, current_rows, current_cols, wals_model._regularization,
|
||||
wals_model._unobserved_weight, row_wts, col_wts)
|
||||
INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX
|
||||
np_matrix_to_tf_sparse = factorization_ops_test_utils.np_matrix_to_tf_sparse
|
||||
|
||||
|
||||
class WalsModelTest(test.TestCase):
|
||||
|
||||
def sparse_input(self):
|
||||
return np_matrix_to_tf_sparse(INPUT_MATRIX)
|
||||
|
||||
def count_rows(self, sp_input):
|
||||
return math_ops.cast(
|
||||
array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0],
|
||||
dtypes.float32)
|
||||
|
||||
def count_cols(self, sp_input):
|
||||
return math_ops.cast(
|
||||
array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0],
|
||||
dtypes.float32)
|
||||
|
||||
def calculate_loss_from_wals_model(self, wals_model, sp_inputs):
|
||||
current_rows = embedding_ops.embedding_lookup(
|
||||
wals_model.row_factors, math_ops.range(wals_model._input_rows),
|
||||
partition_strategy="div")
|
||||
current_cols = embedding_ops.embedding_lookup(
|
||||
wals_model.col_factors, math_ops.range(wals_model._input_cols),
|
||||
partition_strategy="div")
|
||||
row_wts = embedding_ops.embedding_lookup(
|
||||
wals_model._row_weights, math_ops.range(wals_model._input_rows),
|
||||
partition_strategy="div")
|
||||
col_wts = embedding_ops.embedding_lookup(
|
||||
wals_model._col_weights, math_ops.range(wals_model._input_cols),
|
||||
partition_strategy="div")
|
||||
return factorization_ops_test_utils.calculate_loss(
|
||||
sp_inputs, current_rows, current_cols, wals_model._regularization,
|
||||
wals_model._unobserved_weight, row_wts, col_wts)
|
||||
|
||||
def setUp(self):
|
||||
self.col_init = [
|
||||
# shard 0
|
||||
@ -208,7 +104,7 @@ class WalsModelTest(test.TestCase):
|
||||
use_factors_weights_cache,
|
||||
compute_loss=False):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
self._wals_inputs = sparse_input()
|
||||
self._wals_inputs = self.sparse_input()
|
||||
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
|
||||
num_rows = 5
|
||||
num_cols = 7
|
||||
@ -282,10 +178,10 @@ class WalsModelTest(test.TestCase):
|
||||
if compute_loss:
|
||||
# Test loss computation after the row update
|
||||
loss = sum(
|
||||
sess.run(factor_loss * count_rows(inp) / num_rows,
|
||||
sess.run(factor_loss * self.count_rows(inp) / num_rows,
|
||||
feed_dict={sp_feeder: inp})
|
||||
for inp in input_scattered_rows)
|
||||
true_loss = calculate_loss_from_wals_model(
|
||||
true_loss = self.calculate_loss_from_wals_model(
|
||||
wals_model, self._wals_inputs)
|
||||
self.assertNear(
|
||||
loss, true_loss, err=.001,
|
||||
@ -355,10 +251,10 @@ class WalsModelTest(test.TestCase):
|
||||
if compute_loss:
|
||||
# Test loss computation after the column update.
|
||||
loss = sum(
|
||||
sess.run(factor_loss * count_cols(inp) / num_cols,
|
||||
sess.run(factor_loss * self.count_cols(inp) / num_cols,
|
||||
feed_dict={sp_feeder: inp})
|
||||
for inp in input_scattered_cols_non_duplicate)
|
||||
true_loss = calculate_loss_from_wals_model(
|
||||
true_loss = self.calculate_loss_from_wals_model(
|
||||
wals_model, self._wals_inputs)
|
||||
self.assertNear(
|
||||
loss, true_loss, err=.001,
|
||||
@ -368,7 +264,7 @@ class WalsModelTest(test.TestCase):
|
||||
def _run_test_process_input_transposed(self, use_factors_weights_cache,
|
||||
compute_loss=False):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
self._wals_inputs = sparse_input()
|
||||
self._wals_inputs = self.sparse_input()
|
||||
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
|
||||
num_rows = 5
|
||||
num_cols = 7
|
||||
@ -448,10 +344,10 @@ class WalsModelTest(test.TestCase):
|
||||
if compute_loss:
|
||||
# Test loss computation after the row update
|
||||
loss = sum(
|
||||
sess.run(factor_loss * count_cols(inp) / num_rows,
|
||||
sess.run(factor_loss * self.count_cols(inp) / num_rows,
|
||||
feed_dict={sp_feeder: inp})
|
||||
for inp in input_scattered_rows_non_duplicate)
|
||||
true_loss = calculate_loss_from_wals_model(
|
||||
true_loss = self.calculate_loss_from_wals_model(
|
||||
wals_model, self._wals_inputs)
|
||||
self.assertNear(
|
||||
loss, true_loss, err=.001,
|
||||
@ -516,10 +412,10 @@ class WalsModelTest(test.TestCase):
|
||||
if compute_loss:
|
||||
# Test loss computation after the col update
|
||||
loss = sum(
|
||||
sess.run(factor_loss * count_rows(inp) / num_cols,
|
||||
sess.run(factor_loss * self.count_rows(inp) / num_cols,
|
||||
feed_dict={sp_feeder: inp})
|
||||
for inp in input_scattered_cols_non_duplicate)
|
||||
true_loss = calculate_loss_from_wals_model(
|
||||
true_loss = self.calculate_loss_from_wals_model(
|
||||
wals_model, self._wals_inputs)
|
||||
self.assertNear(
|
||||
loss, true_loss, err=.001,
|
||||
@ -534,7 +430,7 @@ class WalsModelTest(test.TestCase):
|
||||
# Here we test that those two give identical results.
|
||||
def _run_test_als(self, use_factors_weights_cache):
|
||||
with ops.Graph().as_default(), self.test_session():
|
||||
self._wals_inputs = sparse_input()
|
||||
self._wals_inputs = self.sparse_input()
|
||||
col_init = np.random.rand(7, 3)
|
||||
als_model = factorization_ops.WALSModel(
|
||||
5,
|
||||
@ -613,7 +509,7 @@ class WalsModelTest(test.TestCase):
|
||||
|
||||
def _run_test_als_transposed(self, use_factors_weights_cache):
|
||||
with ops.Graph().as_default(), self.test_session():
|
||||
self._wals_inputs = sparse_input()
|
||||
self._wals_inputs = self.sparse_input()
|
||||
col_init = np.random.rand(7, 3)
|
||||
als_model = factorization_ops.WALSModel(
|
||||
5,
|
||||
|
@ -0,0 +1,131 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test utils for factorization_ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
|
||||
|
||||
INPUT_MATRIX = np.array(
|
||||
[[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
|
||||
[0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
|
||||
[2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
|
||||
[3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
|
||||
[0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
|
||||
|
||||
|
||||
def np_matrix_to_tf_sparse(np_matrix,
|
||||
row_slices=None,
|
||||
col_slices=None,
|
||||
transpose=False,
|
||||
shuffle=False):
|
||||
"""Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
|
||||
indices = np.nonzero(np_matrix)
|
||||
|
||||
# Only allow slices of whole rows or whole columns.
|
||||
assert not (row_slices is not None and col_slices is not None)
|
||||
|
||||
if row_slices is not None:
|
||||
selected_ind = np.concatenate(
|
||||
[np.where(indices[0] == r)[0] for r in row_slices], 0)
|
||||
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
||||
|
||||
if col_slices is not None:
|
||||
selected_ind = np.concatenate(
|
||||
[np.where(indices[1] == c)[0] for c in col_slices], 0)
|
||||
indices = (indices[0][selected_ind], indices[1][selected_ind])
|
||||
|
||||
if shuffle:
|
||||
shuffled_ind = [x for x in range(len(indices[0]))]
|
||||
random.shuffle(shuffled_ind)
|
||||
indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
|
||||
|
||||
ind = (np.concatenate((np.expand_dims(indices[1], 1),
|
||||
np.expand_dims(indices[0], 1)), 1).astype(np.int64) if
|
||||
transpose else np.concatenate((np.expand_dims(indices[0], 1),
|
||||
np.expand_dims(indices[1], 1)),
|
||||
1).astype(np.int64))
|
||||
val = np_matrix[indices].astype(np.float32)
|
||||
shape = (np.array([max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64)
|
||||
if transpose else np.array(
|
||||
[max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
|
||||
return sparse_tensor.SparseTensor(ind, val, shape)
|
||||
|
||||
|
||||
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
|
||||
w0=1., row_weights=None, col_weights=None):
|
||||
"""Calculates the loss of a given factorization.
|
||||
|
||||
Using a non distributed method, different than the one implemented in the
|
||||
WALS model. The weight of an observed entry (i, j) (i.e. such that
|
||||
input_mat[i, j] is non zero) is (w0 + row_weights[i]col_weights[j]).
|
||||
|
||||
Args:
|
||||
input_mat: The input matrix, a SparseTensor of rank 2.
|
||||
row_factors: The row factors, a dense Tensor of rank 2.
|
||||
col_factors: The col factors, a dense Tensor of rank 2.
|
||||
regularization: the regularization coefficient, a scalar.
|
||||
w0: the weight of unobserved entries. A scalar.
|
||||
row_weights: A dense tensor of rank 1.
|
||||
col_weights: A dense tensor of rank 1.
|
||||
|
||||
Returns:
|
||||
The total loss.
|
||||
"""
|
||||
wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
|
||||
else constant_op.constant(1.))
|
||||
wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
|
||||
else constant_op.constant(1.))
|
||||
reg = (regularization if regularization is not None
|
||||
else constant_op.constant(0.))
|
||||
|
||||
row_indices, col_indices = array_ops.split(input_mat.indices,
|
||||
axis=1,
|
||||
num_or_size_splits=2)
|
||||
gathered_row_factors = array_ops.gather(row_factors, row_indices)
|
||||
gathered_col_factors = array_ops.gather(col_factors, col_indices)
|
||||
sp_approx_vals = array_ops.squeeze(math_ops.matmul(
|
||||
gathered_row_factors, gathered_col_factors, adjoint_b=True))
|
||||
sp_approx = sparse_tensor.SparseTensor(
|
||||
indices=input_mat.indices,
|
||||
values=sp_approx_vals,
|
||||
dense_shape=input_mat.dense_shape)
|
||||
|
||||
sp_approx_sq = math_ops.square(sp_approx)
|
||||
row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
|
||||
col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
|
||||
row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
|
||||
row_factors, col_factors, transpose_b=True)))
|
||||
|
||||
resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
|
||||
resid_sq = math_ops.square(resid)
|
||||
loss = w0 * (
|
||||
sparse_ops.sparse_reduce_sum(resid_sq) -
|
||||
sparse_ops.sparse_reduce_sum(sp_approx_sq)
|
||||
)
|
||||
loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
|
||||
w0 * row_col_norm + reg * (row_norm + col_norm))
|
||||
return loss.eval()
|
@ -21,9 +21,14 @@ py_library(
|
||||
":dense_kernel_mapper_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
@ -31,6 +36,7 @@ py_library(
|
||||
name = "dense_kernel_mapper_py",
|
||||
srcs = ["python/mappers/dense_kernel_mapper.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["@six_archive//:six"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -40,12 +46,12 @@ py_test(
|
||||
deps = [
|
||||
":dense_kernel_mapper_py",
|
||||
":kernel_methods",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_ops",
|
||||
],
|
||||
)
|
||||
|
||||
@ -55,10 +61,12 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":kernel_methods",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -22,7 +22,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier
|
||||
from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper
|
||||
from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
@ -118,6 +118,7 @@ tf_custom_op_py_library(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
"//tensorflow/python:common_shapes",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:embedding_ops",
|
||||
"//tensorflow/python:framework",
|
||||
@ -131,9 +132,11 @@ tf_custom_op_py_library(
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:standard_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||
from tensorflow.contrib.framework.python.framework import experimental
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
@ -36,6 +38,7 @@ from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def _embeddings_from_arguments(column,
|
||||
@ -136,6 +139,58 @@ def _embeddings_from_arguments(column,
|
||||
max_norm=args.max_norm)
|
||||
|
||||
|
||||
def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
|
||||
"""Reshape the input tensor by the following rule.
|
||||
|
||||
1. If `output_rank > input_rank + 1`, raise a `ValueError`.
|
||||
2. If `output_rank == input_rank + 1`, expand the tensor by one dimension.
|
||||
3. If `output_rank == input_rank`, do nothing.
|
||||
4. If `output_rank < input_rank`, flatten the inner dimensions of the tensor.
|
||||
|
||||
Args:
|
||||
tensor: A Tensor or SparseTensor to be reshaped.
|
||||
column_name: A string name of the feature column for the tensor.
|
||||
output_rank: the desired rank of the tensor.
|
||||
Returns:
|
||||
A reshaped Tensor or SparseTensor.
|
||||
Raises:
|
||||
ValueError: if `output_rank > input_rank + 1` for the input tensor.
|
||||
"""
|
||||
input_rank = tensor.get_shape().ndims
|
||||
|
||||
if input_rank is None and isinstance(tensor, sparse_tensor_py.SparseTensor):
|
||||
# Try to get the rank of a sparse tensor by its dense_shape's shape.
|
||||
input_rank = tensor.dense_shape.get_shape().as_list()[0]
|
||||
|
||||
if input_rank is None:
|
||||
raise ValueError('Error while processing column {}. Rank of input Tensor '
|
||||
'can not be None.'.format(column_name))
|
||||
|
||||
if output_rank > input_rank + 1:
|
||||
raise ValueError('Error while processing column {}. Rank of input Tensor '
|
||||
'({}) should be the same as output_rank ({}). For '
|
||||
'example, sequence data should typically be 3 '
|
||||
'dimensional (rank 3) while non-sequence data is '
|
||||
'typically 2 dimensional (rank 2).'.format(
|
||||
column_name, input_rank, output_rank))
|
||||
elif output_rank == input_rank + 1:
|
||||
# Expand the tensor's shape by 1 dimension.
|
||||
if isinstance(tensor, sparse_tensor_py.SparseTensor):
|
||||
output_shape = array_ops.concat([tensor.dense_shape, [1]], 0)
|
||||
return sparse_ops.sparse_reshape(tensor, output_shape)
|
||||
else:
|
||||
reshaped = array_ops.expand_dims(tensor, -1)
|
||||
# Try to calculate the new shape.
|
||||
static_shape = tensor.get_shape()
|
||||
if static_shape is not None and static_shape.dims is not None:
|
||||
reshaped.set_shape(static_shape.as_list() + [1])
|
||||
return reshaped
|
||||
elif output_rank < input_rank:
|
||||
return layers._inner_flatten(tensor, output_rank) # pylint: disable=protected-access
|
||||
else:
|
||||
return tensor
|
||||
|
||||
|
||||
def _input_from_feature_columns(columns_to_tensors,
|
||||
feature_columns,
|
||||
weight_collections,
|
||||
@ -160,6 +215,12 @@ def _input_from_feature_columns(columns_to_tensors,
|
||||
default_name=column.name,
|
||||
values=columns_to_tensors.values()):
|
||||
transformed_tensor = transformer.transform(column)
|
||||
if output_rank == 3:
|
||||
transformed_tensor = nest.map_structure(
|
||||
functools.partial(
|
||||
_maybe_reshape_input_tensor,
|
||||
column_name=column.name,
|
||||
output_rank=output_rank), transformed_tensor)
|
||||
try:
|
||||
# pylint: disable=protected-access
|
||||
arguments = column._deep_embedding_lookup_arguments(
|
||||
@ -548,7 +609,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
|
||||
default_name=column.name,
|
||||
values=columns_to_tensors.values()):
|
||||
tensor = column._to_dense_tensor(transformed_tensor)
|
||||
tensor = fc._reshape_real_valued_tensor(tensor, 2, column.name)
|
||||
tensor = _maybe_reshape_input_tensor(
|
||||
tensor, column.name, output_rank=2)
|
||||
variable = [
|
||||
contrib_variables.model_variable(
|
||||
name='weight',
|
||||
|
@ -1350,6 +1350,35 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
|
||||
|
||||
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||
|
||||
def testEmbeddingColumnWithAutoReshape(self):
|
||||
hash_buckets = 10
|
||||
embedding_dimension = 5
|
||||
ids_tensor = sparse_tensor.SparseTensor(
|
||||
values=["c", "b",
|
||||
"a", "c", "b",
|
||||
"b"],
|
||||
indices=[[0, 0], [0, 1],
|
||||
[1, 0], [1, 1], [1, 2],
|
||||
[3, 2]],
|
||||
dense_shape=[4, 3])
|
||||
|
||||
expected_input_shape = np.array([4, 3, embedding_dimension])
|
||||
|
||||
hashed_ids_column = feature_column.sparse_column_with_hash_bucket(
|
||||
"ids", hash_buckets)
|
||||
embedded_column = feature_column.embedding_column(hashed_ids_column,
|
||||
embedding_dimension)
|
||||
columns_to_tensors = {"ids": ids_tensor}
|
||||
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
|
||||
columns_to_tensors, [embedded_column])
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables_lib.global_variables_initializer().run()
|
||||
data_flow_ops.tables_initializer().run()
|
||||
model_input = sess.run(model_input_tensor)
|
||||
|
||||
self.assertAllEqual(expected_input_shape, model_input.shape)
|
||||
|
||||
def testEmbeddingColumnGradient(self):
|
||||
hash_buckets = 1000
|
||||
embedding_dimension = 3
|
||||
|
@ -836,6 +836,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "model_fn_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/estimators/model_fn_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "multioutput_test",
|
||||
size = "small",
|
||||
|
@ -42,6 +42,7 @@ from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.summary import summary
|
||||
@ -816,7 +817,8 @@ class _BinaryLogisticHead(_SingleHead):
|
||||
loss_fn=self._loss_fn,
|
||||
logits_to_predictions_fn=self._logits_to_predictions,
|
||||
metrics_fn=self._metrics,
|
||||
create_output_alternatives_fn=self._create_output_alternatives,
|
||||
create_output_alternatives_fn=_classification_output_alternatives(
|
||||
self.head_name, self._problem_type),
|
||||
labels=labels,
|
||||
train_op_fn=train_op_fn,
|
||||
logits=logits,
|
||||
@ -885,6 +887,8 @@ class _BinaryLogisticHead(_SingleHead):
|
||||
_indicator_labels_streaming_mean(labels, weights))
|
||||
metrics[_summary_key(self.head_name, mkey.AUC)] = (
|
||||
_streaming_auc(logistic, labels, weights))
|
||||
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = (
|
||||
_streaming_auc(logistic, labels, weights, curve="PR"))
|
||||
|
||||
for threshold in self._thresholds:
|
||||
metrics[_summary_key(
|
||||
@ -1009,7 +1013,8 @@ class _MultiClassHead(_SingleHead):
|
||||
loss_fn=self._wrapped_loss_fn,
|
||||
logits_to_predictions_fn=self._logits_to_predictions,
|
||||
metrics_fn=self._metrics,
|
||||
create_output_alternatives_fn=self._create_output_alternatives,
|
||||
create_output_alternatives_fn=_classification_output_alternatives(
|
||||
self.head_name, self._problem_type, self._label_keys),
|
||||
labels=labels,
|
||||
train_op_fn=train_op_fn,
|
||||
logits=logits,
|
||||
@ -1113,25 +1118,6 @@ class _MultiClassHead(_SingleHead):
|
||||
|
||||
return metrics
|
||||
|
||||
def _create_output_alternatives(self, predictions):
|
||||
"""See superclass."""
|
||||
probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
|
||||
batch_size = array_ops.shape(probabilities)[0]
|
||||
if self._label_keys:
|
||||
classes = array_ops.tile(
|
||||
input=array_ops.expand_dims(input=self._label_keys, axis=0),
|
||||
multiples=[batch_size, 1])
|
||||
else:
|
||||
classes = array_ops.tile(
|
||||
input=array_ops.expand_dims(
|
||||
input=math_ops.range(self.logits_dimension), axis=0),
|
||||
multiples=[batch_size, 1])
|
||||
predictions_for_serving = {
|
||||
prediction_key.PredictionKey.CLASSES: classes,
|
||||
prediction_key.PredictionKey.PROBABILITIES: probabilities,
|
||||
}
|
||||
return {self._head_name: (self._problem_type, predictions_for_serving)}
|
||||
|
||||
|
||||
def _to_labels_tensor(labels, label_name):
|
||||
"""Returns label as a tensor.
|
||||
@ -1226,6 +1212,7 @@ class _BinarySvmHead(_SingleHead):
|
||||
loss_fn=self._loss_fn,
|
||||
logits_to_predictions_fn=self._logits_to_predictions,
|
||||
metrics_fn=self._metrics,
|
||||
# TODO(zakaria): Handle labels for export.
|
||||
create_output_alternatives_fn=self._create_output_alternatives,
|
||||
labels=labels,
|
||||
train_op_fn=train_op_fn,
|
||||
@ -1325,7 +1312,8 @@ class _MultiLabelHead(_SingleHead):
|
||||
loss_fn=self._loss_fn,
|
||||
logits_to_predictions_fn=self._logits_to_predictions,
|
||||
metrics_fn=self._metrics,
|
||||
create_output_alternatives_fn=self._create_output_alternatives,
|
||||
create_output_alternatives_fn=_classification_output_alternatives(
|
||||
self.head_name, self._problem_type),
|
||||
labels=labels,
|
||||
train_op_fn=train_op_fn,
|
||||
logits=logits,
|
||||
@ -1374,6 +1362,8 @@ class _MultiLabelHead(_SingleHead):
|
||||
metrics_lib.streaming_accuracy(classes, labels, weights))
|
||||
metrics[_summary_key(self.head_name, mkey.AUC)] = _streaming_auc(
|
||||
probabilities, labels, weights)
|
||||
metrics[_summary_key(self.head_name, mkey.AUC_PR)] = _streaming_auc(
|
||||
probabilities, labels, weights, curve="PR")
|
||||
|
||||
for class_id in self._metric_class_ids:
|
||||
# TODO(ptucker): Add per-class accuracy, precision, recall.
|
||||
@ -1391,6 +1381,9 @@ class _MultiLabelHead(_SingleHead):
|
||||
_predictions_streaming_mean(logits, weights, class_id))
|
||||
metrics[_summary_key(self.head_name, mkey.CLASS_AUC % class_id)] = (
|
||||
_streaming_auc(probabilities, labels, weights, class_id))
|
||||
metrics[_summary_key(self.head_name, mkey.CLASS_AUC_PR % class_id)] = (
|
||||
_streaming_auc(probabilities, labels, weights, class_id,
|
||||
curve="PR"))
|
||||
|
||||
return metrics
|
||||
|
||||
@ -1857,7 +1850,8 @@ def _class_labels_streaming_mean(labels, weights, class_id):
|
||||
weights=weights)
|
||||
|
||||
|
||||
def _streaming_auc(predictions, labels, weights=None, class_id=None):
|
||||
def _streaming_auc(predictions, labels, weights=None, class_id=None,
|
||||
curve="ROC"):
|
||||
predictions = ops.convert_to_tensor(predictions)
|
||||
labels = ops.convert_to_tensor(labels)
|
||||
if class_id is not None:
|
||||
@ -1866,7 +1860,8 @@ def _streaming_auc(predictions, labels, weights=None, class_id=None):
|
||||
return metrics_lib.streaming_auc(
|
||||
predictions,
|
||||
math_ops.cast(labels, dtypes.bool),
|
||||
weights=_float_weights_or_none(weights))
|
||||
weights=_float_weights_or_none(weights),
|
||||
curve=curve)
|
||||
|
||||
|
||||
def _assert_class_id(class_id, num_classes=None):
|
||||
@ -1901,6 +1896,71 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
|
||||
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
|
||||
|
||||
|
||||
def _classification_output_alternatives(head_name, problem_type,
|
||||
label_keys=None):
|
||||
"""Creates a func to generate output alternatives for classification.
|
||||
|
||||
Servo expects classes to be a string tensor, and have the same dimensions
|
||||
as the probabilities tensor. It should contain the labels of the corresponding
|
||||
entries in probabilities. This function creates a new classes tensor that
|
||||
satisfies these conditions and can be exported.
|
||||
|
||||
Args:
|
||||
head_name: Name of the head.
|
||||
problem_type: `ProblemType`
|
||||
label_keys: Optional label keys
|
||||
|
||||
Returns:
|
||||
A function to generate output alternatives.
|
||||
"""
|
||||
def _create_output_alternatives(predictions):
|
||||
"""Creates output alternative for the Head.
|
||||
|
||||
Args:
|
||||
predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
|
||||
symbolic name for an output Tensor possibly but not necessarily taken
|
||||
from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
|
||||
itself.
|
||||
|
||||
Returns:
|
||||
`dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
|
||||
'submodel_name' is a submodel identifier that should be consistent across
|
||||
the pipeline (here likely taken from the head_name),
|
||||
'problem_type' is a `ProblemType`,
|
||||
'tensor_name' is a symbolic name for an output Tensor possibly but not
|
||||
necessarily taken from `PredictionKey`, and
|
||||
'Tensor' is the corresponding output Tensor itself.
|
||||
|
||||
Raises:
|
||||
ValueError: if predictions does not have PredictionKey.PROBABILITIES key.
|
||||
"""
|
||||
probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES)
|
||||
if probabilities is None:
|
||||
raise ValueError("%s missing in predictions" %
|
||||
prediction_key.PredictionKey.PROBABILITIES)
|
||||
|
||||
with ops.name_scope(None, "_classification_output_alternatives",
|
||||
(probabilities,)):
|
||||
batch_size = array_ops.shape(probabilities)[0]
|
||||
if label_keys:
|
||||
classes = array_ops.tile(
|
||||
input=array_ops.expand_dims(input=label_keys, axis=0),
|
||||
multiples=[batch_size, 1],
|
||||
name="classes_tensor")
|
||||
else:
|
||||
n = array_ops.shape(probabilities)[1]
|
||||
classes = array_ops.tile(
|
||||
input=array_ops.expand_dims(input=math_ops.range(n), axis=0),
|
||||
multiples=[batch_size, 1])
|
||||
classes = string_ops.as_string(classes, name="classes_tensor")
|
||||
|
||||
exported_predictions = {
|
||||
prediction_key.PredictionKey.PROBABILITIES: probabilities,
|
||||
prediction_key.PredictionKey.CLASSES: classes}
|
||||
return {head_name: (problem_type, exported_predictions)}
|
||||
|
||||
return _create_output_alternatives
|
||||
|
||||
# Aliases
|
||||
# TODO(zakaria): Remove these aliases, See b/34751732
|
||||
_regression_head = regression_head
|
||||
|
@ -297,11 +297,15 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
def _expected_eval_metrics(self, expected_loss):
|
||||
return {
|
||||
"accuracy": 1. / 3,
|
||||
"auc": 1. / 4,
|
||||
"loss": expected_loss,
|
||||
"auc": 1. / 4,
|
||||
"auc/class0": 1.,
|
||||
"auc/class1": 1.,
|
||||
"auc/class2": 0.,
|
||||
"auc_precision_recall": 0.166667,
|
||||
"auc_precision_recall/class0": 0,
|
||||
"auc_precision_recall/class1": 0.,
|
||||
"auc_precision_recall/class2": 1.,
|
||||
"labels/actual_label_mean/class0": self._labels[0][0],
|
||||
"labels/actual_label_mean/class1": self._labels[0][1],
|
||||
"labels/actual_label_mean/class2": self._labels[0][2],
|
||||
@ -417,7 +421,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
|
||||
logits_input=((0., 0.),), logits=self._logits)
|
||||
|
||||
def testMultiLabelEvalMode(self):
|
||||
def testMultiLabelEval(self):
|
||||
n_classes = 3
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
@ -433,7 +437,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
_assert_metrics(self, expected_loss,
|
||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||
|
||||
def testMultiClassEvalModeWithLargeLogits(self):
|
||||
def testMultiClassEvalWithLargeLogits(self):
|
||||
n_classes = 3
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
@ -472,6 +476,36 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
_assert_metrics(self, expected_loss,
|
||||
expected_eval_metrics, model_fn_ops)
|
||||
|
||||
def testMultiLabelInfer(self):
|
||||
n_classes = 3
|
||||
head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
|
||||
logits=((1., 0., 0.), (0., 0., 1)))
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
_assert_no_variables(self)
|
||||
with session.Session():
|
||||
self.assertListEqual(
|
||||
[1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])
|
||||
self.assertItemsEqual(
|
||||
["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
|
||||
self.assertEqual(
|
||||
constants.ProblemType.CLASSIFICATION,
|
||||
model_fn_ops.output_alternatives["head_name"][0])
|
||||
|
||||
predictions_for_serving = (
|
||||
model_fn_ops.output_alternatives["head_name"][1])
|
||||
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
||||
self.assertAllEqual(
|
||||
[[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
|
||||
predictions_for_serving["classes"].eval())
|
||||
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
||||
self.assertAllClose(
|
||||
[[0.731059, 0.5, 0.5],
|
||||
[0.5, 0.5, 0.731059,]],
|
||||
predictions_for_serving["probabilities"].eval())
|
||||
|
||||
def testMultiLabelWithLabelName(self):
|
||||
n_classes = 3
|
||||
label_name = "my_label"
|
||||
@ -621,6 +655,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
"accuracy/baseline_label_mean": label_mean,
|
||||
"accuracy/threshold_0.500000_mean": 1. / 2,
|
||||
"auc": 1. / 2,
|
||||
"auc_precision_recall": 0.749999,
|
||||
"labels/actual_label_mean": label_mean,
|
||||
"labels/prediction_mean": .731059, # softmax
|
||||
"loss": expected_loss,
|
||||
@ -691,7 +726,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
|
||||
logits_input=((0., 0.), (0., 0.)), logits=self._logits)
|
||||
|
||||
def testBinaryClassificationEvalMode(self):
|
||||
def testBinaryClassificationEval(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
@ -708,18 +743,32 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
_assert_metrics(self, expected_loss,
|
||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||
|
||||
def testBinaryClassificationInferMode(self):
|
||||
def testBinaryClassificationInfer(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
|
||||
logits=self._logits)
|
||||
self._assert_output_alternatives(model_fn_ops)
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
_assert_no_variables(self)
|
||||
with session.Session():
|
||||
self.assertListEqual(
|
||||
[1, 1], list(model_fn_ops.predictions["classes"].eval()))
|
||||
self.assertItemsEqual(
|
||||
["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
|
||||
self.assertEqual(
|
||||
constants.ProblemType.LOGISTIC_REGRESSION,
|
||||
model_fn_ops.output_alternatives["head_name"][0])
|
||||
predictions_for_serving = (
|
||||
model_fn_ops.output_alternatives["head_name"][1])
|
||||
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
||||
predicted_classes = predictions_for_serving["classes"].eval().tolist()
|
||||
self.assertListEqual(
|
||||
[b"0", b"1"], predicted_classes[0])
|
||||
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
||||
|
||||
def testBinaryClassificationInferMode_withWightColumn(self):
|
||||
n_classes = 2
|
||||
@ -1006,7 +1055,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
"multi_class_head/centered_bias/bias_1",
|
||||
"multi_class_head/centered_bias/bias_2"])
|
||||
|
||||
def testMultiClassEvalMode(self):
|
||||
def testMultiClassEval(self):
|
||||
n_classes = 3
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
@ -1131,7 +1180,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
model_fn_ops.output_alternatives["head_name"][1])
|
||||
self.assertIn("classes", six.iterkeys(predictions_for_serving))
|
||||
self.assertAllEqual(
|
||||
[[0, 1, 2], [0, 1, 2]],
|
||||
[[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
|
||||
predictions_for_serving["classes"].eval())
|
||||
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
|
||||
self.assertAllClose(
|
||||
|
@ -22,7 +22,9 @@ class MetricKey(object):
|
||||
"""Metric key strings."""
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
AUC_PR = "auc_precision_recall"
|
||||
CLASS_AUC = "auc/class%d"
|
||||
CLASS_AUC_PR = "auc_precision_recall/class%d"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
CLASS_PREDICTION_MEAN = "labels/prediction_mean/class%d"
|
||||
CLASS_LOGITS_MEAN = "labels/logits_mean/class%d"
|
||||
|
@ -25,10 +25,16 @@ import six
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.framework import get_graph_from_inputs
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.python.estimator import model_fn as core_model_fn_lib
|
||||
from tensorflow.python.estimator.export import export_output as core_export_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
|
||||
@ -177,3 +183,85 @@ class ModelFnOps(
|
||||
training_chief_hooks=training_chief_hooks,
|
||||
training_hooks=training_hooks,
|
||||
scaffold=scaffold)
|
||||
|
||||
def estimator_spec(self, mode, default_serving_output_alternative_key=None):
|
||||
"""Creates an equivalent `EstimatorSpec`.
|
||||
|
||||
Args:
|
||||
mode: One of `ModeKeys`. Specifies if this training, evaluation or
|
||||
prediction.
|
||||
default_serving_output_alternative_key: Required for multiple heads. If
|
||||
you have multiple entries in `output_alternatives` dict (comparable to
|
||||
multiple heads), `EstimatorSpec` requires a default head that will be
|
||||
used if a Servo request does not explicitly mention which head to infer
|
||||
on. Pass the key of the output alternative here that you want to
|
||||
designate as default. A separate ExportOutpout for this default head
|
||||
wil be added to the export_outputs dict with the special key
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is
|
||||
already an enry in output_alternatives with this special key.
|
||||
|
||||
Returns:
|
||||
Instance of `EstimatorSpec` that is equivalent to this `ModelFnOps`
|
||||
|
||||
Raises:
|
||||
ValueError: If problem type is unknown.
|
||||
"""
|
||||
def _scores(output_tensors):
|
||||
scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
|
||||
if scores is None:
|
||||
scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
|
||||
return scores
|
||||
|
||||
def _classes(output_tensors): # pylint: disable=missing-docstring
|
||||
classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
|
||||
if classes is None:
|
||||
logging.warning(
|
||||
'classes is None, Servo inference will not have class ids.')
|
||||
return None
|
||||
elif classes.dtype != dtypes.string:
|
||||
# Servo classification can only serve string classes
|
||||
logging.warning(
|
||||
'classes is not string, Servo inference will not have class ids.')
|
||||
return None
|
||||
|
||||
return classes
|
||||
|
||||
def _export_output(problem_type, predictions): # pylint: disable=missing-docstring
|
||||
if problem_type == constants.ProblemType.LINEAR_REGRESSION:
|
||||
return core_export_lib.RegressionOutput(_scores(predictions))
|
||||
|
||||
if (problem_type == constants.ProblemType.CLASSIFICATION or
|
||||
problem_type == constants.ProblemType.LOGISTIC_REGRESSION):
|
||||
return core_export_lib.ClassificationOutput(
|
||||
scores=_scores(predictions), classes=_classes(predictions))
|
||||
|
||||
if problem_type == constants.ProblemType.UNSPECIFIED:
|
||||
return core_export_lib.PredictOutput(predictions)
|
||||
|
||||
raise ValueError('Unknown problem_type=%s' % problem_type)
|
||||
|
||||
# Converts output_alternatives
|
||||
export_outputs_dict = None
|
||||
if self.output_alternatives:
|
||||
output_alternatives = self.output_alternatives
|
||||
# Adds default output_alternative if needed.
|
||||
if (len(output_alternatives) > 1 and
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
|
||||
output_alternatives):
|
||||
output_alternatives = output_alternatives.copy()
|
||||
output_alternatives[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
|
||||
output_alternatives[default_serving_output_alternative_key])
|
||||
export_outputs_dict = {key: _export_output(*val) for key, val in
|
||||
output_alternatives.items()}
|
||||
|
||||
return core_model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
predictions=self.predictions,
|
||||
loss=self.loss,
|
||||
train_op=self.train_op,
|
||||
eval_metric_ops=self.eval_metric_ops,
|
||||
export_outputs=export_outputs_dict,
|
||||
training_chief_hooks=self.training_chief_hooks,
|
||||
training_hooks=self.training_hooks,
|
||||
scaffold=self.scaffold)
|
||||
|
@ -0,0 +1,279 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ModelFnOps tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.estimator.export import export_output as core_export_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
|
||||
|
||||
class ModelFnopsTest(test.TestCase):
|
||||
"""Multi-output tests."""
|
||||
|
||||
def create_predictions(self):
|
||||
probabilities = constant_op.constant([1., 1., 1.])
|
||||
scores = constant_op.constant([1., 2., 3.])
|
||||
classes = constant_op.constant([b"0", b"1", b"2"])
|
||||
return {
|
||||
"probabilities": probabilities,
|
||||
"scores": scores,
|
||||
"classes": classes}
|
||||
|
||||
def create_model_fn_ops(self, predictions, output_alternatives,
|
||||
mode=model_fn.ModeKeys.INFER):
|
||||
|
||||
return model_fn.ModelFnOps(
|
||||
model_fn.ModeKeys.INFER,
|
||||
predictions=predictions,
|
||||
loss=constant_op.constant([1]),
|
||||
train_op=control_flow_ops.no_op(),
|
||||
eval_metric_ops={"metric_key": (control_flow_ops.no_op(),
|
||||
control_flow_ops.no_op())},
|
||||
# zzz
|
||||
training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
|
||||
training_hooks=[basic_session_run_hooks.StepCounterHook()],
|
||||
output_alternatives=output_alternatives,
|
||||
scaffold=monitored_session.Scaffold())
|
||||
|
||||
def assertEquals_except_export(self, model_fn_ops, estimator_spec):
|
||||
self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
|
||||
self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
|
||||
self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
|
||||
self.assertEqual(model_fn_ops.eval_metric_ops,
|
||||
estimator_spec.eval_metric_ops)
|
||||
self.assertEqual(model_fn_ops.training_chief_hooks,
|
||||
estimator_spec.training_chief_hooks)
|
||||
self.assertEqual(model_fn_ops.training_hooks, estimator_spec.training_hooks)
|
||||
self.assertEqual(model_fn_ops.scaffold, estimator_spec.scaffold)
|
||||
|
||||
def testEstimatorSpec_except_export(self):
|
||||
predictions = self.create_predictions()
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, None)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
def testEstimatorSpec_export_regression_with_scores(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"regression_head": (
|
||||
constants.ProblemType.LINEAR_REGRESSION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||
self.assertTrue(isinstance(
|
||||
regression_output, core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
regression_output.value.eval())
|
||||
|
||||
def testEstimatorSpec_export_regression_with_probabilities(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["scores"]
|
||||
output_alternatives = {"regression_head": (
|
||||
constants.ProblemType.LINEAR_REGRESSION,
|
||||
output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||
self.assertTrue(isinstance(
|
||||
regression_output, core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["probabilities"].eval(),
|
||||
regression_output.value.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
classification_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_missing_scores(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["scores"]
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["probabilities"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
classification_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_missing_scores_proba(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["scores"]
|
||||
del output_alternatives_predictions["probabilities"]
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertIsNone(classification_output.scores)
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
classification_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_missing_classes(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
del output_alternatives_predictions["classes"]
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertIsNone(classification_output.classes)
|
||||
|
||||
def testEstimatorSpec_export_classsification_with_nonstring_classes(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives_predictions = predictions.copy()
|
||||
output_alternatives_predictions["classes"] = constant_op.constant(
|
||||
[1, 2, 3])
|
||||
output_alternatives = {"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
classification_output = estimator_spec.export_outputs[
|
||||
"classification_head"]
|
||||
self.assertTrue(isinstance(classification_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
classification_output.scores.eval())
|
||||
self.assertIsNone(classification_output.classes)
|
||||
|
||||
def testEstimatorSpec_export_logistic(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"logistic_head": (
|
||||
constants.ProblemType.LOGISTIC_REGRESSION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
logistic_output = estimator_spec.export_outputs["logistic_head"]
|
||||
self.assertTrue(isinstance(logistic_output,
|
||||
core_export_lib.ClassificationOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
logistic_output.scores.eval())
|
||||
self.assertAllEqual(predictions["classes"].eval(),
|
||||
logistic_output.classes.eval())
|
||||
|
||||
def testEstimatorSpec_export_unspecified(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {"unspecified_head": (
|
||||
constants.ProblemType.UNSPECIFIED, predictions)}
|
||||
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
unspecified_output = estimator_spec.export_outputs["unspecified_head"]
|
||||
self.assertTrue(isinstance(unspecified_output,
|
||||
core_export_lib.PredictOutput))
|
||||
self.assertEqual(predictions, unspecified_output.outputs)
|
||||
|
||||
def testEstimatorSpec_export_multihead(self):
|
||||
predictions = self.create_predictions()
|
||||
output_alternatives = {
|
||||
"regression_head": (
|
||||
constants.ProblemType.LINEAR_REGRESSION, predictions),
|
||||
"classification_head": (
|
||||
constants.ProblemType.CLASSIFICATION, predictions)}
|
||||
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)
|
||||
|
||||
estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
|
||||
"regression_head")
|
||||
self.assertEquals_except_export(model_fn_ops, estimator_spec)
|
||||
|
||||
with session.Session():
|
||||
regression_output = estimator_spec.export_outputs["regression_head"]
|
||||
self.assertTrue(isinstance(
|
||||
regression_output, core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
regression_output.value.eval())
|
||||
|
||||
default_output = estimator_spec.export_outputs[
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
self.assertTrue(isinstance(default_output,
|
||||
core_export_lib.RegressionOutput))
|
||||
self.assertAllEqual(predictions["scores"].eval(),
|
||||
default_output.value.eval())
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -66,8 +66,8 @@ def _get_single_cell(cell_type, num_units):
|
||||
ValueError: `cell_type` is an invalid `RNNCell` name.
|
||||
TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
|
||||
"""
|
||||
cell_type = _CELL_TYPES.get(cell_type)
|
||||
if cell_type is None and not issubclass(cell_type, contrib_rnn.RNNCell):
|
||||
cell_type = _CELL_TYPES.get(cell_type, cell_type)
|
||||
if not cell_type or not issubclass(cell_type, contrib_rnn.RNNCell):
|
||||
raise ValueError('The supported cell types are {}; got {}'.format(
|
||||
list(_CELL_TYPES.keys()), cell_type))
|
||||
return cell_type(num_units=num_units)
|
||||
|
@ -97,7 +97,8 @@ class Experiment(object):
|
||||
finite number of batches (generally, 1 epoch over the evaluation data).
|
||||
eval_metrics: `dict` of string, metric function. If `None`, default set
|
||||
is used. This should be `None` if the `estimator` is
|
||||
${tf.estimator.Estimator}.
|
||||
${tf.estimator.Estimator}. If metrics are provided they will be
|
||||
*appended* to the default set.
|
||||
train_steps: Perform this many steps of training. `None`, the default,
|
||||
means train forever.
|
||||
eval_steps: `evaluate` runs until input is exhausted (or another exception
|
||||
|
@ -45,7 +45,7 @@ class SquareLinearOperatorFullMatrixTest(
|
||||
# values are random and we want the same value used for both mat and
|
||||
# feed_dict.
|
||||
matrix = matrix.eval()
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix_ph)
|
||||
feed_dict = {matrix_ph: matrix}
|
||||
else:
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||
@ -105,7 +105,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
|
||||
# feed_dict.
|
||||
matrix = matrix.eval()
|
||||
operator = linalg.LinearOperatorFullMatrix(
|
||||
matrix, is_self_adjoint=True, is_positive_definite=True)
|
||||
matrix_ph, is_self_adjoint=True, is_positive_definite=True)
|
||||
feed_dict = {matrix_ph: matrix}
|
||||
else:
|
||||
operator = linalg.LinearOperatorFullMatrix(
|
||||
@ -144,7 +144,7 @@ class NonSquareLinearOperatorFullMatrixTest(
|
||||
# values are random and we want the same value used for both mat and
|
||||
# feed_dict.
|
||||
matrix = matrix.eval()
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix_ph)
|
||||
feed_dict = {matrix_ph: matrix}
|
||||
else:
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||
|
@ -12,13 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops for nccl AllReduce."""
|
||||
"""Functions for using NVIDIA nccl collective ops.
|
||||
|
||||
@@all_max
|
||||
@@all_min
|
||||
@@all_prod
|
||||
@@all_sum
|
||||
@@broadcast
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max
|
||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min
|
||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod
|
||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum
|
||||
from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__)
|
||||
|
@ -66,7 +66,7 @@ class RNNCellTest(test.TestCase):
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
m = array_ops.zeros([batch_size, state_size])
|
||||
output, state = rnn_cell.CoupledInputForgetGateLSTMCell(
|
||||
num_units=num_units, forget_bias=1.0)(x, m)
|
||||
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state], {
|
||||
x.name:
|
||||
|
@ -466,12 +466,13 @@ class OutputProjectionWrapper(RNNCell):
|
||||
if needed or directly feed into a softmax.
|
||||
"""
|
||||
|
||||
def __init__(self, cell, output_size, reuse=None):
|
||||
def __init__(self, cell, output_size, activation=None, reuse=None):
|
||||
"""Create a cell with output projection.
|
||||
|
||||
Args:
|
||||
cell: an RNNCell, a projection to output_size is added to it.
|
||||
output_size: integer, the size of the output after projection.
|
||||
activation: (optional) an optional activation function.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
@ -487,6 +488,7 @@ class OutputProjectionWrapper(RNNCell):
|
||||
self._cell = cell
|
||||
self._output_size = output_size
|
||||
self._reuse = reuse
|
||||
self._activation = activation
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
@ -507,6 +509,8 @@ class OutputProjectionWrapper(RNNCell):
|
||||
with _checked_scope(self, scope or "output_projection_wrapper",
|
||||
reuse=self._reuse):
|
||||
projected = _linear(output, self._output_size, True)
|
||||
if self._activation:
|
||||
projected = self._activation(projected)
|
||||
return projected, res_state
|
||||
|
||||
|
||||
@ -518,12 +522,13 @@ class InputProjectionWrapper(RNNCell):
|
||||
do the projection on this batch-concatenated sequence, then split it.
|
||||
"""
|
||||
|
||||
def __init__(self, cell, num_proj, input_size=None):
|
||||
def __init__(self, cell, num_proj, activation=None, input_size=None):
|
||||
"""Create a cell with input projection.
|
||||
|
||||
Args:
|
||||
cell: an RNNCell, a projection of inputs is added before it.
|
||||
num_proj: Python integer. The dimension to project to.
|
||||
activation: (optional) an optional activation function.
|
||||
input_size: Deprecated and unused.
|
||||
|
||||
Raises:
|
||||
@ -535,6 +540,7 @@ class InputProjectionWrapper(RNNCell):
|
||||
raise TypeError("The parameter cell is not RNNCell.")
|
||||
self._cell = cell
|
||||
self._num_proj = num_proj
|
||||
self._activation = activation
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
@ -553,6 +559,8 @@ class InputProjectionWrapper(RNNCell):
|
||||
# Default scope: "InputProjectionWrapper"
|
||||
with vs.variable_scope(scope or "input_projection_wrapper"):
|
||||
projected = _linear(inputs, self._num_proj, True)
|
||||
if self._activation:
|
||||
projected = self._activation(projected)
|
||||
return self._cell(projected, state)
|
||||
|
||||
|
||||
|
@ -109,7 +109,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
|
||||
def __init__(self, num_units, use_peepholes=False,
|
||||
initializer=None, num_proj=None, proj_clip=None,
|
||||
num_unit_shards=1, num_proj_shards=1,
|
||||
forget_bias=1.0, state_is_tuple=False,
|
||||
forget_bias=1.0, state_is_tuple=True,
|
||||
activation=math_ops.tanh, reuse=None):
|
||||
"""Initialize the parameters for an LSTM cell.
|
||||
|
||||
@ -457,7 +457,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
|
||||
start_freqindex_list=None,
|
||||
end_freqindex_list=None,
|
||||
couple_input_forget_gates=False,
|
||||
state_is_tuple=False,
|
||||
state_is_tuple=True,
|
||||
reuse=None):
|
||||
"""Initialize the parameters for an LSTM cell.
|
||||
|
||||
@ -571,7 +571,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
|
||||
ValueError: if an input_size was specified and the provided inputs have
|
||||
a different dimension.
|
||||
"""
|
||||
batch_size = int(inputs.get_shape()[0])
|
||||
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
|
||||
freq_inputs = self._make_tf_features(inputs)
|
||||
with _checked_scope(self, scope or "grid_lstm_cell",
|
||||
initializer=self._initializer, reuse=self._reuse):
|
||||
@ -994,7 +994,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
|
||||
ValueError: if an input_size was specified and the provided inputs have
|
||||
a different dimension.
|
||||
"""
|
||||
batch_size = int(inputs.get_shape()[0])
|
||||
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
|
||||
fwd_inputs = self._make_tf_features(inputs)
|
||||
if self._backward_slice_offset:
|
||||
bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
|
||||
@ -1043,7 +1043,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
|
||||
"""
|
||||
|
||||
def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
|
||||
input_size=None, state_is_tuple=False, reuse=None):
|
||||
input_size=None, state_is_tuple=True, reuse=None):
|
||||
"""Create a cell with attention.
|
||||
|
||||
Args:
|
||||
|
@ -56,6 +56,13 @@ class AttentionWrapperTest(test.TestCase):
|
||||
return super(AttentionWrapperTest, self).assertAllClose(
|
||||
*args, **kwargs)
|
||||
|
||||
def testAttentionWrapperState(self):
|
||||
num_fields = len(wrapper.AttentionWrapperState._fields) # pylint: disable=protected-access
|
||||
state = wrapper.AttentionWrapperState(*([None] * num_fields))
|
||||
new_state = state.clone(time=1)
|
||||
self.assertEqual(state.time, None)
|
||||
self.assertEqual(new_state.time, 1)
|
||||
|
||||
def _testWithAttention(self,
|
||||
create_attention_mechanism,
|
||||
expected_final_output,
|
||||
|
@ -369,7 +369,26 @@ class AttentionWrapperState(
|
||||
- `attention_history`: (if enabled) a `TensorArray` containing attention
|
||||
matrices from all time steps. Call `stack()` to convert to a `Tensor`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def clone(self, **kwargs):
|
||||
"""Clone this object, overriding components provided by kwargs.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
|
||||
initial_state = initial_state.clone(cell_state=encoder_state)
|
||||
```
|
||||
|
||||
Args:
|
||||
**kwargs: Any properties of the state object to replace in the returned
|
||||
`AttentionWrapperState`.
|
||||
|
||||
Returns:
|
||||
A new `AttentionWrapperState` whose properties are the same as
|
||||
this one, except any overriden properties as provided in `kwargs`.
|
||||
"""
|
||||
return super(AttentionWrapperState, self)._replace(**kwargs)
|
||||
|
||||
|
||||
def hardmax(logits, name=None):
|
||||
|
@ -431,7 +431,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper):
|
||||
shape=base_shape))
|
||||
|
||||
all_finished = math_ops.reduce_all(finished)
|
||||
no_samples = math_ops.equal(array_ops.shape(sample_ids)[0], 0)
|
||||
no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
|
||||
next_inputs = control_flow_ops.cond(
|
||||
math_ops.logical_or(all_finished, no_samples),
|
||||
lambda: base_next_inputs, maybe_sample)
|
||||
|
@ -31,6 +31,8 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import session_run_hook
|
||||
|
||||
|
||||
@ -95,6 +97,22 @@ class TensorForestLossHook(session_run_hook.SessionRunHook):
|
||||
run_context.request_stop()
|
||||
|
||||
|
||||
class EveryCheckpointPreSaveListener(
|
||||
basic_session_run_hooks.CheckpointSaverListener):
|
||||
"""Runs a given op before each checkpoint save."""
|
||||
|
||||
def __init__(self, op):
|
||||
"""Initializes the object.
|
||||
|
||||
Args:
|
||||
op: An op to run before each checkpoint save.
|
||||
"""
|
||||
self._op = op
|
||||
|
||||
def before_save(self, session, global_step_value):
|
||||
session.run(self._op)
|
||||
|
||||
|
||||
def get_model_fn(params,
|
||||
graph_builder_class,
|
||||
device_assigner,
|
||||
@ -103,6 +121,7 @@ def get_model_fn(params,
|
||||
num_trainers=1,
|
||||
trainer_id=0,
|
||||
report_feature_importances=False,
|
||||
model_dir=None,
|
||||
local_eval=False):
|
||||
"""Return a model function given a way to construct a graph builder."""
|
||||
def _model_fn(features, labels, mode):
|
||||
@ -138,6 +157,8 @@ def get_model_fn(params,
|
||||
# question of why we force everything to adhere to a single model_fn).
|
||||
loss_deps = []
|
||||
training_graph = None
|
||||
training_hooks = []
|
||||
scaffold = None
|
||||
if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
training_graph = control_flow_ops.group(
|
||||
graph_builder.training_graph(
|
||||
@ -146,6 +167,15 @@ def get_model_fn(params,
|
||||
trainer_id=trainer_id),
|
||||
state_ops.assign_add(contrib_framework.get_global_step(), 1))
|
||||
loss_deps.append(training_graph)
|
||||
if hasattr(graph_builder, 'finalize_training'):
|
||||
finalize_listener = EveryCheckpointPreSaveListener(
|
||||
graph_builder.finalize_training())
|
||||
scaffold = monitored_session.Scaffold()
|
||||
training_hooks.append(
|
||||
basic_session_run_hooks.CheckpointSaverHook(
|
||||
model_dir, save_secs=600, save_steps=None,
|
||||
scaffold=scaffold,
|
||||
listeners=[finalize_listener]))
|
||||
|
||||
training_loss = None
|
||||
if (mode == model_fn_lib.ModeKeys.EVAL or
|
||||
@ -158,7 +188,6 @@ def get_model_fn(params,
|
||||
if weights is not None:
|
||||
features[weights_name] = weights
|
||||
|
||||
training_hooks = []
|
||||
if early_stopping_rounds:
|
||||
training_hooks.append(TensorForestLossHook(early_stopping_rounds))
|
||||
|
||||
@ -167,7 +196,9 @@ def get_model_fn(params,
|
||||
predictions=inference,
|
||||
loss=training_loss,
|
||||
train_op=training_graph,
|
||||
training_hooks=training_hooks)
|
||||
training_hooks=training_hooks,
|
||||
scaffold=scaffold)
|
||||
|
||||
return _model_fn
|
||||
|
||||
|
||||
@ -257,6 +288,7 @@ class TensorForestEstimator(estimator.Estimator):
|
||||
num_trainers=num_trainers,
|
||||
trainer_id=trainer_id,
|
||||
report_feature_importances=report_feature_importances,
|
||||
model_dir=model_dir,
|
||||
local_eval=local_eval),
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
|
@ -43,9 +43,9 @@ py_library(
|
||||
srcs = ["plugins/projector/__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":protos_all_py",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/tensorboard/plugins/projector:projector_plugin",
|
||||
"//tensorflow/tensorboard/plugins/projector:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
@ -56,10 +56,10 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":projector",
|
||||
":protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/tensorboard/plugins/projector:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -28,11 +28,10 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
|
||||
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.tensorboard.plugins.projector import projector_plugin
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.tensorboard.plugins.projector.projector_config_pb2 import *
|
||||
from tensorflow.tensorboard.plugins.projector.projector_plugin import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
|
@ -24,10 +24,10 @@ import shutil
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.contrib.tensorboard.plugins import projector
|
||||
from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.summary.writer import writer as writer_lib
|
||||
from tensorflow.tensorboard.plugins.projector import projector_config_pb2
|
||||
|
||||
|
||||
class ProjectorApiTest(test.TestCase):
|
||||
|
62
tensorflow/contrib/xla_tf_graph/BUILD
Normal file
62
tensorflow/contrib/xla_tf_graph/BUILD
Normal file
@ -0,0 +1,62 @@
|
||||
# Description:
|
||||
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_tf_graph_util",
|
||||
srcs = [
|
||||
"xla_tf_graph_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_tf_graph_util.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/client",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "xla_tf_graph_util_test",
|
||||
srcs = ["xla_tf_graph_util_test.cc"],
|
||||
deps = [
|
||||
":xla_tf_graph_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
],
|
||||
)
|
8
tensorflow/contrib/xla_tf_graph/README.md
Normal file
8
tensorflow/contrib/xla_tf_graph/README.md
Normal file
@ -0,0 +1,8 @@
|
||||
# Xla Tf Graph
|
||||
|
||||
## Description
|
||||
|
||||
This module contains utilities to treat xla representation as tf graph to support mobile SOC experiments and leverage tf tools.
|
||||
|
||||
Maintainers:
|
||||
- Satoshi Kataoka (satok@google.com, github.com/satok16)
|
71
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
Normal file
71
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace xla_tf_graph {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* const GRAPH_NAME = "xla_tf_graph_util";
|
||||
|
||||
void SetupXlaCpuClient(std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
||||
std::unique_ptr<FunctionLibraryRuntime>* flr,
|
||||
std::unique_ptr<XlaCompiler>* compiler) {
|
||||
xla::Client* client = xla::ClientLibrary::LocalClientOrDie();
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
|
||||
FunctionDefLibrary flib;
|
||||
flib_def->reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||
|
||||
// Setup compiler options
|
||||
XlaCompiler::Options options;
|
||||
options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
|
||||
options.client = client;
|
||||
compiler->reset(new XlaCompiler(options));
|
||||
|
||||
flr->reset(NewFunctionLibraryRuntime(
|
||||
compiler->get()->device_mgr(), /*env=*/nullptr, compiler->get()->device(),
|
||||
TF_GRAPH_DEF_VERSION, flib_def->get(), OptimizerOptions(),
|
||||
/*custom_kernel_creator=*/nullptr));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
xla::StatusOr<std::unique_ptr<xla::SessionModule>>
|
||||
ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
|
||||
std::unique_ptr<Graph> graph) {
|
||||
CHECK(graph);
|
||||
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def;
|
||||
std::unique_ptr<FunctionLibraryRuntime> flr;
|
||||
std::unique_ptr<XlaCompiler> compiler;
|
||||
|
||||
SetupXlaCpuClient(&flib_def, &flr, &compiler);
|
||||
|
||||
// Compile graph and build computation
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_CHECK_OK(compiler->CompileGraph(GRAPH_NAME, std::move(graph), flr.get(),
|
||||
args, &result));
|
||||
|
||||
return result.computation.Snapshot();
|
||||
}
|
||||
|
||||
} // namespace xla_tf_graph
|
||||
} // namespace tensorflow
|
43
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h
Normal file
43
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
|
||||
#define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/client.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace xla_tf_graph {
|
||||
|
||||
// A set of utilities to handle xla computation requests.
|
||||
// These utilities help developers leverage existing tools to work with
|
||||
// xla computations, also provide a way to support TensorFlow ops by
|
||||
// implementing xla computations so that they can do experiments on their
|
||||
// specialized environments.
|
||||
|
||||
// Convert a tf graph to a xla session module
|
||||
xla::StatusOr<std::unique_ptr<xla::SessionModule>>
|
||||
ConvertTfGraphToXlaSessionModule(const std::vector<XlaCompiler::Argument>& args,
|
||||
std::unique_ptr<Graph> graph);
|
||||
|
||||
} // namespace xla_tf_graph
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_
|
57
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
Normal file
57
tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace xla_tf_graph {
|
||||
|
||||
static std::unique_ptr<Graph> BuildAddGraph() {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
|
||||
auto c = ops::Add(scope.WithOpName("C"), a, b);
|
||||
auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(scope.ToGraph(graph.get()));
|
||||
return graph;
|
||||
}
|
||||
|
||||
TEST(XlaTfGraphUtil, ConvertTfGraphToHloModule) {
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(2);
|
||||
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2});
|
||||
args[1].kind = XlaCompiler::Argument::kParameter;
|
||||
args[1].type = DT_INT32;
|
||||
args[1].shape = TensorShape({2});
|
||||
|
||||
std::unique_ptr<Graph> graph = BuildAddGraph();
|
||||
|
||||
TF_ASSIGN_OR_ASSERT_OK(
|
||||
std::unique_ptr<xla::SessionModule> session_module,
|
||||
ConvertTfGraphToXlaSessionModule(args, std::move(graph)));
|
||||
|
||||
ASSERT_EQ(5, session_module->entry().requests_size());
|
||||
}
|
||||
|
||||
} // namespace xla_tf_graph
|
||||
} // namespace tensorflow
|
@ -453,8 +453,8 @@ void BFCAllocator::RemoveFreeChunkIterFromBin(
|
||||
void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) {
|
||||
Chunk* c = ChunkFromHandle(h);
|
||||
CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
|
||||
int count = BinFromIndex(c->bin_num)->free_chunks.erase(h);
|
||||
CHECK(count > 0) << "Could not find chunk in bin";
|
||||
CHECK_GT(BinFromIndex(c->bin_num)->free_chunks.erase(h), 0)
|
||||
<< "Could not find chunk in bin";
|
||||
c->bin_num = kInvalidBinNum;
|
||||
}
|
||||
|
||||
|
@ -78,7 +78,7 @@ class BFCAllocator : public VisitableAllocator {
|
||||
|
||||
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
|
||||
// kInvalidChunkHandle means an invalid chunk
|
||||
typedef int ChunkHandle;
|
||||
typedef size_t ChunkHandle;
|
||||
static const int kInvalidChunkHandle = -1;
|
||||
|
||||
typedef int BinNum;
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/subgraph.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
@ -304,10 +305,18 @@ Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
|
||||
tensors_to_replace.push_back({n.second, n.first.second});
|
||||
}
|
||||
|
||||
auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
|
||||
// Evaluate the constant foldable nodes.
|
||||
std::vector<Tensor> outputs;
|
||||
Status s = GraphRunner::Run(constant_graph.get(), function_library, env,
|
||||
{} /* inputs*/, tensors_to_fetch_names, &outputs);
|
||||
auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
|
||||
// Output tensors need to be cleared before the GraphRunner is deleted.
|
||||
outputs.clear();
|
||||
graph_runner.reset(nullptr);
|
||||
});
|
||||
|
||||
Status s =
|
||||
graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/,
|
||||
tensors_to_fetch_names, &outputs);
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "Could not fetch constants: " << s;
|
||||
*was_mutated = false;
|
||||
|
@ -44,7 +44,7 @@ DeviceMgr::~DeviceMgr() {
|
||||
}
|
||||
|
||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||
int n = s.size();
|
||||
size_t n = s.size();
|
||||
char* space = name_backing_store_.Alloc(n);
|
||||
memcpy(space, s.data(), n);
|
||||
return StringPiece(space, n);
|
||||
|
@ -427,7 +427,7 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
|
||||
|
||||
// Start parallel Executors.
|
||||
const int num_executors = executors_and_keys->items.size();
|
||||
const size_t num_executors = executors_and_keys->items.size();
|
||||
ExecutorBarrier* barrier = new ExecutorBarrier(
|
||||
num_executors, run_state.rendez, [&run_state](const Status& ret) {
|
||||
{
|
||||
@ -458,7 +458,7 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
options_.config.graph_options().build_cost_model();
|
||||
const int64 build_cost_model_after =
|
||||
options_.config.graph_options().build_cost_model_after();
|
||||
int measure_step_count = executor_step_count - build_cost_model_after;
|
||||
int64 measure_step_count = executor_step_count - build_cost_model_after;
|
||||
if (measure_step_count >= 0) {
|
||||
update_cost_model =
|
||||
((measure_step_count + 1) % build_cost_model_every == 0);
|
||||
@ -611,7 +611,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
|
||||
}
|
||||
|
||||
// Start parallel Executors.
|
||||
const int num_executors = executors_and_keys->items.size();
|
||||
const size_t num_executors = executors_and_keys->items.size();
|
||||
ExecutorBarrier* barrier = new ExecutorBarrier(
|
||||
num_executors, run_state->rendez, [run_state](const Status& ret) {
|
||||
if (!ret.ok()) {
|
||||
|
@ -232,7 +232,7 @@ struct NodeItem {
|
||||
int input_start = 0;
|
||||
|
||||
// Number of output edges.
|
||||
int num_output_edges;
|
||||
size_t num_output_edges;
|
||||
|
||||
PendingCounts::Handle pending_id;
|
||||
|
||||
@ -307,7 +307,7 @@ class GraphView {
|
||||
void Initialize(const Graph* g);
|
||||
Status SetAllocAttrs(const Graph* g, const Device* device);
|
||||
|
||||
NodeItem* node(int id) const {
|
||||
NodeItem* node(size_t id) const {
|
||||
DCHECK_GE(id, 0);
|
||||
DCHECK_LT(id, num_nodes_);
|
||||
uint32 offset = node_offsets_[id];
|
||||
@ -454,7 +454,7 @@ GraphView::~GraphView() {
|
||||
}
|
||||
|
||||
size_t GraphView::NodeItemBytes(const Node* n) {
|
||||
const int num_output_edges = n->out_edges().size();
|
||||
const size_t num_output_edges = n->out_edges().size();
|
||||
const int num_inputs = n->num_inputs();
|
||||
const int num_outputs = n->num_outputs();
|
||||
|
||||
@ -500,11 +500,11 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
|
||||
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
|
||||
// values as "int" vs "size_t" in CHECK_LE.
|
||||
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
|
||||
const uint32 offset = ptr - space_;
|
||||
const uint32 offset = static_cast<uint32>(ptr - space_);
|
||||
node_offsets_[id] = offset;
|
||||
ptr += bytes;
|
||||
|
||||
const int num_output_edges = n->out_edges().size();
|
||||
const size_t num_output_edges = n->out_edges().size();
|
||||
const int num_inputs = n->num_inputs();
|
||||
const int num_outputs = n->num_outputs();
|
||||
|
||||
@ -580,9 +580,10 @@ void GraphView::Initialize(const Graph* g) {
|
||||
CHECK_EQ(ptr, space_ + total_bytes);
|
||||
}
|
||||
|
||||
void GetMaxPendingCounts(const Node* n, int* max_pending, int* max_dead_count) {
|
||||
const int num_in_edges = n->in_edges().size();
|
||||
int initial_count;
|
||||
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
|
||||
size_t* max_dead_count) {
|
||||
const size_t num_in_edges = n->in_edges().size();
|
||||
size_t initial_count;
|
||||
if (IsMerge(n)) {
|
||||
// merge waits all control inputs so we initialize the pending
|
||||
// count to be the number of control edges.
|
||||
@ -626,8 +627,7 @@ Status ExecutorImpl::Initialize() {
|
||||
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
|
||||
|
||||
// See if this node is a root node, and if so, add to root_nodes_.
|
||||
const int num_in_edges = n->in_edges().size();
|
||||
if (num_in_edges == 0) {
|
||||
if (n->in_edges().empty()) {
|
||||
root_nodes_.push_back(n);
|
||||
}
|
||||
|
||||
@ -659,7 +659,7 @@ Status ExecutorImpl::Initialize() {
|
||||
// pending counts data structure, and allocate a handle in
|
||||
// that frame's pending counts data structure that has enough
|
||||
// space to store these maximal count values.
|
||||
int max_pending, max_dead;
|
||||
size_t max_pending, max_dead;
|
||||
GetMaxPendingCounts(n, &max_pending, &max_dead);
|
||||
item->pending_id =
|
||||
frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
|
||||
@ -896,7 +896,7 @@ class ExecutorState {
|
||||
Entry* input_tensors;
|
||||
|
||||
// The number of outstanding ops for each iteration.
|
||||
int outstanding_ops;
|
||||
size_t outstanding_ops;
|
||||
|
||||
// The number of outstanding frames for each iteration.
|
||||
int outstanding_frame_count;
|
||||
@ -1037,13 +1037,13 @@ class ExecutorState {
|
||||
|
||||
inline IterationState* GetIteration(int64 iter)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||
int index = iter % iterations.size();
|
||||
size_t index = iter % iterations.size();
|
||||
return iterations[index];
|
||||
}
|
||||
|
||||
inline void SetIteration(int64 iter, IterationState* state)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||
int index = iter % iterations.size();
|
||||
size_t index = iter % iterations.size();
|
||||
DCHECK(state == nullptr || iterations[index] == nullptr);
|
||||
iterations[index] = state;
|
||||
}
|
||||
@ -1404,7 +1404,7 @@ void ExecutorImpl::InitializePending(const Graph* graph,
|
||||
for (const Node* n : graph->nodes()) {
|
||||
const int id = n->id();
|
||||
const string& name = cf_info.frame_names[id];
|
||||
int max_pending, max_dead;
|
||||
size_t max_pending, max_dead;
|
||||
GetMaxPendingCounts(n, &max_pending, &max_dead);
|
||||
const NodeItem* item = gview_.node(id);
|
||||
PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
|
||||
@ -2027,7 +2027,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
|
||||
}
|
||||
|
||||
bool completed = false;
|
||||
int ready_size = ready.size();
|
||||
size_t ready_size = ready.size();
|
||||
if (ready_size == 0 || !s.ok()) {
|
||||
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
|
||||
} else if (ready_size > 1) {
|
||||
@ -2375,10 +2375,10 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
|
||||
TaggedNodeSeq* ready) {
|
||||
const GraphView& gview = executor->gview_;
|
||||
IterationState* iter_state = GetIteration(iter);
|
||||
const int num_output_edges = item->num_output_edges;
|
||||
const size_t num_output_edges = item->num_output_edges;
|
||||
const EdgeInfo* edges = item->output_edge_list();
|
||||
Entry* input_tensors = iter_state->input_tensors;
|
||||
for (int out_index = 0; out_index < num_output_edges; out_index++) {
|
||||
for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
|
||||
const EdgeInfo& e = edges[out_index];
|
||||
const int dst_id = e.dst_id;
|
||||
const NodeItem* dst_item = gview.node(dst_id);
|
||||
|
@ -162,7 +162,7 @@ class ExecutorBarrier {
|
||||
//
|
||||
// 'done' is called after the last executor completes, and
|
||||
// ExecutorBarrier is deleted.
|
||||
ExecutorBarrier(int num, Rendezvous* r, StatusCallback done)
|
||||
ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
|
||||
: rendez_(r), done_cb_(done), pending_(num) {}
|
||||
|
||||
~ExecutorBarrier() {}
|
||||
|
@ -274,8 +274,9 @@ class CallOp : public AsyncOpKernel {
|
||||
if (!status.ok()) {
|
||||
ctx->SetStatus(status);
|
||||
} else {
|
||||
CHECK_EQ(rets->size(), ctx->num_outputs());
|
||||
for (size_t i = 0; i < rets->size(); ++i) {
|
||||
const int ret_size = static_cast<int>(rets->size());
|
||||
CHECK_EQ(ret_size, ctx->num_outputs());
|
||||
for (int i = 0; i < ret_size; ++i) {
|
||||
ctx->set_output(i, (*rets)[i]);
|
||||
}
|
||||
}
|
||||
@ -1000,7 +1001,7 @@ string NewName(const Node* n, bool pretty) {
|
||||
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
|
||||
// We visit nodes in forward topological sort order, which is a
|
||||
// possible execution order of the graph.
|
||||
std::vector<int> pending(g->num_node_ids());
|
||||
std::vector<size_t> pending(g->num_node_ids());
|
||||
std::deque<const Node*> ready;
|
||||
for (const Node* n : g->nodes()) {
|
||||
pending[n->id()] = n->in_edges().size();
|
||||
@ -1154,7 +1155,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
||||
|
||||
Graph* g = gbody_->graph;
|
||||
|
||||
const int num_y = gbody_->ret_nodes.size();
|
||||
const int num_y = static_cast<int>(gbody_->ret_nodes.size());
|
||||
|
||||
// Populate 'y_node_outputs_' with node function body outputs.
|
||||
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
|
||||
@ -1169,7 +1170,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
||||
y_node_outputs.push_back({y, 0});
|
||||
DCHECK_EQ(y->type_string(), kRetOp);
|
||||
const DataType dtype = y->input_type(0);
|
||||
const int index = gbody_->arg_nodes.size();
|
||||
const int index = static_cast<int>(gbody_->arg_nodes.size());
|
||||
Node* dy = AddArg(g, dtype, index);
|
||||
gbody_->arg_types.push_back(dtype);
|
||||
gbody_->arg_nodes.push_back(dy);
|
||||
@ -1177,7 +1178,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
||||
}
|
||||
|
||||
// Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
|
||||
const int num_x = fbody_->arg_nodes.size();
|
||||
const size_t num_x = fbody_->arg_nodes.size();
|
||||
std::vector<NodeOut> x_node_outputs;
|
||||
x_node_outputs.reserve(num_x);
|
||||
for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
|
||||
@ -1200,7 +1201,8 @@ FunctionBody* SymbolicGradientHelper::Compute() {
|
||||
gbody_->ret_nodes.clear();
|
||||
// Add new return nodes to the function gradient body for each node
|
||||
// in 'x_grad_nodes'.
|
||||
for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
|
||||
const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
|
||||
for (int i = 0; i < arg_types_size; ++i) {
|
||||
Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
|
||||
Node* ret = AddRet(g, grad, i);
|
||||
gbody_->ret_nodes.push_back(ret);
|
||||
|
@ -82,7 +82,7 @@ Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
|
||||
// Determine a suitable stream to use.
|
||||
int stream_id = highest_stream_id + 1;
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
const int fanout = e->src()->out_edges().size();
|
||||
const size_t fanout = e->src()->out_edges().size();
|
||||
if (fanout == 1) {
|
||||
stream_id = (*node_to_stream_id)[e->src()->id()];
|
||||
break;
|
||||
|
@ -191,7 +191,7 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
|
||||
// example, process_state could maybe save the first stream executor
|
||||
// it knows is valid.
|
||||
gpu::StreamExecutor* se = nullptr;
|
||||
for (size_t i = 0; i < gpu_allocators_.size(); ++i) {
|
||||
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
|
||||
if (gpu_allocators_[i] != nullptr) {
|
||||
se = GPUMachineManager()->ExecutorForDevice(i).ValueOrDie();
|
||||
break;
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||
@ -95,22 +94,24 @@ class SimpleRendezvous : public Rendezvous {
|
||||
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
GraphRunner::GraphRunner(Env* env) : cpu_device_(GetCPUDevice(env)) {}
|
||||
|
||||
GraphRunner::~GraphRunner() {}
|
||||
|
||||
Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
Env* env, const NamedTensorList& inputs,
|
||||
const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
std::vector<Tensor>* outputs) {
|
||||
if (cpu_device_ == nullptr) {
|
||||
return errors::NotFound("Cannot find a device for GraphRunner.");
|
||||
}
|
||||
|
||||
// TODO(vrv): Instead of copying the entire graph, consider modifying
|
||||
// the existing graph, and then removing those removed edges.
|
||||
// prior to returning.
|
||||
std::unique_ptr<Graph> graph_to_run(new Graph(graph->op_registry()));
|
||||
CopyGraph(*graph, graph_to_run.get());
|
||||
|
||||
std::unique_ptr<Device> device = GetCPUDevice(env);
|
||||
if (!device) {
|
||||
return errors::NotFound("Cannot find a device for GraphRunner.");
|
||||
}
|
||||
|
||||
SimpleRendezvous* rendez = new SimpleRendezvous;
|
||||
core::ScopedUnref rendez_unref(rendez);
|
||||
|
||||
@ -130,7 +131,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
// Call RewriteGraphForExecution
|
||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
|
||||
device->attributes()));
|
||||
cpu_device_->attributes()));
|
||||
|
||||
// Create the local executor and the Rendezvous for fetching back the
|
||||
// constants.
|
||||
@ -143,10 +144,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
Graph* g = graph_to_run.release();
|
||||
|
||||
LocalExecutorParams params;
|
||||
params.device = device.get();
|
||||
// The ownership of the output tensors are bound to this device's lifetime.
|
||||
params.device = cpu_device_.get();
|
||||
params.function_library = function_library;
|
||||
params.create_kernel = [&device, g](const NodeDef& ndef, OpKernel** kernel) {
|
||||
return CreateNonCachedKernel(device.get(), nullptr, ndef,
|
||||
params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) {
|
||||
return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef,
|
||||
g->versions().producer(), kernel);
|
||||
};
|
||||
params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -44,16 +45,26 @@ namespace tensorflow {
|
||||
// to be particularly lightweight, fast, or efficient.
|
||||
class GraphRunner {
|
||||
public:
|
||||
// REQUIRES: `env` is not nullptr.
|
||||
GraphRunner(Env* env);
|
||||
~GraphRunner();
|
||||
|
||||
// Function semantics for `inputs`, `output_names` and `outputs`
|
||||
// matches those from Session::Run().
|
||||
//
|
||||
// NOTE: The output tensors share lifetime with the GraphRunner, and could
|
||||
// be destroyed once the GraphRunner is destroyed.
|
||||
//
|
||||
// REQUIRES: `graph`, `env`, and `outputs` are not nullptr.
|
||||
// `function_library` may be nullptr.
|
||||
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
|
||||
static Status Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
Env* env, const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
std::vector<Tensor>* outputs);
|
||||
Status Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
const NamedTensorList& inputs,
|
||||
const std::vector<string>& output_names,
|
||||
std::vector<Tensor>* outputs);
|
||||
|
||||
private:
|
||||
std::unique_ptr<Device> cpu_device_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -46,9 +46,9 @@ using test::internal::ExpectEqual;
|
||||
TEST(GraphRunnerTest, SingleConst) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, 42.0f);
|
||||
GraphRunner graph_runner(Env::Default());
|
||||
std::vector<Tensor> outputs;
|
||||
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {},
|
||||
{c.name()}, &outputs);
|
||||
Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs);
|
||||
TF_ASSERT_OK(s);
|
||||
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
||||
}
|
||||
@ -57,9 +57,10 @@ TEST(GraphRunnerTest, MultiFetchConst) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, 42.0f);
|
||||
auto pi = ops::Const(root, 3.14f);
|
||||
GraphRunner graph_runner(Env::Default());
|
||||
std::vector<Tensor> outputs;
|
||||
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), {},
|
||||
{c.name(), pi.name()}, &outputs);
|
||||
Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name(), pi.name()},
|
||||
&outputs);
|
||||
TF_ASSERT_OK(s);
|
||||
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
||||
ExpectEqual(3.14f, outputs[1].scalar<float>()());
|
||||
@ -78,9 +79,10 @@ TEST(GraphRunnerTest, FeedAndFetch) {
|
||||
std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
|
||||
{"p2:0", p2_data}};
|
||||
|
||||
GraphRunner graph_runner(Env::Default());
|
||||
std::vector<Tensor> outputs;
|
||||
Status s = GraphRunner::Run(root.graph(), nullptr, Env::Default(), inputs,
|
||||
{"add:0"}, &outputs);
|
||||
Status s =
|
||||
graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs);
|
||||
TF_ASSERT_OK(s);
|
||||
ExpectEqual(3.0f, outputs[0].scalar<float>()());
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user